diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index e7073ccfd..3bfcdd571 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -5421,6 +5421,92 @@ namespace dxvk { } + void DxbcCompiler::emitInitWorkgroupMemory() { + bool hasTgsm = false; + + for (uint32_t i = 0; i < m_gRegs.size(); i++) { + if (!m_gRegs[i].varId) + continue; + + if (!m_cs.builtinLocalInvocationIndex) { + m_cs.builtinLocalInvocationIndex = emitNewBuiltinVariable({ + { DxbcScalarType::Uint32, 1, 0 }, + spv::StorageClassInput }, + spv::BuiltInLocalInvocationIndex, + "vThreadIndexInGroup"); + } + + uint32_t intTypeId = getScalarTypeId(DxbcScalarType::Uint32); + uint32_t ptrTypeId = m_module.defPointerType( + intTypeId, spv::StorageClassWorkgroup); + + uint32_t numElements = m_gRegs[i].type == DxbcResourceType::Structured + ? m_gRegs[i].elementCount * m_gRegs[i].elementStride / 4 + : m_gRegs[i].elementCount / 4; + + uint32_t numThreads = m_cs.workgroupSizeX * + m_cs.workgroupSizeY * m_cs.workgroupSizeZ; + + uint32_t numElementsPerThread = numElements / numThreads; + uint32_t numElementsRemaining = numElements % numThreads; + + uint32_t threadId = m_module.opLoad( + intTypeId, m_cs.builtinLocalInvocationIndex); + + uint32_t strideId = m_module.constu32(numElementsPerThread); + uint32_t zeroId = m_module.constu32(0); + + for (uint32_t e = 0; e < numElementsPerThread; e++) { + uint32_t ofsId = m_module.opIAdd(intTypeId, + m_module.opIMul(intTypeId, strideId, threadId), + m_module.constu32(e)); + + uint32_t ptrId = m_module.opAccessChain( + ptrTypeId, m_gRegs[i].varId, 1, &ofsId); + + m_module.opStore(ptrId, zeroId); + } + + if (numElementsRemaining) { + uint32_t condition = m_module.opULessThan( + m_module.defBoolType(), threadId, + m_module.constu32(numElementsRemaining)); + + DxbcConditional cond; + cond.labelIf = m_module.allocateId(); + cond.labelEnd = m_module.allocateId(); + + m_module.opSelectionMerge(cond.labelEnd, spv::SelectionControlMaskNone); + m_module.opBranchConditional(condition, cond.labelIf, cond.labelEnd); + + m_module.opLabel(cond.labelIf); + + uint32_t ofsId = m_module.opIAdd(intTypeId, + m_module.constu32(numThreads * numElementsPerThread), + threadId); + + uint32_t ptrId = m_module.opAccessChain( + ptrTypeId, m_gRegs[i].varId, 1, &ofsId); + + m_module.opStore(ptrId, zeroId); + + m_module.opBranch(cond.labelEnd); + m_module.opLabel (cond.labelEnd); + } + + hasTgsm = true; + } + + if (hasTgsm) { + m_module.opControlBarrier( + m_module.constu32(spv::ScopeInvocation), + m_module.constu32(spv::ScopeWorkgroup), + m_module.constu32(spv::MemorySemanticsWorkgroupMemoryMask + | spv::MemorySemanticsAcquireReleaseMask)); + } + } + + DxbcRegisterValue DxbcCompiler::emitVsSystemValueLoad( DxbcSystemValue sv, DxbcRegMask mask) { @@ -6355,9 +6441,14 @@ namespace dxvk { void DxbcCompiler::emitCsFinalize() { this->emitMainFunctionBegin(); + + if (m_moduleInfo.options.zeroInitWorkgroupMemory) + this->emitInitWorkgroupMemory(); + m_module.opFunctionCall( m_module.defVoidType(), m_cs.functionId, 0, nullptr); + this->emitFunctionEnd(); } diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h index 25d1c16f7..bc791b526 100644 --- a/src/dxbc/dxbc_compiler.h +++ b/src/dxbc/dxbc_compiler.h @@ -991,6 +991,8 @@ namespace dxvk { void emitOutputMapping(); void emitOutputDepthClamp(); + void emitInitWorkgroupMemory(); + ////////////////////////////////////////// // System value load methods (per shader) DxbcRegisterValue emitVsSystemValueLoad( diff --git a/src/dxbc/dxbc_options.h b/src/dxbc/dxbc_options.h index d08428d15..f20fca4f4 100644 --- a/src/dxbc/dxbc_options.h +++ b/src/dxbc/dxbc_options.h @@ -17,6 +17,9 @@ namespace dxvk { /// Use clustered subgroup operations bool useSubgroupOpsClustered = false; + + /// Clear thread-group shared memory to zero + bool zeroInitWorkgroupMemory = false; }; } \ No newline at end of file