diff --git a/src/api/l_graphics_pass.c b/src/api/l_graphics_pass.c index 984d6136..a171ff6c 100644 --- a/src/api/l_graphics_pass.c +++ b/src/api/l_graphics_pass.c @@ -442,6 +442,21 @@ static int l_lovrPassBox(lua_State* L) { return 0; } +static int l_lovrPassCompute(lua_State* L) { + Pass* pass = luax_checktype(L, 1, Pass); + Buffer* buffer = luax_totype(L, 2, Buffer); + if (buffer) { + uint32_t offset = lua_tointeger(L, 3); + lovrPassCompute(pass, 0, 0, 0, buffer, offset); + } else { + uint32_t x = luax_optu32(L, 2, 1); + uint32_t y = luax_optu32(L, 3, 1); + uint32_t z = luax_optu32(L, 4, 1); + lovrPassCompute(pass, x, y, z, NULL, 0); + } + return 0; +} + static int l_lovrPassClear(lua_State* L) { Pass* pass = luax_checktype(L, 1, Pass); @@ -623,6 +638,8 @@ const luaL_Reg lovrPass[] = { { "cube", l_lovrPassCube }, { "box", l_lovrPassBox }, + { "compute", l_lovrPassCompute }, + { "clear", l_lovrPassClear }, { "copy", l_lovrPassCopy }, { "blit", l_lovrPassBlit }, diff --git a/src/modules/graphics/graphics.c b/src/modules/graphics/graphics.c index 427cb217..d1c84d18 100644 --- a/src/modules/graphics/graphics.c +++ b/src/modules/graphics/graphics.c @@ -492,6 +492,8 @@ void lovrGraphicsSubmit(Pass** passes, uint32_t count) { streams[extraPassCount + i] = passes[i]->stream; if (passes[i]->info.type == PASS_RENDER) { gpu_render_end(passes[i]->stream); + } else if (passes[i]->info.type == PASS_COMPUTE) { + gpu_compute_end(passes[i]->stream); } } @@ -1261,7 +1263,23 @@ Pass* lovrGraphicsGetPass(PassInfo* info) { pass->info = *info; pass->stream = gpu_stream_begin(info->label); - if (info->type != PASS_RENDER) { + if (info->type == PASS_TRANSFER) { + return pass; + } + + if (info->type == PASS_COMPUTE) { + memset(pass->constants, 0, sizeof(pass->constants)); + pass->constantsDirty = true; + + pass->bindingMask = 0; + pass->bindingsDirty = true; + + pass->pipeline = &pass->pipelines[0]; + pass->pipeline->shader = NULL; + pass->pipeline->dirty = true; + + gpu_compute_begin(pass->stream); + return pass; } @@ -1861,6 +1879,7 @@ static void flushBindings(Pass* pass, Shader* shader) { return; } + uint32_t set = pass->info.type == PASS_RENDER ? 2 : 0; gpu_binding* bindings = tempAlloc(shader->resourceCount * sizeof(gpu_binding)); for (uint32_t i = 0; i < shader->resourceCount; i++) { @@ -1875,7 +1894,7 @@ static void flushBindings(Pass* pass, Shader* shader) { gpu_bundle* bundle = getBundle(shader->layout); gpu_bundle_write(&bundle, &info, 1); - gpu_bind_bundle(pass->stream, shader->gpu, 2, bundle, NULL, 0); + gpu_bind_bundle(pass->stream, shader->gpu, set, bundle, NULL, 0); } static void flushBuiltins(Pass* pass, Draw* draw, Shader* shader) { @@ -2136,6 +2155,34 @@ void lovrPassBox(Pass* pass, float* transform) { memcpy(indices, indexData, sizeof(indexData)); } +void lovrPassCompute(Pass* pass, uint32_t x, uint32_t y, uint32_t z, Buffer* indirect, uint32_t offset) { + lovrCheck(pass->info.type == PASS_COMPUTE, "This function can only be called on a compute pass"); + + Shader* shader = pass->pipeline->shader; + lovrCheck(shader && shader->info.type == SHADER_COMPUTE, "Tried to run a compute shader, but no compute shader is bound"); + lovrCheck(x <= state.limits.computeDispatchCount[0], "Compute %s count exceeds computeDispatchCount limit", "x"); + lovrCheck(y <= state.limits.computeDispatchCount[1], "Compute %s count exceeds computeDispatchCount limit", "y"); + lovrCheck(z <= state.limits.computeDispatchCount[2], "Compute %s count exceeds computeDispatchCount limit", "z"); + + gpu_pipeline* pipeline = state.pipelines.data[shader->computePipeline]; + + if (pass->pipeline->dirty) { + gpu_bind_pipeline(pass->stream, pipeline, true); + pass->pipeline->dirty = false; + } + + flushConstants(pass, shader); + flushBindings(pass, shader); + + if (indirect) { + lovrCheck(offset % 4 == 0, "Indirect compute offset must be a multiple of 4"); + lovrCheck(offset <= indirect->size - 12, "Indirect compute offset overflows the Buffer"); + gpu_compute_indirect(pass->stream, indirect->gpu, offset); + } else { + gpu_compute(pass->stream, x, y, z); + } +} + void lovrPassClearBuffer(Pass* pass, Buffer* buffer, uint32_t offset, uint32_t extent) { if (extent == 0) return; if (extent == ~0u) extent = buffer->size - offset; diff --git a/src/modules/graphics/graphics.h b/src/modules/graphics/graphics.h index 7edd2d02..2f88c7ca 100644 --- a/src/modules/graphics/graphics.h +++ b/src/modules/graphics/graphics.h @@ -392,6 +392,7 @@ void lovrPassPoints(Pass* pass, uint32_t count, float** vertices); void lovrPassLine(Pass* pass, uint32_t count, float** vertices); void lovrPassPlane(Pass* pass, float* transform, uint32_t cols, uint32_t rows); void lovrPassBox(Pass* pass, float* transform); +void lovrPassCompute(Pass* pass, uint32_t x, uint32_t y, uint32_t z, Buffer* indirect, uint32_t offset); void lovrPassClearBuffer(Pass* pass, Buffer* buffer, uint32_t offset, uint32_t extent); void lovrPassClearTexture(Pass* pass, Texture* texture, float value[4], uint32_t layer, uint32_t layerCount, uint32_t level, uint32_t levelCount); void lovrPassCopyDataToBuffer(Pass* pass, void* data, Buffer* buffer, uint32_t offset, uint32_t size);