diff --git a/src/gallium/frontends/rusticl/api/kernel.rs b/src/gallium/frontends/rusticl/api/kernel.rs index f518c43e08c..9ae240e7c0f 100644 --- a/src/gallium/frontends/rusticl/api/kernel.rs +++ b/src/gallium/frontends/rusticl/api/kernel.rs @@ -22,13 +22,13 @@ impl CLInfo for cl_kernel { fn query(&self, q: cl_kernel_info, _: &[u8]) -> CLResult>> { let kernel = self.get_ref()?; Ok(match q { - CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.build.attributes_string), + CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.kernel_info.attributes_string), CL_KERNEL_CONTEXT => { let ptr = Arc::as_ptr(&kernel.prog.context); cl_prop::(cl_context::from_ptr(ptr)) } CL_KERNEL_FUNCTION_NAME => cl_prop::<&str>(&kernel.name), - CL_KERNEL_NUM_ARGS => cl_prop::(kernel.build.args.len() as cl_uint), + CL_KERNEL_NUM_ARGS => cl_prop::(kernel.kernel_info.args.len() as cl_uint), CL_KERNEL_PROGRAM => { let ptr = Arc::as_ptr(&kernel.prog); cl_prop::(cl_program::from_ptr(ptr)) @@ -46,7 +46,7 @@ impl CLInfoObj for cl_kernel { let kernel = self.get_ref()?; // CL_INVALID_ARG_INDEX if arg_index is not a valid argument index. - if idx as usize >= kernel.build.args.len() { + if idx as usize >= kernel.kernel_info.args.len() { return Err(CL_INVALID_ARG_INDEX); } @@ -329,7 +329,7 @@ fn set_kernel_arg( let k = kernel.get_arc()?; // CL_INVALID_ARG_INDEX if arg_index is not a valid argument index. - if let Some(arg) = k.build.args.get(arg_index as usize) { + if let Some(arg) = k.kernel_info.args.get(arg_index as usize) { // CL_INVALID_ARG_SIZE if arg_size does not match the size of the data type for an argument // that is not a memory object or if the argument is a memory object and // arg_size != sizeof(cl_mem) or if arg_size is zero and the argument is declared with the @@ -429,7 +429,7 @@ fn set_kernel_arg_svm_pointer( return Err(CL_INVALID_OPERATION); } - if let Some(arg) = kernel.build.args.get(arg_index) { + if let Some(arg) = kernel.kernel_info.args.get(arg_index) { if !matches!( arg.kind, KernelArgType::MemConstant | KernelArgType::MemGlobal diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 65e46ce8d4a..f2934e15811 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -250,7 +250,17 @@ impl InternalKernelArg { } } -struct CSOWrapper { +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct KernelInfo { + pub args: Vec, + pub internal_args: Vec, + pub attributes_string: String, + pub work_group_size: [usize; 3], + pub subgroup_size: usize, + pub num_subgroups: usize, +} + +pub struct CSOWrapper { pub cso_ptr: *mut c_void, dev: &'static Device, } @@ -286,7 +296,7 @@ impl Drop for CSOWrapper { } } -enum KernelDevStateVariant { +pub enum KernelDevStateVariant { Cso(Arc), Nir(Arc), } @@ -341,7 +351,7 @@ impl KernelDevState { Arc::new(Self { states: states }) } - fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option> { + pub fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option> { let buf = nir.get_constant_buffer(); let len = buf.len() as u32; @@ -371,7 +381,8 @@ pub struct Kernel { pub prog: Arc, pub name: String, pub values: Vec>>, - pub build: Arc, + pub builds: HashMap<&'static Device, Arc>, + pub kernel_info: KernelInfo, } impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL); @@ -830,10 +841,16 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] { impl Kernel { pub fn new(name: String, prog: Arc) -> Arc { - let nir_kernel_build = prog.get_nir_kernel_build(&name); + let prog_build = prog.build_info(); + let kernel_info = prog_build.kernel_info.get(&name).unwrap().clone(); + let builds = prog_build + .builds + .iter() + .map(|(k, v)| (*k, v.kernels.get(&name).unwrap().clone())) + .collect(); // can't use vec!... - let values = nir_kernel_build + let values = kernel_info .args .iter() .map(|_| RefCell::new(None)) @@ -841,10 +858,11 @@ impl Kernel { Arc::new(Self { base: CLObjectBase::new(), - prog: prog, + prog: prog.clone(), name: name, values: values, - build: nir_kernel_build, + builds: builds, + kernel_info: kernel_info, }) } @@ -896,14 +914,14 @@ impl Kernel { grid: &[usize], offsets: &[usize], ) -> CLResult { - let dev_state = self.build.dev_state.get(q.device); + let nir_kernel_build = self.builds.get(q.device).unwrap().clone(); let mut block = create_kernel_arr::(block, 1); let mut grid = create_kernel_arr::(grid, 1); let offsets = create_kernel_arr::(offsets, 0); let mut input: Vec = Vec::new(); let mut resource_info = Vec::new(); // Set it once so we get the alignment padding right - let static_local_size: u64 = dev_state.nir_internal_info.shared_size; + let static_local_size: u64 = nir_kernel_build.shared_size; let mut variable_local_size: u64 = static_local_size; let printf_size = q.device.printf_buffer_size() as u32; let mut samplers = Vec::new(); @@ -921,7 +939,7 @@ impl Kernel { self.optimize_local_size(q.device, &mut grid, &mut block); - for (arg, val) in self.build.args.iter().zip(&self.values) { + for (arg, val) in self.kernel_info.args.iter().zip(&self.values) { if arg.dead { continue; } @@ -1005,18 +1023,21 @@ impl Kernel { } // subtract the shader local_size as we only request something on top of that. - variable_local_size -= dev_state.nir_internal_info.shared_size; + variable_local_size -= static_local_size; let mut printf_buf = None; - for arg in &self.build.internal_args { + for arg in &self.kernel_info.internal_args { if arg.offset > input.len() { input.resize(arg.offset, 0); } match arg.kind { InternalKernelArgType::ConstantBuffer => { - assert!(dev_state.constant_buffer.is_some()); + assert!(nir_kernel_build.constant_buffer.is_some()); input.extend_from_slice(null_ptr); - resource_info.push((dev_state.constant_buffer.clone().unwrap(), arg.offset)); + resource_info.push(( + nir_kernel_build.constant_buffer.clone().unwrap(), + arg.offset, + )); } InternalKernelArgType::GlobalWorkOffsets => { if q.device.address_bits() == 64 { @@ -1061,13 +1082,11 @@ impl Kernel { } } - let k = Arc::clone(self); Ok(Box::new(move |q, ctx| { - let dev_state = k.build.dev_state.get(q.device); let mut input = input.clone(); let mut resources = Vec::with_capacity(resource_info.len()); let mut globals: Vec<*mut u32> = Vec::new(); - let printf_format = &dev_state.nir_internal_info.printf_info; + let printf_format = &nir_kernel_build.printf_info; let mut sviews: Vec<_> = sviews .iter() @@ -1093,7 +1112,7 @@ impl Kernel { ); } - let cso = match &dev_state.nir_or_cso { + let cso = match &nir_kernel_build.nir_or_cso { KernelDevStateVariant::Cso(cso) => cso.clone(), KernelDevStateVariant::Nir(nir) => CSOWrapper::new(q.device, nir), }; @@ -1145,7 +1164,7 @@ impl Kernel { } pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier { - let aq = self.build.args[idx as usize].spirv.access_qualifier; + let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier; if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ @@ -1162,7 +1181,7 @@ impl Kernel { } pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier { - match self.build.args[idx as usize].spirv.address_qualifier { + match self.kernel_info.args[idx as usize].spirv.address_qualifier { clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => { CL_KERNEL_ARG_ADDRESS_PRIVATE } @@ -1179,7 +1198,7 @@ impl Kernel { } pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier { - let tq = self.build.args[idx as usize].spirv.type_qualifier; + let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier; let zero = clc_kernel_arg_type_qualifier(0); let mut res = CL_KERNEL_ARG_TYPE_NONE; @@ -1199,61 +1218,40 @@ impl Kernel { } pub fn work_group_size(&self) -> [usize; 3] { - self.build - .dev_state - .states - .values() - .next() - .unwrap() - .nir_internal_info - .work_group_size + self.kernel_info.work_group_size } pub fn num_subgroups(&self) -> usize { - self.build - .dev_state - .states - .values() - .next() - .unwrap() - .nir_internal_info - .num_subgroups + self.kernel_info.num_subgroups } pub fn subgroup_size(&self) -> usize { - self.build - .dev_state - .states - .values() - .next() - .unwrap() - .nir_internal_info - .subgroup_size + self.kernel_info.subgroup_size } pub fn arg_name(&self, idx: cl_uint) -> &String { - &self.build.args[idx as usize].spirv.name + &self.kernel_info.args[idx as usize].spirv.name } pub fn arg_type_name(&self, idx: cl_uint) -> &String { - &self.build.args[idx as usize].spirv.type_name + &self.kernel_info.args[idx as usize].spirv.type_name } pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong { - self.build.dev_state.get(dev).info.private_memory.into() + self.builds.get(dev).unwrap().info.private_memory as cl_ulong } pub fn max_threads_per_block(&self, dev: &Device) -> usize { - self.build.dev_state.get(dev).info.max_threads as usize + self.builds.get(dev).unwrap().info.max_threads as usize } pub fn preferred_simd_size(&self, dev: &Device) -> usize { - self.build.dev_state.get(dev).info.preferred_simd_size as usize + self.builds.get(dev).unwrap().info.preferred_simd_size as usize } pub fn local_mem_size(&self, dev: &Device) -> cl_ulong { // TODO include args - self.build.dev_state.get(dev).nir_internal_info.shared_size as cl_ulong + self.builds.get(dev).unwrap().shared_size as cl_ulong } pub fn has_svm_devs(&self) -> bool { @@ -1261,7 +1259,7 @@ impl Kernel { } pub fn subgroup_sizes(&self, dev: &Device) -> Vec { - SetBitIndices::from_msb(self.build.dev_state.get(dev).info.simd_sizes) + SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes) .map(|bit| 1 << bit) .collect() } @@ -1292,7 +1290,7 @@ impl Kernel { *block.get(2).unwrap_or(&1) as u32, ]; - match &self.build.dev_state.get(dev).nir_or_cso { + match &self.builds.get(dev).unwrap().nir_or_cso { KernelDevStateVariant::Cso(cso) => { dev.helper_ctx() .compute_state_subgroup_size(cso.cso_ptr, &block) as usize @@ -1311,7 +1309,8 @@ impl Clone for Kernel { prog: self.prog.clone(), name: self.name.clone(), values: self.values.clone(), - build: self.build.clone(), + builds: self.builds.clone(), + kernel_info: self.kernel_info.clone(), } } } diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index bcf2731a305..a2f4d98dbdb 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -8,6 +8,7 @@ use crate::impl_cl_type_trait; use mesa_rust::compiler::clc::spirv::SPIRVBin; use mesa_rust::compiler::clc::*; use mesa_rust::compiler::nir::*; +use mesa_rust::pipe::resource::*; use mesa_rust::util::disk_cache::*; use mesa_rust_gen::*; use rusticl_opencl_gen::*; @@ -68,17 +69,18 @@ pub struct Program { impl_cl_type_trait!(cl_program, Program, CL_INVALID_PROGRAM); pub struct NirKernelBuild { - pub dev_state: Arc, - pub args: Vec, - pub internal_args: Vec, - pub attributes_string: String, + pub nir_or_cso: KernelDevStateVariant, + pub constant_buffer: Option>, + pub info: pipe_compute_state_object_info, + pub shared_size: u64, + pub printf_info: Option, } -pub(super) struct ProgramBuild { - builds: HashMap<&'static Device, ProgramDevBuild>, +pub struct ProgramBuild { + pub builds: HashMap<&'static Device, ProgramDevBuild>, + pub kernel_info: HashMap, spec_constants: HashMap, kernels: Vec, - kernel_builds: HashMap>, } impl ProgramBuild { @@ -104,7 +106,7 @@ impl ProgramBuild { } fn build_nirs(&mut self, is_src: bool) { - for kernel_name in &self.kernels { + for kernel_name in &self.kernels.clone() { let kernel_args: HashSet<_> = self .devs_with_build() .iter() @@ -112,45 +114,64 @@ impl ProgramBuild { .collect(); let args = kernel_args.into_iter().next().unwrap(); - let mut nirs = HashMap::new(); - let mut args_set = HashSet::new(); - let mut internal_args_set = HashSet::new(); - let mut attributes_string_set = HashSet::new(); + let mut kernel_info_set = HashSet::new(); // TODO: we could run this in parallel? - for d in self.devs_with_build() { - let (nir, args, internal_args) = convert_spirv_to_nir(self, kernel_name, &args, d); - let attributes_string = self.attribute_str(kernel_name, d); - nirs.insert(d, nir); - args_set.insert(args); - internal_args_set.insert(internal_args); - attributes_string_set.insert(attributes_string); - } + for dev in self.devs_with_build() { + let (mut nir, args, internal_args) = + convert_spirv_to_nir(self, kernel_name, &args, dev); + let attributes_string = self.attribute_str(kernel_name, dev); + let wgs = nir.workgroup_size(); + let shared_size = nir.shared_size() as u64; + let printf_info = nir.take_printf_info(); - // we want the same (internal) args for every compiled kernel, for now - assert!(args_set.len() == 1); - assert!(internal_args_set.len() == 1); - assert!(attributes_string_set.len() == 1); - let args = args_set.into_iter().next().unwrap(); - let internal_args = internal_args_set.into_iter().next().unwrap(); - - // spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource - // API call the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty. - let attributes_string = if is_src { - attributes_string_set.into_iter().next().unwrap() - } else { - String::new() - }; - - self.kernel_builds.insert( - kernel_name.clone(), - Arc::new(NirKernelBuild { - dev_state: KernelDevState::new(nirs), + let kernel_info = KernelInfo { args: args, internal_args: internal_args, attributes_string: attributes_string, - }), - ); + work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize], + subgroup_size: nir.subgroup_size() as usize, + num_subgroups: nir.num_subgroups() as usize, + }; + + kernel_info_set.insert(kernel_info); + + let cso = CSOWrapper::new(dev, &nir); + let info = cso.get_cso_info(); + let cb = KernelDevState::create_nir_constant_buffer(dev, &nir); + + let nir_or_cso = if !dev.shareable_shaders() { + KernelDevStateVariant::Nir(Arc::new(nir)) + } else { + KernelDevStateVariant::Cso(cso) + }; + + let nir_kernel_build = NirKernelBuild { + nir_or_cso: nir_or_cso, + constant_buffer: cb, + info: info, + shared_size: shared_size, + printf_info: printf_info, + }; + + self.builds + .get_mut(dev) + .unwrap() + .kernels + .insert(kernel_name.clone(), Arc::new(nir_kernel_build)); + } + + // we want the same (internal) args for every compiled kernel, for now + assert!(kernel_info_set.len() == 1); + let mut kernel_info = kernel_info_set.into_iter().next().unwrap(); + + // spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource + // API call the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty. + if !is_src { + kernel_info.attributes_string = String::new(); + } + + self.kernel_info.insert(kernel_name.clone(), kernel_info); } } @@ -228,12 +249,13 @@ impl ProgramBuild { } } -struct ProgramDevBuild { +pub struct ProgramDevBuild { spirv: Option, status: cl_build_status, options: String, log: String, bin_type: cl_program_binary_type, + pub kernels: HashMap>, } fn prepare_options(options: &str, dev: &Device) -> Vec { @@ -297,6 +319,7 @@ impl Program { log: String::from(""), options: String::from(""), bin_type: CL_PROGRAM_BINARY_TYPE_NONE, + kernels: HashMap::new(), }, ) }) @@ -313,7 +336,7 @@ impl Program { builds: Self::create_default_builds(devs), spec_constants: HashMap::new(), kernels: Vec::new(), - kernel_builds: HashMap::new(), + kernel_info: HashMap::new(), }), }) } @@ -372,6 +395,7 @@ impl Program { log: String::from(""), options: String::from(""), bin_type: bin_type, + kernels: HashMap::new(), }, ); } @@ -380,7 +404,7 @@ impl Program { builds: builds, spec_constants: HashMap::new(), kernels: kernels.into_iter().collect(), - kernel_builds: HashMap::new(), + kernel_info: HashMap::new(), }; build.build_nirs(false); @@ -404,20 +428,15 @@ impl Program { builds: builds, spec_constants: HashMap::new(), kernels: Vec::new(), - kernel_builds: HashMap::new(), + kernel_info: HashMap::new(), }), }) } - fn build_info(&self) -> MutexGuard { + pub fn build_info(&self) -> MutexGuard { self.build.lock().unwrap() } - pub fn get_nir_kernel_build(&self, name: &str) -> Arc { - let info = self.build_info(); - info.kernel_builds.get(name).unwrap().clone() - } - pub fn status(&self, dev: &Device) -> cl_build_status { self.build_info().dev_build(dev).status } @@ -510,9 +529,9 @@ impl Program { pub fn active_kernels(&self) -> bool { self.build_info() - .kernel_builds + .builds .values() - .any(|b| Arc::strong_count(b) > 1) + .any(|b| b.kernels.values().any(|b| Arc::strong_count(b) > 1)) } pub fn build(&self, dev: &Device, options: String) -> bool { @@ -668,6 +687,7 @@ impl Program { log: log, options: String::from(""), bin_type: bin_type, + kernels: HashMap::new(), }, ); } @@ -676,7 +696,7 @@ impl Program { builds: builds, spec_constants: HashMap::new(), kernels: kernels.into_iter().collect(), - kernel_builds: HashMap::new(), + kernel_info: HashMap::new(), }; // Pre build nir kernels