util/glsl2spirv: add type annotations
Which are all clean Reviewed-by: Luis Felipe Strano Moraes <luis.strano@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19449>
This commit is contained in:
@@ -21,15 +21,29 @@
|
||||
|
||||
# Converts GLSL shader to SPIR-V library
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import subprocess
|
||||
import os
|
||||
import typing as T
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
|
||||
class Arguments(T.Protocol):
|
||||
input: str
|
||||
output: str
|
||||
create_entry: T.Optional[str]
|
||||
glsl_ver: T.Optional[str]
|
||||
Olib: bool
|
||||
extra: T.Optional[str]
|
||||
vn: str
|
||||
stage: str
|
||||
|
||||
class ShaderCompileError(RuntimeError):
|
||||
def __init__(self, *args):
|
||||
super(ShaderCompileError, self).__init__(*args)
|
||||
|
||||
def get_args():
|
||||
def get_args() -> Arguments:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('input', help="Name of input file.")
|
||||
parser.add_argument('output', help="Name of output file.")
|
||||
@@ -64,7 +78,7 @@ def get_args():
|
||||
return args
|
||||
|
||||
|
||||
def create_include_guard(lines, filename):
|
||||
def create_include_guard(lines: T.List[str], filename: str) -> T.List[str]:
|
||||
filename = filename.replace('.', '_')
|
||||
upper_name = filename.upper()
|
||||
|
||||
@@ -81,7 +95,7 @@ def create_include_guard(lines, filename):
|
||||
return guard_head + lines + guard_tail
|
||||
|
||||
|
||||
def convert_to_static_variable(lines, varname):
|
||||
def convert_to_static_variable(lines: T.List[str], varname: str) -> T.List[str]:
|
||||
for idx, l in enumerate(lines):
|
||||
if l.find(varname) != -1:
|
||||
lines[idx] = "static " + lines[idx]
|
||||
@@ -89,7 +103,7 @@ def convert_to_static_variable(lines, varname):
|
||||
raise RuntimeError(f'Did not find {varname}, this is unexpected')
|
||||
|
||||
|
||||
def override_version(lines, glsl_version):
|
||||
def override_version(lines: T.List[str], glsl_version: str) -> T.List[str]:
|
||||
for idx, l in enumerate(lines):
|
||||
if l.find('#version ') != -1:
|
||||
lines[idx] = "#version {}\n".format(glsl_version)
|
||||
@@ -97,7 +111,7 @@ def override_version(lines, glsl_version):
|
||||
raise RuntimeError('Did not find #version directive, this is unexpected')
|
||||
|
||||
|
||||
def postprocess_file(args):
|
||||
def postprocess_file(args: Arguments) -> None:
|
||||
with open(args.output, "r") as r:
|
||||
lines = r.readlines()
|
||||
|
||||
@@ -111,7 +125,7 @@ def postprocess_file(args):
|
||||
w.writelines(lines)
|
||||
|
||||
|
||||
def preprocess_file(args, origin_file):
|
||||
def preprocess_file(args: Arguments, origin_file: T.TextIO) -> str:
|
||||
with open(origin_file.name + ".copy", "w") as copy_file:
|
||||
lines = origin_file.readlines()
|
||||
|
||||
@@ -126,7 +140,7 @@ def preprocess_file(args, origin_file):
|
||||
return copy_file.name
|
||||
|
||||
|
||||
def process_file(args):
|
||||
def process_file(args: Arguments) -> None:
|
||||
with open(args.input, "r") as infile:
|
||||
copy_file = preprocess_file(args, infile)
|
||||
|
||||
@@ -169,7 +183,7 @@ def process_file(args):
|
||||
os.remove(copy_file)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
args = get_args()
|
||||
process_file(args)
|
||||
|
||||
|
Reference in New Issue
Block a user