diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 5d8de704520..d70e91b980f 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -718,6 +718,12 @@ typedef struct nir_register { /* The bit-size of each channel; must be one of 8, 16, 32, or 64 */ uint8_t bit_size; + /** + * True if this register may have different values in different SIMD + * invocations of the shader. + */ + bool divergent; + /** generic register index. */ unsigned index; @@ -967,8 +973,7 @@ nir_src_is_const(nir_src src) static inline bool nir_src_is_divergent(nir_src src) { - assert(src.is_ssa); - return src.ssa->divergent; + return src.is_ssa ? src.ssa->divergent : src.reg.reg->divergent; } static inline unsigned @@ -986,8 +991,7 @@ nir_dest_num_components(nir_dest dest) static inline bool nir_dest_is_divergent(nir_dest dest) { - assert(dest.is_ssa); - return dest.ssa.divergent; + return dest.is_ssa ? dest.ssa.divergent : dest.reg.reg->divergent; } /* Are all components the same, ie. .xxxx */ diff --git a/src/compiler/nir/nir_from_ssa.c b/src/compiler/nir/nir_from_ssa.c index fcaf156d218..3e98f317140 100644 --- a/src/compiler/nir/nir_from_ssa.c +++ b/src/compiler/nir/nir_from_ssa.c @@ -106,6 +106,7 @@ typedef struct { typedef struct merge_set { struct exec_list nodes; unsigned size; + bool divergent; nir_register *reg; } merge_set; @@ -144,6 +145,7 @@ get_merge_node(nir_ssa_def *def, struct from_ssa_state *state) merge_set *set = ralloc(state->dead_ctx, merge_set); exec_list_make_empty(&set->nodes); set->size = 1; + set->divergent = def->divergent; set->reg = NULL; merge_node *node = ralloc(state->dead_ctx, merge_node); @@ -186,6 +188,7 @@ merge_merge_sets(merge_set *a, merge_set *b) a->size += b->size; b->size = 0; + a->divergent |= b->divergent; return a; } @@ -358,6 +361,7 @@ isolate_phi_nodes_block(nir_block *block, void *dead_ctx) nir_ssa_dest_init(&pcopy->instr, &entry->dest, phi->dest.ssa.num_components, phi->dest.ssa.bit_size, src->src.ssa->name); + 
entry->dest.ssa.divergent = nir_src_is_divergent(src->src); exec_list_push_tail(&pcopy->entries, &entry->node); assert(src->src.is_ssa); @@ -372,6 +376,7 @@ isolate_phi_nodes_block(nir_block *block, void *dead_ctx) nir_ssa_dest_init(&block_pcopy->instr, &entry->dest, phi->dest.ssa.num_components, phi->dest.ssa.bit_size, phi->dest.ssa.name); + entry->dest.ssa.divergent = phi->dest.ssa.divergent; exec_list_push_tail(&block_pcopy->entries, &entry->node); nir_ssa_def_rewrite_uses(&phi->dest.ssa, @@ -432,6 +437,12 @@ aggressive_coalesce_parallel_copy(nir_parallel_copy_instr *pcopy, if (src_node->set == dest_node->set) continue; + /* TODO: We can probably do better here but for now we should be safe if + * we just don't coalesce things with different divergence. + */ + if (dest_node->set->divergent != src_node->set->divergent) + continue; + if (!merge_sets_interfere(src_node->set, dest_node->set)) merge_merge_sets(src_node->set, dest_node->set); } @@ -493,8 +504,10 @@ rewrite_ssa_def(nir_ssa_def *def, void *void_state) * the things in the merge set should be the same so it doesn't * matter which node's definition we use. */ - if (node->set->reg == NULL) + if (node->set->reg == NULL) { node->set->reg = create_reg_for_ssa_def(def, state->builder.impl); + node->set->reg->divergent = node->set->divergent; + } reg = node->set->reg; } else { @@ -562,6 +575,8 @@ emit_copy(nir_builder *b, nir_src src, nir_src dest_src) dest_src.reg.indirect == NULL && dest_src.reg.base_offset == 0); + assert(!nir_src_is_divergent(src) || nir_src_is_divergent(dest_src)); + if (src.is_ssa) assert(src.ssa->num_components >= dest_src.reg.reg->num_components); else @@ -699,13 +714,26 @@ resolve_parallel_copy(nir_parallel_copy_instr *pcopy, /* b has been filled, mark it as not needing to be copied */ pred[b] = -1; - /* If a needs to be filled... 
*/ - if (pred[a] != -1) { - /* If any other copies want a they can find it at b */ + /* The next bit only applies if the source and destination have the + * same divergence. If they differ (it must be convergent -> + * divergent), then we can't guarantee we won't need the convergent + * version of it again. + */ + if (nir_src_is_divergent(values[a]) == + nir_src_is_divergent(values[b])) { + /* If any other copies want a they can find it at b but only if the + * two have the same divergence. + */ loc[a] = b; - /* It's ready for copying now */ - ready[++ready_idx] = a; + /* If a needs to be filled... */ + if (pred[a] != -1) { + /* If any other copies want a they can find it at b */ + loc[a] = b; + + /* It's ready for copying now */ + ready[++ready_idx] = a; + } } } int b = to_do[to_do_idx--]; @@ -732,6 +760,7 @@ resolve_parallel_copy(nir_parallel_copy_instr *pcopy, reg->num_components = values[b].reg.reg->num_components; reg->bit_size = values[b].reg.reg->bit_size; } + reg->divergent = nir_src_is_divergent(values[b]); values[num_vals].is_ssa = false; values[num_vals].reg.reg = reg;