diff --git a/WickedEngine/shaders/ShaderInterop_GaussianSplat.h b/WickedEngine/shaders/ShaderInterop_GaussianSplat.h index f2c3f2fd2..671fa32af 100644 --- a/WickedEngine/shaders/ShaderInterop_GaussianSplat.h +++ b/WickedEngine/shaders/ShaderInterop_GaussianSplat.h @@ -5,20 +5,24 @@ struct GaussianSplat { - float3 position; + uint2 position_radius; // unorm16x3 | half uint2 color; // half4 - uint2 cov3D_M11_M12_M13_radius; // half4 + uint2 cov3D_M11_M12_M13; // half3 uint2 cov3D_M22_M23_M33; // half3 }; struct ShaderGaussianSplatModel { ShaderTransform transform; ShaderTransform transform_inverse; - float4x4 modelViewMatrices[6]; // 6 for cubemap rendering - int sphericalHarmonicsDegree; - int splatStride; - int descriptor_splatBuffer; - int descriptor_shBuffer; + float3 aabb_min; // for position remapping from unorm16x3 to float3 + float maxScale; // for culling radius modification + float3 aabb_max; // for position remapping from unorm16x3 to float3 + float padding0; + float4x4 modelViewMatrices[6]; // 6 for cubemap rendering + int sphericalHarmonicsDegree; // 0, 1, 2 or 3 + int splatStride; // element count of shBuffer per splat + int descriptor_splatBuffer; // pointer to StructuredBuffer (one element per splat) + int descriptor_shBuffer; // pointer to Buffer (splatStride element count per splat) }; static const uint GAUSSIAN_COMPUTE_THREADSIZE = 64; diff --git a/WickedEngine/shaders/gaussian_splatCS.hlsl b/WickedEngine/shaders/gaussian_splatCS.hlsl index 7cfe683bc..df5c1e945 100644 --- a/WickedEngine/shaders/gaussian_splatCS.hlsl +++ b/WickedEngine/shaders/gaussian_splatCS.hlsl @@ -31,8 +31,8 @@ void main(uint DTid : SV_DispatchThreadID) return; ShaderSphere sphere; - sphere.center = mul(model.transform.GetMatrix(), float4(splats[splatIndex].position, 1)).xyz; - sphere.radius = max3(mul(model.transform.GetMatrixAdjoint(), unpack_half4(splats[splatIndex].cov3D_M11_M12_M13_radius).www)); + sphere.center = mul(model.transform.GetMatrix(), float4(lerp(model.aabb_min, model.aabb_max, unpack_unorm16x4(splats[splatIndex].position_radius).xyz), 1)).xyz; + sphere.radius = unpack_half4(splats[splatIndex].position_radius).w * model.maxScale; const float3 eyeVector = sphere.center - GetCamera().position; const float distSq = dot(eyeVector, eyeVector); diff --git a/WickedEngine/shaders/gaussian_splatVS.hlsl b/WickedEngine/shaders/gaussian_splatVS.hlsl index 143d63366..560d1af67 100644 --- a/WickedEngine/shaders/gaussian_splatVS.hlsl +++ b/WickedEngine/shaders/gaussian_splatVS.hlsl @@ -44,7 +44,7 @@ static const float frustumDilation = 0.2; float3x3 fetchCovariance(in StructuredBuffer splats, in uint splatIndex) { - const half3 cov3D_M11_M12_M13 = unpack_half4(splats[splatIndex].cov3D_M11_M12_M13_radius).xyz; + const half3 cov3D_M11_M12_M13 = unpack_half3(splats[splatIndex].cov3D_M11_M12_M13); const half3 cov3D_M22_M23_M33 = unpack_half3(splats[splatIndex].cov3D_M22_M23_M33); return float3x3(cov3D_M11_M12_M13.x, cov3D_M11_M12_M13.y, cov3D_M11_M12_M13.z, cov3D_M11_M12_M13.y, cov3D_M22_M23_M33.x, cov3D_M22_M23_M33.y, cov3D_M11_M12_M13.z, cov3D_M22_M23_M33.y, cov3D_M22_M23_M33.z); } @@ -252,11 +252,10 @@ void main(in uint vertexID : SV_VertexID, in uint instanceID : SV_InstanceID, ou //const float4x4 modelViewMatrix = mul(camera.view, model.transform.GetMatrix()); const float4x4 modelViewMatrix = model.modelViewMatrices[cameraIndex]; // optimization: precomputed above matrix on CPU - const float3 splatCenter = float4(splats[splatIndex].position, 1).xyz; + const float3 splatCenter = float4(lerp(model.aabb_min, model.aabb_max, unpack_unorm16x4(splats[splatIndex].position_radius).xyz), 1).xyz; color = unpack_half4(splats[splatIndex].color); half3 viewDir = normalize(splatCenter - mul(model.transform_inverse.GetMatrix(), float4(camera.position, 1)).xyz); color.rgb += fetchViewDependentRadiance(model, splatIndex, viewDir); - color.rgb = RemoveSRGBCurve_Fast(color.rgb); // Checked against SuperSplat Editor, result is more similar with gamma remove const float4 viewCenter = mul(modelViewMatrix, float4(splatCenter, 1.0)); const float4 clipCenter = mul(camera.projection, viewCenter); diff --git a/WickedEngine/wiGaussianSplatModel.cpp b/WickedEngine/wiGaussianSplatModel.cpp index 55c9835e8..af8032ad2 100644 --- a/WickedEngine/wiGaussianSplatModel.cpp +++ b/WickedEngine/wiGaussianSplatModel.cpp @@ -28,6 +28,10 @@ namespace wi GraphicsDevice* device = GetDevice(); aabb_rest = AABB(); + for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx) + { + aabb_rest.AddPoint(positions[splatIdx]); + } GPUBufferDesc desc; desc.bind_flags = BindFlag::SHADER_RESOURCE; @@ -64,27 +68,29 @@ namespace wi uint16_t* sh_dest = (uint16_t*)((uint8_t*)dest + sh_aligned_offset); for (size_t splatIdx = 0; splatIdx < positions.size(); ++splatIdx) { - aabb_rest.AddPoint(positions[splatIdx]); - GaussianSplat splat = {}; - splat.position = positions[splatIdx]; - XMFLOAT3 scale = XMFLOAT3(std::exp(scales[splatIdx].x), std::exp(scales[splatIdx].y), std::exp(scales[splatIdx].z)); - static const float sqrt8 = std::sqrt(8.0f); - const float radius = std::max(scale.x, std::max(scale.y, scale.z)) * sqrt8; // culling + // position remap and quantize to 16 bit UNORM: + const XMFLOAT3 pos = wi::math::InverseLerp(aabb_rest._min, aabb_rest._max, positions[splatIdx]); + splat.position_radius.x = uint32_t(saturate(pos.x) * 65535.0f); + splat.position_radius.x |= uint32_t(saturate(pos.y) * 65535.0f) << 16u; + splat.position_radius.y = uint32_t(saturate(pos.z) * 65535.0f); + + const XMFLOAT3 scale = XMFLOAT3(std::exp(scales[splatIdx].x), std::exp(scales[splatIdx].y), std::exp(scales[splatIdx].z)); + const float radius = std::max(scale.x, std::max(scale.y, scale.z)); // culling + splat.position_radius.y |= uint32_t(XMConvertFloatToHalf(radius)) << 16u; // radius in half precision // f_dc is L0 spherical harmonics (not view dependent), so it's converted to rgb color here // https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp + // I also remove SRGB curve here with pow(rgb, 2.2) static constexpr float SH_C0 = 0.28209479177387814f; - float4 color = {}; - color.x = saturate(0.5f + SH_C0 * f_dc[splatIdx].x); - color.y = saturate(0.5f + SH_C0 * f_dc[splatIdx].y); - color.z = saturate(0.5f + SH_C0 * f_dc[splatIdx].z); + float4 color; + color.x = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].x), 2.2f); + color.y = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].y), 2.2f); + color.z = std::pow(saturate(0.5f + SH_C0 * f_dc[splatIdx].z), 2.2f); color.w = saturate(1.0f / (1.0f + std::exp(-opacities[splatIdx]))); splat.color = pack_half4(color); - const float opacity = saturate(1.0f / (1.0f + std::exp(-opacities[splatIdx]))); - // covariance from: https://github.com/nvpro-samples/vk_gaussian_splatting/blob/main/src/splat_set_vk.cpp // changed from glm to DirectXMath (column->row major, matrix mul order changed) const XMMATRIX scaleMatrix = XMMatrixScaling(scale.x, scale.y, scale.z); @@ -93,16 +99,15 @@ namespace wi const XMMATRIX transformedCovarianceMatrix = XMMatrixMultiply(XMMatrixTranspose(covarianceMatrix), covarianceMatrix); XMFLOAT3X3 transformedCovariance; XMStoreFloat3x3(&transformedCovariance, transformedCovarianceMatrix); - float4 cov3D_M11_M12_M13_radius; + float3 cov3D_M11_M12_M13; float3 cov3D_M22_M23_M33; - cov3D_M11_M12_M13_radius.x = transformedCovariance._11; - cov3D_M11_M12_M13_radius.y = transformedCovariance._12; - cov3D_M11_M12_M13_radius.z = transformedCovariance._13; - cov3D_M11_M12_M13_radius.w = radius; + cov3D_M11_M12_M13.x = transformedCovariance._11; + cov3D_M11_M12_M13.y = transformedCovariance._12; + cov3D_M11_M12_M13.z = transformedCovariance._13; cov3D_M22_M23_M33.x = transformedCovariance._22; cov3D_M22_M23_M33.y = transformedCovariance._23; cov3D_M22_M23_M33.z = transformedCovariance._33; - splat.cov3D_M11_M12_M13_radius = pack_half4(cov3D_M11_M12_M13_radius); + splat.cov3D_M11_M12_M13 = pack_half3(cov3D_M11_M12_M13); splat.cov3D_M22_M23_M33 = pack_half3(cov3D_M22_M23_M33); std::memcpy(splat_dest + splatIdx, &splat, sizeof(splat)); // memcpy into uncached @@ -155,10 +160,16 @@ namespace wi void GaussianSplatModel::Update(const XMFLOAT4X4& matrix) { - transform = matrix; - XMFLOAT4X4 transform_inverse; - XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, XMLoadFloat4x4(&matrix))); - aabb = aabb_rest.transform(matrix); + const XMMATRIX W = XMLoadFloat4x4(&matrix); + aabb = aabb_rest.transform(W); + + static const float sqrt8 = std::sqrt(8.0f); + XMVECTOR scale = XMVectorSet(1, 1, 1, 1); + scale = XMVector3TransformNormal(scale, W); + maxScale = std::max(XMVectorGetX(scale), std::max(XMVectorGetY(scale), XMVectorGetZ(scale))) * sqrt8; + + XMStoreFloat4x4(&transform, W); + XMStoreFloat4x4(&transform_inverse, XMMatrixInverse(nullptr, W)); } void GaussianSplatModel::Serialize(wi::Archive& archive, wi::ecs::EntitySerializer& seri) @@ -351,6 +362,9 @@ namespace wi shmodel.descriptor_shBuffer = device->GetDescriptorIndex(&model.buffer, SubresourceType::SRV, model.subresource_shBuffer); shmodel.transform.Create(model.transform); shmodel.transform_inverse.Create(model.transform_inverse); + shmodel.aabb_min = model.aabb_rest._min; + shmodel.aabb_max = model.aabb_rest._max; + shmodel.maxScale = model.maxScale; const XMMATRIX modelMatrix = XMLoadFloat4x4(&model.transform); for (uint32_t i = 0; i < camera_count; ++i) { diff --git a/WickedEngine/wiGaussianSplatModel.h b/WickedEngine/wiGaussianSplatModel.h index 18668aa70..1f0ed743b 100644 --- a/WickedEngine/wiGaussianSplatModel.h +++ b/WickedEngine/wiGaussianSplatModel.h @@ -27,6 +27,7 @@ namespace wi wi::primitive::AABB aabb; // aabb with transformation XMFLOAT4X4 transform = wi::math::IDENTITY_MATRIX; XMFLOAT4X4 transform_inverse = wi::math::IDENTITY_MATRIX; + float maxScale = 1; wi::graphics::GPUBuffer buffer; int subresource_splatBuffer = -1; int subresource_shBuffer = -1; diff --git a/WickedEngine/wiVersion.cpp b/WickedEngine/wiVersion.cpp index e10050258..769209c00 100644 --- a/WickedEngine/wiVersion.cpp +++ b/WickedEngine/wiVersion.cpp @@ -9,7 +9,7 @@ namespace wi::version // minor features, major updates, breaking compatibility changes const int minor = 72; // minor bug fixes, alterations, refactors, updates - const int revision = 52; + const int revision = 53; const std::string version_string = std::to_string(major) + "." + std::to_string(minor) + "." + std::to_string(revision);