#!/usr/bin/env python3
# pylint: disable=C0301
# SPDX-License-Identifier: (GPL-2.0 OR MIT)
#
# Copyright (C) 2023 Intel Corporation

import abc
import argparse
import logging
import hashlib
import sys
import typing

# The struct definitions below should match the ones from i915
# (drivers/gpu/drm/i915/gt/uc/intel_uc_fw_abi.h) and xe
# (drivers/gpu/drm/xe/xe_uc_fw_abi.h).
#
# For compatibility with the dissect.cstruct Python module, the following
# things are changed from the original kernel definitions:
#
#       1) #defines in the middle of a struct are commented out
#       2) Anonymous unions are given a name, since they are not compatible
#          with dumpstruct()

CDEF = """
struct uc_css_header {
	u32 module_type;
	/*
	 * header_size includes all non-uCode bits, including css_header, rsa
	 * key, modulus key and exponent data.
	 */
	u32 header_size_dw;
	u32 header_version;
	u32 module_id;
	u32 module_vendor;
	u32 date;
// #define CSS_DATE_DAY				(0xFF << 0)
// #define CSS_DATE_MONTH			(0xFF << 8)
// #define CSS_DATE_YEAR			(0xFFFF << 16)
	u32 size_dw; /* uCode plus header_size_dw */
	u32 key_size_dw;
	u32 modulus_size_dw;
	u32 exponent_size_dw;
	u32 time;
// #define CSS_TIME_HOUR			(0xFF << 0)
// #define CSS_DATE_MIN				(0xFF << 8)
// #define CSS_DATE_SEC				(0xFFFF << 16)
	char username[8];
	char buildnumber[12];
	u32 sw_version;
// #define CSS_SW_VERSION_UC_MAJOR		(0xFF << 16)
// #define CSS_SW_VERSION_UC_MINOR		(0xFF << 8)
// #define CSS_SW_VERSION_UC_PATCH		(0xFF << 0)
	u32 vf_version;
	u32 reserved0[12];
	union {
		u32 private_data_size; /* only applies to GuC */
		u32 reserved1;
	} rsvd;
	u32 header_info;
};

/* Code partition directory (CPD) structures */
#define GSC_CPD_HEADER_MARKER 0x44504324
struct gsc_cpd_header_v2 {
	u32 header_marker;

	u32 num_of_entries;
	u8 header_version;
	u8 entry_version;
	u8 header_length; /* in bytes */
	u8 flags;
	u32 partition_name;
	u32 crc32;
};

// Added as python code
// #define GSC_CPD_ENTRY_OFFSET_MASK GENMASK(24, 0)

struct gsc_version {
	u16 major;
	u16 minor;
	u16 hotfix;
	u16 build;
};

struct gsc_cpd_entry {
	u8 name[12];

	/*
	 * Bits 0-24: offset from the beginning of the code partition
	 * Bit 25: huffman compressed
	 * Bits 26-31: reserved
	 */
	u32 offset;

	/*
	 * Module/Item length, in bytes. For Huffman-compressed modules, this
	 * refers to the uncompressed size. For software-compressed modules,
	 * this refers to the compressed size.
	 */
	u32 length;

	u8 reserved[4];
};

struct gsc_manifest_header {
	u32 header_type; /* 0x4 for manifest type */
	u32 header_length; /* in dwords */
	u32 header_version;
	u32 flags;
	u32 vendor;
	u32 date;
	u32 size; /* In dwords, size of entire manifest (header + extensions) */
	u32 header_id;
	u32 internal_data;
	struct gsc_version fw_version;
	u32 security_version;
	struct gsc_version meu_kit_version;
	u32 meu_manifest_version;
	u8 general_data[4];
	u8 reserved3[56];
	u32 modulus_size; /* in dwords */
	u32 exponent_size; /* in dwords */
};

struct magic {
	char data[4];
};
"""

# Configure logging before the guarded import below, so that its failure
# message is printed in the same "LEVEL: message" format as every other
# message of this tool.
logging.basicConfig(format="%(levelname)s: %(message)s")

# dissect.cstruct is a third-party module: when it is missing, log where to
# get it and exit with status 1 (the ImportError is kept as the cause).
try:
    from dissect import cstruct
except ImportError as e:
    logging.critical(
        "Could not import dissect.cstruct module. See https://github.com/fox-it/dissect.cstruct for installation options"
    )
    raise SystemExit(1) from e


# Python counterpart of the kernel's GENMASK(24, 0): bits 0-24 of
# gsc_cpd_entry.offset hold the offset from the beginning of the code
# partition, the bits above are flags/reserved.
GSC_CPD_ENTRY_OFFSET_MASK = (1 << 25) - 1


def ffs(x: int) -> int:
    """Find the first (least significant) set bit of `x`.

    The result is the zero-based position of that bit, e.g.
    ``ffs(0b1000) == 3``. When `x` is 0 there is no set bit and the
    result is -1.
    """
    # x & -x keeps only the lowest set bit (two's complement identity)
    lowest_bit = x & -x
    return lowest_bit.bit_length() - 1


def FIELD_GET(mask: int, value: int) -> int:
    """Extract the bit field selected by `mask` from `value`.

    The bits of `value` covered by `mask` are shifted right so that the
    lowest bit of `mask` ends up at bit 0 of the result.
    """
    shift = ffs(mask)
    field = value & mask
    return field >> shift


class Fw(abc.ABC):
    """Base class for the firmware decoders of this tool.

    Each direct subclass provides a ``MAGIC`` class attribute - the bytes
    that identify its format, or None to be used as the fallback when no
    other subclass matches - and implements decode().
    """

    def __init__(self, cparser, f):
        self.cparser = cparser  # cstruct parser used to read the structs
        self.f = f  # firmware file object
        self.fw = None  # parsed header; filled in by decode()

    def checksum(self):
        """Return the SHA-256 hex digest of the whole firmware file."""
        self.f.seek(0, 0)
        contents = self.f.read()
        digest = hashlib.sha256(contents)
        return digest.hexdigest()

    def dump(self, **kw):
        """Dump the header stored in ``self.fw`` through cstruct.dumpstruct().

        All keyword arguments are forwarded to dumpstruct().
        """
        cstruct.dumpstruct(self.fw, **kw)

    @classmethod
    def create(cls, f):
        """Build the decoder matching the magic read from `f`.

        A cstruct parser is set up with the definitions in CDEF, then the
        4-byte ``magic`` struct is read from `f` and compared against the
        ``MAGIC`` of every direct subclass.

        Returns an instance of the matching subclass, of the fallback
        subclass (``MAGIC`` is None) when nothing matches, or None when
        there is neither.
        """
        cparser = cstruct.cstruct()
        # cstruct defines some "convenience types" that are not very
        # convenient when parsing a kernel header - e.g. we don't want u16
        # to be parsed as uint128. Point them at the fixed-width types.
        kernel_types = (
            ("u8", "uint8_t"),
            ("u16", "uint16_t"),
            ("u32", "uint32_t"),
            ("u64", "uint64_t"),
        )
        for alias, ctype in kernel_types:
            cparser.add_type(alias, ctype, replace=True)
        cparser.load(CDEF)

        magic = cparser.magic(f).data

        fallback = None
        for candidate in cls.__subclasses__():
            if candidate.MAGIC is None:
                # No magic to compare: remember it in case nothing matches
                fallback = candidate
                continue
            if candidate.MAGIC == magic:
                return candidate(cparser, f)

        if fallback is None:
            return None

        return fallback(cparser, f)

    @abc.abstractmethod
    def decode(self):
        """Parse the firmware header; implemented by each subclass."""
        raise NotImplementedError()


class FwCss(Fw):
    """Decoder for firmware files starting with a ``uc_css_header``.

    ``MAGIC`` is None, which makes Fw.create() pick this class when no
    other decoder matches the magic of the file.
    """

    MAGIC = None

    def decode(self):
        """Parse the CSS header from the start of the file.

        The parsed header is stored in ``self.fw``. Returns the list of
        text lines describing the header type, version and date.
        """
        self.f.seek(0, 0)
        self.fw = self.cparser.uc_css_header(self.f)
        header = self.fw

        # Same layouts as the commented-out #defines in CDEF
        sw_version_masks = (0xFF << 16, 0xFF << 8, 0xFF)  # major, minor, patch
        date_masks = (0xFFFF << 16, 0xFF << 8, 0xFF)  # year, month, day

        major, minor, patch = (
            FIELD_GET(mask, header.sw_version) for mask in sw_version_masks
        )
        year, month, day = (FIELD_GET(mask, header.date) for mask in date_masks)

        # NOTE(review): the date fields are printed as hex digits, presumably
        # because they are BCD-encoded - confirm against the kernel header.
        return [
            "header-type: CSS",
            f"version: {major}.{minor}.{patch}",
            f"date: {year:02x}-{month:02x}-{day:02x}",
        ]


class FwGsc(Fw):
    """Decoder for firmware files starting with a "$CPD" directory header."""

    MAGIC = b"$CPD"

    def decode(self):
        """Locate the "HUCP.man" manifest and parse its header.

        The code partition directory is walked entry by entry until the
        manifest is found; its header is stored in ``self.fw``. Returns
        the list of text lines describing the header type and version.

        Logs a fatal error and exits the process with status 1 when the
        directory has no "HUCP.man" entry.
        """
        self.f.seek(0, 0)
        cpd_hdr = self.cparser.gsc_cpd_header_v2(self.f)

        def read_entries():
            # Entries are read lazily, one after the other, from the file
            for _ in range(cpd_hdr.num_of_entries):
                yield self.cparser.gsc_cpd_entry(self.f)

        def entry_name(entry):
            # name is a NUL-padded array of byte values
            return "".join(chr(c) for c in entry.name).rstrip("\x00")

        # The directory entries start right after the CPD header
        self.f.seek(cpd_hdr.header_length, 0)
        manifest = next(
            (e for e in read_entries() if entry_name(e) == "HUCP.man"), None
        )
        if manifest is None:
            logging.fatal("Unknown/corrupted firmware - no manifest available")
            sys.exit(1)

        # Only the low bits of the field are the offset, the rest are flags
        self.f.seek(manifest.offset & GSC_CPD_ENTRY_OFFSET_MASK, 0)
        self.fw = self.cparser.gsc_manifest_header(self.f)

        version = self.fw.fw_version
        major = version.major
        minor = version.minor
        hotfix = version.hotfix

        return ["header-type: GSC", f"version: {major}.{minor}.{hotfix}"]


def parse_args(argv: typing.List[str]) -> argparse.Namespace:
    """Parse the command line of the tool.

    `argv` is the argument list without the program name. The returned
    namespace has the boolean attributes ``raw`` and ``checksum`` and the
    string attribute ``filename``.
    """
    parser = argparse.ArgumentParser(
        prog="intel-gfx-fw-info",
        description="Dump GuC/HuC firmware information",
    )

    # (short option, long option, help text) of every on/off switch
    switches = (
        ("-x", "--raw", "Also print raw header content"),
        ("-c", "--checksum", "Also print checksum"),
    )
    for short_opt, long_opt, help_text in switches:
        parser.add_argument(short_opt, long_opt, action="store_true", help=help_text)

    parser.add_argument("filename", help="GuC/HuC firmware file")

    return parser.parse_args(argv)


def main(argv: typing.List[str]) -> int:
    """Decode the firmware file named on the command line and print the result.

    Args:
        argv: Command-line arguments, without the program name.

    Returns:
        0 on success, 1 when the file cannot be opened or read, is
        truncated, or has an unknown firmware type.

    Note that a decoder may still terminate the process on its own:
    FwGsc.decode() calls sys.exit(1) when no manifest is found.
    """
    args = parse_args(argv)

    checksum = None

    try:
        with open(args.filename, mode="rb") as f:
            fw = Fw.create(f)
            if fw is None:
                logging.fatal("Unknown firmware type in '%s'", args.filename)
                # Report through the return value like the other failure
                # paths; the "with" block still closes the file.
                return 1

            decoded_data = fw.decode()
            if args.checksum:
                checksum = fw.checksum()
    except OSError as e:
        # FileNotFoundError, but also e.g. a permission error or a directory
        # passed as filename: log the reason instead of a traceback.
        logging.fatal(e)
        return 1
    except EOFError as e:
        # NOTE(review): dissect.cstruct is expected to raise EOFError when the
        # file ends before a struct is fully read - confirm for the installed
        # version.
        logging.fatal(
            "'%s' is truncated or not a valid firmware file: %s", args.filename, e
        )
        return 1

    print(*decoded_data, sep="\n")

    if args.raw:
        # The parsed header is kept in memory, so it can be dumped after the
        # file has been closed.
        print("raw dump:", end="")
        fw.dump(color=sys.stdout.isatty())

    if checksum is not None:
        print(f"checksum: {checksum}")

    return 0


if __name__ == "__main__":
    # Executed as a script: main()'s return value becomes the exit status
    raise SystemExit(main(sys.argv[1:]))
