rusticl: Drop some Kernel data and have a NirKernelBuild ref instead

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23999>
This commit is contained in:
Antonio Gomes
2023-07-04 21:33:41 -03:00
committed by Marge Bot
parent 005b41fd39
commit 3dde5c231e
3 changed files with 32 additions and 37 deletions

View File

@@ -21,13 +21,13 @@ impl CLInfo<cl_kernel_info> for cl_kernel {
fn query(&self, q: cl_kernel_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let kernel = self.get_ref()?;
Ok(match q {
CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.attributes_string),
CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.build.attributes_string),
CL_KERNEL_CONTEXT => {
let ptr = Arc::as_ptr(&kernel.prog.context);
cl_prop::<cl_context>(cl_context::from_ptr(ptr))
}
CL_KERNEL_FUNCTION_NAME => cl_prop::<&str>(&kernel.name),
CL_KERNEL_NUM_ARGS => cl_prop::<cl_uint>(kernel.args.len() as cl_uint),
CL_KERNEL_NUM_ARGS => cl_prop::<cl_uint>(kernel.build.args.len() as cl_uint),
CL_KERNEL_PROGRAM => {
let ptr = Arc::as_ptr(&kernel.prog);
cl_prop::<cl_program>(cl_program::from_ptr(ptr))
@@ -45,7 +45,7 @@ impl CLInfoObj<cl_kernel_arg_info, cl_uint> for cl_kernel {
let kernel = self.get_ref()?;
// CL_INVALID_ARG_INDEX if arg_index is not a valid argument index.
if idx as usize >= kernel.args.len() {
if idx as usize >= kernel.build.args.len() {
return Err(CL_INVALID_ARG_INDEX);
}
@@ -229,7 +229,7 @@ fn set_kernel_arg(
let k = kernel.get_arc()?;
// CL_INVALID_ARG_INDEX if arg_index is not a valid argument index.
if let Some(arg) = k.args.get(arg_index as usize) {
if let Some(arg) = k.build.args.get(arg_index as usize) {
// CL_INVALID_ARG_SIZE if arg_size does not match the size of the data type for an argument
// that is not a memory object or if the argument is a memory object and
// arg_size != sizeof(cl_mem) or if arg_size is zero and the argument is declared with the
@@ -329,7 +329,7 @@ fn set_kernel_arg_svm_pointer(
return Err(CL_INVALID_OPERATION);
}
if let Some(arg) = kernel.args.get(arg_index) {
if let Some(arg) = kernel.build.args.get(arg_index) {
if !matches!(
arg.kind,
KernelArgType::MemConstant | KernelArgType::MemGlobal

View File

@@ -273,15 +273,15 @@ impl Drop for KernelDevState {
}
impl KernelDevState {
fn new(nirs: HashMap<Arc<Device>, Arc<NirShader>>) -> Arc<Self> {
fn new(nirs: &HashMap<Arc<Device>, Arc<NirShader>>) -> Arc<Self> {
let states = nirs
.into_iter()
.iter()
.map(|(dev, nir)| {
let mut cso = dev
.helper_ctx()
.create_compute_state(&nir, nir.shared_size());
.create_compute_state(nir, nir.shared_size());
let info = dev.helper_ctx().compute_state_info(cso);
let cb = Self::create_nir_constant_buffer(&dev, &nir);
let cb = Self::create_nir_constant_buffer(dev, nir);
// if we can't share the cso between threads, destroy it now.
if !dev.shareable_shaders() {
@@ -290,9 +290,9 @@ impl KernelDevState {
};
(
dev,
dev.clone(),
KernelDevStateInner {
nir: nir,
nir: nir.clone(),
constant_buffer: cb,
cso: cso,
info: info,
@@ -333,11 +333,9 @@ pub struct Kernel {
pub base: CLObjectBase<CL_INVALID_KERNEL>,
pub prog: Arc<Program>,
pub name: String,
pub args: Vec<KernelArg>,
pub values: Vec<RefCell<Option<KernelArgValue>>>,
pub work_group_size: [usize; 3],
pub attributes_string: String,
internal_args: Vec<InternalKernelArg>,
pub build: Arc<NirKernelBuild>,
dev_state: Arc<KernelDevState>,
}
@@ -798,17 +796,18 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
impl Kernel {
pub fn new(name: String, prog: Arc<Program>) -> Arc<Kernel> {
let nir_kernel_build = prog.get_nir_kernel_build(&name);
let mut nirs = nir_kernel_build.nirs;
let args = nir_kernel_build.args;
let internal_args = nir_kernel_build.internal_args;
let attributes_string = nir_kernel_build.attributes_string;
let nirs = &nir_kernel_build.nirs;
let nir = nirs.values_mut().next().unwrap();
let nir = nirs.values().next().unwrap();
let wgs = nir.workgroup_size();
let work_group_size = [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize];
// can't use vec!...
let values = args.iter().map(|_| RefCell::new(None)).collect();
let values = nir_kernel_build
.args
.iter()
.map(|_| RefCell::new(None))
.collect();
// increase ref
prog.kernel_count.fetch_add(1, Ordering::Relaxed);
@@ -817,12 +816,10 @@ impl Kernel {
base: CLObjectBase::new(),
prog: prog,
name: name,
args: args,
work_group_size: work_group_size,
attributes_string: attributes_string,
values: values,
internal_args: internal_args,
dev_state: KernelDevState::new(nirs),
build: nir_kernel_build,
})
}
@@ -899,7 +896,7 @@ impl Kernel {
self.optimize_local_size(&q.device, &mut grid, &mut block);
for (arg, val) in self.args.iter().zip(&self.values) {
for (arg, val) in self.build.args.iter().zip(&self.values) {
if arg.dead {
continue;
}
@@ -986,7 +983,7 @@ impl Kernel {
variable_local_size -= dev_state.nir.shared_size() as u64;
let mut printf_buf = None;
for arg in &self.internal_args {
for arg in &self.build.internal_args {
if arg.offset > input.len() {
input.resize(arg.offset, 0);
}
@@ -1134,7 +1131,7 @@ impl Kernel {
}
pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
let aq = self.args[idx as usize].spirv.access_qualifier;
let aq = self.build.args[idx as usize].spirv.access_qualifier;
if aq
== clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
@@ -1151,7 +1148,7 @@ impl Kernel {
}
pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
match self.args[idx as usize].spirv.address_qualifier {
match self.build.args[idx as usize].spirv.address_qualifier {
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
CL_KERNEL_ARG_ADDRESS_PRIVATE
}
@@ -1168,7 +1165,7 @@ impl Kernel {
}
pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
let tq = self.args[idx as usize].spirv.type_qualifier;
let tq = self.build.args[idx as usize].spirv.type_qualifier;
let zero = clc_kernel_arg_type_qualifier(0);
let mut res = CL_KERNEL_ARG_TYPE_NONE;
@@ -1188,11 +1185,11 @@ impl Kernel {
}
pub fn arg_name(&self, idx: cl_uint) -> &String {
&self.args[idx as usize].spirv.name
&self.build.args[idx as usize].spirv.name
}
pub fn arg_type_name(&self, idx: cl_uint) -> &String {
&self.args[idx as usize].spirv.type_name
&self.build.args[idx as usize].spirv.type_name
}
pub fn priv_mem_size(&self, dev: &Arc<Device>) -> cl_ulong {
@@ -1223,11 +1220,9 @@ impl Clone for Kernel {
base: CLObjectBase::new(),
prog: self.prog.clone(),
name: self.name.clone(),
args: self.args.clone(),
values: self.values.clone(),
work_group_size: self.work_group_size,
attributes_string: self.attributes_string.clone(),
internal_args: self.internal_args.clone(),
build: self.build.clone(),
dev_state: self.dev_state.clone(),
}
}

View File

@@ -82,7 +82,7 @@ pub(super) struct ProgramBuild {
builds: HashMap<Arc<Device>, ProgramDevBuild>,
spec_constants: HashMap<u32, nir_const_value>,
kernels: Vec<String>,
kernel_builds: HashMap<String, NirKernelBuild>,
kernel_builds: HashMap<String, Arc<NirKernelBuild>>,
}
impl ProgramBuild {
@@ -148,12 +148,12 @@ impl ProgramBuild {
self.kernel_builds.insert(
kernel_name.clone(),
NirKernelBuild {
Arc::new(NirKernelBuild {
nirs: nirs,
args: args,
internal_args: internal_args,
attributes_string: attributes_string,
},
}),
);
}
}
@@ -418,7 +418,7 @@ impl Program {
self.build.lock().unwrap()
}
pub fn get_nir_kernel_build(&self, name: &str) -> NirKernelBuild {
pub fn get_nir_kernel_build(&self, name: &str) -> Arc<NirKernelBuild> {
let info = self.build_info();
info.kernel_builds.get(name).unwrap().clone()
}