rusticl/program: implement clCreateProgramWithBinary

Signed-off-by: Karol Herbst <kherbst@redhat.com>
Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15439>
This commit is contained in:
Karol Herbst
2022-03-17 11:44:22 +01:00
committed by Marge Bot
parent 84d16045d0
commit e028baa177
5 changed files with 273 additions and 3 deletions

View File

@@ -49,7 +49,7 @@ pub static DISPATCH: cl_icd_dispatch = cl_icd_dispatch {
clReleaseSampler: Some(cl_release_sampler),
clGetSamplerInfo: Some(cl_get_sampler_info),
clCreateProgramWithSource: Some(cl_create_program_with_source),
clCreateProgramWithBinary: None,
clCreateProgramWithBinary: Some(cl_create_program_with_binary),
clRetainProgram: Some(cl_retain_program),
clReleaseProgram: Some(cl_release_program),
clBuildProgram: Some(cl_build_program),
@@ -699,6 +699,28 @@ extern "C" fn cl_create_program_with_source(
)
}
extern "C" fn cl_create_program_with_binary(
context: cl_context,
num_devices: cl_uint,
device_list: *const cl_device_id,
lengths: *const usize,
binaries: *mut *const ::std::os::raw::c_uchar,
binary_status: *mut cl_int,
errcode_ret: *mut cl_int,
) -> cl_program {
match_obj!(
create_program_with_binary(
context,
num_devices,
device_list,
lengths,
binaries,
binary_status
),
errcode_ret
)
}
extern "C" fn cl_retain_program(program: cl_program) -> cl_int {
match_err!(program.retain())
}

View File

@@ -20,9 +20,11 @@ use std::slice;
use std::sync::Arc;
impl CLInfo<cl_program_info> for cl_program {
fn query(&self, q: cl_program_info, _: &[u8]) -> CLResult<Vec<u8>> {
fn query(&self, q: cl_program_info, vals: &[u8]) -> CLResult<Vec<u8>> {
let prog = self.get_ref()?;
Ok(match q {
CL_PROGRAM_BINARIES => cl_prop::<Vec<*mut u8>>(prog.binaries(vals)),
CL_PROGRAM_BINARY_SIZES => cl_prop::<Vec<usize>>(prog.bin_sizes()),
CL_PROGRAM_CONTEXT => {
// Note we use as_ptr here which doesn't increase the reference count.
let ptr = Arc::as_ptr(&prog.context);
@@ -134,6 +136,64 @@ pub fn create_program_with_source(
)))
}
pub fn create_program_with_binary(
context: cl_context,
num_devices: cl_uint,
device_list: *const cl_device_id,
lengths: *const usize,
binaries: *mut *const ::std::os::raw::c_uchar,
binary_status: *mut cl_int,
) -> CLResult<cl_program> {
let c = context.get_arc()?;
let devs = cl_device_id::get_arc_vec_from_arr(device_list, num_devices)?;
// CL_INVALID_VALUE if device_list is NULL or num_devices is zero.
if devs.is_empty() {
return Err(CL_INVALID_VALUE);
}
// CL_INVALID_VALUE if lengths or binaries is NULL
if lengths.is_null() || binaries.is_null() {
return Err(CL_INVALID_VALUE);
}
// CL_INVALID_DEVICE if any device in device_list is not in the list of devices associated with
// context.
if !devs.iter().all(|d| c.devs.contains(d)) {
return Err(CL_INVALID_DEVICE);
}
let lengths = unsafe { slice::from_raw_parts(lengths, num_devices as usize) };
let binaries = unsafe { slice::from_raw_parts(binaries, num_devices as usize) };
// now device specific stuff
let mut err = 0;
let mut bins: Vec<&[u8]> = vec![&[]; num_devices as usize];
for i in 0..num_devices as usize {
let mut dev_err = 0;
// CL_INVALID_VALUE if lengths[i] is zero or if binaries[i] is a NULL value
if lengths[i] == 0 || binaries[i].is_null() {
dev_err = CL_INVALID_VALUE;
}
if !binary_status.is_null() {
unsafe { binary_status.add(i).write(dev_err) };
}
// just return the last one
err = dev_err;
bins[i] = unsafe { slice::from_raw_parts(binaries[i], lengths[i] as usize) };
}
if err != 0 {
return Err(err);
}
Ok(cl_program::from_arc(Program::from_bins(c, devs, &bins)))
//• CL_INVALID_BINARY if an invalid program binary was encountered for any device. binary_status will return specific status for each device.
}
pub fn build_program(
program: cl_program,
num_devices: cl_uint,

View File

@@ -15,10 +15,23 @@ use self::rusticl_opencl_gen::*;
use std::collections::HashMap;
use std::collections::HashSet;
use std::ffi::CString;
use std::mem::size_of;
use std::ptr;
use std::slice;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::MutexGuard;
const BIN_HEADER_SIZE_V1: usize =
// 1. format version
size_of::<u32>() +
// 2. spirv len
size_of::<u32>() +
// 3. binary_type
size_of::<cl_program_binary_type>();
const BIN_HEADER_SIZE: usize = BIN_HEADER_SIZE_V1;
#[repr(C)]
pub struct Program {
pub base: CLObjectBase<CL_INVALID_PROGRAM>,
@@ -89,6 +102,76 @@ impl Program {
})
}
pub fn from_bins(
context: Arc<Context>,
devs: Vec<Arc<Device>>,
bins: &[&[u8]],
) -> Arc<Program> {
let mut builds = HashMap::new();
let mut kernels = HashSet::new();
for (d, b) in devs.iter().zip(bins) {
let mut ptr = b.as_ptr();
let bin_type;
let spirv;
unsafe {
// 1. version
let version = ptr.cast::<u32>().read();
ptr = ptr.add(size_of::<u32>());
match version {
1 => {
// 2. size of the spirv
let spirv_size = ptr.cast::<u32>().read();
ptr = ptr.add(size_of::<u32>());
// 3. binary_type
bin_type = ptr.cast::<cl_program_binary_type>().read();
ptr = ptr.add(size_of::<cl_program_binary_type>());
// 4. the spirv
assert!(b.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr);
assert!(b.len() == BIN_HEADER_SIZE_V1 + spirv_size as usize);
spirv = Some(spirv::SPIRVBin::from_bin(
slice::from_raw_parts(ptr, spirv_size as usize),
bin_type == CL_PROGRAM_BINARY_TYPE_EXECUTABLE,
));
}
_ => panic!("unknown version"),
}
}
if let Some(spirv) = &spirv {
for k in spirv.kernels() {
kernels.insert(k);
}
}
builds.insert(
d.clone(),
ProgramDevBuild {
spirv: spirv,
status: CL_BUILD_SUCCESS as cl_build_status,
log: String::from(""),
options: String::from(""),
bin_type: bin_type,
},
);
}
Arc::new(Self {
base: CLObjectBase::new(),
context: context,
devs: devs,
src: CString::new("").unwrap(),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: kernels.into_iter().collect(),
}),
})
}
fn build_info(&self) -> MutexGuard<ProgramBuild> {
self.build.lock().unwrap()
}
@@ -120,6 +203,65 @@ impl Program {
.clone()
}
// we need to precalculate the size
pub fn bin_sizes(&self) -> Vec<usize> {
let mut lock = self.build_info();
let mut res = Vec::new();
for d in &self.devs {
let info = Self::dev_build_info(&mut lock, d);
res.push(
info.spirv
.as_ref()
.map_or(0, |s| s.to_bin().len() + BIN_HEADER_SIZE),
);
}
res
}
pub fn binaries(&self, vals: &[u8]) -> Vec<*mut u8> {
// if the application didn't provide any pointers, just return the length of devices
if vals.is_empty() {
return vec![std::ptr::null_mut(); self.devs.len()];
}
// vals is an array of pointers where we should write the device binaries into
if vals.len() != self.devs.len() * size_of::<*const u8>() {
panic!("wrong size")
}
let ptrs: &[*mut u8] = unsafe {
slice::from_raw_parts(vals.as_ptr().cast(), vals.len() / size_of::<*mut u8>())
};
let mut lock = self.build_info();
for (i, d) in self.devs.iter().enumerate() {
let mut ptr = ptrs[i];
let info = Self::dev_build_info(&mut lock, d);
let spirv = info.spirv.as_ref().unwrap().to_bin();
unsafe {
// 1. binary format version
ptr.cast::<u32>().write(1);
ptr = ptr.add(size_of::<u32>());
// 2. size of the spirv
ptr.cast::<u32>().write(spirv.len() as u32);
ptr = ptr.add(size_of::<u32>());
// 3. binary_type
ptr.cast::<cl_program_binary_type>().write(info.bin_type);
ptr = ptr.add(size_of::<cl_program_binary_type>());
// 4. the spirv
assert!(ptrs[i].add(BIN_HEADER_SIZE) == ptr);
ptr::copy_nonoverlapping(spirv.as_ptr(), ptr, spirv.len());
}
}
ptrs.to_vec()
}
pub fn args(&self, dev: &Arc<Device>, kernel: &str) -> Vec<spirv::SPIRVKernelArg> {
Self::dev_build_info(&mut self.build_info(), dev)
.spirv
@@ -133,6 +275,11 @@ impl Program {
}
pub fn build(&self, dev: &Arc<Device>, options: String) -> bool {
// program binary
if self.src.as_bytes().is_empty() {
return true;
}
let mut info = self.build_info();
let d = Self::dev_build_info(&mut info, dev);
let lib = options.contains("-create-library");
@@ -175,6 +322,11 @@ impl Program {
options: String,
headers: &[spirv::CLCHeader],
) -> bool {
// program binary
if self.src.as_bytes().is_empty() {
return true;
}
let mut info = self.build_info();
let d = Self::dev_build_info(&mut info, dev);
let args = prepare_options(&options, dev);

View File

@@ -234,6 +234,33 @@ impl SPIRVBin {
nir_load_libclc_shader(64, shader_cache, &spirv_options, nir_options)
})
}
pub fn to_bin(&self) -> &[u8] {
unsafe { slice::from_raw_parts(self.spirv.data.cast(), self.spirv.size) }
}
pub fn from_bin(bin: &[u8], executable: bool) -> Self {
unsafe {
let ptr = malloc(bin.len());
ptr::copy_nonoverlapping(bin.as_ptr(), ptr.cast(), bin.len());
let spirv = clc_binary {
data: ptr,
size: bin.len(),
};
let info = if executable {
let mut pspirv = clc_parsed_spirv::default();
clc_parse_spirv(&spirv, ptr::null(), &mut pspirv);
Some(pspirv)
} else {
None
};
SPIRVBin {
spirv: spirv,
info: info,
}
}
}
}
impl Drop for SPIRVBin {

View File

@@ -101,6 +101,10 @@ rusticl_bindgen_args = [
'--anon-fields-prefix', 'anon_',
]
rusticl_bindgen_c_args = [
'-fno-builtin-malloc',
]
rusticl_opencl_bindings_rs = rust.bindgen(
input : [
'rusticl_opencl_bindings.h',
@@ -111,6 +115,7 @@ rusticl_opencl_bindings_rs = rust.bindgen(
inc_include,
],
c_args : [
rusticl_bindgen_c_args,
'-DCL_USE_DEPRECATED_OPENCL_1_0_APIS',
'-DCL_USE_DEPRECATED_OPENCL_1_1_APIS',
'-DCL_USE_DEPRECATED_OPENCL_1_2_APIS',
@@ -179,7 +184,10 @@ rusticl_mesa_bindings_rs = rust.bindgen(
inc_nir,
inc_src,
],
c_args : pre_args,
c_args : [
rusticl_bindgen_c_args,
pre_args,
],
args : [
rusticl_bindgen_args,
'--whitelist-function', 'clc_.*',
@@ -189,6 +197,7 @@ rusticl_mesa_bindings_rs = rust.bindgen(
'--whitelist-function', 'rusticl_.*',
'--whitelist-function', 'rz?alloc_.*',
'--whitelist-function', 'spirv_.*',
'--whitelist-function', 'malloc',
'--whitelist-type', 'pipe_endian',
'--whitelist-type', 'clc_kernel_arg_access_qualifier',
'--bitfield-enum', 'clc_kernel_arg_access_qualifier',