diff --git a/.pick_status.json b/.pick_status.json index fb7e1ef51f0..bbbe9196b8b 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -954,7 +954,7 @@ "description": "rusticl: check for overrun status when deserializing", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 0e5d66f6c77..34b1f916523 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -25,6 +25,7 @@ use std::convert::TryInto; use std::fmt::Debug; use std::fmt::Display; use std::ops::Index; +use std::ops::Not; use std::os::raw::c_void; use std::ptr; use std::slice; @@ -58,8 +59,10 @@ pub enum KernelArgType { impl KernelArgType { fn deserialize(blob: &mut blob_reader) -> Option { - Some(match unsafe { blob_read_uint8(blob) } { + // SAFETY: we get 0 on an overrun, but we verify that later and act accordingly. + let res = match unsafe { blob_read_uint8(blob) } { 0 => { + // SAFETY: same here let size = unsafe { blob_read_uint16(blob) }; KernelArgType::Constant(size) } @@ -71,7 +74,9 @@ impl KernelArgType { 6 => KernelArgType::MemConstant, 7 => KernelArgType::MemLocal, _ => return None, - }) + }; + + blob.overrun.not().then_some(res) } fn serialize(&self, blob: &mut blob) { @@ -192,24 +197,24 @@ impl KernelArg { } fn deserialize(blob: &mut blob_reader) -> Option> { - unsafe { - let len = blob_read_uint16(blob) as usize; - let mut res = Vec::with_capacity(len); + // SAFETY: we check the overrun status, blob_read returns 0 in such a case. + let len = unsafe { blob_read_uint16(blob) } as usize; + let mut res = Vec::with_capacity(len); - for _ in 0..len { - let spirv = spirv::SPIRVKernelArg::deserialize(blob)?; - let dead = blob_read_uint8(blob) != 0; - let kind = KernelArgType::deserialize(blob)?; + for _ in 0..len { + let spirv = spirv::SPIRVKernelArg::deserialize(blob)?; + // SAFETY: we check the overrun status + let dead = unsafe { blob_read_uint8(blob) } != 0; + let kind = KernelArgType::deserialize(blob)?; - res.push(Self { - spirv: spirv, - kind: kind, - dead: dead, - }); - } - - Some(res) + res.push(Self { + spirv: spirv, + kind: kind, + dead: dead, + }); } + + blob.overrun.not().then_some(res) } } @@ -1080,18 +1085,16 @@ impl SPIRVToNirResult { let args = KernelArg::deserialize(&mut reader)?; let default_build = CompilationResult::deserialize(&mut reader, d)?; + // SAFETY: on overrun this returns 0 let optimized = match unsafe { blob_read_uint8(&mut reader) } { 0 => None, _ => Some(CompilationResult::deserialize(&mut reader, d)?), }; - Some(SPIRVToNirResult::new( - d, - kernel_info, - args, - default_build, - optimized, - )) + reader + .overrun + .not() + .then(|| SPIRVToNirResult::new(d, kernel_info, args, default_build, optimized)) } // we can't use Self here as the nir shader might be compiled to a cso already and we can't diff --git a/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs b/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs index a25c5189ded..89e2e30abae 100644 --- a/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs +++ b/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs @@ -9,6 +9,7 @@ use mesa_rust_util::string::*; use std::ffi::CString; use std::fmt::Debug; +use std::ops::Not; use std::os::raw::c_char; use std::os::raw::c_void; use std::ptr; @@ -484,7 +485,8 @@ impl SPIRVKernelArg { _ => return None, }; - Some(Self { + // check overrun to ensure nothing went wrong + blob.overrun.not().then(|| Self { name: String::from_utf8_unchecked(name.to_owned()), type_name: String::from_utf8_unchecked(type_name.to_owned()), access_qualifier: clc_kernel_arg_access_qualifier(access_qualifier), diff --git a/src/gallium/frontends/rusticl/mesa/compiler/nir.rs b/src/gallium/frontends/rusticl/mesa/compiler/nir.rs index e8e9f078fd3..36c632a4720 100644 --- a/src/gallium/frontends/rusticl/mesa/compiler/nir.rs +++ b/src/gallium/frontends/rusticl/mesa/compiler/nir.rs @@ -5,6 +5,7 @@ use mesa_rust_util::offset_of; use std::convert::TryInto; use std::ffi::CString; use std::marker::PhantomData; +use std::ops::Not; use std::ptr; use std::ptr::NonNull; use std::slice; @@ -158,7 +159,9 @@ impl NirShader { blob: &mut blob_reader, options: *const nir_shader_compiler_options, ) -> Option { - unsafe { Self::new(nir_deserialize(ptr::null_mut(), options, blob)) } + // we already create the NirShader here so it gets automatically deallocated on overrun. + let nir = Self::new(unsafe { nir_deserialize(ptr::null_mut(), options, blob) })?; + blob.overrun.not().then_some(nir) } pub fn serialize(&self, blob: &mut blob) {