#include "wiGaussianSplatModel.h"
#include "shaders/ShaderInterop_GaussianSplat.h"
#include "wiRenderer.h"
#include "wiBacklog.h"
#include "wiTimer.h"
#include "wiEventHandler.h"
#include "wiGPUSortLib.h"
#include "wiScene_Components.h"
#include "wiProfiler.h"

using namespace wi::math;
using namespace wi::graphics;
using namespace wi::primitive;

namespace wi
{
	// Shaders and pipeline states shared by all gaussian splat models, set up in GaussianSplatModel::Initialize()
	static Shader computeShader;
	static Shader computeShader_indirect;
	static Shader vertexShader;
	static Shader pixelShader;
	static BlendState blendState;
	static RasterizerState rasterizerState;
	static DepthStencilState depthStencilState;
	static PipelineState pipelineState;

	// Describes how the view dependent spherical harmonics (SH) data of one model is laid out
	struct SphericalHarmonicsLayout
	{
		uint32_t componentsPerSplat = 0;     // number of values stored in f_rest for every splat (all color channels together)
		uint32_t coefficientsPerChannel = 0; // componentsPerSplat / 3
		int degree = 0;                      // highest SH degree that is fully available (0 = no view dependent data)
		int splatStride = 0;                 // number of half values written per splat into the GPU SH buffer
	};

	// Computes the SH layout from the element count of f_rest and the splat count.
	//	This is the single place where the degree and stride are derived, so that the CPU upload and
	//	the per-model shader data always agree.
	//	An empty model (splat_count == 0) yields an all-zero layout instead of dividing by zero.
	static SphericalHarmonicsLayout ComputeSphericalHarmonicsLayout(size_t f_rest_count, size_t splat_count)
	{
		SphericalHarmonicsLayout layout;
		if (splat_count == 0)
		{
			return layout;
		}
		layout.componentsPerSplat = uint32_t(f_rest_count / splat_count);
		layout.coefficientsPerChannel = layout.componentsPerSplat / 3;
		if (layout.coefficientsPerChannel >= 3)
		{
			layout.degree = 1;
			layout.splatStride += 3 * 3;
		}
		if (layout.coefficientsPerChannel >= 8)
		{
			layout.degree = 2;
			layout.splatStride += 5 * 3;
		}
		if (layout.coefficientsPerChannel == 15)
		{
			layout.degree = 3;
			layout.splatStride += 7 * 3;
		}
		return layout;
	}

	// Builds the GPU buffer of the model from the CPU side splat attributes.
	//	The buffer contains the packed GaussianSplat structures first, followed (at an aligned offset)
	//	by the view dependent SH coefficients as 16-bit floats. Two SRV subresources are created for them.
	//	Also recomputes aabb_rest from the splat positions.
	void GaussianSplatModel::CreateRenderData()
	{
		GraphicsDevice* device = GetDevice();

		aabb_rest = AABB();
		for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
		{
			aabb_rest.AddPoint(positions[splatIdx]);
		}

		if (positions.size() == 0)
		{
			// Nothing to upload, and a zero sized GPU buffer must not be requested
			return;
		}

		const SphericalHarmonicsLayout sh = ComputeSphericalHarmonicsLayout(f_rest.size(), positions.size());

		GPUBufferDesc desc;
		desc.bind_flags = BindFlag::SHADER_RESOURCE;
		desc.misc_flags = ResourceMiscFlag::TYPED_FORMAT_CASTING | ResourceMiscFlag::NO_DEFAULT_DESCRIPTORS;
		const uint64_t alignment = device->GetMinOffsetAlignment(&desc);
		const uint64_t splat_data_size = uint64_t(positions.size() * sizeof(GaussianSplat));
		const uint64_t sh_data_size = uint64_t(f_rest.size() * sizeof(uint16_t));
		// The SH block starts right after the (aligned) splat block:
		const uint64_t sh_aligned_offset = align(splat_data_size, alignment);
		desc.size = sh_aligned_offset + align(sh_data_size, alignment);

		auto fill_gpu = [&](void* dest) {
			// Number of SH coefficients per color channel in each band (degree 1, 2, 3):
			constexpr uint32_t band_sizes[] = { 3, 5, 7 };

			// Clamps to [0, 1] and quantizes to 16 bit UNORM.
			//	NaN fails the comparison and maps to 0. A NaN is possible when an axis of aabb_rest has
			//	zero extent (presumably InverseLerp divides by the extent), and converting NaN to an
			//	integer would be undefined behavior.
			const auto quantize_unorm16 = [](float value) -> uint32_t {
				const float clamped = value > 0.0f ? std::min(value, 1.0f) : 0.0f;
				return uint32_t(clamped * 65535.0f);
			};

			GaussianSplat* splat_dest = (GaussianSplat*)dest;
			uint16_t* sh_dest = (uint16_t*)((uint8_t*)dest + sh_aligned_offset);

			for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
			{
				GaussianSplat splat = {};

				// position remap and quantize to 16 bit UNORM:
				const XMFLOAT3 pos = wi::math::InverseLerp(aabb_rest._min, aabb_rest._max, positions[splatIdx]);
				splat.position_radius.x = quantize_unorm16(pos.x);
				splat.position_radius.x |= quantize_unorm16(pos.y) << 16u;
				splat.position_radius.y = quantize_unorm16(pos.z);

				// scales are stored in log space:
				const XMFLOAT3 scale = XMFLOAT3(std::exp(scales[splatIdx].x), std::exp(scales[splatIdx].y), std::exp(scales[splatIdx].z));

				const float radius = std::max(scale.x, std::max(scale.y, scale.z)); // culling
				splat.position_radius.y |= uint32_t(XMConvertFloatToHalf(radius)) << 16u; // radius in half precision

				// f_dc is L0 spherical harmonics (not view dependent), so it's converted to rgb color here
				//	https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp
				//	I also remove SRGB curve here with pow(rgb, 2.2)
				static constexpr float SH_C0 = 0.28209479177387814f;
				float4 color;
				color.x = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].x), 2.2f);
				color.y = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].y), 2.2f);
				color.z = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].z), 2.2f);
				color.w = saturate(1.0f / (1.0f + std::exp(-opacities[splatIdx]))); // sigmoid of the stored opacity
				splat.color = pack_half4(color);

				// covariance from: https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp
				//	changed from glm to DirectXMath (column->row major, matrix mul order changed)
				const XMMATRIX scaleMatrix = XMMatrixScaling(scale.x, scale.y, scale.z);
				const XMMATRIX rotationMatrix = XMMatrixRotationQuaternion(XMLoadFloat4(&rotations[splatIdx]));
				const XMMATRIX covarianceMatrix = XMMatrixMultiply(scaleMatrix, rotationMatrix);
				const XMMATRIX transformedCovarianceMatrix = XMMatrixMultiply(XMMatrixTranspose(covarianceMatrix), covarianceMatrix);
				XMFLOAT3X3 transformedCovariance;
				XMStoreFloat3x3(&transformedCovariance, transformedCovarianceMatrix);

				// The covariance is symmetric, only the upper triangle is stored:
				float3 cov3D_M11_M12_M13;
				float3 cov3D_M22_M23_M33;
				cov3D_M11_M12_M13.x = transformedCovariance._11;
				cov3D_M11_M12_M13.y = transformedCovariance._12;
				cov3D_M11_M12_M13.z = transformedCovariance._13;
				cov3D_M22_M23_M33.x = transformedCovariance._22;
				cov3D_M22_M23_M33.y = transformedCovariance._23;
				cov3D_M22_M23_M33.z = transformedCovariance._33;
				splat.cov3D_M11_M12_M13 = pack_half3(cov3D_M11_M12_M13);
				splat.cov3D_M22_M23_M33 = pack_half3(cov3D_M22_M23_M33);

				std::memcpy(splat_dest + splatIdx, &splat, sizeof(splat)); // memcpy into uncached

				// View dependent SH data is deinterleaved (all coefficients of R, then G, then B per splat),
				//	now I interleave it into rgb vectors per splat.
				//	Only the bands that the model really has are converted: splatStride reserves space only
				//	for those, and f_rest contains only those.
				uint16_t* dst = sh_dest + size_t(sh.splatStride) * splatIdx;
				const size_t srcBase = size_t(sh.componentsPerSplat) * splatIdx; // f_rest is indexed by its own per-splat stride
				int dstOffset = 0;
				uint32_t band_start = 0; // index of the first coefficient of the current band within one channel
				for (int band = 0; band < sh.degree; ++band)
				{
					for (uint32_t i = 0; i < band_sizes[band]; ++i)
					{
						for (uint32_t rgb = 0; rgb < 3; ++rgb)
						{
							const size_t srcIndex = srcBase + size_t(sh.coefficientsPerChannel) * rgb + band_start + i;
							dst[dstOffset++] = XMConvertFloatToHalf(f_rest[srcIndex]);
						}
					}
					band_start += band_sizes[band];
				}
			}
		};

		bool success = device->CreateBuffer2(&desc, fill_gpu, &buffer);
		assert(success);
		device->SetName(&buffer, "GaussianSplatModel::buffer");

		// Structured view of the splat block:
		static constexpr uint32_t structured_stride = sizeof(GaussianSplat);
		subresource_splatBuffer = device->CreateSubresource(&buffer, SubresourceType::SRV, 0, positions.size() * sizeof(GaussianSplat), nullptr, &structured_stride);

		// Typed (half float) view of the SH block:
		static constexpr Format sh_format = Format::R16_FLOAT;
		subresource_shBuffer = device->CreateSubresource(&buffer, SubresourceType::SRV, sh_aligned_offset, f_rest.size() * sizeof(uint16_t), &sh_format);
	}

	// Updates the world space state of the model from its world matrix:
	//	the transformed bounding box, the transform and its inverse, and maxScale.
	void GaussianSplatModel::Update(const XMFLOAT4X4& matrix)
	{
		const XMMATRIX W = XMLoadFloat4x4(&matrix);
		aabb = aabb_rest.transform(W);

		static const float sqrt8 = std::sqrt(8.0f);

		// The axis scales are the lengths of the basis rows of the world matrix.
		//	Transforming the (1,1,1) vector instead would make the result depend on the rotation,
		//	and it could even become zero or negative.
		const float scaleX = XMVectorGetX(XMVector3Length(W.r[0]));
		const float scaleY = XMVectorGetX(XMVector3Length(W.r[1]));
		const float scaleZ = XMVectorGetX(XMVector3Length(W.r[2]));
		maxScale = std::max(scaleX, std::max(scaleY, scaleZ)) * sqrt8;

		XMStoreFloat4x4(&transform, W);
		XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, W));
	}

	// Reads or writes the CPU side splat attributes.
	//	In read mode the GPU data is rebuilt afterwards in a job that is executed on the serializer's job context.
	void GaussianSplatModel::Serialize(wi::Archive& archive, wi::ecs::EntitySerializer& seri)
	{
		if (archive.IsReadMode())
		{
			archive >> positions;
			archive >> rotations;
			archive >> scales;
			archive >> opacities;
			archive >> f_dc;
			archive >> f_rest;

			wi::jobsystem::Execute(seri.ctx, [this](wi::jobsystem::JobArgs args) {
				CreateRenderData();
			});
		}
		else
		{
			archive << positions;
			archive << rotations;
			archive << scales;
			archive << opacities;
			archive << f_dc;
			archive << f_rest;
		}
	}

	// Loads the shaders, subscribes to the shader reload event and creates the render pipeline state
	//	that is used for drawing the splats.
	void GaussianSplatModel::Initialize()
	{
		wi::Timer timer;

		static auto LoadShaders = [] {
			wi::renderer::LoadShader(ShaderStage::CS, computeShader, "gaussian_splatCS.cso");
			wi::renderer::LoadShader(ShaderStage::CS, computeShader_indirect, "gaussian_splat_indirectCS.cso");
			wi::renderer::LoadShader(ShaderStage::VS, vertexShader, "gaussian_splatVS.cso");
			wi::renderer::LoadShader(ShaderStage::PS, pixelShader, "gaussian_splatPS.cso");
		};

		static wi::eventhandler::Handle handle = wi::eventhandler::Subscribe(wi::eventhandler::EVENT_RELOAD_SHADERS, [](uint64_t userdata) { LoadShaders(); });
		LoadShaders();

		GraphicsDevice* device = GetDevice();

		// Depth is tested but not written:
		depthStencilState.depth_enable = true;
		depthStencilState.depth_write_mask = DepthWriteMask::ZERO;
		depthStencilState.depth_func = ComparisonFunc::GREATER;

		rasterizerState.cull_mode = CullMode::BACK;
		rasterizerState.fill_mode = FillMode::SOLID;
		rasterizerState.depth_clip_enable = true;

		// Blending: src * ONE + dest * (1 - src alpha), for both color and alpha
		blendState.render_target[0].blend_enable = true;
		blendState.render_target[0].src_blend = Blend::ONE;
		blendState.render_target[0].dest_blend = Blend::INV_SRC_ALPHA;
		blendState.render_target[0].blend_op = BlendOp::ADD;
		blendState.render_target[0].src_blend_alpha = Blend::ONE;
		blendState.render_target[0].dest_blend_alpha = Blend::INV_SRC_ALPHA;
		blendState.render_target[0].blend_op_alpha = BlendOp::ADD;
		blendState.render_target[0].render_target_write_mask = ColorWrite::ENABLE_ALL;

		PipelineStateDesc desc;
		desc.vs = &vertexShader;
		desc.ps = &pixelShader;
		desc.rs = &rasterizerState;
		desc.bs = &blendState;
		desc.dss = &depthStencilState;
		desc.pt = PrimitiveTopology::TRIANGLESTRIP;
		bool success = device->CreatePipelineState(&desc, &pipelineState);
		assert(success);

		wilog("wi::GaussianSplatModel Initialized (%d ms)", (int)std::round(timer.elapsed()));
	}

	// Returns the highest spherical harmonics degree (0-3) that the model has view dependent data for.
	//	An empty model returns 0.
	int GaussianSplatModel::GetSphericalHarmonicsDegree() const
	{
		return ComputeSphericalHarmonicsLayout(f_rest.size(), positions.size()).degree;
	}

	// Returns the size of the CPU side splat attribute arrays in bytes
	size_t GaussianSplatModel::GetMemorySizeCPU() const
	{
		size_t ret = 0;
		ret += positions.size() * sizeof(XMFLOAT3);
		ret += rotations.size() * sizeof(XMFLOAT4);
		ret += scales.size() * sizeof(XMFLOAT3);
		ret += opacities.size() * sizeof(float);
		ret += f_dc.size() * sizeof(XMFLOAT3);
		ret += f_rest.size() * sizeof(float);
		return ret;
	}

	// Returns the size of the model's GPU buffer in bytes
	size_t GaussianSplatModel::GetMemorySizeGPU() const
	{
		return buffer.desc.size;
	}

	// Makes sure that the scene-wide GPU buffers are large enough for the given models.
	//	The indirect argument buffer is created once, the per-splat and per-model buffers are
	//	recreated whenever the required capacity grows (they never shrink here).
	void GaussianSplatScene::MakeReservations(const GaussianSplatModel* models, size_t model_count)
	{
		GraphicsDevice* device = GetDevice();

		size_t global_splat_count = 0;
		for (size_t model_index = 0; model_index < model_count; ++model_index)
		{
			global_splat_count += models[model_index].GetSplatCount();
		}

		if (!indirectBuffer.IsValid())
		{
			GPUBufferDesc desc;
			desc.stride = sizeof(IndirectDrawArgsInstanced);
			desc.size = desc.stride;
			desc.bind_flags = BindFlag::SHADER_RESOURCE | BindFlag::UNORDERED_ACCESS;
			desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED | ResourceMiscFlag::INDIRECT_ARGS;
			bool success = device->CreateBuffer(&desc, nullptr, &indirectBuffer);
			assert(success);
			device->SetName(&indirectBuffer, "GaussianSplatScene::indirectBuffer");
		}

		if (global_splat_count > splat_capacity)
		{
			splat_capacity = global_splat_count;

			// One uint index per splat:
			GPUBufferDesc desc;
			desc.stride = sizeof(uint32_t);
			desc.size = splat_capacity * desc.stride;
			desc.bind_flags = BindFlag::SHADER_RESOURCE | BindFlag::UNORDERED_ACCESS;
			desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED;
			bool success = device->CreateBuffer(&desc, nullptr, &sortedIndexBuffer);
			assert(success);
			device->SetName(&sortedIndexBuffer, "GaussianSplatScene::sortedIndexBuffer");

			// One float sorting key per splat:
			desc.stride = sizeof(float);
			desc.size = splat_capacity * desc.stride;
			desc.bind_flags = BindFlag::SHADER_RESOURCE | BindFlag::UNORDERED_ACCESS;
			desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED;
			success = device->CreateBuffer(&desc, nullptr, &distanceBuffer);
			assert(success);
			device->SetName(&distanceBuffer, "GaussianSplatScene::distanceBuffer");

			// One uint2 lookup entry per splat:
			desc.stride = sizeof(uint2);
			desc.size = splat_capacity * desc.stride;
			desc.bind_flags = BindFlag::SHADER_RESOURCE | BindFlag::UNORDERED_ACCESS;
			desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED;
			success = device->CreateBuffer(&desc, nullptr, &splatLookupBuffer);
			assert(success);
			device->SetName(&splatLookupBuffer, "GaussianSplatScene::splatLookupBuffer");
		}

		if (model_count > model_capacity)
		{
			model_capacity = model_count;

			GPUBufferDesc desc;
			desc.stride = sizeof(ShaderGaussianSplatModel);
			desc.size = model_capacity * desc.stride;
			desc.bind_flags = BindFlag::SHADER_RESOURCE;
			desc.misc_flags = ResourceMiscFlag::BUFFER_STRUCTURED;
			bool success = device->CreateBuffer(&desc, nullptr, &modelBuffer);
			assert(success);
			device->SetName(&modelBuffer, "GaussianSplatScene::modelBuffer");
		}
	}

	// Records the per-frame GPU work of the splat scene into cmd:
	//	uploads the per-model shader data, resets the indirect draw arguments, dispatches the
	//	compute shader for every model, sorts the result globally and prepares the indirect
	//	arguments for Draw().
	//	models: array of model_count model pointers
	//	viewmatrices: array of camera_count view matrices
	//	NOTE(review): every barrier below names the state that the buffer is expected to be in, but
	//	nothing in this function transitions the buffers back to those states for the next call;
	//	presumably this is handled elsewhere (or by the device) - confirm against the callers.
	//	NOTE(review): camera_count must not exceed the capacity of ShaderGaussianSplatModel::modelViewMatrices,
	//	this is not checked here.
	void GaussianSplatScene::UpdateGPU(const GaussianSplatModel** models, size_t model_count, CommandList cmd, const XMFLOAT4X4* viewmatrices, uint32_t camera_count) const
	{
		ScopedGPUProfiling("Gaussian splat culling and sorting", cmd);
		GraphicsDevice* device = GetDevice();
		device->EventBegin("Gaussian Splat Update", cmd);

		size_t global_splat_count = 0;

		// Per-model shader data is written into a temporary GPU allocation, then copied to modelBuffer:
		const size_t shmodeldatasize = sizeof(ShaderGaussianSplatModel) * model_count;
		auto alloc = device->AllocateGPU(shmodeldatasize, cmd);
		ShaderGaussianSplatModel* dest = (ShaderGaussianSplatModel*)alloc.data;
		for (size_t model_index = 0; model_index < model_count; ++model_index)
		{
			const GaussianSplatModel& model = *models[model_index];

			ShaderGaussianSplatModel shmodel = {};
			shmodel.descriptor_splatBuffer = device->GetDescriptorIndex(&model.buffer, SubresourceType::SRV, model.subresource_splatBuffer);
			shmodel.descriptor_shBuffer = device->GetDescriptorIndex(&model.buffer, SubresourceType::SRV, model.subresource_shBuffer);
			shmodel.transform.Create(model.transform);
			shmodel.transform_inverse.Create(model.transform_inverse);
			shmodel.aabb_min = model.aabb_rest._min;
			shmodel.aabb_max = model.aabb_rest._max;
			shmodel.maxScale = model.maxScale;

			const XMMATRIX modelMatrix = XMLoadFloat4x4(&model.transform);
			for (uint32_t i = 0; i < camera_count; ++i)
			{
				XMStoreFloat4x4(shmodel.modelViewMatrices + i, XMMatrixMultiply(modelMatrix, XMLoadFloat4x4(viewmatrices + i)));
			}

			// Same layout computation as used when the SH data was uploaded:
			const SphericalHarmonicsLayout sh = ComputeSphericalHarmonicsLayout(model.f_rest.size(), model.positions.size());
			shmodel.sphericalHarmonicsDegree = sh.degree;
			shmodel.splatStride = sh.splatStride;

			std::memcpy(dest + model_index, &shmodel, sizeof(shmodel)); // memcpy into uncached

			global_splat_count += model.GetSplatCount();
		}
		device->CopyBuffer(&modelBuffer, 0, &alloc.buffer, alloc.offset, shmodeldatasize, cmd);

		// One quad (4 vertex triangle strip) per instance, the instance count is filled by the GPU:
		IndirectDrawArgsInstanced args = {};
		args.VertexCountPerInstance = 4;
		args.InstanceCount = 0; // shader atomic dest
		args.StartVertexLocation = 0;
		args.StartInstanceLocation = 0;
		device->UpdateBuffer(&indirectBuffer, &args, cmd);

		{
			GPUBarrier barriers[] = {
				GPUBarrier::Buffer(&modelBuffer, ResourceState::COPY_DST, ResourceState::SHADER_RESOURCE),
				GPUBarrier::Buffer(&indirectBuffer, ResourceState::COPY_DST, ResourceState::UNORDERED_ACCESS),
			};
			device->Barrier(barriers, arraysize(barriers), cmd);
		}

		// All models dispatch themselves into global buffer for sorting:
		for (size_t model_index = 0; model_index < model_count; ++model_index)
		{
			const GaussianSplatModel& model = *models[model_index];

			device->BindComputeShader(&computeShader, cmd);
			device->BindResource(&modelBuffer, 0, cmd);
			device->BindUAV(&indirectBuffer, 0, cmd);
			device->BindUAV(&sortedIndexBuffer, 1, cmd);
			device->BindUAV(&distanceBuffer, 2, cmd);
			device->BindUAV(&splatLookupBuffer, 3, cmd);

			struct Push
			{
				uint model_index;
				uint camera_count;
				uint dispatch_offset;
			} push = {};
			push.model_index = (uint)model_index;
			push.camera_count = camera_count;

			// Some GPU can't exceed dispatch group count of 65535 in single dimension (DX12 validation), so I do multiple dispatches with max 65535 group count each:
			int remaining_threadgroups = ((int)model.GetSplatCount() + GAUSSIAN_COMPUTE_THREADSIZE - 1) / GAUSSIAN_COMPUTE_THREADSIZE;
			uint32_t group_offset = 0;
			while (remaining_threadgroups > 0)
			{
				// dispatch_offset is given in threads (splats), not in groups:
				push.dispatch_offset = group_offset * GAUSSIAN_COMPUTE_THREADSIZE;
				device->PushConstants(&push, sizeof(push), cmd);
				const uint32_t threadgroups = (uint32_t)std::min(remaining_threadgroups, 65535);
				device->Dispatch(threadgroups, 1, 1, cmd);
				remaining_threadgroups -= threadgroups;
				group_offset += threadgroups;
			}
		}

		{
			GPUBarrier barriers[] = {
				GPUBarrier::Buffer(&indirectBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::SHADER_RESOURCE),
				GPUBarrier::Buffer(&sortedIndexBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::SHADER_RESOURCE),
				GPUBarrier::Buffer(&distanceBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::SHADER_RESOURCE),
				GPUBarrier::Buffer(&splatLookupBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::SHADER_RESOURCE),
			};
			device->Barrier(barriers, arraysize(barriers), cmd);
		}

		// Sorting is done globally for buffer containing all models:
		wi::gpusortlib::Sort((uint32_t)global_splat_count, distanceBuffer, indirectBuffer, offsetof(IndirectDrawArgsInstanced, InstanceCount), sortedIndexBuffer, cmd);

		if (camera_count > 1)
		{
			// InstanceCount multiplied by cameraCount on GPU after sorting:
			device->BindComputeShader(&computeShader_indirect, cmd);
			device->BindUAV(&indirectBuffer, 0, cmd);
			device->PushConstants(&camera_count, sizeof(camera_count), cmd);
			device->Barrier(GPUBarrier::Buffer(&indirectBuffer, ResourceState::SHADER_RESOURCE, ResourceState::UNORDERED_ACCESS), cmd);
			device->Dispatch(1, 1, 1, cmd);
			device->Barrier(GPUBarrier::Buffer(&indirectBuffer, ResourceState::UNORDERED_ACCESS, ResourceState::INDIRECT_ARGUMENT), cmd);
		}
		else
		{
			device->Barrier(GPUBarrier::Buffer(&indirectBuffer, ResourceState::SHADER_RESOURCE, ResourceState::INDIRECT_ARGUMENT), cmd);
		}

		device->EventEnd(cmd);
	}

	// Records the indirect draw of the splats into cmd.
	//	It uses the buffers that were written by UpdateGPU(), the instance count comes from indirectBuffer.
	void GaussianSplatScene::Draw(CommandList cmd, uint32_t camera_count) const
	{
		ScopedGPUProfiling("Gaussian splat drawing", cmd);
		GraphicsDevice* device = GetDevice();
		device->EventBegin("Gaussian Splat Render", cmd);

		device->BindPipelineState(&pipelineState, cmd);

		device->BindResource(&sortedIndexBuffer, 0, cmd);
		device->BindResource(&splatLookupBuffer, 1, cmd);
		device->BindResource(&modelBuffer, 2, cmd);

		device->PushConstants(&camera_count, sizeof(camera_count), cmd);

		device->DrawInstancedIndirect(&indirectBuffer, 0, cmd);

		device->EventEnd(cmd);
	}

	// Drops all scene-wide GPU buffers and resets the capacities, so the next MakeReservations() recreates everything
	void GaussianSplatScene::Clear()
	{
		splat_capacity = 0;
		model_capacity = 0;
		modelBuffer = {};
		indirectBuffer = {};
		sortedIndexBuffer = {};
		distanceBuffer = {};
		splatLookupBuffer = {};
	}
}