diff --git a/src/gallium/frontends/rusticl/api/icd.rs b/src/gallium/frontends/rusticl/api/icd.rs index 0615ad4daa2..40bdd8cc452 100644 --- a/src/gallium/frontends/rusticl/api/icd.rs +++ b/src/gallium/frontends/rusticl/api/icd.rs @@ -49,7 +49,7 @@ pub static DISPATCH: cl_icd_dispatch = cl_icd_dispatch { clReleaseSampler: Some(cl_release_sampler), clGetSamplerInfo: Some(cl_get_sampler_info), clCreateProgramWithSource: Some(cl_create_program_with_source), - clCreateProgramWithBinary: None, + clCreateProgramWithBinary: Some(cl_create_program_with_binary), clRetainProgram: Some(cl_retain_program), clReleaseProgram: Some(cl_release_program), clBuildProgram: Some(cl_build_program), @@ -699,6 +699,28 @@ extern "C" fn cl_create_program_with_source( ) } +extern "C" fn cl_create_program_with_binary( + context: cl_context, + num_devices: cl_uint, + device_list: *const cl_device_id, + lengths: *const usize, + binaries: *mut *const ::std::os::raw::c_uchar, + binary_status: *mut cl_int, + errcode_ret: *mut cl_int, +) -> cl_program { + match_obj!( + create_program_with_binary( + context, + num_devices, + device_list, + lengths, + binaries, + binary_status + ), + errcode_ret + ) +} + extern "C" fn cl_retain_program(program: cl_program) -> cl_int { match_err!(program.retain()) } diff --git a/src/gallium/frontends/rusticl/api/program.rs b/src/gallium/frontends/rusticl/api/program.rs index c7c08da3581..68bb1505bb7 100644 --- a/src/gallium/frontends/rusticl/api/program.rs +++ b/src/gallium/frontends/rusticl/api/program.rs @@ -20,9 +20,11 @@ use std::slice; use std::sync::Arc; impl CLInfo for cl_program { - fn query(&self, q: cl_program_info, _: &[u8]) -> CLResult> { + fn query(&self, q: cl_program_info, vals: &[u8]) -> CLResult> { let prog = self.get_ref()?; Ok(match q { + CL_PROGRAM_BINARIES => cl_prop::>(prog.binaries(vals)), + CL_PROGRAM_BINARY_SIZES => cl_prop::>(prog.bin_sizes()), CL_PROGRAM_CONTEXT => { // Note we use as_ptr here which doesn't increase the reference count. let ptr = Arc::as_ptr(&prog.context); @@ -134,6 +136,64 @@ pub fn create_program_with_source( ))) } +pub fn create_program_with_binary( + context: cl_context, + num_devices: cl_uint, + device_list: *const cl_device_id, + lengths: *const usize, + binaries: *mut *const ::std::os::raw::c_uchar, + binary_status: *mut cl_int, +) -> CLResult { + let c = context.get_arc()?; + let devs = cl_device_id::get_arc_vec_from_arr(device_list, num_devices)?; + + // CL_INVALID_VALUE if device_list is NULL or num_devices is zero. + if devs.is_empty() { + return Err(CL_INVALID_VALUE); + } + + // CL_INVALID_VALUE if lengths or binaries is NULL + if lengths.is_null() || binaries.is_null() { + return Err(CL_INVALID_VALUE); + } + + // CL_INVALID_DEVICE if any device in device_list is not in the list of devices associated with + // context. + if !devs.iter().all(|d| c.devs.contains(d)) { + return Err(CL_INVALID_DEVICE); + } + + let lengths = unsafe { slice::from_raw_parts(lengths, num_devices as usize) }; + let binaries = unsafe { slice::from_raw_parts(binaries, num_devices as usize) }; + + // now device specific stuff + let mut err = 0; + let mut bins: Vec<&[u8]> = vec![&[]; num_devices as usize]; + for i in 0..num_devices as usize { + let mut dev_err = 0; + + // CL_INVALID_VALUE if lengths[i] is zero or if binaries[i] is a NULL value + if lengths[i] == 0 || binaries[i].is_null() { + dev_err = CL_INVALID_VALUE; + } + + if !binary_status.is_null() { + unsafe { binary_status.add(i).write(dev_err) }; + } + + // just return the last one + err = dev_err; + bins[i] = unsafe { slice::from_raw_parts(binaries[i], lengths[i] as usize) }; + } + + if err != 0 { + return Err(err); + } + + Ok(cl_program::from_arc(Program::from_bins(c, devs, &bins))) + //• CL_INVALID_BINARY if an invalid program binary was encountered for any device. binary_status will return specific status for each device. +} + pub fn build_program( program: cl_program, num_devices: cl_uint, diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index 2c203632357..0c123a405a9 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -15,10 +15,23 @@ use self::rusticl_opencl_gen::*; use std::collections::HashMap; use std::collections::HashSet; use std::ffi::CString; +use std::mem::size_of; +use std::ptr; +use std::slice; use std::sync::Arc; use std::sync::Mutex; use std::sync::MutexGuard; +const BIN_HEADER_SIZE_V1: usize = + // 1. format version + size_of::() + + // 2. spirv len + size_of::() + + // 3. binary_type + size_of::(); + +const BIN_HEADER_SIZE: usize = BIN_HEADER_SIZE_V1; + #[repr(C)] pub struct Program { pub base: CLObjectBase, @@ -89,6 +102,76 @@ impl Program { }) } + pub fn from_bins( + context: Arc, + devs: Vec>, + bins: &[&[u8]], + ) -> Arc { + let mut builds = HashMap::new(); + let mut kernels = HashSet::new(); + + for (d, b) in devs.iter().zip(bins) { + let mut ptr = b.as_ptr(); + let bin_type; + let spirv; + + unsafe { + // 1. version + let version = ptr.cast::().read(); + ptr = ptr.add(size_of::()); + + match version { + 1 => { + // 2. size of the spirv + let spirv_size = ptr.cast::().read(); + ptr = ptr.add(size_of::()); + + // 3. binary_type + bin_type = ptr.cast::().read(); + ptr = ptr.add(size_of::()); + + // 4. the spirv + assert!(b.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr); + assert!(b.len() == BIN_HEADER_SIZE_V1 + spirv_size as usize); + spirv = Some(spirv::SPIRVBin::from_bin( + slice::from_raw_parts(ptr, spirv_size as usize), + bin_type == CL_PROGRAM_BINARY_TYPE_EXECUTABLE, + )); + } + _ => panic!("unknown version"), + } + } + + if let Some(spirv) = &spirv { + for k in spirv.kernels() { + kernels.insert(k); + } + } + + builds.insert( + d.clone(), + ProgramDevBuild { + spirv: spirv, + status: CL_BUILD_SUCCESS as cl_build_status, + log: String::from(""), + options: String::from(""), + bin_type: bin_type, + }, + ); + } + + Arc::new(Self { + base: CLObjectBase::new(), + context: context, + devs: devs, + src: CString::new("").unwrap(), + build: Mutex::new(ProgramBuild { + builds: builds, + kernels: kernels.into_iter().collect(), + }), + }) + } + fn build_info(&self) -> MutexGuard { self.build.lock().unwrap() } @@ -120,6 +203,65 @@ impl Program { .clone() } + // we need to precalculate the size + pub fn bin_sizes(&self) -> Vec { + let mut lock = self.build_info(); + let mut res = Vec::new(); + for d in &self.devs { + let info = Self::dev_build_info(&mut lock, d); + + res.push( + info.spirv + .as_ref() + .map_or(0, |s| s.to_bin().len() + BIN_HEADER_SIZE), + ); + } + res + } + + pub fn binaries(&self, vals: &[u8]) -> Vec<*mut u8> { + // if the application didn't provide any pointers, just return the length of devices + if vals.is_empty() { + return vec![std::ptr::null_mut(); self.devs.len()]; + } + + // vals is an array of pointers where we should write the device binaries into + if vals.len() != self.devs.len() * size_of::<*const u8>() { + panic!("wrong size") + } + + let ptrs: &[*mut u8] = unsafe { + slice::from_raw_parts(vals.as_ptr().cast(), vals.len() / size_of::<*mut u8>()) + }; + + let mut lock = self.build_info(); + for (i, d) in self.devs.iter().enumerate() { + let mut ptr = ptrs[i]; + let info = Self::dev_build_info(&mut lock, d); + let spirv = info.spirv.as_ref().unwrap().to_bin(); + + unsafe { + // 1. binary format version + ptr.cast::().write(1); + ptr = ptr.add(size_of::()); + + // 2. size of the spirv + ptr.cast::().write(spirv.len() as u32); + ptr = ptr.add(size_of::()); + + // 3. binary_type + ptr.cast::().write(info.bin_type); + ptr = ptr.add(size_of::()); + + // 4. the spirv + assert!(ptrs[i].add(BIN_HEADER_SIZE) == ptr); + ptr::copy_nonoverlapping(spirv.as_ptr(), ptr, spirv.len()); + } + } + + ptrs.to_vec() + } + pub fn args(&self, dev: &Arc, kernel: &str) -> Vec { Self::dev_build_info(&mut self.build_info(), dev) .spirv @@ -133,6 +275,11 @@ impl Program { } pub fn build(&self, dev: &Arc, options: String) -> bool { + // program binary + if self.src.as_bytes().is_empty() { + return true; + } + let mut info = self.build_info(); let d = Self::dev_build_info(&mut info, dev); let lib = options.contains("-create-library"); @@ -175,6 +322,11 @@ impl Program { options: String, headers: &[spirv::CLCHeader], ) -> bool { + // program binary + if self.src.as_bytes().is_empty() { + return true; + } + let mut info = self.build_info(); let d = Self::dev_build_info(&mut info, dev); let args = prepare_options(&options, dev); diff --git a/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs b/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs index f98bb9061d3..7bf82587bdb 100644 --- a/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs +++ b/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs @@ -234,6 +234,33 @@ impl SPIRVBin { nir_load_libclc_shader(64, shader_cache, &spirv_options, nir_options) }) } + + pub fn to_bin(&self) -> &[u8] { + unsafe { slice::from_raw_parts(self.spirv.data.cast(), self.spirv.size) } + } + + pub fn from_bin(bin: &[u8], executable: bool) -> Self { + unsafe { + let ptr = malloc(bin.len()); + ptr::copy_nonoverlapping(bin.as_ptr(), ptr.cast(), bin.len()); + let spirv = clc_binary { + data: ptr, + size: bin.len(), + }; + let info = if executable { + let mut pspirv = clc_parsed_spirv::default(); + clc_parse_spirv(&spirv, ptr::null(), &mut pspirv); + Some(pspirv) + } else { + None + }; + + SPIRVBin { + spirv: spirv, + info: info, + } + } + } } impl Drop for SPIRVBin { diff --git a/src/gallium/frontends/rusticl/meson.build b/src/gallium/frontends/rusticl/meson.build index 8e3b1e85b1d..d892c5b2318 100644 --- a/src/gallium/frontends/rusticl/meson.build +++ b/src/gallium/frontends/rusticl/meson.build @@ -101,6 +101,10 @@ rusticl_bindgen_args = [ '--anon-fields-prefix', 'anon_', ] +rusticl_bindgen_c_args = [ + '-fno-builtin-malloc', +] + rusticl_opencl_bindings_rs = rust.bindgen( input : [ 'rusticl_opencl_bindings.h', @@ -111,6 +115,7 @@ rusticl_opencl_bindings_rs = rust.bindgen( inc_include, ], c_args : [ + rusticl_bindgen_c_args, '-DCL_USE_DEPRECATED_OPENCL_1_0_APIS', '-DCL_USE_DEPRECATED_OPENCL_1_1_APIS', '-DCL_USE_DEPRECATED_OPENCL_1_2_APIS', @@ -179,7 +184,10 @@ rusticl_mesa_bindings_rs = rust.bindgen( inc_nir, inc_src, ], - c_args : pre_args, + c_args : [ + rusticl_bindgen_c_args, + pre_args, + ], args : [ rusticl_bindgen_args, '--whitelist-function', 'clc_.*', @@ -189,6 +197,7 @@ rusticl_mesa_bindings_rs = rust.bindgen( '--whitelist-function', 'rusticl_.*', '--whitelist-function', 'rz?alloc_.*', '--whitelist-function', 'spirv_.*', + '--whitelist-function', 'malloc', '--whitelist-type', 'pipe_endian', '--whitelist-type', 'clc_kernel_arg_access_qualifier', '--bitfield-enum', 'clc_kernel_arg_access_qualifier',