gaussian splat gpu position data compression
This commit is contained in:
@@ -5,20 +5,24 @@
|
||||
|
||||
struct GaussianSplat
|
||||
{
|
||||
float3 position;
|
||||
uint2 position_radius; // unorm16x3 | half
|
||||
uint2 color; // half4
|
||||
uint2 cov3D_M11_M12_M13_radius; // half4
|
||||
uint2 cov3D_M11_M12_M13; // half3
|
||||
uint2 cov3D_M22_M23_M33; // half3
|
||||
};
|
||||
struct ShaderGaussianSplatModel
|
||||
{
|
||||
ShaderTransform transform;
|
||||
ShaderTransform transform_inverse;
|
||||
float4x4 modelViewMatrices[6]; // 6 for cubemap rendering
|
||||
int sphericalHarmonicsDegree;
|
||||
int splatStride;
|
||||
int descriptor_splatBuffer;
|
||||
int descriptor_shBuffer;
|
||||
float3 aabb_min; // for position remapping from unorm16x3 to float3
|
||||
float maxScale; // for culling radius modification
|
||||
float3 aabb_max; // for position remapping from unorm16x3 to float3
|
||||
float padding0;
|
||||
float4x4 modelViewMatrices[6]; // 6 for cubemap rendering
|
||||
int sphericalHarmonicsDegree; // 0, 1, 2 or 3
|
||||
int splatStride; // element count of shBuffer per splat
|
||||
int descriptor_splatBuffer; // pointer to StructuredBuffer<GaussianSplat> (one element per splat)
|
||||
int descriptor_shBuffer; // pointer to Buffer<half> (splatStride element count per splat)
|
||||
};
|
||||
|
||||
static const uint GAUSSIAN_COMPUTE_THREADSIZE = 64;
|
||||
|
||||
@@ -31,8 +31,8 @@ void main(uint DTid : SV_DispatchThreadID)
|
||||
return;
|
||||
|
||||
ShaderSphere sphere;
|
||||
sphere.center = mul(model.transform.GetMatrix(), float4(splats[splatIndex].position, 1)).xyz;
|
||||
sphere.radius = max3(mul(model.transform.GetMatrixAdjoint(), unpack_half4(splats[splatIndex].cov3D_M11_M12_M13_radius).www));
|
||||
sphere.center = mul(model.transform.GetMatrix(), float4(lerp(model.aabb_min, model.aabb_max, unpack_unorm16x4(splats[splatIndex].position_radius).xyz), 1)).xyz;
|
||||
sphere.radius = unpack_half4(splats[splatIndex].position_radius).w * model.maxScale;
|
||||
|
||||
const float3 eyeVector = sphere.center - GetCamera().position;
|
||||
const float distSq = dot(eyeVector, eyeVector);
|
||||
|
||||
@@ -44,7 +44,7 @@ static const float frustumDilation = 0.2;
|
||||
|
||||
float3x3 fetchCovariance(in StructuredBuffer<GaussianSplat> splats, in uint splatIndex)
|
||||
{
|
||||
const half3 cov3D_M11_M12_M13 = unpack_half4(splats[splatIndex].cov3D_M11_M12_M13_radius).xyz;
|
||||
const half3 cov3D_M11_M12_M13 = unpack_half3(splats[splatIndex].cov3D_M11_M12_M13);
|
||||
const half3 cov3D_M22_M23_M33 = unpack_half3(splats[splatIndex].cov3D_M22_M23_M33);
|
||||
return float3x3(cov3D_M11_M12_M13.x, cov3D_M11_M12_M13.y, cov3D_M11_M12_M13.z, cov3D_M11_M12_M13.y, cov3D_M22_M23_M33.x, cov3D_M22_M23_M33.y, cov3D_M11_M12_M13.z, cov3D_M22_M23_M33.y, cov3D_M22_M23_M33.z);
|
||||
}
|
||||
@@ -252,11 +252,10 @@ void main(in uint vertexID : SV_VertexID, in uint instanceID : SV_InstanceID, ou
|
||||
//const float4x4 modelViewMatrix = mul(camera.view, model.transform.GetMatrix());
|
||||
const float4x4 modelViewMatrix = model.modelViewMatrices[cameraIndex]; // optimization: precomputed above matrix on CPU
|
||||
|
||||
const float3 splatCenter = float4(splats[splatIndex].position, 1).xyz;
|
||||
const float3 splatCenter = float4(lerp(model.aabb_min, model.aabb_max, unpack_unorm16x4(splats[splatIndex].position_radius).xyz), 1).xyz;
|
||||
color = unpack_half4(splats[splatIndex].color);
|
||||
half3 viewDir = normalize(splatCenter - mul(model.transform_inverse.GetMatrix(), float4(camera.position, 1)).xyz);
|
||||
color.rgb += fetchViewDependentRadiance(model, splatIndex, viewDir);
|
||||
color.rgb = RemoveSRGBCurve_Fast(color.rgb); // Checked against SuperSplat Editor, result is more similar with gamma remove
|
||||
|
||||
const float4 viewCenter = mul(modelViewMatrix, float4(splatCenter, 1.0));
|
||||
const float4 clipCenter = mul(camera.projection, viewCenter);
|
||||
|
||||
@@ -28,6 +28,10 @@ namespace wi
|
||||
GraphicsDevice* device = GetDevice();
|
||||
|
||||
aabb_rest = AABB();
|
||||
for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
|
||||
{
|
||||
aabb_rest.AddPoint(positions[splatIdx]);
|
||||
}
|
||||
|
||||
GPUBufferDesc desc;
|
||||
desc.bind_flags = BindFlag::SHADER_RESOURCE;
|
||||
@@ -64,27 +68,29 @@ namespace wi
|
||||
uint16_t* sh_dest = (uint16_t*)((uint8_t*)dest + sh_aligned_offset);
|
||||
for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx)
|
||||
{
|
||||
aabb_rest.AddPoint(positions[splatIdx]);
|
||||
|
||||
GaussianSplat splat = {};
|
||||
splat.position = positions[splatIdx];
|
||||
|
||||
XMFLOAT3 scale = XMFLOAT3(std::exp(scales[splatIdx].x), std::exp(scales[splatIdx].y), std::exp(scales[splatIdx].z));
|
||||
static const float sqrt8 = std::sqrt(8.0f);
|
||||
const float radius = std::max(scale.x, std::max(scale.y, scale.z)) * sqrt8; // culling
|
||||
// position remap and quantize to 16 bit UNORM:
|
||||
const XMFLOAT3 pos = wi::math::InverseLerp(aabb_rest._min, aabb_rest._max, positions[splatIdx]);
|
||||
splat.position_radius.x = uint32_t(saturate(pos.x) * 65535.0f);
|
||||
splat.position_radius.x |= uint32_t(saturate(pos.y) * 65535.0f) << 16u;
|
||||
splat.position_radius.y = uint32_t(saturate(pos.z) * 65535.0f);
|
||||
|
||||
const XMFLOAT3 scale = XMFLOAT3(std::exp(scales[splatIdx].x), std::exp(scales[splatIdx].y), std::exp(scales[splatIdx].z));
|
||||
const float radius = std::max(scale.x, std::max(scale.y, scale.z)); // culling
|
||||
splat.position_radius.y |= uint32_t(XMConvertFloatToHalf(radius)) << 16u; // radius in half precision
|
||||
|
||||
// f_dc is L0 spherical harmonics (not view dependent), so it's converted to rgb color here
|
||||
// https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp
|
||||
// I also remove SRGB curve here with pow(rgb, 2.2)
|
||||
static constexpr float SH_C0 = 0.28209479177387814f;
|
||||
float4 color = {};
|
||||
color.x = saturate(0.5f + SH_C0 * f_dc[splatIdx].x);
|
||||
color.y = saturate(0.5f + SH_C0 * f_dc[splatIdx].y);
|
||||
color.z = saturate(0.5f + SH_C0 * f_dc[splatIdx].z);
|
||||
float4 color;
|
||||
color.x = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].x), 2.2f);
|
||||
color.y = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].y), 2.2f);
|
||||
color.z = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].z), 2.2f);
|
||||
color.w = saturate(1.0f / (1.0f + std::exp(-opacities[splatIdx])));
|
||||
splat.color = pack_half4(color);
|
||||
|
||||
const float opacity = saturate(1.0f / (1.0f + std::exp(-opacities[splatIdx])));
|
||||
|
||||
// covariance from: https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp
|
||||
// changed from glm to DirectXMath (column->row major, matrix mul order changed)
|
||||
const XMMATRIX scaleMatrix = XMMatrixScaling(scale.x, scale.y, scale.z);
|
||||
@@ -93,16 +99,15 @@ namespace wi
|
||||
const XMMATRIX transformedCovarianceMatrix = XMMatrixMultiply(XMMatrixTranspose(covarianceMatrix), covarianceMatrix);
|
||||
XMFLOAT3X3 transformedCovariance;
|
||||
XMStoreFloat3x3(&transformedCovariance, transformedCovarianceMatrix);
|
||||
float4 cov3D_M11_M12_M13_radius;
|
||||
float3 cov3D_M11_M12_M13;
|
||||
float3 cov3D_M22_M23_M33;
|
||||
cov3D_M11_M12_M13_radius.x = transformedCovariance._11;
|
||||
cov3D_M11_M12_M13_radius.y = transformedCovariance._12;
|
||||
cov3D_M11_M12_M13_radius.z = transformedCovariance._13;
|
||||
cov3D_M11_M12_M13_radius.w = radius;
|
||||
cov3D_M11_M12_M13.x = transformedCovariance._11;
|
||||
cov3D_M11_M12_M13.y = transformedCovariance._12;
|
||||
cov3D_M11_M12_M13.z = transformedCovariance._13;
|
||||
cov3D_M22_M23_M33.x = transformedCovariance._22;
|
||||
cov3D_M22_M23_M33.y = transformedCovariance._23;
|
||||
cov3D_M22_M23_M33.z = transformedCovariance._33;
|
||||
splat.cov3D_M11_M12_M13_radius = pack_half4(cov3D_M11_M12_M13_radius);
|
||||
splat.cov3D_M11_M12_M13 = pack_half3(cov3D_M11_M12_M13);
|
||||
splat.cov3D_M22_M23_M33 = pack_half3(cov3D_M22_M23_M33);
|
||||
|
||||
std::memcpy(splat_dest + splatIdx, &splat, sizeof(splat)); // memcpy into uncached
|
||||
@@ -155,10 +160,16 @@ namespace wi
|
||||
|
||||
void GaussianSplatModel::Update(const XMFLOAT4X4& matrix)
|
||||
{
|
||||
transform = matrix;
|
||||
XMFLOAT4X4 transform_inverse;
|
||||
XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, XMLoadFloat4x4(&matrix)));
|
||||
aabb = aabb_rest.transform(matrix);
|
||||
const XMMATRIX W = XMLoadFloat4x4(&matrix);
|
||||
aabb = aabb_rest.transform(W);
|
||||
|
||||
static const float sqrt8 = std::sqrt(8.0f);
|
||||
XMVECTOR scale = XMVectorSet(1, 1, 1, 1);
|
||||
scale = XMVector3TransformNormal(scale, W);
|
||||
maxScale = std::max(XMVectorGetX(scale), std::max(XMVectorGetY(scale), XMVectorGetZ(scale))) * sqrt8;
|
||||
|
||||
XMStoreFloat4x4(&transform, W);
|
||||
XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, W));
|
||||
}
|
||||
|
||||
void GaussianSplatModel::Serialize(wi::Archive& archive, wi::ecs::EntitySerializer& seri)
|
||||
@@ -351,6 +362,9 @@ namespace wi
|
||||
shmodel.descriptor_shBuffer = device->GetDescriptorIndex(&model.buffer, SubresourceType::SRV, model.subresource_shBuffer);
|
||||
shmodel.transform.Create(model.transform);
|
||||
shmodel.transform_inverse.Create(model.transform_inverse);
|
||||
shmodel.aabb_min = model.aabb_rest._min;
|
||||
shmodel.aabb_max = model.aabb_rest._max;
|
||||
shmodel.maxScale = model.maxScale;
|
||||
const XMMATRIX modelMatrix = XMLoadFloat4x4(&model.transform);
|
||||
for (uint32_t i = 0; i < camera_count; ++i)
|
||||
{
|
||||
|
||||
@@ -27,6 +27,7 @@ namespace wi
|
||||
wi::primitive::AABB aabb; // aabb with transformation
|
||||
XMFLOAT4X4 transform = wi::math::IDENTITY_MATRIX;
|
||||
XMFLOAT4X4 transform_inverse = wi::math::IDENTITY_MATRIX;
|
||||
float maxScale = 1;
|
||||
wi::graphics::GPUBuffer buffer;
|
||||
int subresource_splatBuffer = -1;
|
||||
int subresource_shBuffer = -1;
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace wi::version
|
||||
// minor features, major updates, breaking compatibility changes
|
||||
const int minor = 72;
|
||||
// minor bug fixes, alterations, refactors, updates
|
||||
const int revision = 52;
|
||||
const int revision = 53;
|
||||
|
||||
const std::string version_string = std::to_string(major) + "." + std::to_string(minor) + "." + std::to_string(revision);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user