From a1964514001eb242b1785863aa9f799d92080048 Mon Sep 17 00:00:00 2001 From: Philip Rebohle Date: Sun, 1 Jul 2018 15:24:21 +0200 Subject: [PATCH] [dxbc] Respect number of input/output components Fixes issues with geometry shaders exceeding output component limits. --- src/dxbc/dxbc_chunk_isgn.cpp | 13 +++ src/dxbc/dxbc_chunk_isgn.h | 3 + src/dxbc/dxbc_compiler.cpp | 152 +++++++++++++++++++++++------------ src/dxbc/dxbc_compiler.h | 12 ++- src/dxbc/dxbc_decoder.h | 11 +++ 5 files changed, 134 insertions(+), 57 deletions(-) diff --git a/src/dxbc/dxbc_chunk_isgn.cpp b/src/dxbc/dxbc_chunk_isgn.cpp index 7ca329ccb..1462dffc0 100644 --- a/src/dxbc/dxbc_chunk_isgn.cpp +++ b/src/dxbc/dxbc_chunk_isgn.cpp @@ -56,6 +56,19 @@ namespace dxvk { } + DxbcRegMask DxbcIsgn::regMask( + uint32_t registerId) const { + DxbcRegMask mask; + + for (auto e = this->begin(); e != this->end(); e++) { + if (e->registerId == registerId) + mask |= e->componentMask; + } + + return mask; + } + + bool DxbcIsgn::compareSemanticNames( const std::string& a, const std::string& b) const { if (a.size() != b.size()) diff --git a/src/dxbc/dxbc_chunk_isgn.h b/src/dxbc/dxbc_chunk_isgn.h index e083605cd..44c4f6669 100644 --- a/src/dxbc/dxbc_chunk_isgn.h +++ b/src/dxbc/dxbc_chunk_isgn.h @@ -49,6 +49,9 @@ namespace dxvk { uint32_t semanticIndex, uint32_t streamIndex) const; + DxbcRegMask regMask( + uint32_t registerId) const; + private: std::vector m_entries; diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index a6b1fd375..0545c3d19 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -4957,11 +4957,17 @@ namespace dxvk { for (uint32_t i = 0; i < m_vRegs.size(); i++) { if (m_vRegs.at(i).id != 0) { const uint32_t registerId = m_module.consti32(i); - const uint32_t srcTypeId = getVectorTypeId(m_vRegs.at(i).type); - const uint32_t srcValue = m_module.opLoad(srcTypeId, m_vRegs.at(i).id); + + DxbcRegisterPointer srcPtr = m_vRegs.at(i); + DxbcRegisterValue srcValue = 
emitRegisterBitcast( + emitValueLoad(srcPtr), DxbcScalarType::Float32); - m_module.opStore(m_module.opAccessChain(ptrTypeId, m_vArray, 1, &registerId), - vecTypeId != srcTypeId ? m_module.opBitcast(vecTypeId, srcValue) : srcValue); + DxbcRegisterPointer dstPtr; + dstPtr.type = { DxbcScalarType::Float32, 4 }; + dstPtr.id = m_module.opAccessChain( + ptrTypeId, m_vArray, 1, &registerId); + + emitValueStore(dstPtr, srcValue, DxbcRegMask::firstN(srcValue.type.ccount)); } } @@ -4994,7 +5000,6 @@ namespace dxvk { // that the outer index of the array is the vertex index. const uint32_t vecTypeId = m_module.defVectorType(m_module.defFloatType(32), 4); const uint32_t dstPtrTypeId = m_module.defPointerType(vecTypeId, spv::StorageClassPrivate); - const uint32_t srcPtrTypeId = m_module.defPointerType(vecTypeId, spv::StorageClassInput); for (uint32_t i = 0; i < m_vRegs.size(); i++) { if (m_vRegs.at(i).id != 0) { @@ -5004,13 +5009,21 @@ namespace dxvk { std::array indices = {{ m_module.consti32(v), registerId }}; - const uint32_t srcTypeId = getVectorTypeId(m_vRegs.at(i).type); - const uint32_t srcValue = m_module.opLoad(srcTypeId, - m_module.opAccessChain(srcPtrTypeId, m_vRegs.at(i).id, 1, indices.data())); + DxbcRegisterPointer srcPtr; + srcPtr.type = m_vRegs.at(i).type; + srcPtr.id = m_module.opAccessChain( + m_module.defPointerType(getVectorTypeId(srcPtr.type), spv::StorageClassInput), + m_vRegs.at(i).id, 1, indices.data()); - m_module.opStore( - m_module.opAccessChain(dstPtrTypeId, m_vArray, indices.size(), indices.data()), - vecTypeId != srcTypeId ? 
m_module.opBitcast(vecTypeId, srcValue) : srcValue); + DxbcRegisterValue srcValue = emitRegisterBitcast( + emitValueLoad(srcPtr), DxbcScalarType::Float32); + + DxbcRegisterPointer dstPtr; + dstPtr.type = { DxbcScalarType::Float32, 4 }; + dstPtr.id = m_module.opAccessChain( + dstPtrTypeId, m_vArray, 2, indices.data()); + + emitValueStore(dstPtr, srcValue, DxbcRegMask::firstN(srcValue.type.ccount)); } } } @@ -5050,6 +5063,7 @@ namespace dxvk { if (m_version.type() == DxbcProgramType::HullShader) { uint32_t registerIndex = m_module.constu32(svMapping.regId); + outputReg.type = { DxbcScalarType::Float32, 4 }; outputReg.id = m_module.opAccessChain( m_module.defPointerType( getVectorTypeId(outputReg.type), @@ -5425,8 +5439,8 @@ namespace dxvk { { m_hs.builtinTessLevelOuter, 1 }, // FinalLineDensityTessFactor }}; - const TessFactor tessFactor = s_tessFactors.at(static_cast(sv) - - static_cast(DxbcSystemValue::FinalQuadUeq0EdgeTessFactor)); + const TessFactor tessFactor = s_tessFactors.at(uint32_t(sv) + - uint32_t(DxbcSystemValue::FinalQuadUeq0EdgeTessFactor)); const uint32_t tessFactorArrayIndex = m_module.constu32(tessFactor.index); @@ -6117,24 +6131,25 @@ namespace dxvk { DxbcInterpolationMode::Undefined); // Vector type index - uint32_t vecTypeId = getVectorTypeId({ DxbcScalarType::Float32, 4 }); - - uint32_t dstPtrTypeId = m_module.defPointerType(vecTypeId, spv::StorageClassOutput); - uint32_t srcPtrTypeId = m_module.defPointerType(vecTypeId, spv::StorageClassInput); - const std::array dstIndices = {{ invocationId, m_module.constu32(i->registerId) }}; - uint32_t dstPtr = m_module.opAccessChain( - dstPtrTypeId, m_hs.outputPerVertex, - dstIndices.size(), dstIndices.data()); + DxbcRegisterPointer srcPtr; + srcPtr.type = m_vRegs.at(i->registerId).type; + srcPtr.id = m_module.opAccessChain( + m_module.defPointerType(getVectorTypeId(srcPtr.type), spv::StorageClassInput), + m_vRegs.at(i->registerId).id, 1, &invocationId); - uint32_t srcPtr = m_module.opAccessChain( - 
srcPtrTypeId, m_vRegs.at(i->registerId).id, - 1, &invocationId); - - m_module.opStore(dstPtr, - m_module.opLoad(vecTypeId, srcPtr)); + DxbcRegisterValue srcValue = emitRegisterBitcast( + emitValueLoad(srcPtr), DxbcScalarType::Float32); + + DxbcRegisterPointer dstPtr; + dstPtr.type = { DxbcScalarType::Float32, 4 }; + dstPtr.id = m_module.opAccessChain( + m_module.defPointerType(getVectorTypeId(dstPtr.type), spv::StorageClassOutput), + m_hs.outputPerVertex, dstIndices.size(), dstIndices.data()); + + emitValueStore(dstPtr, srcValue, DxbcRegMask::firstN(srcValue.type.ccount)); } // End function @@ -6458,39 +6473,70 @@ namespace dxvk { DxbcVectorType DxbcCompiler::getInputRegType(uint32_t regIdx) const { - DxbcVectorType result; - result.ctype = DxbcScalarType::Float32; - result.ccount = 4; - - // Vertex shader inputs must match the type of the input layout - if (m_version.type() == DxbcProgramType::VertexShader) { - const DxbcSgnEntry* entry = m_isgn->findByRegister(regIdx); - - if (entry != nullptr) - result.ctype = entry->componentType; + switch (m_version.type()) { + case DxbcProgramType::VertexShader: { + const DxbcSgnEntry* entry = m_isgn->findByRegister(regIdx); + + DxbcVectorType result; + result.ctype = DxbcScalarType::Float32; + result.ccount = 4; + + if (entry != nullptr) { + result.ctype = entry->componentType; + result.ccount = entry->componentMask.popCount(); + } + + return result; + } + + case DxbcProgramType::DomainShader: { + DxbcVectorType result; + result.ctype = DxbcScalarType::Float32; + result.ccount = 4; + return result; + } + + default: { + DxbcVectorType result; + result.ctype = DxbcScalarType::Float32; + result.ccount = m_isgn->regMask(regIdx).minComponents(); + return result; + } } - - return result; } DxbcVectorType DxbcCompiler::getOutputRegType(uint32_t regIdx) const { - DxbcVectorType result; - result.ctype = DxbcScalarType::Float32; - result.ccount = 4; - - // Pixel shader outputs are required to match the type of - // the render target, 
so we'll scan the output signature. - if (m_version.type() == DxbcProgramType::PixelShader) { - const DxbcSgnEntry* entry = m_osgn->findByRegister(regIdx); - - if (entry != nullptr) { - result.ctype = entry->componentType; - result.ccount = entry->componentMask.popCount(); + switch (m_version.type()) { + case DxbcProgramType::PixelShader: { + const DxbcSgnEntry* entry = m_osgn->findByRegister(regIdx); + + DxbcVectorType result; + result.ctype = DxbcScalarType::Float32; + result.ccount = 4; + + if (entry != nullptr) { + result.ctype = entry->componentType; + result.ccount = entry->componentMask.popCount(); + } + + return result; + } + + case DxbcProgramType::HullShader: { + DxbcVectorType result; + result.ctype = DxbcScalarType::Float32; + result.ccount = 4; + return result; + } + + default: { + DxbcVectorType result; + result.ctype = DxbcScalarType::Float32; + result.ccount = m_osgn->regMask(regIdx).minComponents(); + return result; } } - - return result; } diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h index 889c1a14d..699aedf14 100644 --- a/src/dxbc/dxbc_compiler.h +++ b/src/dxbc/dxbc_compiler.h @@ -404,15 +404,19 @@ namespace dxvk { /////////////////////////////////////////////////////////// // v# registers as defined by the shader. The type of each // of these inputs is either float4 or an array of float4. - std::array m_vRegs; - std::vector m_vMappings; + std::array< + DxbcRegisterPointer, + DxbcMaxInterfaceRegs> m_vRegs; + std::vector m_vMappings; ////////////////////////////////////////////////////////// // o# registers as defined by the shader. In the fragment // shader stage, these registers are typed by the signature, // in all other stages, they are float4 registers or arrays. - std::array m_oRegs; - std::vector m_oMappings; + std::array< + DxbcRegisterPointer, + DxbcMaxInterfaceRegs> m_oRegs; + std::vector m_oMappings; ////////////////////////////////////////////////////// // Shader resource variables. 
These provide access to diff --git a/src/dxbc/dxbc_decoder.h b/src/dxbc/dxbc_decoder.h index df21fcbc9..cb2c60ba5 100644 --- a/src/dxbc/dxbc_decoder.h +++ b/src/dxbc/dxbc_decoder.h @@ -163,9 +163,20 @@ namespace dxvk { return n[m_mask & 0xF]; } + uint32_t minComponents() const { + const uint8_t n[16] = { 0, 1, 2, 2, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4 }; + return n[m_mask & 0xF]; + } + bool operator == (const DxbcRegMask& other) const { return m_mask == other.m_mask; } bool operator != (const DxbcRegMask& other) const { return m_mask != other.m_mask; } + DxbcRegMask& operator |= (const DxbcRegMask& other) { + m_mask |= other.m_mask; + return *this; + } + static DxbcRegMask firstN(uint32_t n) { return DxbcRegMask(n >= 1, n >= 2, n >= 3, n >= 4); }