diff --git a/src/imagination/csbgen/gen_pack_header.py b/src/imagination/csbgen/gen_pack_header.py index 6058459bbd1..e7616948fe6 100644 --- a/src/imagination/csbgen/gen_pack_header.py +++ b/src/imagination/csbgen/gen_pack_header.py @@ -27,11 +27,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from __future__ import annotations + import copy import os import sys import textwrap +import typing as t import xml.parsers.expat as expat +from abc import ABC from ast import literal_eval @@ -73,14 +77,14 @@ PACK_FILE_HEADER = """%(license)s """ -def safe_name(name): +def safe_name(name: str) -> str: if not name[0].isalpha(): name = "_" + name return name -def num_from_str(num_str): +def num_from_str(num_str: str) -> int: if num_str.lower().startswith("0x"): return int(num_str, base=16) @@ -90,8 +94,13 @@ def num_from_str(num_str): return int(num_str) -class Node: - def __init__(self, parent, name, name_is_safe=False): +class Node(ABC): + __slots__ = ["parent", "name"] + + parent: Node + name: str + + def __init__(self, parent: Node, name: str, *, name_is_safe: bool = False) -> None: self.parent = parent if name_is_safe: self.name = name @@ -99,19 +108,31 @@ class Node: self.name = safe_name(name) @property - def full_name(self): + def full_name(self) -> str: if self.name[0] == "_": return self.parent.prefix + self.name.upper() return self.parent.prefix + "_" + self.name.upper() @property - def prefix(self): + def prefix(self) -> str: return self.parent.prefix + def add(self, element: Node) -> None: + raise RuntimeError("Element cannot be nested in %s. Element Type: %s" + % (type(self).__name__.lower(), type(element).__name__)) + class Csbgen(Node): - def __init__(self, name, prefix, filename): + __slots__ = ["prefix_field", "filename", "_defines", "_enums", "_structs"] + + prefix_field: str + filename: str + _defines: t.List[Define] + _enums: t.Dict[str, Enum] + _structs: t.Dict[str, Struct] + + def __init__(self, name: str, prefix: str, filename: str) -> None: super().__init__(None, name.upper()) self.prefix_field = safe_name(prefix.upper()) self.filename = filename @@ -121,14 +142,14 @@ class Csbgen(Node): self._structs = {} @property - def full_name(self): + def full_name(self) -> str: return self.name + "_" + self.prefix_field @property - def prefix(self): + def prefix(self) -> str: return self.full_name - def add(self, element): + def add(self, element: Node) -> None: if isinstance(element, Enum): if element.name in self._enums: raise RuntimeError("Enum redefined. Enum: %s" % element.name) @@ -146,12 +167,12 @@ class Csbgen(Node): self._defines.append(element) else: - raise RuntimeError("Element '%s' cannot be nested in csbgen." % type(element).__name__) + super().add(element) - def _gen_guard(self): + def _gen_guard(self) -> str: return os.path.basename(self.filename).replace(".xml", "_h").upper() - def emit(self): + def emit(self) -> None: print(PACK_FILE_HEADER % { "license": MIT_LICENSE_COMMENT % {"copyright": "2022 Imagination Technologies Ltd."}, "platform": self.name, @@ -171,18 +192,22 @@ class Csbgen(Node): print("#endif /* %s */" % self._gen_guard()) - def is_known_struct(self, struct_name): + def is_known_struct(self, struct_name: str) -> bool: return struct_name in self._structs.keys() - def is_known_enum(self, enum_name): + def is_known_enum(self, enum_name: str) -> bool: return enum_name in self._enums.keys() - def get_enum(self, enum_name): + def get_enum(self, enum_name: str) -> Enum: return self._enums[enum_name] class Enum(Node): - def __init__(self, parent, name): + __slots__ = ["_values"] + + _values: t.Dict[str, Value] + + def __init__(self, parent: Node, name: str) -> None: super().__init__(parent, name) self._values = {} @@ -191,23 +216,22 @@ class Enum(Node): # We override prefix so that the values will contain the enum's name too. @property - def prefix(self): + def prefix(self) -> str: return self.full_name - def get_value(self, value_name): + def get_value(self, value_name: str) -> Value: return self._values[value_name] - def add(self, element): + def add(self, element: Node) -> None: if not isinstance(element, Value): - raise RuntimeError("Element cannot be nested in enum. Element Type: %s, Enum: %s" - % (type(element).__name__, self.full_name)) + super().add(element) if element.name in self._values: raise RuntimeError("Value is being redefined. Value: '%s'" % element.name) self._values[element.name] = element - def emit(self): + def emit(self) -> None: # This check is invalid if tags other than Value can be nested within an enum. if not self._values.values(): raise RuntimeError("Enum definition is empty. Enum: '%s'" % self.full_name) @@ -219,10 +243,14 @@ class Enum(Node): class Value(Node): - def __init__(self, parent, name, value): + __slots__ = ["value"] + + value: int + + def __init__(self, parent: Node, name: str, value: int) -> None: super().__init__(parent, name) - self.value = int(value) + self.value = value self.parent.add(self) @@ -231,10 +259,16 @@ class Value(Node): class Struct(Node): - def __init__(self, parent, name, length): + __slots__ = ["length", "size", "_children"] + + length: int + size: int + _children: t.Dict[str, t.Union[Condition, Field]] + + def __init__(self, parent: Node, name: str, length: int) -> None: super().__init__(parent, name) - self.length = int(length) + self.length = length self.size = self.length * 32 if self.length <= 0: @@ -245,7 +279,7 @@ class Struct(Node): self.parent.add(self) @property - def fields(self): + def fields(self) -> t.List[Field]: # TODO: Should we cache? See TODO in equivalent Condition getter. fields = [] @@ -258,10 +292,10 @@ class Struct(Node): return fields @property - def prefix(self): + def prefix(self) -> str: return self.full_name - def add(self, element): + def add(self, element: Node) -> None: # We don't support conditions and field having the same name. if isinstance(element, Field): if element.name in self._children.keys(): @@ -281,10 +315,9 @@ class Struct(Node): raise RuntimeError("Unknown condition: '%s'" % element.name) else: - raise RuntimeError("Element cannot be nested in struct. Element Type: %s, Struct: %s" - % (type(element).__name__, self.full_name)) + super().add(element) - def _emit_header(self, root): + def _emit_header(self, root: Csbgen) -> None: fields = filter(lambda f: hasattr(f, "default"), self.fields) default_fields = [] @@ -313,7 +346,7 @@ class Struct(Node): print(", \\\n".join(default_fields)) print("") - def _emit_helper_macros(self): + def _emit_helper_macros(self) -> None: fields_with_defines = filter(lambda f: f.defines, self.fields) for field in fields_with_defines: @@ -324,7 +357,7 @@ class Struct(Node): print() - def _emit_pack_function(self, root): + def _emit_pack_function(self, root: Csbgen) -> None: print(textwrap.dedent("""\ static inline __attribute__((always_inline)) void %s_pack(__attribute__((unused)) void * restrict dst, @@ -341,7 +374,7 @@ class Struct(Node): print("}\n") - def emit(self, root): + def emit(self, root: Csbgen) -> None: print("#define %-33s %6d" % (self.full_name + "_length", self.length)) self._emit_header(root) @@ -357,11 +390,21 @@ class Struct(Node): class Field(Node): - def __init__(self, parent, name, start, end, ty, default=None, shift=None): + __slots__ = ["start", "end", "type", "default", "shift", "_defines"] + + start: int + end: int + type: str + default: t.Union[str, int] + shift: t.Optional[int] + _defines: t.Dict[str, Define] + + def __init__(self, parent: Node, name: str, start: int, end: int, ty: str, *, + default: str = None, shift: int = None) -> None: super().__init__(parent, name) - self.start = int(start) - self.end = int(end) + self.start = start + self.end = end self.type = ty self._defines = {} @@ -393,21 +436,23 @@ class Field(Node): if self.type == "address": raise RuntimeError("Field of address type requires a shift attribute. Field '%s'" % self.name) + self.shift = None + @property - def defines(self): + def defines(self) -> t.Iterator[Define]: return self._defines.values() # We override prefix so that the defines will contain the field's name too. @property - def prefix(self): + def prefix(self) -> str: return self.full_name @property - def is_builtin_type(self): + def is_builtin_type(self) -> bool: builtins = {"address", "bool", "float", "mbo", "offset", "int", "uint"} return self.type in builtins - def _get_c_type(self, root): + def _get_c_type(self, root: Csbgen) -> str: if self.type == "address": return "__pvr_address_type" elif self.type == "bool": @@ -432,7 +477,7 @@ class Field(Node): return "enum " + root.get_enum(self.type).full_name raise RuntimeError("Unknown type. Type: '%s', Field: '%s'" % (self.type, self.name)) - def add(self, element): + def add(self, element: Node) -> None: if self.type == "mbo": raise RuntimeError("No element can be nested in an mbo field. Element Type: %s, Field: %s" % (type(element).__name__, self.name)) @@ -443,10 +488,9 @@ class Field(Node): self._defines[element.name] = element else: - raise RuntimeError("Element cannot be nested in a field. Element Type: %s, Field: %s" - % (type(element).__name__, self.name)) + super().add(element) - def emit(self, root): + def emit(self, root: Csbgen) -> None: if self.type == "mbo": return @@ -454,19 +498,29 @@ class Field(Node): class Define(Node): - def __init__(self, parent, name, value): + __slots__ = ["value"] + + value: int + + def __init__(self, parent: Node, name: str, value: int) -> None: super().__init__(parent, name) self.value = value self.parent.add(self) - def emit(self): + def emit(self) -> None: print("#define %-40s %d" % (self.full_name, self.value)) class Condition(Node): - def __init__(self, parent, name, ty): + __slots__ = ["type", "_children", "_child_branch"] + + type: str + _children: t.Dict[str, t.Union[Condition, Field]] + _child_branch: t.Optional[Condition] + + def __init__(self, parent: Node, name: str, ty: str) -> None: super().__init__(parent, name, name_is_safe=True) self.type = ty @@ -483,7 +537,7 @@ class Condition(Node): self.parent.add(self) @property - def fields(self): + def fields(self) -> t.List[Field]: # TODO: Should we use some kind of state to indicate the all of the # child nodes have been added and then cache the fields in here on the # first call so that we don't have to traverse them again per each call? @@ -504,7 +558,7 @@ class Condition(Node): return fields @staticmethod - def _is_valid_type(ty): + def _is_valid_type(ty: str) -> bool: types = {"if", "elif", "else", "endif"} return ty in types @@ -514,7 +568,7 @@ class Condition(Node): return (branch.type in types[idx + 1:] or self.type == "elif" and branch.type == "elif") - def _add_branch(self, branch): + def _add_branch(self, branch: Condition) -> None: if branch.type == "elif" and branch.name == self.name: raise RuntimeError("Elif branch cannot have same check as previous branch. Check: '%s'" % branch.name) @@ -533,13 +587,15 @@ class Condition(Node): # TODO: Redo this to improve speed? Would caching this be helpful? We could # just save the name of the if instead of having to walk towards it whenever # a new condition is being added. - def _top_branch_name(self): + def _top_branch_name(self) -> str: if self.type == "if": return self.name + # If we're not an 'if' condition, our parent must be another condition. + assert isinstance(self.parent, Condition) return self.parent._top_branch_name() - def add(self, element): + def add(self, element: Node) -> None: if isinstance(element, Field): if element.name in self._children.keys(): raise ValueError("Duplicate field. Field: '%s'" % element.name) @@ -578,10 +634,9 @@ class Condition(Node): self._children[element.name] = element else: - raise RuntimeError("Element cannot be nested in a condition. Element Type: %s, Check: %s" - % (type(element).__name__, self.name)) + super().add(element) - def emit(self, root): + def emit(self, root: Csbgen) -> None: if self.type == "if": print("/* if %s is supported use: */" % self.name) elif self.type == "elif": @@ -600,20 +655,33 @@ class Condition(Node): self._child_branch.emit(root) -class Group(object): - def __init__(self, start, count, size, fields): +class Group: + __slots__ = ["start", "count", "size", "fields"] + + start: int + count: int + size: int + fields: t.List[Field] + + def __init__(self, start: int, count: int, size: int, fields) -> None: self.start = start self.count = count self.size = size self.fields = fields class DWord: - def __init__(self): + __slots__ = ["size", "fields", "addresses"] + + size: int + fields: t.List[Field] + addresses: t.List[Field] + + def __init__(self) -> None: self.size = 32 self.fields = [] self.addresses = [] - def collect_dwords(self, dwords, start): + def collect_dwords(self, dwords: t.Dict[int, Group.DWord], start: int) -> None: for field in self.fields: index = (start + field.start) // 32 if index not in dwords: @@ -640,9 +708,9 @@ class Group(object): dwords[index + 1] = dwords[index] index = index + 1 - def collect_dwords_and_length(self): + def collect_dwords_and_length(self) -> t.Tuple[t.Dict[int, Group.DWord], int]: dwords = {} - self.collect_dwords(dwords, 0, "") + self.collect_dwords(dwords, 0) # Determine number of dwords in this group. If we have a size, use # that, since that'll account for MBZ dwords at the end of a group @@ -657,7 +725,7 @@ class Group(object): return dwords, length - def emit_pack_function(self, root, dwords, length): + def emit_pack_function(self, root: Csbgen, dwords: t.Dict[int, Group.DWord], length: int) -> None: for index in range(length): # Handle MBZ dwords if index not in dwords: @@ -688,7 +756,7 @@ class Group(object): # to the dword for those fields. field_index = 0 for field in dw.fields: - if isinstance(field, Field) and root.is_known_struct(field.type): + if root.is_known_struct(field.type): print("") print(" uint32_t v%d_%d;" % (index, field_index)) print(" %s_pack(data, &v%d_%d, &values->%s);" @@ -775,15 +843,22 @@ class Group(object): print(" dw[%d] = %s >> 32;" % (index + 1, v)) -class Parser(object): - def __init__(self): +class Parser: + __slots__ = ["parser", "context", "filename"] + + parser: expat.XMLParserType + context: t.List[Node] + filename: str + + def __init__(self) -> None: self.parser = expat.ParserCreate() self.parser.StartElementHandler = self.start_element self.parser.EndElementHandler = self.end_element self.context = [] + self.filename = "" - def start_element(self, name, attrs): + def start_element(self, name: str, attrs: t.Dict[str, str]) -> None: if not name == "csbgen": parent = self.context[-1] @@ -798,7 +873,7 @@ class Parser(object): self.context.append(csbgen) elif name == "struct": - struct = Struct(parent, attrs["name"], attrs["length"]) + struct = Struct(parent, attrs["name"], int(attrs["length"])) self.context.append(struct) elif name == "field": @@ -819,11 +894,11 @@ class Parser(object): self.context.append(enum) elif name == "value": - value = Value(parent, attrs["name"], literal_eval(attrs["value"])) + value = Value(parent, attrs["name"], int(literal_eval(attrs["value"]))) self.context.append(value) elif name == "define": - define = Define(parent, attrs["name"], literal_eval(attrs["value"])) + define = Define(parent, attrs["name"], int(literal_eval(attrs["value"]))) self.context.append(define) elif name == "condition": @@ -851,7 +926,7 @@ class Parser(object): else: raise RuntimeError("Unknown tag: '%s'" % name) - def end_element(self, name): + def end_element(self, name: str) -> None: if name == "condition": element = self.context[-1] if not isinstance(element, Condition) and not isinstance(element, Struct): @@ -884,7 +959,7 @@ class Parser(object): else: raise RuntimeError("Unknown closing element: '%s'" % name) - def parse(self, filename): + def parse(self, filename: str) -> None: file = open(filename, "rb") self.filename = filename self.parser.ParseFile(file)