gaussian splat improvements

This commit is contained in:
Turánszki János
2026-03-03 08:49:35 +01:00
parent 9f9f166a18
commit 4d2f91afc7
7 changed files with 44 additions and 26 deletions
@@ -15,6 +15,7 @@ struct GaussianSplatCB
{
ShaderTransform transform;
ShaderTransform transform_inverse;
float4x4 modelViewMatrix;
int sphericalHarmonicsDegree;
int splatStride;
int padding0;
+6 -7
View File
@@ -228,19 +228,18 @@ half3 fetchViewDependentRadiance(in uint splatIndex, in float3 worldViewDir)
void main(in uint vertexID : SV_VertexID, in uint instanceID : SV_InstanceID, out float4 pos : SV_Position, out half4 color : COLOR, out half2 localPos : LOCALPOS)
{
const uint splatIndex = sortedIndexBuffer[instanceID];
ShaderCamera camera = GetCamera();
const float3 splatCenter = float4(splats[splatIndex].position, 1).xyz;
color = unpack_half4(splats[splatIndex].color);
float3 viewDir = normalize(mul(cb.transform_inverse.GetMatrix(), float4(GetCamera().position, 1)).xyz - splatCenter);
float3 viewDir = normalize(mul(cb.transform_inverse.GetMatrix(), float4(camera.position, 1)).xyz - splatCenter);
color.rgb += fetchViewDependentRadiance(splatIndex, viewDir);
color.rgb = RemoveSRGBCurve_Fast(color.rgb); // Checked against SuperSplat Editor, result is more similar with gamma remove
float4x4 modelViewMatrix = mul(GetCamera().view, cb.transform.GetMatrix());
const float4 viewCenter = mul(modelViewMatrix, float4(splatCenter, 1.0));
const float4 clipCenter = mul(GetCamera().projection, viewCenter);
const float4 viewCenter = mul(cb.modelViewMatrix, float4(splatCenter, 1.0));
const float4 clipCenter = mul(camera.projection, viewCenter);
const float3x3 cov3Dm = fetchCovariance(splatIndex);
const float3 cov2Dv = threedgsCovarianceProjection(cov3Dm, viewCenter, GetCamera().focal, modelViewMatrix); // computes the basis vectors of the extent of the projected covariance
const float3 cov2Dv = threedgsCovarianceProjection(cov3Dm, viewCenter, camera.focal, cb.modelViewMatrix); // computes the basis vectors of the extent of the projected covariance
const float2 fragPos = BILLBOARD[vertexID].xy;
// We use sqrt(8) standard deviations instead of 3 to eliminate more of the splat with a very low opacity.
@@ -253,7 +252,7 @@ void main(in uint vertexID : SV_VertexID, in uint instanceID : SV_InstanceID, ou
else
{
const float3 ndcCenter = clipCenter.xyz / clipCenter.w;
const float2 ndcOffset = float2(fragPos.x * basisVector1 + fragPos.y * basisVector2) * /*frameInfo.basisViewport*/GetCamera().internal_resolution_rcp * 2.0 * /*frameInfo.inverseFocalAdjustment*/1.0;
const float2 ndcOffset = float2(fragPos.x * basisVector1 + fragPos.y * basisVector2) * /*frameInfo.basisViewport*/camera.internal_resolution_rcp * 2.0 * /*frameInfo.inverseFocalAdjustment*/1.0;
const float4 quadPos = float4(ndcCenter.xy + ndcOffset, ndcCenter.z, 1.0);
pos = quadPos;
}
+13 -6
View File
@@ -5,6 +5,7 @@
#include "wiTimer.h"
#include "wiEventHandler.h"
#include "wiGPUSortLib.h"
#include "wiScene_Components.h"
using namespace wi::math;
using namespace wi::graphics;
@@ -24,13 +25,13 @@ namespace wi
{
GraphicsDevice* device = GetDevice();
aabb = AABB();
aabb_rest = AABB();
auto fill_gpu = [&](void* dest) {
GaussianSplat* splat_dest = (GaussianSplat*)dest;
for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
{
aabb.AddPoint(positions[splatIdx]);
aabb_rest.AddPoint(positions[splatIdx]);
GaussianSplat splat = {};
splat.position = positions[splatIdx];
@@ -177,14 +178,19 @@ namespace wi
device->SetName(&constantBuffer, "GaussianSplatModel::constantBuffer");
}
void GaussianSplatModel::Update(const XMFLOAT4X4& transform, wi::graphics::CommandList cmd)
void GaussianSplatModel::Update(const XMFLOAT4X4& matrix)
{
transform = matrix;
XMFLOAT4X4 transform_inverse;
XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, XMLoadFloat4x4(&matrix)));
aabb = aabb_rest.transform(matrix);
}
void GaussianSplatModel::UpdateGPU(const wi::scene::CameraComponent& camera, wi::graphics::CommandList cmd)
{
GraphicsDevice* device = GetDevice();
device->EventBegin("Gaussian Splat Update", cmd);
XMFLOAT4X4 transform_inverse;
XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, XMLoadFloat4x4(&transform)));
const uint32_t totalSphericalHarmonicsComponentCount = uint32_t(f_rest.size() / positions.size());
const uint32_t sphericalHarmonicsCoefficientsPerChannel = totalSphericalHarmonicsComponentCount / 3;
int sphericalHarmonicsDegree = 0;
@@ -208,6 +214,7 @@ namespace wi
GaussianSplatCB cb = {};
cb.transform.Create(transform);
cb.transform_inverse.Create(transform_inverse);
XMStoreFloat4x4(&cb.modelViewMatrix, XMMatrixMultiply(XMLoadFloat4x4(&transform), camera.GetView()));
cb.sphericalHarmonicsDegree = sphericalHarmonicsDegree;
cb.splatStride = splatStride;
device->UpdateBuffer(&constantBuffer, &cb, cmd);
+9 -4
View File
@@ -5,6 +5,7 @@
#include "wiVector.h"
#include "wiMath.h"
#include "wiECS.h"
#include "wiScene_Decl.h"
namespace wi
{
@@ -18,10 +19,13 @@ namespace wi
wi::vector<XMFLOAT3> scales;
wi::vector<float> opacities;
wi::vector<XMFLOAT3> f_dc;
wi::vector<float> f_rest; // 45 floats per splat (15 * rgb coefficient SH3)
wi::vector<float> f_rest; // number of floats depends on SH degree
// Below this are non-serialized attributes:
wi::primitive::AABB aabb;
wi::primitive::AABB aabb_rest; // aabb without trasformation
wi::primitive::AABB aabb; // aabb with trasformation
XMFLOAT4X4 transform = wi::math::IDENTITY_MATRIX;
XMFLOAT4X4 transform_inverse = wi::math::IDENTITY_MATRIX;
wi::graphics::GPUBuffer splatBuffer;
wi::graphics::GPUBuffer shBuffer;
wi::graphics::GPUBuffer indirectBuffer;
@@ -36,8 +40,9 @@ namespace wi
void CreateRenderData();
void Update(const XMFLOAT4X4& transform, wi::graphics::CommandList cmd);
void Draw(wi::graphics::CommandList cmd);
void Update(const XMFLOAT4X4& matrix);
void UpdateGPU(const wi::scene::CameraComponent& camera, wi::graphics::CommandList cmd); // culling and sorting
void Draw(wi::graphics::CommandList cmd); // will be drawn with culling and sorting based on previous call to UpdateGPU
void Serialize(wi::Archive& archive, wi::ecs::EntitySerializer& seri);
+1 -8
View File
@@ -5644,18 +5644,11 @@ void UpdateRenderDataAsync(
const wi::GaussianSplatModel& splat = vis.scene->gaussian_splats[i];
if (!vis.camera->frustum.CheckBoxFast(splat.aabb))
continue;
XMFLOAT4X4 matrix = wi::math::IDENTITY_MATRIX;
Entity entity = vis.scene->gaussian_splats.GetEntity(i);
const TransformComponent* transform = vis.scene->transforms.GetComponent(entity);
if (transform != nullptr)
{
matrix = transform->world;
}
if (prof_splats == 0)
{
prof_splats = wi::profiler::BeginRangeGPU("Gaussian Splat Culling and Sorting", cmd);
}
vis.scene->gaussian_splats[i].Update(matrix, cmd);
vis.scene->gaussian_splats[i].UpdateGPU(*vis.camera, cmd);
}
if (prof_splats != 0)
{
+13
View File
@@ -5263,6 +5263,19 @@ namespace wi::scene
}
});
for (size_t i = 0; i < gaussian_splats.GetCount(); ++i)
{
const wi::GaussianSplatModel& splat = gaussian_splats[i];
XMFLOAT4X4 matrix = wi::math::IDENTITY_MATRIX;
Entity entity = gaussian_splats.GetEntity(i);
const TransformComponent* transform = transforms.GetComponent(entity);
if (transform != nullptr)
{
matrix = transform->world;
}
gaussian_splats[i].Update(matrix);
}
}
void Scene::RunWeatherUpdateSystem(wi::jobsystem::context& ctx)
{
+1 -1
View File
@@ -9,7 +9,7 @@ namespace wi::version
// minor features, major updates, breaking compatibility changes
const int minor = 72;
// minor bug fixes, alterations, refactors, updates
const int revision = 48;
const int revision = 49;
const std::string version_string = std::to_string(major) + "." + std::to_string(minor) + "." + std::to_string(revision);