util/glsl2spirv: add type annotations
Which are all clean Reviewed-by: Luis Felipe Strano Moraes <luis.strano@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19449>
This commit is contained in:
@@ -21,15 +21,29 @@
|
|||||||
|
|
||||||
# Converts GLSL shader to SPIR-V library
|
# Converts GLSL shader to SPIR-V library
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
|
import typing as T
|
||||||
|
|
||||||
|
if T.TYPE_CHECKING:
|
||||||
|
|
||||||
|
class Arguments(T.Protocol):
|
||||||
|
input: str
|
||||||
|
output: str
|
||||||
|
create_entry: T.Optional[str]
|
||||||
|
glsl_ver: T.Optional[str]
|
||||||
|
Olib: bool
|
||||||
|
extra: T.Optional[str]
|
||||||
|
vn: str
|
||||||
|
stage: str
|
||||||
|
|
||||||
class ShaderCompileError(RuntimeError):
|
class ShaderCompileError(RuntimeError):
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
super(ShaderCompileError, self).__init__(*args)
|
super(ShaderCompileError, self).__init__(*args)
|
||||||
|
|
||||||
def get_args():
|
def get_args() -> Arguments:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('input', help="Name of input file.")
|
parser.add_argument('input', help="Name of input file.")
|
||||||
parser.add_argument('output', help="Name of output file.")
|
parser.add_argument('output', help="Name of output file.")
|
||||||
@@ -64,7 +78,7 @@ def get_args():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def create_include_guard(lines, filename):
|
def create_include_guard(lines: T.List[str], filename: str) -> T.List[str]:
|
||||||
filename = filename.replace('.', '_')
|
filename = filename.replace('.', '_')
|
||||||
upper_name = filename.upper()
|
upper_name = filename.upper()
|
||||||
|
|
||||||
@@ -81,7 +95,7 @@ def create_include_guard(lines, filename):
|
|||||||
return guard_head + lines + guard_tail
|
return guard_head + lines + guard_tail
|
||||||
|
|
||||||
|
|
||||||
def convert_to_static_variable(lines, varname):
|
def convert_to_static_variable(lines: T.List[str], varname: str) -> T.List[str]:
|
||||||
for idx, l in enumerate(lines):
|
for idx, l in enumerate(lines):
|
||||||
if l.find(varname) != -1:
|
if l.find(varname) != -1:
|
||||||
lines[idx] = "static " + lines[idx]
|
lines[idx] = "static " + lines[idx]
|
||||||
@@ -89,7 +103,7 @@ def convert_to_static_variable(lines, varname):
|
|||||||
raise RuntimeError(f'Did not find {varname}, this is unexpected')
|
raise RuntimeError(f'Did not find {varname}, this is unexpected')
|
||||||
|
|
||||||
|
|
||||||
def override_version(lines, glsl_version):
|
def override_version(lines: T.List[str], glsl_version: str) -> T.List[str]:
|
||||||
for idx, l in enumerate(lines):
|
for idx, l in enumerate(lines):
|
||||||
if l.find('#version ') != -1:
|
if l.find('#version ') != -1:
|
||||||
lines[idx] = "#version {}\n".format(glsl_version)
|
lines[idx] = "#version {}\n".format(glsl_version)
|
||||||
@@ -97,7 +111,7 @@ def override_version(lines, glsl_version):
|
|||||||
raise RuntimeError('Did not find #version directive, this is unexpected')
|
raise RuntimeError('Did not find #version directive, this is unexpected')
|
||||||
|
|
||||||
|
|
||||||
def postprocess_file(args):
|
def postprocess_file(args: Arguments) -> None:
|
||||||
with open(args.output, "r") as r:
|
with open(args.output, "r") as r:
|
||||||
lines = r.readlines()
|
lines = r.readlines()
|
||||||
|
|
||||||
@@ -111,7 +125,7 @@ def postprocess_file(args):
|
|||||||
w.writelines(lines)
|
w.writelines(lines)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_file(args, origin_file):
|
def preprocess_file(args: Arguments, origin_file: T.TextIO) -> str:
|
||||||
with open(origin_file.name + ".copy", "w") as copy_file:
|
with open(origin_file.name + ".copy", "w") as copy_file:
|
||||||
lines = origin_file.readlines()
|
lines = origin_file.readlines()
|
||||||
|
|
||||||
@@ -126,7 +140,7 @@ def preprocess_file(args, origin_file):
|
|||||||
return copy_file.name
|
return copy_file.name
|
||||||
|
|
||||||
|
|
||||||
def process_file(args):
|
def process_file(args: Arguments) -> None:
|
||||||
with open(args.input, "r") as infile:
|
with open(args.input, "r") as infile:
|
||||||
copy_file = preprocess_file(args, infile)
|
copy_file = preprocess_file(args, infile)
|
||||||
|
|
||||||
@@ -169,7 +183,7 @@ def process_file(args):
|
|||||||
os.remove(copy_file)
|
os.remove(copy_file)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
args = get_args()
|
args = get_args()
|
||||||
process_file(args)
|
process_file(args)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user