diff --git a/src/dxvk/dxvk_shader.cpp b/src/dxvk/dxvk_shader.cpp
index 8e2f6080..cfe822af 100644
--- a/src/dxvk/dxvk_shader.cpp
+++ b/src/dxvk/dxvk_shader.cpp
@@ -1,6 +1,8 @@
 #include "dxvk_shader.h"
 
 #include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
 
 namespace dxvk {
 
@@ -182,6 +184,10 @@ namespace dxvk {
     if (info.fsDualSrcBlend && m_o1IdxOffset && m_o1LocOffset)
       std::swap(code[m_o1IdxOffset], code[m_o1LocOffset]);
 
+    // Replace undefined input variables with zero
+    for (uint32_t u = info.undefinedInputs; u; u &= u - 1)
+      eliminateInput(spirvCode, bit::tzcnt(u));
+
     return DxvkShaderModule(vkd, this, spirvCode);
   }
 
@@ -189,5 +195,221 @@ namespace dxvk {
   void DxvkShader::dump(std::ostream& outputStream) const {
     m_code.decompress().store(outputStream);
   }
+
+
+  /**
+   * \brief Replaces a shader input with zero
+   *
+   * Finds the input variable decorated with the given
+   * location and re-declares it as a private variable
+   * that is initialized to zero. The variable is also
+   * removed from the entry point interface, its location
+   * decoration is dropped, and access chains into it
+   * are changed to use private pointer types. The code
+   * is not modified if no such variable is found, or if
+   * the type of the variable was not recognized.
+   * \param [in,out] code SPIR-V code to patch
+   * \param [in] location Input location to eliminate
+   */
+  void DxvkShader::eliminateInput(SpirvCodeBuffer& code, uint32_t location) {
+    struct SpirvTypeInfo {
+      spv::Op           op            = spv::OpNop;
+      uint32_t          baseTypeId    = 0;
+      uint32_t          compositeSize = 0;
+      spv::StorageClass storageClass  = spv::StorageClassMax;
+    };
+
+    std::unordered_map<uint32_t, SpirvTypeInfo> types;
+    std::unordered_map<uint32_t, uint32_t>      constants;
+    std::unordered_set<uint32_t>                candidates;
+
+    // Find the input variable in question
+    size_t   inputVarOffset = 0;
+    uint32_t inputVarTypeId = 0;
+    uint32_t inputVarId     = 0;
+
+    for (auto ins : code) {
+      if (ins.opCode() == spv::OpDecorate) {
+        if (ins.arg(2) == spv::DecorationLocation
+         && ins.arg(3) == location)
+          candidates.insert(ins.arg(1));
+      }
+
+      if (ins.opCode() == spv::OpConstant)
+        constants.insert({ ins.arg(2), ins.arg(3) });
+
+      if (ins.opCode() == spv::OpTypeFloat || ins.opCode() == spv::OpTypeInt)
+        types.insert({ ins.arg(1), { ins.opCode(), 0, ins.arg(2), spv::StorageClassMax }});
+
+      if (ins.opCode() == spv::OpTypeVector)
+        types.insert({ ins.arg(1), { ins.opCode(), ins.arg(2), ins.arg(3), spv::StorageClassMax }});
+
+      if (ins.opCode() == spv::OpTypeArray) {
+        // The array length is given as a constant ID, so arrays
+        // whose length was not recorded above are not tracked
+        auto constant = constants.find(ins.arg(3));
+        if (constant == constants.end())
+          continue;
+        types.insert({ ins.arg(1), { ins.opCode(), ins.arg(2), constant->second, spv::StorageClassMax }});
+      }
+
+      if (ins.opCode() == spv::OpTypePointer)
+        types.insert({ ins.arg(1), { ins.opCode(), ins.arg(3), 0, spv::StorageClass(ins.arg(2)) }});
+
+      if (ins.opCode() == spv::OpVariable && spv::StorageClass(ins.arg(3)) == spv::StorageClassInput) {
+        if (candidates.find(ins.arg(2)) != candidates.end()) {
+          inputVarOffset = ins.offset();
+          inputVarTypeId = ins.arg(1);
+          inputVarId     = ins.arg(2);
+          break;
+        }
+      }
+    }
+
+    if (!inputVarId)
+      return;
+
+    // Declare private pointer types
+    auto pointerType = types.find(inputVarTypeId);
+    if (pointerType == types.end())
+      return;
+
+    // Bail out before touching the code if the pointee type is
+    // unknown. Without any private type, there would be nothing
+    // to re-declare the variable with after erasing it below.
+    auto pointeeType = types.find(pointerType->second.baseTypeId);
+    if (pointeeType == types.end())
+      return;
+
+    code.beginInsertion(inputVarOffset);
+    std::vector<std::pair<uint32_t, SpirvTypeInfo>> privateTypes;
+
+    for (auto p = pointeeType;
+              p != types.end();
+              p = types.find(p->second.baseTypeId)) {
+      std::pair<uint32_t, SpirvTypeInfo> info = *p;
+      info.first = 0;
+      info.second.baseTypeId = p->first;
+      info.second.storageClass = spv::StorageClassPrivate;
+
+      // Reuse an existing private pointer to this type if the shader
+      // declares one. The candidates are pointer types, so compare
+      // against OpTypePointer rather than the op of the pointee.
+      for (const auto& t : types) {
+        if (t.second.op == spv::OpTypePointer
+         && t.second.baseTypeId == info.second.baseTypeId
+         && t.second.storageClass == info.second.storageClass)
+          info.first = t.first;
+      }
+
+      if (!info.first) {
+        info.first = code.allocId();
+
+        code.putIns(spv::OpTypePointer, 4);
+        code.putWord(info.first);
+        code.putWord(info.second.storageClass);
+        code.putWord(info.second.baseTypeId);
+      }
+
+      privateTypes.push_back(info);
+    }
+
+    // Define zero constants, innermost type first, so that each
+    // composite can be built from the constant of its base type
+    uint32_t constantId = 0;
+
+    for (auto i = privateTypes.rbegin(); i != privateTypes.rend(); ++i) {
+      if (constantId) {
+        uint32_t compositeSize = i->second.compositeSize;
+        uint32_t compositeId   = code.allocId();
+
+        code.putIns(spv::OpConstantComposite, 3 + compositeSize);
+        code.putWord(i->second.baseTypeId);
+        code.putWord(compositeId);
+
+        for (uint32_t c = 0; c < compositeSize; c++)
+          code.putWord(constantId);
+
+        constantId = compositeId;
+      } else {
+        constantId = code.allocId();
+
+        code.putIns(spv::OpConstant, 4);
+        code.putWord(i->second.baseTypeId);
+        code.putWord(constantId);
+        code.putWord(0);
+      }
+    }
+
+    // Erase and re-declare variable
+    code.erase(4);
+
+    code.putIns(spv::OpVariable, 5);
+    code.putWord(privateTypes[0].first);
+    code.putWord(inputVarId);
+    code.putWord(spv::StorageClassPrivate);
+    code.putWord(constantId);
+
+    code.endInsertion();
+
+    // Remove variable from interface list
+    for (auto ins : code) {
+      if (ins.opCode() == spv::OpEntryPoint) {
+        uint32_t argIdx = 2 + code.strLen(ins.chr(2));
+
+        while (argIdx < ins.length()) {
+          if (ins.arg(argIdx) == inputVarId) {
+            // Shrink the word count of the instruction by one
+            ins.setArg(0, spv::OpEntryPoint | ((ins.length() - 1) << spv::WordCountShift));
+
+            code.beginInsertion(ins.offset() + argIdx);
+            code.erase(1);
+            code.endInsertion();
+            break;
+          }
+
+          argIdx += 1;
+        }
+      }
+    }
+
+    // Remove location declarations
+    for (auto ins : code) {
+      if (ins.opCode() == spv::OpDecorate
+       && ins.arg(2) == spv::DecorationLocation
+       && ins.arg(1) == inputVarId) {
+        code.beginInsertion(ins.offset());
+        code.erase(4);
+        code.endInsertion();
+        break;
+      }
+    }
+
+    // Fix up pointer types used in access chain instructions.
+    // Maps access chain result IDs to their depth in the type.
+    std::unordered_map<uint32_t, uint32_t> accessChainIds;
+
+    for (auto ins : code) {
+      if (ins.opCode() == spv::OpAccessChain
+       || ins.opCode() == spv::OpInBoundsAccessChain) {
+        uint32_t depth = ins.length() - 4;
+
+        if (ins.arg(3) == inputVarId) {
+          // Access chains accessing the variable directly
+          ins.setArg(1, privateTypes.at(depth).first);
+          accessChainIds.insert({ ins.arg(2), depth });
+        } else {
+          // Access chains derived from the variable. Look up
+          // the base pointer, not the result of this chain.
+          auto entry = accessChainIds.find(ins.arg(3));
+          if (entry != accessChainIds.end()) {
+            depth += entry->second;
+            ins.setArg(1, privateTypes.at(depth).first);
+            accessChainIds.insert({ ins.arg(2), depth });
+          }
+        }
+      }
+    }
+  }
 
 }
\ No newline at end of file
diff --git a/src/dxvk/dxvk_shader.h b/src/dxvk/dxvk_shader.h
index 716c0cf0..065da8a4 100644
--- a/src/dxvk/dxvk_shader.h
+++ b/src/dxvk/dxvk_shader.h
@@ -121,7 +121,8 @@ namespace dxvk {
    * \brief Shader module create info
    */
   struct DxvkShaderModuleCreateInfo {
-    bool fsDualSrcBlend;
+    bool fsDualSrcBlend = false;
+    uint32_t undefinedInputs = 0;
   };
 
 
@@ -293,6 +294,8 @@ namespace dxvk {
     size_t m_o1IdxOffset = 0;
     size_t m_o1LocOffset = 0;
 
+    static void eliminateInput(SpirvCodeBuffer& code, uint32_t location);
+
   };
 
 