diff --git a/src/spirv/spirv_module.cpp b/src/spirv/spirv_module.cpp index 6cf71f8f7..ff9271a3d 100644 --- a/src/spirv/spirv_module.cpp +++ b/src/spirv/spirv_module.cpp @@ -361,6 +361,38 @@ namespace dxvk { } + uint32_t SpirvModule::lateConst32( + uint32_t typeId) { + uint32_t resultId = this->allocateId(); + m_lateConsts.insert(resultId); + + m_typeConstDefs.putIns (spv::OpConstant, 4); + m_typeConstDefs.putWord(typeId); + m_typeConstDefs.putWord(resultId); + m_typeConstDefs.putWord(0); + return resultId; + } + + + void SpirvModule::setLateConst( + uint32_t constId, + const uint32_t* argIds) { + for (auto ins : m_typeConstDefs) { + if (ins.opCode() != spv::OpConstant + && ins.opCode() != spv::OpConstantComposite) + continue; + + if (ins.arg(2) != constId) + continue; + + for (uint32_t i = 3; i < ins.length(); i++) + ins.setArg(i, argIds[i - 3]); + + return; + } + } + + uint32_t SpirvModule::specConstBool( bool v) { uint32_t typeId = this->defBoolType(); @@ -3371,8 +3403,13 @@ namespace dxvk { for (uint32_t i = 0; i < argCount && match; i++) match &= ins.arg(3 + i) == argIds[i]; - if (match) - return ins.arg(2); + if (!match) + continue; + + uint32_t id = ins.arg(2); + + if (m_lateConsts.find(id) == m_lateConsts.end()) + return id; } // Constant not yet declared, make a new one diff --git a/src/spirv/spirv_module.h b/src/spirv/spirv_module.h index 15d577069..0eac52f13 100644 --- a/src/spirv/spirv_module.h +++ b/src/spirv/spirv_module.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "spirv_code_buffer.h" namespace dxvk { @@ -167,6 +169,13 @@ namespace dxvk { uint32_t constUndef( uint32_t typeId); + uint32_t lateConst32( + uint32_t typeId); + + void setLateConst( + uint32_t constId, + const uint32_t* argIds); + uint32_t specConstBool( bool v); @@ -1166,6 +1175,8 @@ namespace dxvk { SpirvCodeBuffer m_typeConstDefs; SpirvCodeBuffer m_variables; SpirvCodeBuffer m_code; + + std::unordered_set m_lateConsts; uint32_t defType( spv::Op op,