rusticl/program: implement clCreateProgramWithBinary
Signed-off-by: Karol Herbst <kherbst@redhat.com> Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15439>
This commit is contained in:
@@ -49,7 +49,7 @@ pub static DISPATCH: cl_icd_dispatch = cl_icd_dispatch {
|
||||
clReleaseSampler: Some(cl_release_sampler),
|
||||
clGetSamplerInfo: Some(cl_get_sampler_info),
|
||||
clCreateProgramWithSource: Some(cl_create_program_with_source),
|
||||
clCreateProgramWithBinary: None,
|
||||
clCreateProgramWithBinary: Some(cl_create_program_with_binary),
|
||||
clRetainProgram: Some(cl_retain_program),
|
||||
clReleaseProgram: Some(cl_release_program),
|
||||
clBuildProgram: Some(cl_build_program),
|
||||
@@ -699,6 +699,28 @@ extern "C" fn cl_create_program_with_source(
|
||||
)
|
||||
}
|
||||
|
||||
extern "C" fn cl_create_program_with_binary(
|
||||
context: cl_context,
|
||||
num_devices: cl_uint,
|
||||
device_list: *const cl_device_id,
|
||||
lengths: *const usize,
|
||||
binaries: *mut *const ::std::os::raw::c_uchar,
|
||||
binary_status: *mut cl_int,
|
||||
errcode_ret: *mut cl_int,
|
||||
) -> cl_program {
|
||||
match_obj!(
|
||||
create_program_with_binary(
|
||||
context,
|
||||
num_devices,
|
||||
device_list,
|
||||
lengths,
|
||||
binaries,
|
||||
binary_status
|
||||
),
|
||||
errcode_ret
|
||||
)
|
||||
}
|
||||
|
||||
extern "C" fn cl_retain_program(program: cl_program) -> cl_int {
|
||||
match_err!(program.retain())
|
||||
}
|
||||
|
@@ -20,9 +20,11 @@ use std::slice;
|
||||
use std::sync::Arc;
|
||||
|
||||
impl CLInfo<cl_program_info> for cl_program {
|
||||
fn query(&self, q: cl_program_info, _: &[u8]) -> CLResult<Vec<u8>> {
|
||||
fn query(&self, q: cl_program_info, vals: &[u8]) -> CLResult<Vec<u8>> {
|
||||
let prog = self.get_ref()?;
|
||||
Ok(match q {
|
||||
CL_PROGRAM_BINARIES => cl_prop::<Vec<*mut u8>>(prog.binaries(vals)),
|
||||
CL_PROGRAM_BINARY_SIZES => cl_prop::<Vec<usize>>(prog.bin_sizes()),
|
||||
CL_PROGRAM_CONTEXT => {
|
||||
// Note we use as_ptr here which doesn't increase the reference count.
|
||||
let ptr = Arc::as_ptr(&prog.context);
|
||||
@@ -134,6 +136,64 @@ pub fn create_program_with_source(
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn create_program_with_binary(
|
||||
context: cl_context,
|
||||
num_devices: cl_uint,
|
||||
device_list: *const cl_device_id,
|
||||
lengths: *const usize,
|
||||
binaries: *mut *const ::std::os::raw::c_uchar,
|
||||
binary_status: *mut cl_int,
|
||||
) -> CLResult<cl_program> {
|
||||
let c = context.get_arc()?;
|
||||
let devs = cl_device_id::get_arc_vec_from_arr(device_list, num_devices)?;
|
||||
|
||||
// CL_INVALID_VALUE if device_list is NULL or num_devices is zero.
|
||||
if devs.is_empty() {
|
||||
return Err(CL_INVALID_VALUE);
|
||||
}
|
||||
|
||||
// CL_INVALID_VALUE if lengths or binaries is NULL
|
||||
if lengths.is_null() || binaries.is_null() {
|
||||
return Err(CL_INVALID_VALUE);
|
||||
}
|
||||
|
||||
// CL_INVALID_DEVICE if any device in device_list is not in the list of devices associated with
|
||||
// context.
|
||||
if !devs.iter().all(|d| c.devs.contains(d)) {
|
||||
return Err(CL_INVALID_DEVICE);
|
||||
}
|
||||
|
||||
let lengths = unsafe { slice::from_raw_parts(lengths, num_devices as usize) };
|
||||
let binaries = unsafe { slice::from_raw_parts(binaries, num_devices as usize) };
|
||||
|
||||
// now device specific stuff
|
||||
let mut err = 0;
|
||||
let mut bins: Vec<&[u8]> = vec![&[]; num_devices as usize];
|
||||
for i in 0..num_devices as usize {
|
||||
let mut dev_err = 0;
|
||||
|
||||
// CL_INVALID_VALUE if lengths[i] is zero or if binaries[i] is a NULL value
|
||||
if lengths[i] == 0 || binaries[i].is_null() {
|
||||
dev_err = CL_INVALID_VALUE;
|
||||
}
|
||||
|
||||
if !binary_status.is_null() {
|
||||
unsafe { binary_status.add(i).write(dev_err) };
|
||||
}
|
||||
|
||||
// just return the last one
|
||||
err = dev_err;
|
||||
bins[i] = unsafe { slice::from_raw_parts(binaries[i], lengths[i] as usize) };
|
||||
}
|
||||
|
||||
if err != 0 {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(cl_program::from_arc(Program::from_bins(c, devs, &bins)))
|
||||
//• CL_INVALID_BINARY if an invalid program binary was encountered for any device. binary_status will return specific status for each device.
|
||||
}
|
||||
|
||||
pub fn build_program(
|
||||
program: cl_program,
|
||||
num_devices: cl_uint,
|
||||
|
@@ -15,10 +15,23 @@ use self::rusticl_opencl_gen::*;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::ffi::CString;
|
||||
use std::mem::size_of;
|
||||
use std::ptr;
|
||||
use std::slice;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::MutexGuard;
|
||||
|
||||
const BIN_HEADER_SIZE_V1: usize =
|
||||
// 1. format version
|
||||
size_of::<u32>() +
|
||||
// 2. spirv len
|
||||
size_of::<u32>() +
|
||||
// 3. binary_type
|
||||
size_of::<cl_program_binary_type>();
|
||||
|
||||
const BIN_HEADER_SIZE: usize = BIN_HEADER_SIZE_V1;
|
||||
|
||||
#[repr(C)]
|
||||
pub struct Program {
|
||||
pub base: CLObjectBase<CL_INVALID_PROGRAM>,
|
||||
@@ -89,6 +102,76 @@ impl Program {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_bins(
|
||||
context: Arc<Context>,
|
||||
devs: Vec<Arc<Device>>,
|
||||
bins: &[&[u8]],
|
||||
) -> Arc<Program> {
|
||||
let mut builds = HashMap::new();
|
||||
let mut kernels = HashSet::new();
|
||||
|
||||
for (d, b) in devs.iter().zip(bins) {
|
||||
let mut ptr = b.as_ptr();
|
||||
let bin_type;
|
||||
let spirv;
|
||||
|
||||
unsafe {
|
||||
// 1. version
|
||||
let version = ptr.cast::<u32>().read();
|
||||
ptr = ptr.add(size_of::<u32>());
|
||||
|
||||
match version {
|
||||
1 => {
|
||||
// 2. size of the spirv
|
||||
let spirv_size = ptr.cast::<u32>().read();
|
||||
ptr = ptr.add(size_of::<u32>());
|
||||
|
||||
// 3. binary_type
|
||||
bin_type = ptr.cast::<cl_program_binary_type>().read();
|
||||
ptr = ptr.add(size_of::<cl_program_binary_type>());
|
||||
|
||||
// 4. the spirv
|
||||
assert!(b.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr);
|
||||
assert!(b.len() == BIN_HEADER_SIZE_V1 + spirv_size as usize);
|
||||
spirv = Some(spirv::SPIRVBin::from_bin(
|
||||
slice::from_raw_parts(ptr, spirv_size as usize),
|
||||
bin_type == CL_PROGRAM_BINARY_TYPE_EXECUTABLE,
|
||||
));
|
||||
}
|
||||
_ => panic!("unknown version"),
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(spirv) = &spirv {
|
||||
for k in spirv.kernels() {
|
||||
kernels.insert(k);
|
||||
}
|
||||
}
|
||||
|
||||
builds.insert(
|
||||
d.clone(),
|
||||
ProgramDevBuild {
|
||||
spirv: spirv,
|
||||
status: CL_BUILD_SUCCESS as cl_build_status,
|
||||
log: String::from(""),
|
||||
options: String::from(""),
|
||||
bin_type: bin_type,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Arc::new(Self {
|
||||
base: CLObjectBase::new(),
|
||||
context: context,
|
||||
devs: devs,
|
||||
src: CString::new("").unwrap(),
|
||||
build: Mutex::new(ProgramBuild {
|
||||
builds: builds,
|
||||
kernels: kernels.into_iter().collect(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
fn build_info(&self) -> MutexGuard<ProgramBuild> {
|
||||
self.build.lock().unwrap()
|
||||
}
|
||||
@@ -120,6 +203,65 @@ impl Program {
|
||||
.clone()
|
||||
}
|
||||
|
||||
// we need to precalculate the size
|
||||
pub fn bin_sizes(&self) -> Vec<usize> {
|
||||
let mut lock = self.build_info();
|
||||
let mut res = Vec::new();
|
||||
for d in &self.devs {
|
||||
let info = Self::dev_build_info(&mut lock, d);
|
||||
|
||||
res.push(
|
||||
info.spirv
|
||||
.as_ref()
|
||||
.map_or(0, |s| s.to_bin().len() + BIN_HEADER_SIZE),
|
||||
);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub fn binaries(&self, vals: &[u8]) -> Vec<*mut u8> {
|
||||
// if the application didn't provide any pointers, just return the length of devices
|
||||
if vals.is_empty() {
|
||||
return vec![std::ptr::null_mut(); self.devs.len()];
|
||||
}
|
||||
|
||||
// vals is an array of pointers where we should write the device binaries into
|
||||
if vals.len() != self.devs.len() * size_of::<*const u8>() {
|
||||
panic!("wrong size")
|
||||
}
|
||||
|
||||
let ptrs: &[*mut u8] = unsafe {
|
||||
slice::from_raw_parts(vals.as_ptr().cast(), vals.len() / size_of::<*mut u8>())
|
||||
};
|
||||
|
||||
let mut lock = self.build_info();
|
||||
for (i, d) in self.devs.iter().enumerate() {
|
||||
let mut ptr = ptrs[i];
|
||||
let info = Self::dev_build_info(&mut lock, d);
|
||||
let spirv = info.spirv.as_ref().unwrap().to_bin();
|
||||
|
||||
unsafe {
|
||||
// 1. binary format version
|
||||
ptr.cast::<u32>().write(1);
|
||||
ptr = ptr.add(size_of::<u32>());
|
||||
|
||||
// 2. size of the spirv
|
||||
ptr.cast::<u32>().write(spirv.len() as u32);
|
||||
ptr = ptr.add(size_of::<u32>());
|
||||
|
||||
// 3. binary_type
|
||||
ptr.cast::<cl_program_binary_type>().write(info.bin_type);
|
||||
ptr = ptr.add(size_of::<cl_program_binary_type>());
|
||||
|
||||
// 4. the spirv
|
||||
assert!(ptrs[i].add(BIN_HEADER_SIZE) == ptr);
|
||||
ptr::copy_nonoverlapping(spirv.as_ptr(), ptr, spirv.len());
|
||||
}
|
||||
}
|
||||
|
||||
ptrs.to_vec()
|
||||
}
|
||||
|
||||
pub fn args(&self, dev: &Arc<Device>, kernel: &str) -> Vec<spirv::SPIRVKernelArg> {
|
||||
Self::dev_build_info(&mut self.build_info(), dev)
|
||||
.spirv
|
||||
@@ -133,6 +275,11 @@ impl Program {
|
||||
}
|
||||
|
||||
pub fn build(&self, dev: &Arc<Device>, options: String) -> bool {
|
||||
// program binary
|
||||
if self.src.as_bytes().is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut info = self.build_info();
|
||||
let d = Self::dev_build_info(&mut info, dev);
|
||||
let lib = options.contains("-create-library");
|
||||
@@ -175,6 +322,11 @@ impl Program {
|
||||
options: String,
|
||||
headers: &[spirv::CLCHeader],
|
||||
) -> bool {
|
||||
// program binary
|
||||
if self.src.as_bytes().is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut info = self.build_info();
|
||||
let d = Self::dev_build_info(&mut info, dev);
|
||||
let args = prepare_options(&options, dev);
|
||||
|
@@ -234,6 +234,33 @@ impl SPIRVBin {
|
||||
nir_load_libclc_shader(64, shader_cache, &spirv_options, nir_options)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_bin(&self) -> &[u8] {
|
||||
unsafe { slice::from_raw_parts(self.spirv.data.cast(), self.spirv.size) }
|
||||
}
|
||||
|
||||
pub fn from_bin(bin: &[u8], executable: bool) -> Self {
|
||||
unsafe {
|
||||
let ptr = malloc(bin.len());
|
||||
ptr::copy_nonoverlapping(bin.as_ptr(), ptr.cast(), bin.len());
|
||||
let spirv = clc_binary {
|
||||
data: ptr,
|
||||
size: bin.len(),
|
||||
};
|
||||
let info = if executable {
|
||||
let mut pspirv = clc_parsed_spirv::default();
|
||||
clc_parse_spirv(&spirv, ptr::null(), &mut pspirv);
|
||||
Some(pspirv)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
SPIRVBin {
|
||||
spirv: spirv,
|
||||
info: info,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SPIRVBin {
|
||||
|
@@ -101,6 +101,10 @@ rusticl_bindgen_args = [
|
||||
'--anon-fields-prefix', 'anon_',
|
||||
]
|
||||
|
||||
rusticl_bindgen_c_args = [
|
||||
'-fno-builtin-malloc',
|
||||
]
|
||||
|
||||
rusticl_opencl_bindings_rs = rust.bindgen(
|
||||
input : [
|
||||
'rusticl_opencl_bindings.h',
|
||||
@@ -111,6 +115,7 @@ rusticl_opencl_bindings_rs = rust.bindgen(
|
||||
inc_include,
|
||||
],
|
||||
c_args : [
|
||||
rusticl_bindgen_c_args,
|
||||
'-DCL_USE_DEPRECATED_OPENCL_1_0_APIS',
|
||||
'-DCL_USE_DEPRECATED_OPENCL_1_1_APIS',
|
||||
'-DCL_USE_DEPRECATED_OPENCL_1_2_APIS',
|
||||
@@ -179,7 +184,10 @@ rusticl_mesa_bindings_rs = rust.bindgen(
|
||||
inc_nir,
|
||||
inc_src,
|
||||
],
|
||||
c_args : pre_args,
|
||||
c_args : [
|
||||
rusticl_bindgen_c_args,
|
||||
pre_args,
|
||||
],
|
||||
args : [
|
||||
rusticl_bindgen_args,
|
||||
'--whitelist-function', 'clc_.*',
|
||||
@@ -189,6 +197,7 @@ rusticl_mesa_bindings_rs = rust.bindgen(
|
||||
'--whitelist-function', 'rusticl_.*',
|
||||
'--whitelist-function', 'rz?alloc_.*',
|
||||
'--whitelist-function', 'spirv_.*',
|
||||
'--whitelist-function', 'malloc',
|
||||
'--whitelist-type', 'pipe_endian',
|
||||
'--whitelist-type', 'clc_kernel_arg_access_qualifier',
|
||||
'--bitfield-enum', 'clc_kernel_arg_access_qualifier',
|
||||
|
Reference in New Issue
Block a user