rusticl: Move NirKernelBuild to ProgramDevBuild

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23898>
This commit is contained in:
Antonio Gomes
2023-07-23 13:02:21 -03:00
committed by Marge Bot
parent 7ec9b9cd07
commit 323dcbb4b5
3 changed files with 133 additions and 114 deletions

View File

@@ -22,13 +22,13 @@ impl CLInfo<cl_kernel_info> for cl_kernel {
fn query(&self, q: cl_kernel_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let kernel = self.get_ref()?;
Ok(match q {
CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.build.attributes_string),
CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.kernel_info.attributes_string),
CL_KERNEL_CONTEXT => {
let ptr = Arc::as_ptr(&kernel.prog.context);
cl_prop::<cl_context>(cl_context::from_ptr(ptr))
}
CL_KERNEL_FUNCTION_NAME => cl_prop::<&str>(&kernel.name),
CL_KERNEL_NUM_ARGS => cl_prop::<cl_uint>(kernel.build.args.len() as cl_uint),
CL_KERNEL_NUM_ARGS => cl_prop::<cl_uint>(kernel.kernel_info.args.len() as cl_uint),
CL_KERNEL_PROGRAM => {
let ptr = Arc::as_ptr(&kernel.prog);
cl_prop::<cl_program>(cl_program::from_ptr(ptr))
@@ -46,7 +46,7 @@ impl CLInfoObj<cl_kernel_arg_info, cl_uint> for cl_kernel {
let kernel = self.get_ref()?;
// CL_INVALID_ARG_INDEX if arg_index is not a valid argument index.
if idx as usize >= kernel.build.args.len() {
if idx as usize >= kernel.kernel_info.args.len() {
return Err(CL_INVALID_ARG_INDEX);
}
@@ -329,7 +329,7 @@ fn set_kernel_arg(
let k = kernel.get_arc()?;
// CL_INVALID_ARG_INDEX if arg_index is not a valid argument index.
if let Some(arg) = k.build.args.get(arg_index as usize) {
if let Some(arg) = k.kernel_info.args.get(arg_index as usize) {
// CL_INVALID_ARG_SIZE if arg_size does not match the size of the data type for an argument
// that is not a memory object or if the argument is a memory object and
// arg_size != sizeof(cl_mem) or if arg_size is zero and the argument is declared with the
@@ -429,7 +429,7 @@ fn set_kernel_arg_svm_pointer(
return Err(CL_INVALID_OPERATION);
}
if let Some(arg) = kernel.build.args.get(arg_index) {
if let Some(arg) = kernel.kernel_info.args.get(arg_index) {
if !matches!(
arg.kind,
KernelArgType::MemConstant | KernelArgType::MemGlobal

View File

@@ -250,7 +250,17 @@ impl InternalKernelArg {
}
}
struct CSOWrapper {
/// Per-kernel metadata that is stored by kernel name in `ProgramBuild::kernel_info` and
/// cloned into each `Kernel`.
///
/// `build_nirs` creates one value per device, collects them in a `HashSet` and asserts that
/// only a single distinct value remains, which is why `Eq` and `Hash` are derived here.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct KernelInfo {
/// Kernel arguments; looked up by the OpenCL argument index (e.g. in `set_kernel_arg`).
pub args: Vec<KernelArg>,
/// Additional arguments that the launch code writes into the input buffer at `arg.offset`.
pub internal_args: Vec<InternalKernelArg>,
/// String returned for `CL_KERNEL_ATTRIBUTES`; `build_nirs` resets it to an empty string
/// when it is called with `is_src == false`.
pub attributes_string: String,
/// Filled from `nir.workgroup_size()` in `build_nirs`.
pub work_group_size: [usize; 3],
/// Filled from `nir.subgroup_size()` in `build_nirs`.
pub subgroup_size: usize,
/// Filled from `nir.num_subgroups()` in `build_nirs`.
pub num_subgroups: usize,
}
pub struct CSOWrapper {
pub cso_ptr: *mut c_void,
dev: &'static Device,
}
@@ -286,7 +296,7 @@ impl Drop for CSOWrapper {
}
}
enum KernelDevStateVariant {
/// What a kernel build keeps for launching: either an already created compute state object
/// or the NIR shader itself.
///
/// `build_nirs` stores `Nir` when `dev.shareable_shaders()` returns false and `Cso` otherwise.
pub enum KernelDevStateVariant {
/// A shared compute state object; the launch code clones the `Arc` and uses it directly.
Cso(Arc<CSOWrapper>),
/// The NIR shader; the launch code calls `CSOWrapper::new(q.device, nir)` on it instead.
Nir(Arc<NirShader>),
}
@@ -341,7 +351,7 @@ impl KernelDevState {
Arc::new(Self { states: states })
}
fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
pub fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
let buf = nir.get_constant_buffer();
let len = buf.len() as u32;
@@ -371,7 +381,8 @@ pub struct Kernel {
pub prog: Arc<Program>,
pub name: String,
pub values: Vec<RefCell<Option<KernelArgValue>>>,
pub build: Arc<NirKernelBuild>,
pub builds: HashMap<&'static Device, Arc<NirKernelBuild>>,
pub kernel_info: KernelInfo,
}
impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
@@ -830,10 +841,16 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
impl Kernel {
pub fn new(name: String, prog: Arc<Program>) -> Arc<Kernel> {
let nir_kernel_build = prog.get_nir_kernel_build(&name);
let prog_build = prog.build_info();
let kernel_info = prog_build.kernel_info.get(&name).unwrap().clone();
let builds = prog_build
.builds
.iter()
.map(|(k, v)| (*k, v.kernels.get(&name).unwrap().clone()))
.collect();
// can't use vec!...
let values = nir_kernel_build
let values = kernel_info
.args
.iter()
.map(|_| RefCell::new(None))
@@ -841,10 +858,11 @@ impl Kernel {
Arc::new(Self {
base: CLObjectBase::new(),
prog: prog,
prog: prog.clone(),
name: name,
values: values,
build: nir_kernel_build,
builds: builds,
kernel_info: kernel_info,
})
}
@@ -896,14 +914,14 @@ impl Kernel {
grid: &[usize],
offsets: &[usize],
) -> CLResult<EventSig> {
let dev_state = self.build.dev_state.get(q.device);
let nir_kernel_build = self.builds.get(q.device).unwrap().clone();
let mut block = create_kernel_arr::<u32>(block, 1);
let mut grid = create_kernel_arr::<u32>(grid, 1);
let offsets = create_kernel_arr::<u64>(offsets, 0);
let mut input: Vec<u8> = Vec::new();
let mut resource_info = Vec::new();
// Set it once so we get the alignment padding right
let static_local_size: u64 = dev_state.nir_internal_info.shared_size;
let static_local_size: u64 = nir_kernel_build.shared_size;
let mut variable_local_size: u64 = static_local_size;
let printf_size = q.device.printf_buffer_size() as u32;
let mut samplers = Vec::new();
@@ -921,7 +939,7 @@ impl Kernel {
self.optimize_local_size(q.device, &mut grid, &mut block);
for (arg, val) in self.build.args.iter().zip(&self.values) {
for (arg, val) in self.kernel_info.args.iter().zip(&self.values) {
if arg.dead {
continue;
}
@@ -1005,18 +1023,21 @@ impl Kernel {
}
// subtract the shader local_size as we only request something on top of that.
variable_local_size -= dev_state.nir_internal_info.shared_size;
variable_local_size -= static_local_size;
let mut printf_buf = None;
for arg in &self.build.internal_args {
for arg in &self.kernel_info.internal_args {
if arg.offset > input.len() {
input.resize(arg.offset, 0);
}
match arg.kind {
InternalKernelArgType::ConstantBuffer => {
assert!(dev_state.constant_buffer.is_some());
assert!(nir_kernel_build.constant_buffer.is_some());
input.extend_from_slice(null_ptr);
resource_info.push((dev_state.constant_buffer.clone().unwrap(), arg.offset));
resource_info.push((
nir_kernel_build.constant_buffer.clone().unwrap(),
arg.offset,
));
}
InternalKernelArgType::GlobalWorkOffsets => {
if q.device.address_bits() == 64 {
@@ -1061,13 +1082,11 @@ impl Kernel {
}
}
let k = Arc::clone(self);
Ok(Box::new(move |q, ctx| {
let dev_state = k.build.dev_state.get(q.device);
let mut input = input.clone();
let mut resources = Vec::with_capacity(resource_info.len());
let mut globals: Vec<*mut u32> = Vec::new();
let printf_format = &dev_state.nir_internal_info.printf_info;
let printf_format = &nir_kernel_build.printf_info;
let mut sviews: Vec<_> = sviews
.iter()
@@ -1093,7 +1112,7 @@ impl Kernel {
);
}
let cso = match &dev_state.nir_or_cso {
let cso = match &nir_kernel_build.nir_or_cso {
KernelDevStateVariant::Cso(cso) => cso.clone(),
KernelDevStateVariant::Nir(nir) => CSOWrapper::new(q.device, nir),
};
@@ -1145,7 +1164,7 @@ impl Kernel {
}
pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
let aq = self.build.args[idx as usize].spirv.access_qualifier;
let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;
if aq
== clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
@@ -1162,7 +1181,7 @@ impl Kernel {
}
pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
match self.build.args[idx as usize].spirv.address_qualifier {
match self.kernel_info.args[idx as usize].spirv.address_qualifier {
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
CL_KERNEL_ARG_ADDRESS_PRIVATE
}
@@ -1179,7 +1198,7 @@ impl Kernel {
}
pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
let tq = self.build.args[idx as usize].spirv.type_qualifier;
let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
let zero = clc_kernel_arg_type_qualifier(0);
let mut res = CL_KERNEL_ARG_TYPE_NONE;
@@ -1199,61 +1218,40 @@ impl Kernel {
}
/// Returns the three-component work-group size recorded for this kernel.
pub fn work_group_size(&self) -> [usize; 3] {
self.build
.dev_state
.states
.values()
.next()
.unwrap()
.nir_internal_info
.work_group_size
self.kernel_info.work_group_size
}
/// Returns the number of subgroups recorded for this kernel.
pub fn num_subgroups(&self) -> usize {
self.build
.dev_state
.states
.values()
.next()
.unwrap()
.nir_internal_info
.num_subgroups
self.kernel_info.num_subgroups
}
/// Returns the subgroup size recorded for this kernel.
pub fn subgroup_size(&self) -> usize {
self.build
.dev_state
.states
.values()
.next()
.unwrap()
.nir_internal_info
.subgroup_size
self.kernel_info.subgroup_size
}
/// Returns the `spirv.name` of the kernel argument at index `idx`.
///
/// The argument list is indexed directly, so an out-of-range `idx` panics; callers are
/// expected to validate the index first.
pub fn arg_name(&self, idx: cl_uint) -> &String {
&self.build.args[idx as usize].spirv.name
&self.kernel_info.args[idx as usize].spirv.name
}
/// Returns the `spirv.type_name` of the kernel argument at index `idx`.
///
/// The argument list is indexed directly, so an out-of-range `idx` panics; callers are
/// expected to validate the index first.
pub fn arg_type_name(&self, idx: cl_uint) -> &String {
&self.build.args[idx as usize].spirv.type_name
&self.kernel_info.args[idx as usize].spirv.type_name
}
/// Returns `info.private_memory` of this kernel's build for `dev`, as a `cl_ulong`.
///
/// The `builds` lookup is unwrapped, so a device without an entry panics.
pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
self.build.dev_state.get(dev).info.private_memory.into()
self.builds.get(dev).unwrap().info.private_memory as cl_ulong
}
/// Returns `info.max_threads` of this kernel's build for `dev`, as a `usize`.
///
/// The `builds` lookup is unwrapped, so a device without an entry panics.
pub fn max_threads_per_block(&self, dev: &Device) -> usize {
self.build.dev_state.get(dev).info.max_threads as usize
self.builds.get(dev).unwrap().info.max_threads as usize
}
/// Returns `info.preferred_simd_size` of this kernel's build for `dev`, as a `usize`.
///
/// The `builds` lookup is unwrapped, so a device without an entry panics.
pub fn preferred_simd_size(&self, dev: &Device) -> usize {
self.build.dev_state.get(dev).info.preferred_simd_size as usize
self.builds.get(dev).unwrap().info.preferred_simd_size as usize
}
/// Returns the `shared_size` of this kernel's build for `dev`, as a `cl_ulong`.
///
/// Only the shader's own shared size is reported; memory requested through kernel
/// arguments is not included yet (see the TODO below). The `builds` lookup is unwrapped, so
/// a device without an entry panics.
pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
// TODO include args
self.build.dev_state.get(dev).nir_internal_info.shared_size as cl_ulong
self.builds.get(dev).unwrap().shared_size as cl_ulong
}
pub fn has_svm_devs(&self) -> bool {
@@ -1261,7 +1259,7 @@ impl Kernel {
}
/// Returns the subgroup sizes encoded in `info.simd_sizes` of this kernel's build for `dev`.
///
/// Every index yielded by `SetBitIndices::from_msb` is turned into the value `1 << bit`.
/// NOTE(review): presumably the indices are those of the set bits, visited from the most
/// significant bit downwards — confirm against `SetBitIndices`.
pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
SetBitIndices::from_msb(self.build.dev_state.get(dev).info.simd_sizes)
SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
.map(|bit| 1 << bit)
.collect()
}
@@ -1292,7 +1290,7 @@ impl Kernel {
*block.get(2).unwrap_or(&1) as u32,
];
match &self.build.dev_state.get(dev).nir_or_cso {
match &self.builds.get(dev).unwrap().nir_or_cso {
KernelDevStateVariant::Cso(cso) => {
dev.helper_ctx()
.compute_state_subgroup_size(cso.cso_ptr, &block) as usize
@@ -1311,7 +1309,8 @@ impl Clone for Kernel {
prog: self.prog.clone(),
name: self.name.clone(),
values: self.values.clone(),
build: self.build.clone(),
builds: self.builds.clone(),
kernel_info: self.kernel_info.clone(),
}
}
}

View File

@@ -8,6 +8,7 @@ use crate::impl_cl_type_trait;
use mesa_rust::compiler::clc::spirv::SPIRVBin;
use mesa_rust::compiler::clc::*;
use mesa_rust::compiler::nir::*;
use mesa_rust::pipe::resource::*;
use mesa_rust::util::disk_cache::*;
use mesa_rust_gen::*;
use rusticl_opencl_gen::*;
@@ -68,17 +69,18 @@ pub struct Program {
impl_cl_type_trait!(cl_program, Program, CL_INVALID_PROGRAM);
/// Compiled state of a single kernel; `build_nirs` creates one per device and stores it in
/// `ProgramDevBuild::kernels` under the kernel's name.
pub struct NirKernelBuild {
pub dev_state: Arc<KernelDevState>,
pub args: Vec<KernelArg>,
pub internal_args: Vec<InternalKernelArg>,
pub attributes_string: String,
/// Either the compute state object or the NIR shader, see `KernelDevStateVariant`.
pub nir_or_cso: KernelDevStateVariant,
/// Result of `KernelDevState::create_nir_constant_buffer`; the launch code asserts it is
/// `Some` when the kernel has a `ConstantBuffer` internal argument.
pub constant_buffer: Option<Arc<PipeResource>>,
/// Obtained from `cso.get_cso_info()` in `build_nirs`.
pub info: pipe_compute_state_object_info,
/// Taken from `nir.shared_size()` in `build_nirs`.
pub shared_size: u64,
/// Taken out of the shader with `nir.take_printf_info()` in `build_nirs`.
pub printf_info: Option<NirPrintfInfo>,
}
pub(super) struct ProgramBuild {
builds: HashMap<&'static Device, ProgramDevBuild>,
/// Build state of a program; `Program::build_info` hands out a `MutexGuard` to it.
pub struct ProgramBuild {
/// Build state per device.
pub builds: HashMap<&'static Device, ProgramDevBuild>,
/// Kernel metadata keyed by kernel name; filled by `build_nirs`.
pub kernel_info: HashMap<String, KernelInfo>,
spec_constants: HashMap<u32, nir_const_value>,
/// Names of the kernels that `build_nirs` iterates over.
kernels: Vec<String>,
kernel_builds: HashMap<String, Arc<NirKernelBuild>>,
}
impl ProgramBuild {
@@ -104,7 +106,7 @@ impl ProgramBuild {
}
fn build_nirs(&mut self, is_src: bool) {
for kernel_name in &self.kernels {
for kernel_name in &self.kernels.clone() {
let kernel_args: HashSet<_> = self
.devs_with_build()
.iter()
@@ -112,45 +114,64 @@ impl ProgramBuild {
.collect();
let args = kernel_args.into_iter().next().unwrap();
let mut nirs = HashMap::new();
let mut args_set = HashSet::new();
let mut internal_args_set = HashSet::new();
let mut attributes_string_set = HashSet::new();
let mut kernel_info_set = HashSet::new();
// TODO: we could run this in parallel?
for d in self.devs_with_build() {
let (nir, args, internal_args) = convert_spirv_to_nir(self, kernel_name, &args, d);
let attributes_string = self.attribute_str(kernel_name, d);
nirs.insert(d, nir);
args_set.insert(args);
internal_args_set.insert(internal_args);
attributes_string_set.insert(attributes_string);
}
for dev in self.devs_with_build() {
let (mut nir, args, internal_args) =
convert_spirv_to_nir(self, kernel_name, &args, dev);
let attributes_string = self.attribute_str(kernel_name, dev);
let wgs = nir.workgroup_size();
let shared_size = nir.shared_size() as u64;
let printf_info = nir.take_printf_info();
// we want the same (internal) args for every compiled kernel, for now
assert!(args_set.len() == 1);
assert!(internal_args_set.len() == 1);
assert!(attributes_string_set.len() == 1);
let args = args_set.into_iter().next().unwrap();
let internal_args = internal_args_set.into_iter().next().unwrap();
// spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource
// API call the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty.
let attributes_string = if is_src {
attributes_string_set.into_iter().next().unwrap()
} else {
String::new()
};
self.kernel_builds.insert(
kernel_name.clone(),
Arc::new(NirKernelBuild {
dev_state: KernelDevState::new(nirs),
let kernel_info = KernelInfo {
args: args,
internal_args: internal_args,
attributes_string: attributes_string,
}),
);
work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
subgroup_size: nir.subgroup_size() as usize,
num_subgroups: nir.num_subgroups() as usize,
};
kernel_info_set.insert(kernel_info);
let cso = CSOWrapper::new(dev, &nir);
let info = cso.get_cso_info();
let cb = KernelDevState::create_nir_constant_buffer(dev, &nir);
let nir_or_cso = if !dev.shareable_shaders() {
KernelDevStateVariant::Nir(Arc::new(nir))
} else {
KernelDevStateVariant::Cso(cso)
};
let nir_kernel_build = NirKernelBuild {
nir_or_cso: nir_or_cso,
constant_buffer: cb,
info: info,
shared_size: shared_size,
printf_info: printf_info,
};
self.builds
.get_mut(dev)
.unwrap()
.kernels
.insert(kernel_name.clone(), Arc::new(nir_kernel_build));
}
// we want the same (internal) args for every compiled kernel, for now
assert!(kernel_info_set.len() == 1);
let mut kernel_info = kernel_info_set.into_iter().next().unwrap();
// spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource
// API call the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty.
if !is_src {
kernel_info.attributes_string = String::new();
}
self.kernel_info.insert(kernel_name.clone(), kernel_info);
}
}
@@ -228,12 +249,13 @@ impl ProgramBuild {
}
}
struct ProgramDevBuild {
/// Build state of a program for one device (the values of `ProgramBuild::builds`).
pub struct ProgramDevBuild {
spirv: Option<spirv::SPIRVBin>,
/// Build status, returned by `Program::status`.
status: cl_build_status,
options: String,
log: String,
bin_type: cl_program_binary_type,
/// Compiled kernels keyed by kernel name; filled by `build_nirs` and starts out empty.
pub kernels: HashMap<String, Arc<NirKernelBuild>>,
}
fn prepare_options(options: &str, dev: &Device) -> Vec<CString> {
@@ -297,6 +319,7 @@ impl Program {
log: String::from(""),
options: String::from(""),
bin_type: CL_PROGRAM_BINARY_TYPE_NONE,
kernels: HashMap::new(),
},
)
})
@@ -313,7 +336,7 @@ impl Program {
builds: Self::create_default_builds(devs),
spec_constants: HashMap::new(),
kernels: Vec::new(),
kernel_builds: HashMap::new(),
kernel_info: HashMap::new(),
}),
})
}
@@ -372,6 +395,7 @@ impl Program {
log: String::from(""),
options: String::from(""),
bin_type: bin_type,
kernels: HashMap::new(),
},
);
}
@@ -380,7 +404,7 @@ impl Program {
builds: builds,
spec_constants: HashMap::new(),
kernels: kernels.into_iter().collect(),
kernel_builds: HashMap::new(),
kernel_info: HashMap::new(),
};
build.build_nirs(false);
@@ -404,20 +428,15 @@ impl Program {
builds: builds,
spec_constants: HashMap::new(),
kernels: Vec::new(),
kernel_builds: HashMap::new(),
kernel_info: HashMap::new(),
}),
})
}
fn build_info(&self) -> MutexGuard<ProgramBuild> {
/// Locks `self.build` and returns the guard giving access to the `ProgramBuild`.
///
/// The result of `lock()` is unwrapped, so this panics if locking fails
/// (presumably only when the mutex is poisoned — confirm the mutex type).
/// The build state stays locked for as long as the returned guard is alive.
pub fn build_info(&self) -> MutexGuard<ProgramBuild> {
self.build.lock().unwrap()
}
/// Returns a clone of the `Arc<NirKernelBuild>` stored under `name` in `kernel_builds`.
///
/// The lookup is unwrapped, so this panics when no build exists for `name`.
pub fn get_nir_kernel_build(&self, name: &str) -> Arc<NirKernelBuild> {
let info = self.build_info();
info.kernel_builds.get(name).unwrap().clone()
}
/// Returns the build status recorded for `dev` in this program's build info.
pub fn status(&self, dev: &Device) -> cl_build_status {
let build_info = self.build_info();
build_info.dev_build(dev).status
}
@@ -510,9 +529,9 @@ impl Program {
/// Returns `true` if any stored kernel build `Arc` has a strong count greater than one,
/// i.e. something besides the build info itself still holds a reference to it
/// (presumably a live `Kernel` object — confirm against the callers).
pub fn active_kernels(&self) -> bool {
self.build_info()
.kernel_builds
.builds
.values()
.any(|b| Arc::strong_count(b) > 1)
.any(|b| b.kernels.values().any(|b| Arc::strong_count(b) > 1))
}
pub fn build(&self, dev: &Device, options: String) -> bool {
@@ -668,6 +687,7 @@ impl Program {
log: log,
options: String::from(""),
bin_type: bin_type,
kernels: HashMap::new(),
},
);
}
@@ -676,7 +696,7 @@ impl Program {
builds: builds,
spec_constants: HashMap::new(),
kernels: kernels.into_iter().collect(),
kernel_builds: HashMap::new(),
kernel_info: HashMap::new(),
};
// Pre build nir kernels