diff --git a/src/intel/compiler/brw_nir_lower_shader_calls.c b/src/intel/compiler/brw_nir_lower_shader_calls.c
index 4f88f10ee0a..38c4e0a3345 100644
--- a/src/intel/compiler/brw_nir_lower_shader_calls.c
+++ b/src/intel/compiler/brw_nir_lower_shader_calls.c
@@ -124,143 +124,142 @@ store_resume_addr(nir_builder *b, nir_intrinsic_instr *call)
    nir_btd_stack_push_intel(b, offset);
 }
 
+/**
+ * Per-instruction callback for nir_shader_instructions_pass().
+ *
+ * Lowers rt_trace_ray and rt_execute_callable: the replacement code is
+ * emitted right before the intrinsic, which is then removed.  Returns true
+ * if the instruction was lowered and false if it was left untouched.  The
+ * data argument is unused.
+ */
+static bool
+lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   /* Leave nir_intrinsic_rt_resume to be lowered by
+    * brw_nir_lower_rt_intrinsics()
+    */
+   nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
+
+   switch (call->intrinsic) {
+   case nir_intrinsic_rt_trace_ray: {
+      /* Emit the lowered code right before the call being replaced. */
+      b->cursor = nir_before_instr(instr);
+      store_resume_addr(b, call);
+
+      nir_ssa_def *as_addr = call->src[0].ssa;
+      nir_ssa_def *ray_flags = call->src[1].ssa;
+      /* From the SPIR-V spec:
+       *
+       *    "Only the 8 least-significant bits of Cull Mask are used by this
+       *    instruction - other bits are ignored.
+       *
+       *    Only the 4 least-significant bits of SBT Offset and SBT Stride are
+       *    used by this instruction - other bits are ignored.
+       *
+       *    Only the 16 least-significant bits of Miss Index are used by this
+       *    instruction - other bits are ignored."
+       */
+      nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
+      nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
+      nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
+      nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
+      nir_ssa_def *ray_orig = call->src[6].ssa;
+      nir_ssa_def *ray_t_min = call->src[7].ssa;
+      nir_ssa_def *ray_dir = call->src[8].ssa;
+      nir_ssa_def *ray_t_max = call->src[9].ssa;
+
+      /* The hardware packet takes the address to the root node in the
+       * acceleration structure, not the acceleration structure itself. To
+       * find that, we have to read the root node offset from the acceleration
+       * structure which is the first QWord.
+       */
+      nir_ssa_def *root_node_ptr =
+         nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64));
+
+      /* The hardware packet requires an address to the first element of the
+       * hit SBT.
+       *
+       * In order to calculate this, we must multiply the "SBT Offset"
+       * provided to OpTraceRay by the SBT stride provided for the hit SBT in
+       * the call to vkCmdTraceRay() and add that to the base address of the
+       * hit SBT. This stride is not to be confused with the "SBT Stride"
+       * provided to OpTraceRay which is in units of this stride. It's a
+       * rather terrible overload of the word "stride". The hardware docs
+       * calls the SPIR-V stride value the "shader index multiplier" which is
+       * a much more sane name.
+       */
+      nir_ssa_def *hit_sbt_stride_B =
+         nir_load_ray_hit_sbt_stride_intel(b);
+      nir_ssa_def *hit_sbt_offset_B =
+         nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
+      nir_ssa_def *hit_sbt_addr =
+         nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
+                     nir_u2u64(b, hit_sbt_offset_B));
+
+      /* The hardware packet takes an address to the miss BSR. */
+      nir_ssa_def *miss_sbt_stride_B =
+         nir_load_ray_miss_sbt_stride_intel(b);
+      nir_ssa_def *miss_sbt_offset_B =
+         nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
+      nir_ssa_def *miss_sbt_addr =
+         nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
+                     nir_u2u64(b, miss_sbt_offset_B));
+
+      struct brw_nir_rt_mem_ray_defs ray_defs = {
+         .root_node_ptr = root_node_ptr,
+         .ray_flags = nir_u2u16(b, ray_flags),
+         .ray_mask = cull_mask,
+         .hit_group_sr_base_ptr = hit_sbt_addr,
+         .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
+         .miss_sr_ptr = miss_sbt_addr,
+         .orig = ray_orig,
+         .t_near = ray_t_min,
+         .dir = ray_dir,
+         .t_far = ray_t_max,
+         .shader_index_multiplier = sbt_stride,
+      };
+      brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
+      nir_trace_ray_initial_intel(b);
+
+      /* The call is fully lowered at this point; drop the original. */
+      nir_instr_remove(&call->instr);
+      return true;
+   }
+
+   case nir_intrinsic_rt_execute_callable: {
+      /* Emit the lowered code right before the call being replaced. */
+      b->cursor = nir_before_instr(instr);
+      store_resume_addr(b, call);
+
+      nir_ssa_def *sbt_offset32 =
+         nir_imul(b, call->src[0].ssa,
+                     nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
+      nir_ssa_def *sbt_addr =
+         nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
+                     nir_u2u64(b, sbt_offset32));
+      brw_nir_btd_spawn(b, sbt_addr);
+
+      /* The call is fully lowered at this point; drop the original. */
+      nir_instr_remove(&call->instr);
+      return true;
+   }
+
+   default:
+      return false;
+   }
+}
+
 bool
 brw_nir_lower_shader_calls(nir_shader *shader)
 {
-   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
-   bool progress = false;
-
-   nir_builder _b, *b = &_b;
-   nir_builder_init(&_b, impl);
-
-   nir_foreach_block_safe(block, impl) {
-      nir_foreach_instr_safe(instr, block) {
-         if (instr->type != nir_instr_type_intrinsic)
-            continue;
-
-         nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
-         if (call->intrinsic != nir_intrinsic_rt_trace_ray &&
-             call->intrinsic != nir_intrinsic_rt_execute_callable &&
-             call->intrinsic != nir_intrinsic_rt_resume)
-            continue;
-
-         b->cursor = nir_before_instr(instr);
-
-         progress = true;
-
-         switch (call->intrinsic) {
-         case nir_intrinsic_rt_trace_ray: {
-            store_resume_addr(b, call);
-
-            nir_ssa_def *as_addr = call->src[0].ssa;
-            nir_ssa_def *ray_flags = call->src[1].ssa;
-            /* From the SPIR-V spec:
-             *
-             *    "Only the 8 least-significant bits of Cull Mask are used by
-             *    this instruction - other bits are ignored.
-             *
-             *    Only the 4 least-significant bits of SBT Offset and SBT
-             *    Stride are used by this instruction - other bits are
-             *    ignored.
-             *
-             *    Only the 16 least-significant bits of Miss Index are used by
-             *    this instruction - other bits are ignored."
-             */
-            nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
-            nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
-            nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
-            nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
-            nir_ssa_def *ray_orig = call->src[6].ssa;
-            nir_ssa_def *ray_t_min = call->src[7].ssa;
-            nir_ssa_def *ray_dir = call->src[8].ssa;
-            nir_ssa_def *ray_t_max = call->src[9].ssa;
-
-            /* The hardware packet takes the address to the root node in the
-             * acceleration structure, not the acceleration structure itself.
-             * To find that, we have to read the root node offset from the
-             * acceleration structure which is the first QWord.
-             */
-            nir_ssa_def *root_node_ptr =
-               nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64));
-
-            /* The hardware packet requires an address to the first element of
-             * the hit SBT.
-             *
-             * In order to calculate this, we must multiply the "SBT Offset"
-             * provided to OpTraceRay by the SBT stride provided for the hit
-             * SBT in the call to vkCmdTraceRay() and add that to the base
-             * address of the hit SBT. This stride is not to be confused with
-             * the "SBT Stride" provided to OpTraceRay which is in units of
-             * this stride. It's a rather terrible overload of the word
-             * "stride". The hardware docs calls the SPIR-V stride value the
-             * "shader index multiplier" which is a much more sane name.
-             */
-            nir_ssa_def *hit_sbt_stride_B =
-               nir_load_ray_hit_sbt_stride_intel(b);
-            nir_ssa_def *hit_sbt_offset_B =
-               nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
-            nir_ssa_def *hit_sbt_addr =
-               nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
-                           nir_u2u64(b, hit_sbt_offset_B));
-
-            /* The hardware packet takes an address to the miss BSR. */
-            nir_ssa_def *miss_sbt_stride_B =
-               nir_load_ray_miss_sbt_stride_intel(b);
-            nir_ssa_def *miss_sbt_offset_B =
-               nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
-            nir_ssa_def *miss_sbt_addr =
-               nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
-                           nir_u2u64(b, miss_sbt_offset_B));
-
-            struct brw_nir_rt_mem_ray_defs ray_defs = {
-               .root_node_ptr = root_node_ptr,
-               .ray_flags = nir_u2u16(b, ray_flags),
-               .ray_mask = cull_mask,
-               .hit_group_sr_base_ptr = hit_sbt_addr,
-               .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
-               .miss_sr_ptr = miss_sbt_addr,
-               .orig = ray_orig,
-               .t_near = ray_t_min,
-               .dir = ray_dir,
-               .t_far = ray_t_max,
-               .shader_index_multiplier = sbt_stride,
-            };
-            brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
-            nir_trace_ray_initial_intel(b);
-            break;
-         }
-
-         case nir_intrinsic_rt_execute_callable: {
-            store_resume_addr(b, call);
-
-            nir_ssa_def *sbt_offset32 =
-               nir_imul(b, call->src[0].ssa,
-                           nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
-            nir_ssa_def *sbt_addr =
-               nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
-                           nir_u2u64(b, sbt_offset32));
-            brw_nir_btd_spawn(b, sbt_addr);
-            break;
-         }
-
-         default:
-            unreachable("Invalid intrinsic");
-         }
-
-         nir_instr_remove(&call->instr);
-      }
-   }
-
-   nir_foreach_block_safe(block, impl) {
-      nir_foreach_instr_safe(instr, block) {
-         if (instr->type != nir_instr_type_intrinsic)
-            continue;
-
-
-      }
-   }
-
-   return progress;
+   return nir_shader_instructions_pass(shader,
+                                       lower_shader_calls_instr,
+                                       nir_metadata_block_index |
+                                       nir_metadata_dominance,
+                                       NULL);
 }
 
 /** Creates a trivial return shader