From 949c3b55dbdfaca4e62d67cc4de1f4e3958813c4 Mon Sep 17 00:00:00 2001 From: Dylan Baker Date: Tue, 1 Nov 2022 12:59:06 -0700 Subject: [PATCH] util/glsl2spirv: add type annotations Which are all clean Reviewed-by: Luis Felipe Strano Moraes Part-of: --- src/util/glsl2spirv.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/util/glsl2spirv.py b/src/util/glsl2spirv.py index 214bfed667d..5e216a0b96c 100644 --- a/src/util/glsl2spirv.py +++ b/src/util/glsl2spirv.py @@ -21,15 +21,29 @@ # Converts GLSL shader to SPIR-V library +from __future__ import annotations import argparse import subprocess import os +import typing as T + +if T.TYPE_CHECKING: + + class Arguments(T.Protocol): + input: str + output: str + create_entry: T.Optional[str] + glsl_ver: T.Optional[str] + Olib: bool + extra: T.Optional[str] + vn: str + stage: str class ShaderCompileError(RuntimeError): def __init__(self, *args): super(ShaderCompileError, self).__init__(*args) -def get_args(): +def get_args() -> Arguments: parser = argparse.ArgumentParser() parser.add_argument('input', help="Name of input file.") parser.add_argument('output', help="Name of output file.") @@ -64,7 +78,7 @@ def get_args(): return args -def create_include_guard(lines, filename): +def create_include_guard(lines: T.List[str], filename: str) -> T.List[str]: filename = filename.replace('.', '_') upper_name = filename.upper() @@ -81,7 +95,7 @@ def create_include_guard(lines, filename): return guard_head + lines + guard_tail -def convert_to_static_variable(lines, varname): +def convert_to_static_variable(lines: T.List[str], varname: str) -> T.List[str]: for idx, l in enumerate(lines): if l.find(varname) != -1: lines[idx] = "static " + lines[idx] @@ -89,7 +103,7 @@ def convert_to_static_variable(lines, varname): raise RuntimeError(f'Did not find {varname}, this is unexpected') -def override_version(lines, glsl_version): +def override_version(lines: T.List[str], glsl_version: str) -> T.List[str]: for idx, l in enumerate(lines): if l.find('#version ') != -1: lines[idx] = "#version {}\n".format(glsl_version) @@ -97,7 +111,7 @@ def override_version(lines, glsl_version): raise RuntimeError('Did not find #version directive, this is unexpected') -def postprocess_file(args): +def postprocess_file(args: Arguments) -> None: with open(args.output, "r") as r: lines = r.readlines() @@ -111,7 +125,7 @@ def postprocess_file(args): w.writelines(lines) -def preprocess_file(args, origin_file): +def preprocess_file(args: Arguments, origin_file: T.TextIO) -> str: with open(origin_file.name + ".copy", "w") as copy_file: lines = origin_file.readlines() @@ -126,7 +140,7 @@ def preprocess_file(args, origin_file): return copy_file.name -def process_file(args): +def process_file(args: Arguments) -> None: with open(args.input, "r") as infile: copy_file = preprocess_file(args, infile) @@ -169,7 +183,7 @@ def process_file(args): os.remove(copy_file) -def main(): +def main() -> None: args = get_args() process_file(args)