diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index 16ec63b2c..ad393d554 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -2273,21 +2273,8 @@ namespace dxvk { && bufferInfo.type != DxbcResourceType::Typed && isUav; - // Perform atomic operations on UAVs only if the UAV - // is bound and if there is nothing else stopping us. - DxbcConditional cond; - - if (isUav) { - uint32_t writeTest = emitUavWriteTest(bufferInfo); - - cond.labelIf = m_module.allocateId(); - cond.labelEnd = m_module.allocateId(); - - m_module.opSelectionMerge(cond.labelEnd, spv::SelectionControlMaskNone); - m_module.opBranchConditional(writeTest, cond.labelIf, cond.labelEnd); - - m_module.opLabel(cond.labelIf); - } + // Perform atomic operations on UAVs only if the invocation is alive + DxbcConditional cond = emitBeginPsKillTest(); // Retrieve destination pointer for the atomic operation> const DxbcRegisterPointer pointer = emitGetAtomicPointer( @@ -2413,11 +2400,7 @@ namespace dxvk { if (isImm) emitRegisterStore(ins.dst[0], value); - // End conditional block - if (isUav) { - m_module.opBranch(cond.labelEnd); - m_module.opLabel (cond.labelEnd); - } + emitEndPsKillTest(cond); } @@ -2425,30 +2408,22 @@ namespace dxvk { // imm_atomic_alloc and imm_atomic_consume have the following operands: // (dst0) The register that will hold the old counter value // (dst1) The UAV whose counter is going to be modified - const DxbcBufferInfo bufferInfo = getBufferInfo(ins.dst[1]); - const uint32_t registerId = ins.dst[1].idx[0].offset; if (m_uavs.at(registerId).ctrId == 0) m_uavs.at(registerId).ctrId = emitDclUavCounter(registerId); - // Only perform the operation if the UAV is bound - uint32_t writeTest = emitUavWriteTest(bufferInfo); - - DxbcConditional cond; - cond.labelIf = m_module.allocateId(); - cond.labelEnd = m_module.allocateId(); - - m_module.opSelectionMerge(cond.labelEnd, spv::SelectionControlMaskNone); - m_module.opBranchConditional(writeTest, cond.labelIf, cond.labelEnd); - - m_module.opLabel(cond.labelIf); + // Only perform the operation if the invocation is alive + DxbcConditional cond = emitBeginPsKillTest(); // Only use subgroup ops on compute to avoid having to // deal with helper invocations or hardware limitations bool useSubgroupOps = m_moduleInfo.options.useSubgroupOpsForAtomicCounters && m_programInfo.type() == DxbcProgramType::ComputeShader; + // Current block ID used in a phi later on + uint32_t baseBlockId = m_module.getBlockId(); + // In case we have subgroup ops enabled, we need to // count the number of active lanes, the lane index, // and we need to perform the atomic op conditionally @@ -2550,7 +2525,7 @@ namespace dxvk { std::array phiLabels = {{ { value.id, elect.labelIf }, - { undef, cond.labelIf }, + { undef, baseBlockId }, }}; value.id = m_module.opPhi(typeId, @@ -2562,10 +2537,7 @@ namespace dxvk { // Store the result emitRegisterStore(ins.dst[0], value); - - // End conditional block - m_module.opBranch(cond.labelEnd); - m_module.opLabel (cond.labelEnd); + emitEndPsKillTest(cond); } @@ -3677,17 +3649,8 @@ namespace dxvk { // (src1) The value to store const DxbcBufferInfo uavInfo = getBufferInfo(ins.dst[0]); - // Execute write op only if the UAV is bound - uint32_t writeTest = emitUavWriteTest(uavInfo); - - DxbcConditional cond; - cond.labelIf = m_module.allocateId(); - cond.labelEnd = m_module.allocateId(); - - m_module.opSelectionMerge (cond.labelEnd, spv::SelectionControlMaskNone); - m_module.opBranchConditional(writeTest, cond.labelIf, cond.labelEnd); - - m_module.opLabel(cond.labelIf); + // Execute write op only if the invocation is active + DxbcConditional cond = emitBeginPsKillTest(); // Load texture coordinates DxbcRegisterValue texCoord = emitLoadTexCoord(ins.src[0], uavInfo.image); @@ -3702,10 +3665,8 @@ namespace dxvk { m_module.opImageWrite( m_module.opLoad(uavInfo.typeId, uavInfo.varId), texCoord.id, texValue.id, SpirvImageOperands()); - - // End conditional block - m_module.opBranch(cond.labelEnd); - m_module.opLabel (cond.labelEnd); + + emitEndPsKillTest(cond); } @@ -5152,21 +5113,8 @@ namespace dxvk { bool isSsbo = m_moduleInfo.options.minSsboAlignment <= bufferInfo.align && !isTgsm; - // Perform UAV writes only if the UAV is bound and if there - // is nothing else preventing us from writing to it. - DxbcConditional cond; - - if (!isTgsm) { - uint32_t writeTest = emitUavWriteTest(bufferInfo); - - cond.labelIf = m_module.allocateId(); - cond.labelEnd = m_module.allocateId(); - - m_module.opSelectionMerge(cond.labelEnd, spv::SelectionControlMaskNone); - m_module.opBranchConditional(writeTest, cond.labelIf, cond.labelEnd); - - m_module.opLabel(cond.labelIf); - } + // Perform UAV writes only if the invocation is active + DxbcConditional cond = emitBeginPsKillTest(); // Perform the actual write operation uint32_t bufferId = isTgsm || isSsbo ? 0 : m_module.opLoad(bufferInfo.typeId, bufferInfo.varId); @@ -5228,12 +5176,8 @@ namespace dxvk { m_module.constu32(spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsAcquireReleaseMask)); } - - // End conditional block - if (!isTgsm) { - m_module.opBranch(cond.labelEnd); - m_module.opLabel (cond.labelEnd); - } + + emitEndPsKillTest(cond); } @@ -6483,23 +6427,35 @@ namespace dxvk { } - uint32_t DxbcCompiler::emitUavWriteTest(const DxbcBufferInfo& uav) { - uint32_t typeId = m_module.defBoolType(); - uint32_t testId = 0; + DxbcConditional DxbcCompiler::emitBeginPsKillTest() { + if (!m_ps.killState) + return DxbcConditional(); + + uint32_t boolId = m_module.defBoolType(); + uint32_t killState = m_module.opLoad(boolId, m_ps.killState); + uint32_t testId = m_module.opLogicalNot(boolId, killState); + + DxbcConditional cond; + cond.labelIf = m_module.allocateId(); + cond.labelEnd = m_module.allocateId(); - if (m_ps.killState != 0) { - uint32_t killState = m_module.opLoad(typeId, m_ps.killState); - - testId = m_module.opLogicalAnd(typeId, testId, - m_module.opLogicalNot(typeId, killState)); - } else { - testId = m_module.constBool(true); - } + m_module.opSelectionMerge(cond.labelEnd, spv::SelectionControlMaskNone); + m_module.opBranchConditional(testId, cond.labelIf, cond.labelEnd); - return testId; + m_module.opLabel(cond.labelIf); + return cond; } - - + + + void DxbcCompiler::emitEndPsKillTest(const DxbcConditional& cond) { + if (!m_ps.killState) + return; + + m_module.opBranch(cond.labelEnd); + m_module.opLabel(cond.labelEnd); + } + + void DxbcCompiler::emitInit() { // Set up common capabilities for all shaders m_module.enableCapability(spv::CapabilityShader); diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h index dea5007bb..29d410c70 100644 --- a/src/dxbc/dxbc_compiler.h +++ b/src/dxbc/dxbc_compiler.h @@ -1063,8 +1063,9 @@ namespace dxvk { /////////////////////////////// // Some state checking methods - uint32_t emitUavWriteTest( - const DxbcBufferInfo& uav); + DxbcConditional emitBeginPsKillTest(); + + void emitEndPsKillTest(const DxbcConditional& cond); ////////////////////////////////////// // Common function definition methods