diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index b9317e97f..84a46ab3e 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -3983,15 +3983,33 @@ namespace dxvk { m_module.constu32(spv::ScopeSubgroup), killState); - uint32_t invocationMask = m_module.opLoad( - getVectorTypeId({ DxbcScalarType::Uint32, 4 }), - m_ps.invocationMask); + uint32_t laneId = m_module.opLoad( + getScalarTypeId(DxbcScalarType::Uint32), + m_ps.builtinLaneId); - uint32_t killSubgroup = m_module.opAll( + uint32_t laneIdPart = m_module.opShiftRightLogical( + getScalarTypeId(DxbcScalarType::Uint32), + laneId, m_module.constu32(5)); + + uint32_t laneMask = m_module.opVectorExtractDynamic( + getScalarTypeId(DxbcScalarType::Uint32), + ballot, laneIdPart); + + uint32_t laneIdQuad = m_module.opBitwiseAnd( + getScalarTypeId(DxbcScalarType::Uint32), + laneId, m_module.constu32(0x1c)); + + laneMask = m_module.opShiftRightLogical( + getScalarTypeId(DxbcScalarType::Uint32), + laneMask, laneIdQuad); + + laneMask = m_module.opBitwiseAnd( + getScalarTypeId(DxbcScalarType::Uint32), + laneMask, m_module.constu32(0xf)); + + uint32_t killSubgroup = m_module.opIEqual( m_module.defBoolType(), - m_module.opIEqual( - m_module.defVectorType(m_module.defBoolType(), 4), - ballot, invocationMask)); + laneMask, m_module.constu32(0xf)); DxbcConditional cond; cond.labelIf = m_module.allocateId(); @@ -6544,18 +6562,13 @@ namespace dxvk { m_module.enableCapability(spv::CapabilityGroupNonUniform); m_module.enableCapability(spv::CapabilityGroupNonUniformBallot); - DxbcRegisterInfo invocationMask; - invocationMask.type = { DxbcScalarType::Uint32, 4, 0 }; - invocationMask.sclass = spv::StorageClassFunction; + DxbcRegisterInfo laneId; + laneId.type = { DxbcScalarType::Uint32, 1, 0 }; + laneId.sclass = spv::StorageClassInput; - m_ps.invocationMask = emitNewVariable(invocationMask); - m_module.setDebugName(m_ps.invocationMask, "fInvocationMask"); - - m_module.opStore(m_ps.invocationMask, - m_module.opGroupNonUniformBallot( - getVectorTypeId({ DxbcScalarType::Uint32, 4 }), - m_module.constu32(spv::ScopeSubgroup), - m_module.constBool(true))); + m_ps.builtinLaneId = emitNewBuiltinVariable( + laneId, spv::BuiltInSubgroupLocalInvocationId, + "fLaneId"); } } } diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h index 3b717169d..3714e1d0d 100644 --- a/src/dxbc/dxbc_compiler.h +++ b/src/dxbc/dxbc_compiler.h @@ -178,7 +178,7 @@ namespace dxvk { uint32_t builtinLayer = 0; uint32_t builtinViewportId = 0; - uint32_t invocationMask = 0; + uint32_t builtinLaneId = 0; uint32_t killState = 0; };