gaussian splat gpu position data compression

This commit is contained in:
Turánszki János
2026-03-05 12:41:17 +01:00
parent deba493574
commit 2faa2fcb0e
6 changed files with 53 additions and 35 deletions
@@ -5,20 +5,24 @@
// Per-splat GPU record. The uint2 members each hold packed 16-bit values (half floats or unorm16).
// NOTE(review): this listing appears to interleave the pre-change and post-change lines of a diff;
//	`position` and `cov3D_M11_M12_M13_radius` look like the superseded fields (their data is also
//	carried by `position_radius` and `cov3D_M11_M12_M13`) - confirm before relying on this layout.
struct GaussianSplat
{
float3 position;
uint2 position_radius; // unorm16x3 | half: x = pos.x | (pos.y << 16), y = pos.z | (half radius << 16); position is normalized to the model's rest AABB
uint2 color; // half4: rgb color, a = opacity
uint2 cov3D_M11_M12_M13_radius; // half4
uint2 cov3D_M11_M12_M13; // half3: first row of the symmetric 3D covariance matrix
uint2 cov3D_M22_M23_M33; // half3: remaining unique elements of the symmetric 3D covariance matrix
};
// Per-model constants consumed by the gaussian splat shaders.
// Every member is declared exactly once; the float3 + float pairs keep the AABB/maxScale
// fields grouped into four-float rows ahead of the matrix array.
struct ShaderGaussianSplatModel
{
	ShaderTransform transform; // applied to the decoded splat position to get the culling sphere center
	ShaderTransform transform_inverse; // applied to the camera position to compute the view direction relative to the splat
	float3 aabb_min; // for position remapping from unorm16x3 to float3
	float maxScale; // for culling radius modification (multiplied with the per-splat half radius)
	float3 aabb_max; // for position remapping from unorm16x3 to float3
	float padding0; // NOTE(review): presumably only pads aabb_max to four floats - confirm
	float4x4 modelViewMatrices[6]; // 6 for cubemap rendering
	int sphericalHarmonicsDegree; // 0, 1, 2 or 3
	int splatStride; // element count of shBuffer per splat
	int descriptor_splatBuffer; // pointer to StructuredBuffer<GaussianSplat> (one element per splat)
	int descriptor_shBuffer; // pointer to Buffer<half> (splatStride element count per splat)
};
static const uint GAUSSIAN_COMPUTE_THREADSIZE = 64;
+2 -2
View File
@@ -31,8 +31,8 @@ void main(uint DTid : SV_DispatchThreadID)
return;
ShaderSphere sphere;
sphere.center = mul(model.transform.GetMatrix(), float4(splats[splatIndex].position, 1)).xyz;
sphere.radius = max3(mul(model.transform.GetMatrixAdjoint(), unpack_half4(splats[splatIndex].cov3D_M11_M12_M13_radius).www));
sphere.center = mul(model.transform.GetMatrix(), float4(lerp(model.aabb_min, model.aabb_max, unpack_unorm16x4(splats[splatIndex].position_radius).xyz), 1)).xyz;
sphere.radius = unpack_half4(splats[splatIndex].position_radius).w * model.maxScale;
const float3 eyeVector = sphere.center - GetCamera().position;
const float distSq = dot(eyeVector, eyeVector);
+2 -3
View File
@@ -44,7 +44,7 @@ static const float frustumDilation = 0.2;
// Rebuilds the symmetric 3x3 covariance matrix of one splat from its six packed half values.
//	splats     : buffer holding one GaussianSplat per splat
//	splatIndex : index of the splat to read
// Returns the matrix with the off-diagonal elements mirrored (M21 = M12, M31 = M13, M32 = M23).
float3x3 fetchCovariance(in StructuredBuffer<GaussianSplat> splats, in uint splatIndex)
{
	// The first row is read from its dedicated half3 field; it is declared only once here.
	const half3 cov3D_M11_M12_M13 = unpack_half3(splats[splatIndex].cov3D_M11_M12_M13);
	const half3 cov3D_M22_M23_M33 = unpack_half3(splats[splatIndex].cov3D_M22_M23_M33);
	return float3x3(
		cov3D_M11_M12_M13.x, cov3D_M11_M12_M13.y, cov3D_M11_M12_M13.z,
		cov3D_M11_M12_M13.y, cov3D_M22_M23_M33.x, cov3D_M22_M23_M33.y,
		cov3D_M11_M12_M13.z, cov3D_M22_M23_M33.y, cov3D_M22_M23_M33.z
	);
}
@@ -252,11 +252,10 @@ void main(in uint vertexID : SV_VertexID, in uint instanceID : SV_InstanceID, ou
//const float4x4 modelViewMatrix = mul(camera.view, model.transform.GetMatrix());
const float4x4 modelViewMatrix = model.modelViewMatrices[cameraIndex]; // optimization: precomputed above matrix on CPU
const float3 splatCenter = float4(splats[splatIndex].position, 1).xyz;
const float3 splatCenter = float4(lerp(model.aabb_min, model.aabb_max, unpack_unorm16x4(splats[splatIndex].position_radius).xyz), 1).xyz;
color = unpack_half4(splats[splatIndex].color);
half3 viewDir = normalize(splatCenter - mul(model.transform_inverse.GetMatrix(), float4(camera.position, 1)).xyz);
color.rgb += fetchViewDependentRadiance(model, splatIndex, viewDir);
color.rgb = RemoveSRGBCurve_Fast(color.rgb); // Checked against SuperSplat Editor, result is more similar with gamma remove
const float4 viewCenter = mul(modelViewMatrix, float4(splatCenter, 1.0));
const float4 clipCenter = mul(camera.projection, viewCenter);
+36 -22
View File
@@ -28,6 +28,10 @@ namespace wi
GraphicsDevice* device = GetDevice();
aabb_rest = AABB();
for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
{
aabb_rest.AddPoint(positions[splatIdx]);
}
GPUBufferDesc desc;
desc.bind_flags = BindFlag::SHADER_RESOURCE;
@@ -64,27 +68,29 @@ namespace wi
uint16_t* sh_dest = (uint16_t*)((uint8_t*)dest + sh_aligned_offset);
for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
{
aabb_rest.AddPoint(positions[splatIdx]);
GaussianSplat splat = {};
splat.position = positions[splatIdx];
XMFLOAT3 scale = XMFLOAT3(std::exp(scales[splatIdx].x), std::exp(scales[splatIdx].y), std::exp(scales[splatIdx].z));
static const float sqrt8 = std::sqrt(8.0f);
const float radius = std::max(scale.x, std::max(scale.y, scale.z)) * sqrt8; // culling
// position remap and quantize to 16 bit UNORM:
const XMFLOAT3 pos = wi::math::InverseLerp(aabb_rest._min, aabb_rest._max, positions[splatIdx]);
splat.position_radius.x = uint32_t(saturate(pos.x) * 65535.0f);
splat.position_radius.x |= uint32_t(saturate(pos.y) * 65535.0f) << 16u;
splat.position_radius.y = uint32_t(saturate(pos.z) * 65535.0f);
const XMFLOAT3 scale = XMFLOAT3(std::exp(scales[splatIdx].x), std::exp(scales[splatIdx].y), std::exp(scales[splatIdx].z));
const float radius = std::max(scale.x, std::max(scale.y, scale.z)); // culling
splat.position_radius.y |= uint32_t(XMConvertFloatToHalf(radius)) << 16u; // radius in half precision
// f_dc is L0 spherical harmonics (not view dependent), so it's converted to rgb color here
// https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp
// I also remove SRGB curve here with pow(rgb, 2.2)
static constexpr float SH_C0 = 0.28209479177387814f;
float4 color = {};
color.x = saturate(0.5f + SH_C0 * f_dc[splatIdx].x);
color.y = saturate(0.5f + SH_C0 * f_dc[splatIdx].y);
color.z = saturate(0.5f + SH_C0 * f_dc[splatIdx].z);
float4 color;
color.x = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].x), 2.2f);
color.y = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].y), 2.2f);
color.z = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].z), 2.2f);
color.w = saturate(1.0f / (1.0f + std::exp(-opacities[splatIdx])));
splat.color = pack_half4(color);
const float opacity = saturate(1.0f / (1.0f + std::exp(-opacities[splatIdx])));
// covariance from: https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp
// changed from glm to DirectXMath (column->row major, matrix mul order changed)
const XMMATRIX scaleMatrix = XMMatrixScaling(scale.x, scale.y, scale.z);
@@ -93,16 +99,15 @@ namespace wi
const XMMATRIX transformedCovarianceMatrix = XMMatrixMultiply(XMMatrixTranspose(covarianceMatrix), covarianceMatrix);
XMFLOAT3X3 transformedCovariance;
XMStoreFloat3x3(&transformedCovariance, transformedCovarianceMatrix);
float4 cov3D_M11_M12_M13_radius;
float3 cov3D_M11_M12_M13;
float3 cov3D_M22_M23_M33;
cov3D_M11_M12_M13_radius.x = transformedCovariance._11;
cov3D_M11_M12_M13_radius.y = transformedCovariance._12;
cov3D_M11_M12_M13_radius.z = transformedCovariance._13;
cov3D_M11_M12_M13_radius.w = radius;
cov3D_M11_M12_M13.x = transformedCovariance._11;
cov3D_M11_M12_M13.y = transformedCovariance._12;
cov3D_M11_M12_M13.z = transformedCovariance._13;
cov3D_M22_M23_M33.x = transformedCovariance._22;
cov3D_M22_M23_M33.y = transformedCovariance._23;
cov3D_M22_M23_M33.z = transformedCovariance._33;
splat.cov3D_M11_M12_M13_radius = pack_half4(cov3D_M11_M12_M13_radius);
splat.cov3D_M11_M12_M13 = pack_half3(cov3D_M11_M12_M13);
splat.cov3D_M22_M23_M33 = pack_half3(cov3D_M22_M23_M33);
std::memcpy(splat_dest + splatIdx, &splat, sizeof(splat)); // memcpy into uncached
@@ -155,10 +160,16 @@ namespace wi
void GaussianSplatModel::Update(const XMFLOAT4X4& matrix)
{
transform = matrix;
XMFLOAT4X4 transform_inverse;
XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, XMLoadFloat4x4(&matrix)));
aabb = aabb_rest.transform(matrix);
const XMMATRIX W = XMLoadFloat4x4(&matrix);
aabb = aabb_rest.transform(W);
static const float sqrt8 = std::sqrt(8.0f);
XMVECTOR scale = XMVectorSet(1, 1, 1, 1);
scale = XMVector3TransformNormal(scale, W);
maxScale = std::max(XMVectorGetX(scale), std::max(XMVectorGetY(scale), XMVectorGetZ(scale))) * sqrt8;
XMStoreFloat4x4(&transform, W);
XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, W));
}
void GaussianSplatModel::Serialize(wi::Archive& archive, wi::ecs::EntitySerializer& seri)
@@ -351,6 +362,9 @@ namespace wi
shmodel.descriptor_shBuffer = device->GetDescriptorIndex(&model.buffer, SubresourceType::SRV, model.subresource_shBuffer);
shmodel.transform.Create(model.transform);
shmodel.transform_inverse.Create(model.transform_inverse);
shmodel.aabb_min = model.aabb_rest._min;
shmodel.aabb_max = model.aabb_rest._max;
shmodel.maxScale = model.maxScale;
const XMMATRIX modelMatrix = XMLoadFloat4x4(&model.transform);
for (uint32_t i = 0; i < camera_count; ++i)
{
+1
View File
@@ -27,6 +27,7 @@ namespace wi
wi::primitive::AABB aabb; // aabb with transformation
XMFLOAT4X4 transform = wi::math::IDENTITY_MATRIX;
XMFLOAT4X4 transform_inverse = wi::math::IDENTITY_MATRIX;
float maxScale = 1;
wi::graphics::GPUBuffer buffer;
int subresource_splatBuffer = -1;
int subresource_shBuffer = -1;
+1 -1
View File
@@ -9,7 +9,7 @@ namespace wi::version
// minor features, major updates, breaking compatibility changes
const int minor = 72;
// minor bug fixes, alterations, refactors, updates
const int revision = 52;
const int revision = 53;
const std::string version_string = std::to_string(major) + "." + std::to_string(minor) + "." + std::to_string(revision);