507 lines
19 KiB
C++
507 lines
19 KiB
C++
#include "wiGaussianSplatModel.h"
|
|
#include "shaders/ShaderInterop_GaussianSplat.h"
|
|
#include "wiRenderer.h"
|
|
#include "wiBacklog.h"
|
|
#include "wiTimer.h"
|
|
#include "wiEventHandler.h"
|
|
#include "wiGPUSortLib.h"
|
|
#include "wiScene_Components.h"
|
|
#include "wiProfiler.h"
|
|
|
|
using namespace wi::math;
|
|
using namespace wi::graphics;
|
|
using namespace wi::primitive;
|
|
|
|
namespace wi
|
|
{
|
|
static Shader computeShader;
|
|
static Shader computeShader_indirect;
|
|
static Shader vertexShader;
|
|
static Shader pixelShader;
|
|
static BlendState blendState;
|
|
static RasterizerState rasterizerState;
|
|
static DepthStencilState depthStencilState;
|
|
static PipelineState pipelineState;
|
|
|
|
void GaussianSplatModel::CreateRenderData()
|
|
{
|
|
GraphicsDevice* device = GetDevice();
|
|
|
|
aabb_rest = AABB();
|
|
for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
|
|
{
|
|
aabb_rest.AddPoint(positions[splatIdx]);
|
|
}
|
|
|
|
GPUBufferDesc desc;
|
|
desc.bind_flags = BindFlag::SHADER_RESOURCE;
|
|
desc.misc_flags = ResourceMiscFlag::TYPED_FORMAT_CASTING | ResourceMiscFlag::NO_DEFAULT_DESCRIPTORS;
|
|
|
|
const uint64_t alignment = device->GetMinOffsetAlignment(&desc);
|
|
desc.size =
|
|
align(uint64_t(positions.size() * sizeof(GaussianSplat)), alignment) +
|
|
align(uint64_t(f_rest.size() * sizeof(uint16_t)), alignment);
|
|
|
|
const uint64_t sh_aligned_offset = align(uint64_t(positions.size() * sizeof(GaussianSplat)), alignment);
|
|
|
|
auto fill_gpu = [&](void* dest) {
|
|
const uint32_t totalSphericalHarmonicsComponentCount = uint32_t(f_rest.size() / positions.size());
|
|
const uint32_t sphericalHarmonicsCoefficientsPerChannel = totalSphericalHarmonicsComponentCount / 3;
|
|
int sphericalHarmonicsDegree = 0;
|
|
int splatStride = 0;
|
|
if (sphericalHarmonicsCoefficientsPerChannel >= 3)
|
|
{
|
|
sphericalHarmonicsDegree = 1;
|
|
splatStride += 3 * 3;
|
|
}
|
|
if (sphericalHarmonicsCoefficientsPerChannel >= 8)
|
|
{
|
|
sphericalHarmonicsDegree = 2;
|
|
splatStride += 5 * 3;
|
|
}
|
|
if (sphericalHarmonicsCoefficientsPerChannel == 15)
|
|
{
|
|
sphericalHarmonicsDegree = 3;
|
|
splatStride += 7 * 3;
|
|
}
|
|
GaussianSplat* splat_dest = (GaussianSplat*)dest;
|
|
uint16_t* sh_dest = (uint16_t*)((uint8_t*)dest + sh_aligned_offset);
|
|
for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
|
|
{
|
|
GaussianSplat splat = {};
|
|
|
|
// position remap and quantize to 16 bit UNORM:
|
|
const XMFLOAT3 pos = wi::math::InverseLerp(aabb_rest._min, aabb_rest._max, positions[splatIdx]);
|
|
splat.position_radius.x = uint32_t(saturate(pos.x) * 65535.0f);
|
|
splat.position_radius.x |= uint32_t(saturate(pos.y) * 65535.0f) << 16u;
|
|
splat.position_radius.y = uint32_t(saturate(pos.z) * 65535.0f);
|
|
|
|
const XMFLOAT3 scale = XMFLOAT3(std::exp(scales[splatIdx].x), std::exp(scales[splatIdx].y), std::exp(scales[splatIdx].z));
|
|
const float radius = std::max(scale.x, std::max(scale.y, scale.z)); // culling
|
|
splat.position_radius.y |= uint32_t(XMConvertFloatToHalf(radius)) << 16u; // radius in half precision
|
|
|
|
// f_dc is L0 spherical harmonics (not view dependent), so it's converted to rgb color here
|
|
// https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp
|
|
// I also remove SRGB curve here with pow(rgb, 2.2)
|
|
static constexpr float SH_C0 = 0.28209479177387814f;
|
|
float4 color;
|
|
color.x = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].x), 2.2f);
|
|
color.y = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].y), 2.2f);
|
|
color.z = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].z), 2.2f);
|
|
color.w = saturate(1.0f / (1.0f + std::exp(-opacities[splatIdx])));
|
|
splat.color = pack_half4(color);
|
|
|
|
// covariance from: https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp
|
|
// changed from glm to DirectXMath (column->row major, matrix mul order changed)
|
|
const XMMATRIX scaleMatrix = XMMatrixScaling(scale.x, scale.y, scale.z);
|
|
const XMMATRIX rotationMatrix = XMMatrixRotationQuaternion(XMLoadFloat4(&rotations[splatIdx]));
|
|
const XMMATRIX covarianceMatrix = XMMatrixMultiply(scaleMatrix, rotationMatrix);
|
|
const XMMATRIX transformedCovarianceMatrix = XMMatrixMultiply(XMMatrixTranspose(covarianceMatrix), covarianceMatrix);
|
|
XMFLOAT3X3 transformedCovariance;
|
|
XMStoreFloat3x3(&transformedCovariance, transformedCovarianceMatrix);
|
|
float3 cov3D_M11_M12_M13;
|
|
float3 cov3D_M22_M23_M33;
|
|
cov3D_M11_M12_M13.x = transformedCovariance._11;
|
|
cov3D_M11_M12_M13.y = transformedCovariance._12;
|
|
cov3D_M11_M12_M13.z = transformedCovariance._13;
|
|
cov3D_M22_M23_M33.x = transformedCovariance._22;
|
|
cov3D_M22_M23_M33.y = transformedCovariance._23;
|
|
cov3D_M22_M23_M33.z = transformedCovariance._33;
|
|
splat.cov3D_M11_M12_M13 = pack_half3(cov3D_M11_M12_M13);
|
|
splat.cov3D_M22_M23_M33 = pack_half3(cov3D_M22_M23_M33);
|
|
|
|
std::memcpy(splat_dest + splatIdx, &splat, sizeof(splat)); // memcpy into uncached
|
|
|
|
|
|
// View dependent SH data is deinterleaved, now I interleave it into 16x (rgb) vectors per splat:
|
|
uint16_t* dst = sh_dest + splatStride * splatIdx;
|
|
const auto srcBase = splatStride * splatIdx;
|
|
int dstOffset = 0;
|
|
// degree 1, three coefs per component
|
|
for (auto i = 0; i < 3; i++)
|
|
{
|
|
for (auto rgb = 0; rgb < 3; rgb++)
|
|
{
|
|
const auto srcIndex = srcBase + (sphericalHarmonicsCoefficientsPerChannel * rgb + i);
|
|
dst[dstOffset++] = XMConvertFloatToHalf(f_rest[srcIndex]);
|
|
}
|
|
}
|
|
// degree 2, five coefs per component
|
|
for (auto i = 0; i < 5; i++)
|
|
{
|
|
for (auto rgb = 0; rgb < 3; rgb++)
|
|
{
|
|
const auto srcIndex = srcBase + (sphericalHarmonicsCoefficientsPerChannel * rgb + 3 + i);
|
|
dst[dstOffset++] = XMConvertFloatToHalf(f_rest[srcIndex]);
|
|
}
|
|
}
|
|
// degree 3, seven coefs per component
|
|
for (auto i = 0; i < 7; i++)
|
|
{
|
|
for (auto rgb = 0; rgb < 3; rgb++)
|
|
{
|
|
const auto srcIndex = srcBase + (sphericalHarmonicsCoefficientsPerChannel * rgb + 3 + 5 + i);
|
|
dst[dstOffset++] = XMConvertFloatToHalf(f_rest[srcIndex]);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
bool success = device->CreateBuffer2(&desc, fill_gpu, &buffer);
|
|
assert(success);
|
|
device->SetName(&buffer, "GaussianSplatModel::buffer");
|
|
|
|
static constexpr uint32_t structured_stride = sizeof(GaussianSplat);
|
|
subresource_splatBuffer = device->CreateSubresource(&buffer, SubresourceType::SRV, 0, positions.size() * sizeof(GaussianSplat), nullptr, &structured_stride);
|
|
|
|
static constexpr Format sh_format = Format::R16_FLOAT;
|
|
subresource_shBuffer = device->CreateSubresource(&buffer, SubresourceType::SRV, sh_aligned_offset, f_rest.size() * sizeof(uint16_t), &sh_format);
|
|
}
|
|
|
|
void GaussianSplatModel::Update(const XMFLOAT4X4& matrix)
|
|
{
|
|
const XMMATRIX W = XMLoadFloat4x4(&matrix);
|
|
aabb = aabb_rest.transform(W);
|
|
|
|
static const float sqrt8 = std::sqrt(8.0f);
|
|
XMVECTOR scale = XMVectorSet(1, 1, 1, 1);
|
|
scale = XMVector3TransformNormal(scale, W);
|
|
maxScale = std::max(XMVectorGetX(scale), std::max(XMVectorGetY(scale), XMVectorGetZ(scale))) * sqrt8;
|
|
|
|
XMStoreFloat4x4(&transform, W);
|
|
XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, W));
|
|
}
|
|
|
|
// Reads or writes the CPU-side splat arrays, depending on the mode of the archive.
//	archive : source or destination of the data
//	seri    : its job context (seri.ctx) receives the render data creation job after reading
// The arrays are always stored in this order: positions, rotations, scales, opacities, f_dc, f_rest.
void GaussianSplatModel::Serialize(wi::Archive& archive, wi::ecs::EntitySerializer& seri)
{
	if (!archive.IsReadMode())
	{
		// Write mode only stores the arrays, GPU data is not part of the archive:
		archive << positions;
		archive << rotations;
		archive << scales;
		archive << opacities;
		archive << f_dc;
		archive << f_rest;
		return;
	}

	archive >> positions;
	archive >> rotations;
	archive >> scales;
	archive >> opacities;
	archive >> f_dc;
	archive >> f_rest;

	// The GPU buffer is rebuilt from the loaded arrays in a job on the serializer's context:
	wi::jobsystem::Execute(seri.ctx, [this](wi::jobsystem::JobArgs) {
		CreateRenderData();
	});
}
|
|
|
|
void GaussianSplatModel::Initialize()
|
|
{
|
|
wi::Timer timer;
|
|
|
|
static auto LoadShaders = []{
|
|
wi::renderer::LoadShader(ShaderStage::CS, computeShader, "gaussian_splatCS.cso");
|
|
wi::renderer::LoadShader(ShaderStage::CS, computeShader_indirect, "gaussian_splat_indirectCS.cso");
|
|
wi::renderer::LoadShader(ShaderStage::VS, vertexShader, "gaussian_splatVS.cso");
|
|
wi::renderer::LoadShader(ShaderStage::PS, pixelShader, "gaussian_splatPS.cso");
|
|
};
|
|
static wi::eventhandler::Handle handle = wi::eventhandler::Subscribe(wi::eventhandler::EVENT_RELOAD_SHADERS, [](uint64_t userdata) { LoadShaders(); });
|
|
LoadShaders();
|
|
|
|
GraphicsDevice* device = GetDevice();
|
|
|
|
depthStencilState.depth_enable = true;
|
|
depthStencilState.depth_write_mask = DepthWriteMask::ZERO;
|
|
depthStencilState.depth_func = ComparisonFunc::GREATER;
|
|
|
|
rasterizerState.cull_mode = CullMode::BACK;
|
|
rasterizerState.fill_mode = FillMode::SOLID;
|
|
rasterizerState.depth_clip_enable = true;
|
|
|
|
blendState.render_target[0].blend_enable = true;
|
|
blendState.render_target[0].src_blend = Blend::ONE;
|
|
blendState.render_target[0].dest_blend = Blend::INV_SRC_ALPHA;
|
|
blendState.render_target[0].blend_op = BlendOp::ADD;
|
|
blendState.render_target[0].src_blend_alpha = Blend::ONE;
|
|
blendState.render_target[0].dest_blend_alpha = Blend::INV_SRC_ALPHA;
|
|
blendState.render_target[0].blend_op_alpha = BlendOp::ADD;
|
|
blendState.render_target[0].render_target_write_mask = ColorWrite::ENABLE_ALL;
|
|
|
|
PipelineStateDesc desc;
|
|
desc.vs = &vertexShader;
|
|
desc.ps = &pixelShader;
|
|
desc.rs = &rasterizerState;
|
|
desc.bs = &blendState;
|
|
desc.dss = &depthStencilState;
|
|
desc.pt = PrimitiveTopology::TRIANGLESTRIP;
|
|
bool success = device->CreatePipelineState(&desc, &pipelineState);
|
|
assert(success);
|
|
|
|
wilog("wi::GaussianSplatModel Initialized (%d ms)", (int)std::round(timer.elapsed()));
|
|
}
|
|
|
|
int GaussianSplatModel::GetSphericalHarmonicsDegree() const
|
|
{
|
|
const uint32_t totalSphericalHarmonicsComponentCount = uint32_t(f_rest.size() / positions.size());
|
|
const uint32_t sphericalHarmonicsCoefficientsPerChannel = totalSphericalHarmonicsComponentCount / 3;
|
|
int sphericalHarmonicsDegree = 0;
|
|
if (sphericalHarmonicsCoefficientsPerChannel >= 3)
|
|
{
|
|
sphericalHarmonicsDegree = 1;
|
|
}
|
|
if (sphericalHarmonicsCoefficientsPerChannel >= 8)
|
|
{
|
|
sphericalHarmonicsDegree = 2;
|
|
}
|
|
if (sphericalHarmonicsCoefficientsPerChannel == 15)
|
|
{
|
|
sphericalHarmonicsDegree = 3;
|
|
}
|
|
return sphericalHarmonicsDegree;
|
|
}
|
|
size_t GaussianSplatModel::GetMemorySizeCPU() const
|
|
{
|
|
size_t ret = 0;
|
|
ret += positions.size() * sizeof(XMFLOAT3);
|
|
ret += rotations.size() * sizeof(XMFLOAT4);
|
|
ret += scales.size() * sizeof(XMFLOAT3);
|
|
ret += opacities.size() * sizeof(float);
|
|
ret += f_dc.size() * sizeof(XMFLOAT3);
|
|
ret += f_rest.size() * sizeof(float);
|
|
return ret;
|
|
}
|
|
size_t GaussianSplatModel::GetMemorySizeGPU() const
|
|
{
|
|
return buffer.desc.size;
|
|
}
|
|
|
|
void GaussianSplatScene::MakeReservations(const GaussianSplatModel* models, size_t model_count)
|
|
{
|
|
GraphicsDevice* device = GetDevice();
|
|
size_t global_splat_count = 0;
|
|
for (size_t model_index = 0; model_index < model_count; ++model_index)
|
|
{
|
|
global_splat_count += models[model_index].GetSplatCount();
|
|
}
|
|
|
|
if (!indirectBuffer.IsValid())
|
|
{
|
|
GPUBufferDesc desc;
|
|
desc.stride = sizeof(IndirectDrawArgsInstanced);
|
|
desc.size = desc.stride;
|
|
desc.bind_flags = BindFlag::SHADER_RESOURCE | BindFlag::UNORDERED_ACCESS;
|
|
desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED | ResourceMiscFlag::INDIRECT_ARGS;
|
|
bool success = device->CreateBuffer(&desc, nullptr, &indirectBuffer);
|
|
assert(success);
|
|
device->SetName(&indirectBuffer, "GaussianSplatScene::indirectBuffer");
|
|
}
|
|
|
|
if (global_splat_count > splat_capacity)
|
|
{
|
|
splat_capacity = global_splat_count;
|
|
|
|
GPUBufferDesc desc;
|
|
desc.stride = sizeof(uint32_t);
|
|
desc.size = splat_capacity * desc.stride;
|
|
desc.bind_flags = BindFlag::SHADER_RESOURCE | BindFlag::UNORDERED_ACCESS;
|
|
desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED;
|
|
bool success = device->CreateBuffer(&desc, nullptr, &sortedIndexBuffer);
|
|
assert(success);
|
|
device->SetName(&sortedIndexBuffer, "GaussianSplatScene::sortedIndexBuffer");
|
|
|
|
desc.stride = sizeof(float);
|
|
desc.size = splat_capacity * desc.stride;
|
|
desc.bind_flags = BindFlag::SHADER_RESOURCE | BindFlag::UNORDERED_ACCESS;
|
|
desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED;
|
|
success = device->CreateBuffer(&desc, nullptr, &distanceBuffer);
|
|
assert(success);
|
|
device->SetName(&distanceBuffer, "GaussianSplatScene::distanceBuffer");
|
|
|
|
desc.stride = sizeof(uint2);
|
|
desc.size = splat_capacity * desc.stride;
|
|
desc.bind_flags = BindFlag::SHADER_RESOURCE | BindFlag::UNORDERED_ACCESS;
|
|
desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED;
|
|
success = device->CreateBuffer(&desc, nullptr, &splatLookupBuffer);
|
|
assert(success);
|
|
device->SetName(&splatLookupBuffer, "GaussianSplatScene::splatLookupBuffer");
|
|
}
|
|
|
|
if (model_count > model_capacity)
|
|
{
|
|
model_capacity = model_count;
|
|
|
|
GPUBufferDesc desc;
|
|
desc.stride = sizeof(ShaderGaussianSplatModel);
|
|
desc.size = model_capacity * desc.stride;
|
|
desc.bind_flags = BindFlag::SHADER_RESOURCE;
|
|
desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED;
|
|
bool success = device->CreateBuffer(&desc, nullptr, &modelBuffer);
|
|
assert(success);
|
|
device->SetName(&modelBuffer, "GaussianSplatScene::modelBuffer");
|
|
}
|
|
}
|
|
|
|
void GaussianSplatScene::UpdateGPU(const GaussianSplatModel** models, size_t model_count, CommandList cmd, const XMFLOAT4X4* viewmatrices, uint32_t camera_count) const
|
|
{
|
|
ScopedGPUProfiling("Gaussian splat culling and sorting", cmd);
|
|
GraphicsDevice* device = GetDevice();
|
|
device->EventBegin("Gaussian Splat Update", cmd);
|
|
|
|
size_t global_splat_count = 0;
|
|
const size_t shmodeldatasize = sizeof(ShaderGaussianSplatModel) * model_count;
|
|
auto alloc = device->AllocateGPU(shmodeldatasize, cmd);
|
|
ShaderGaussianSplatModel* dest = (ShaderGaussianSplatModel*)alloc.data;
|
|
for (size_t model_index = 0; model_index < model_count; ++model_index)
|
|
{
|
|
const GaussianSplatModel& model = *models[model_index];
|
|
ShaderGaussianSplatModel shmodel = {};
|
|
shmodel.descriptor_splatBuffer = device->GetDescriptorIndex(&model.buffer, SubresourceType::SRV, model.subresource_splatBuffer);
|
|
shmodel.descriptor_shBuffer = device->GetDescriptorIndex(&model.buffer, SubresourceType::SRV, model.subresource_shBuffer);
|
|
shmodel.transform.Create(model.transform);
|
|
shmodel.transform_inverse.Create(model.transform_inverse);
|
|
shmodel.aabb_min = model.aabb_rest._min;
|
|
shmodel.aabb_max = model.aabb_rest._max;
|
|
shmodel.maxScale = model.maxScale;
|
|
const XMMATRIX modelMatrix = XMLoadFloat4x4(&model.transform);
|
|
for (uint32_t i = 0; i < camera_count; ++i)
|
|
{
|
|
XMStoreFloat4x4(shmodel.modelViewMatrices + i, XMMatrixMultiply(modelMatrix, XMLoadFloat4x4(viewmatrices + i)));
|
|
}
|
|
|
|
const uint32_t totalSphericalHarmonicsComponentCount = uint32_t(model.f_rest.size() / model.positions.size());
|
|
const uint32_t sphericalHarmonicsCoefficientsPerChannel = totalSphericalHarmonicsComponentCount / 3;
|
|
int sphericalHarmonicsDegree = 0;
|
|
int splatStride = 0;
|
|
if (sphericalHarmonicsCoefficientsPerChannel >= 3)
|
|
{
|
|
sphericalHarmonicsDegree = 1;
|
|
splatStride += 3 * 3;
|
|
}
|
|
if (sphericalHarmonicsCoefficientsPerChannel >= 8)
|
|
{
|
|
sphericalHarmonicsDegree = 2;
|
|
splatStride += 5 * 3;
|
|
}
|
|
if (sphericalHarmonicsCoefficientsPerChannel == 15)
|
|
{
|
|
sphericalHarmonicsDegree = 3;
|
|
splatStride += 7 * 3;
|
|
}
|
|
shmodel.sphericalHarmonicsDegree = sphericalHarmonicsDegree;
|
|
shmodel.splatStride = splatStride;
|
|
|
|
std::memcpy(dest + model_index, &shmodel, sizeof(shmodel)); // memcpy into uncached
|
|
|
|
global_splat_count += model.GetSplatCount();
|
|
}
|
|
device->CopyBuffer(&modelBuffer, 0, &alloc.buffer, alloc.offset, shmodeldatasize, cmd);
|
|
|
|
IndirectDrawArgsInstanced args = {};
|
|
args.VertexCountPerInstance = 4;
|
|
args.InstanceCount = 0; // shader atomic dest
|
|
args.StartVertexLocation = 0;
|
|
args.StartInstanceLocation = 0;
|
|
device->UpdateBuffer(&indirectBuffer, &args, cmd);
|
|
|
|
{
|
|
GPUBarrier barriers[] = {
|
|
GPUBarrier::Buffer(&modelBuffer, ResourceState::COPY_DST, ResourceState::SHADER_RESOURCE),
|
|
GPUBarrier::Buffer(&indirectBuffer, ResourceState::COPY_DST, ResourceState::UNORDERED_ACCESS),
|
|
};
|
|
device->Barrier(barriers, arraysize(barriers), cmd);
|
|
}
|
|
|
|
// All models dispatch themselves into global buffer for sorting:
|
|
for (size_t model_index = 0; model_index < model_count; ++model_index)
|
|
{
|
|
const GaussianSplatModel& model = *models[model_index];
|
|
device->BindComputeShader(&computeShader, cmd);
|
|
device->BindResource(&modelBuffer, 0, cmd);
|
|
device->BindUAV(&indirectBuffer, 0, cmd);
|
|
device->BindUAV(&sortedIndexBuffer, 1, cmd);
|
|
device->BindUAV(&distanceBuffer, 2, cmd);
|
|
device->BindUAV(&splatLookupBuffer, 3, cmd);
|
|
|
|
struct Push
|
|
{
|
|
uint model_index;
|
|
uint camera_count;
|
|
uint dispatch_offset;
|
|
} push = {};
|
|
push.model_index = (uint)model_index;
|
|
push.camera_count = camera_count;
|
|
|
|
// Some GPU can't exceed dispatch group count of 65535 in single dimension (DX12 validation), so I do multiple dispatches with max 65535 group count each:
|
|
int remaining_threadgroups = ((int)model.GetSplatCount() + GAUSSIAN_COMPUTE_THREADSIZE - 1) / GAUSSIAN_COMPUTE_THREADSIZE;
|
|
uint32_t group_offset = 0;
|
|
while (remaining_threadgroups > 0)
|
|
{
|
|
push.dispatch_offset = group_offset * GAUSSIAN_COMPUTE_THREADSIZE;
|
|
device->PushConstants(&push, sizeof(push), cmd);
|
|
const uint32_t threadgroups = (uint32_t)std::min(remaining_threadgroups, 65535);
|
|
device->Dispatch(threadgroups, 1, 1, cmd);
|
|
remaining_threadgroups -= threadgroups;
|
|
group_offset += threadgroups;
|
|
}
|
|
}
|
|
|
|
{
|
|
GPUBarrier barriers[] = {
|
|
GPUBarrier::Buffer(&indirectBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::SHADER_RESOURCE),
|
|
GPUBarrier::Buffer(&sortedIndexBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::SHADER_RESOURCE),
|
|
GPUBarrier::Buffer(&distanceBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::SHADER_RESOURCE),
|
|
GPUBarrier::Buffer(&splatLookupBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::SHADER_RESOURCE),
|
|
};
|
|
device->Barrier(barriers, arraysize(barriers), cmd);
|
|
}
|
|
|
|
// Sorting is done globally for buffer containing all models:
|
|
wi::gpusortlib::Sort((uint32_t)global_splat_count, distanceBuffer, indirectBuffer, offsetof(IndirectDrawArgsInstanced, InstanceCount), sortedIndexBuffer, cmd);
|
|
|
|
if (camera_count > 1)
|
|
{
|
|
// InstanceCount multiplied by cameraCount on GPU after sorting:
|
|
device->BindComputeShader(&computeShader_indirect, cmd);
|
|
device->BindUAV(&indirectBuffer, 0, cmd);
|
|
device->PushConstants(&camera_count, sizeof(camera_count), cmd);
|
|
device->Barrier(GPUBarrier::Buffer(&indirectBuffer, ResourceState::SHADER_RESOURCE, ResourceState::UNORDERED_ACCESS), cmd);
|
|
device->Dispatch(1, 1, 1, cmd);
|
|
device->Barrier(GPUBarrier::Buffer(&indirectBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::INDIRECT_ARGUMENT), cmd);
|
|
}
|
|
else
|
|
{
|
|
device->Barrier(GPUBarrier::Buffer(&indirectBuffer, ResourceState::SHADER_RESOURCE, ResourceState::INDIRECT_ARGUMENT), cmd);
|
|
}
|
|
|
|
device->EventEnd(cmd);
|
|
}
|
|
|
|
// Records the indirect draw of all splats into a command list.
//	cmd          : command list that receives the draw
//	camera_count : passed to the shaders as push constant
// The draw arguments are read from the start of indirectBuffer.
void GaussianSplatScene::Draw(CommandList cmd, uint32_t camera_count) const
{
	ScopedGPUProfiling("Gaussian splat drawing", cmd);
	GraphicsDevice* device = GetDevice();
	device->EventBegin("Gaussian Splat Render", cmd);

	device->BindPipelineState(&pipelineState, cmd);

	// Shader resources in slot order: 0 = sorted indices, 1 = splat lookup, 2 = per-model data
	const GPUResource* const shader_resources[] = {
		&sortedIndexBuffer,
		&splatLookupBuffer,
		&modelBuffer,
	};
	for (uint32_t slot = 0; slot < arraysize(shader_resources); ++slot)
	{
		device->BindResource(shader_resources[slot], slot, cmd);
	}

	device->PushConstants(&camera_count, sizeof(camera_count), cmd);
	device->DrawInstancedIndirect(&indirectBuffer, 0, cmd);

	device->EventEnd(cmd);
}
|
|
|
|
void GaussianSplatScene::Clear()
|
|
{
|
|
splat_capacity = 0;
|
|
model_capacity = 0;
|
|
modelBuffer = {};
|
|
indirectBuffer = {};
|
|
sortedIndexBuffer = {};
|
|
distanceBuffer = {};
|
|
splatLookupBuffer = {};
|
|
}
|
|
}
|