zink/spirv: implement loops

Acked-by: Jordan Justen <jordan.l.justen@intel.com>
This commit is contained in:
Erik Faye-Lund
2019-03-18 20:29:49 +01:00
parent acdd12dae3
commit b458863c1e
3 changed files with 70 additions and 3 deletions

View File

@@ -54,6 +54,7 @@ struct ntv_context {
const SpvId *block_ids;
size_t num_blocks;
bool block_started;
SpvId loop_break, loop_cont;
};
static SpvId
@@ -1183,6 +1184,25 @@ branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
ctx->block_started = false;
}
static void
emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
{
switch (jump->type) {
case nir_jump_break:
assert(ctx->loop_break);
branch(ctx, ctx->loop_break);
break;
case nir_jump_continue:
assert(ctx->loop_cont);
branch(ctx, ctx->loop_cont);
break;
default:
unreachable("Unsupported jump type\n");
}
}
static void
emit_block(struct ntv_context *ctx, struct nir_block *block)
{
@@ -1208,7 +1228,7 @@ emit_block(struct ntv_context *ctx, struct nir_block *block)
unreachable("nir_instr_type_phi not supported");
break;
case nir_instr_type_jump:
unreachable("nir_instr_type_jump not supported");
emit_jump(ctx, nir_instr_as_jump(instr));
break;
case nir_instr_type_call:
unreachable("nir_instr_type_call not supported");
@@ -1260,13 +1280,45 @@ emit_if(struct ntv_context *ctx, nir_if *if_stmt)
emit_cf_list(ctx, &if_stmt->then_list);
if (has_else) {
branch(ctx, endif_id);
if (ctx->block_started)
branch(ctx, endif_id);
emit_cf_list(ctx, &if_stmt->else_list);
}
start_block(ctx, endif_id);
}
static void
emit_loop(struct ntv_context *ctx, nir_loop *loop)
{
SpvId header_id = spirv_builder_new_id(&ctx->builder);
SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
SpvId break_id = spirv_builder_new_id(&ctx->builder);
SpvId cont_id = spirv_builder_new_id(&ctx->builder);
/* create a header-block */
start_block(ctx, header_id);
spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
branch(ctx, begin_id);
SpvId save_break = ctx->loop_break;
SpvId save_cont = ctx->loop_cont;
ctx->loop_break = break_id;
ctx->loop_cont = cont_id;
emit_cf_list(ctx, &loop->body);
ctx->loop_break = save_break;
ctx->loop_cont = save_cont;
branch(ctx, cont_id);
start_block(ctx, cont_id);
branch(ctx, header_id);
start_block(ctx, break_id);
}
static void
emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
{
@@ -1281,7 +1333,7 @@ emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
break;
case nir_cf_node_loop:
unreachable("nir_cf_node_loop not supported");
emit_loop(ctx, nir_cf_node_as_loop(node));
break;
case nir_cf_node_function:

View File

@@ -450,6 +450,17 @@ spirv_builder_emit_selection_merge(struct spirv_builder *b, SpvId merge_block,
spirv_buffer_emit_word(&b->instructions, selection_control);
}
void
spirv_builder_loop_merge(struct spirv_builder *b, SpvId merge_block,
SpvId cont_target, SpvLoopControlMask loop_control)
{
spirv_buffer_prepare(&b->instructions, 4);
spirv_buffer_emit_word(&b->instructions, SpvOpLoopMerge | (4 << 16));
spirv_buffer_emit_word(&b->instructions, merge_block);
spirv_buffer_emit_word(&b->instructions, cont_target);
spirv_buffer_emit_word(&b->instructions, loop_control);
}
void
spirv_builder_emit_branch_conditional(struct spirv_builder *b, SpvId condition,
SpvId true_label, SpvId false_label)

View File

@@ -182,6 +182,10 @@ void
spirv_builder_emit_selection_merge(struct spirv_builder *b, SpvId merge_block,
SpvSelectionControlMask selection_control);
void
spirv_builder_loop_merge(struct spirv_builder *b, SpvId merge_block,
SpvId cont_target, SpvLoopControlMask loop_control);
void
spirv_builder_emit_branch_conditional(struct spirv_builder *b, SpvId condition,
SpvId true_label, SpvId false_label);