diff --git a/drivers/metal/metal_objects.mm b/drivers/metal/metal_objects.mm index e65d9154840..76a3f67d0de 100644 --- a/drivers/metal/metal_objects.mm +++ b/drivers/metal/metal_objects.mm @@ -722,8 +722,6 @@ void MDCommandBuffer::encodeRenderCommandEncoderWithDescriptor(MTLRenderPassDesc void MDCommandBuffer::render_bind_uniform_sets(VectorView p_uniform_sets, RDD::ShaderID p_shader, uint32_t p_first_set_index, uint32_t p_set_count, uint32_t p_dynamic_offsets) { DEV_ASSERT(type == MDCommandBufferStateType::Render); - render.dynamic_offsets |= p_dynamic_offsets; - if (uint32_t new_size = p_first_set_index + p_set_count; render.uniform_sets.size() < new_size) { uint32_t s = render.uniform_sets.size(); render.uniform_sets.resize(new_size); @@ -734,6 +732,20 @@ void MDCommandBuffer::render_bind_uniform_sets(VectorView p_u const MDShader *shader = (const MDShader *)p_shader.id; DynamicOffsetLayout layout = shader->dynamic_offset_layout; + // Clear bits for sets being rebound before OR'ing new values. + // This prevents corruption when the same set is bound multiple times + // with different frame indices (e.g., OPAQUE pass then ALPHA pass). + for (uint32_t i = 0; i < p_set_count && render.dynamic_offsets != 0; i++) { + uint32_t set_index = p_first_set_index + i; + uint32_t count = layout.get_count(set_index); + if (count > 0) { + uint32_t shift = layout.get_offset_index_shift(set_index); + uint32_t mask = ((1u << (count * 4u)) - 1u) << shift; + render.dynamic_offsets &= ~mask; + } + } + render.dynamic_offsets |= p_dynamic_offsets; + for (size_t i = 0; i < p_set_count; ++i) { MDUniformSet *set = (MDUniformSet *)(p_uniform_sets[i].id); @@ -1620,8 +1632,6 @@ void MDCommandBuffer::ComputeState::reset() { void MDCommandBuffer::compute_bind_uniform_sets(VectorView p_uniform_sets, RDD::ShaderID p_shader, uint32_t p_first_set_index, uint32_t p_set_count, uint32_t p_dynamic_offsets) { DEV_ASSERT(type == MDCommandBufferStateType::Compute); - compute.dynamic_offsets |= p_dynamic_offsets; - if (uint32_t new_size = p_first_set_index + p_set_count; compute.uniform_sets.size() < new_size) { uint32_t s = compute.uniform_sets.size(); compute.uniform_sets.resize(new_size); @@ -1632,6 +1642,20 @@ void MDCommandBuffer::compute_bind_uniform_sets(VectorView p_ const MDShader *shader = (const MDShader *)p_shader.id; DynamicOffsetLayout layout = shader->dynamic_offset_layout; + // Clear bits for sets being rebound before OR'ing new values. + // This prevents corruption when the same set is bound multiple times + // with different frame indices. + for (uint32_t i = 0; i < p_set_count && compute.dynamic_offsets != 0; i++) { + uint32_t set_index = p_first_set_index + i; + uint32_t count = layout.get_count(set_index); + if (count > 0) { + uint32_t shift = layout.get_offset_index_shift(set_index); + uint32_t mask = ((1u << (count * 4u)) - 1u) << shift; + compute.dynamic_offsets &= ~mask; + } + } + compute.dynamic_offsets |= p_dynamic_offsets; + for (size_t i = 0; i < p_set_count; ++i) { MDUniformSet *set = (MDUniformSet *)(p_uniform_sets[i].id);