diff --git a/src/dxvk/dxvk_shader.cpp b/src/dxvk/dxvk_shader.cpp
index 558c8d2d3..58d64c04b 100644
--- a/src/dxvk/dxvk_shader.cpp
+++ b/src/dxvk/dxvk_shader.cpp
@@ -209,6 +209,12 @@ namespace dxvk {
     for (uint32_t u : bit::BitMask(state.undefinedInputs))
       eliminateInput(spirvCode, u);
 
+    // Patch primitive topology as necessary
+    if (m_info.stage == VK_SHADER_STAGE_GEOMETRY_BIT
+     && state.inputTopology != m_info.inputTopology
+     && state.inputTopology != VK_PRIMITIVE_TOPOLOGY_MAX_ENUM)
+      patchInputTopology(spirvCode, state.inputTopology);
+
     // Emit fragment shader swizzles as necessary
     if (m_info.stage == VK_SHADER_STAGE_FRAGMENT_BIT)
       emitOutputSwizzles(spirvCode, m_info.outputMask, state.rtSwizzles.data());
@@ -475,7 +481,7 @@ namespace dxvk {
       }
     }
   }
-  
+
   
   void DxvkShader::emitOutputSwizzles(
           SpirvCodeBuffer&          code,
@@ -828,6 +834,311 @@ namespace dxvk {
   }
 
 
+  void DxvkShader::patchInputTopology(SpirvCodeBuffer& code, VkPrimitiveTopology topology) {
+    struct TopologyInfo {
+      VkPrimitiveTopology topology;
+      spv::ExecutionMode mode;
+      uint32_t vertexCount;
+    };
+
+    // Input topologies that can be patched, along with the matching
+    // execution mode and the number of vertices per input primitive
+    static const std::array<TopologyInfo, 5> s_topologies = {{
+      { VK_PRIMITIVE_TOPOLOGY_POINT_LIST,                   spv::ExecutionModeInputPoints,             1u },
+      { VK_PRIMITIVE_TOPOLOGY_LINE_LIST,                    spv::ExecutionModeInputLines,              2u },
+      { VK_PRIMITIVE_TOPOLOGY_LINE_LIST_WITH_ADJACENCY,     spv::ExecutionModeInputLinesAdjacency,     4u },
+      { VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST,                spv::ExecutionModeTriangles,               3u },
+      { VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST_WITH_ADJACENCY, spv::ExecutionModeInputTrianglesAdjacency, 6u },
+    }};
+
+    const TopologyInfo* topologyInfo = nullptr;
+
+    for (const auto& top : s_topologies) {
+      if (top.topology == topology) {
+        topologyInfo = &top;
+        break;
+      }
+    }
+
+    if (!topologyInfo)
+      return;
+
+    uint32_t typeUint32Id = 0u;
+    uint32_t typeSint32Id = 0u;
+
+    struct ConstantInfo {
+      uint32_t typeId;
+      uint32_t value;
+    };
+
+    struct ArrayTypeInfo {
+      uint32_t arrayLengthId;
+      uint32_t scalarTypeId;
+      uint32_t replaceTypeId;
+    };
+
+    struct PointerTypeInfo {
+      uint32_t objectTypeId;
+    };
+
+    // IDs and type info gathered while iterating over the module
+    std::unordered_map<uint32_t, uint32_t> nullConstantsByType;
+    std::unordered_map<uint32_t, ConstantInfo> constants;
+    std::unordered_map<uint32_t, uint32_t> uintConstantValueToId;
+    std::unordered_map<uint32_t, ArrayTypeInfo> arrayTypes;
+    std::unordered_map<uint32_t, PointerTypeInfo> pointerTypes;
+    std::unordered_map<uint32_t, uint32_t> variableTypes;
+    std::unordered_set<uint32_t> nullAccessChains;
+    std::unordered_map<uint32_t, uint32_t> nullVarsByType;
+    std::vector<std::pair<uint32_t, uint32_t>> newNullVars;
+
+    uint32_t functionOffset = 0u;
+
+    for (auto iter = code.begin(); iter != code.end(); ) {
+      auto ins = *iter;
+
+      switch (ins.opCode()) {
+        case spv::OpExecutionMode: {
+          // Replace any input primitive execution mode with the new one
+          bool isTopology = false;
+
+          for (const auto& top : s_topologies)
+            isTopology |= spv::ExecutionMode(ins.arg(2)) == top.mode;
+
+          if (isTopology)
+            ins.setArg(2, uint32_t(topologyInfo->mode));
+        } break;
+
+        case spv::OpConstant: {
+          if (ins.arg(1) == typeUint32Id || ins.arg(1) == typeSint32Id) {
+            ConstantInfo c = { };
+            c.typeId = ins.arg(1);
+            c.value = ins.arg(3);
+
+            constants.insert({ ins.arg(2), c });
+            uintConstantValueToId.insert({ ins.arg(3), ins.arg(2) });
+          }
+        } break;
+
+        case spv::OpConstantNull: {
+          nullConstantsByType.insert({ ins.arg(1), ins.arg(2) });
+        } break;
+
+        case spv::OpTypeInt: {
+          // Track 32-bit integer type IDs, operand 3 is the signedness
+          if (ins.arg(2u) == 32u) {
+            if (ins.arg(3u))
+              typeSint32Id = ins.arg(1u);
+            else
+              typeUint32Id = ins.arg(1u);
+          }
+        } break;
+
+        case spv::OpTypeArray: {
+          ArrayTypeInfo t = { };
+          t.arrayLengthId = ins.arg(3);
+          t.scalarTypeId = ins.arg(2);
+          t.replaceTypeId = 0u;
+
+          arrayTypes.insert({ ins.arg(1), t });
+        } break;
+
+        case spv::OpTypePointer: {
+          // We know that all input arrays use the vertex count as their outer
+          // array size, so it is safe for us to simply replace the array type
+          // of any pointer type declaration with an appropriately sized array.
+          auto storageClass = spv::StorageClass(ins.arg(2));
+
+          if (storageClass == spv::StorageClassInput) {
+            uint32_t len = ins.length();
+
+            uint32_t arrayTypeId = 0u;
+            uint32_t scalarTypeId = 0u;
+
+            PointerTypeInfo t = { };
+            t.objectTypeId = ins.arg(3);
+
+            auto entry = arrayTypes.find(t.objectTypeId);
+
+            if (entry != arrayTypes.end()) {
+              if (!entry->second.replaceTypeId) {
+                arrayTypeId = code.allocId();
+                scalarTypeId = entry->second.scalarTypeId;
+
+                entry->second.replaceTypeId = arrayTypeId;
+              }
+
+              t.objectTypeId = entry->second.replaceTypeId;
+              ins.setArg(3, t.objectTypeId);
+            }
+
+            pointerTypes.insert({ ins.arg(1), t });
+
+            // If we replaced the array type, emit it before the pointer type
+            // declaration as necessary. It is legal to declare identical array
+            // types multiple times.
+            if (arrayTypeId) {
+              code.beginInsertion(ins.offset());
+
+              auto lengthId = uintConstantValueToId.find(topologyInfo->vertexCount);
+
+              if (lengthId == uintConstantValueToId.end()) {
+                if (!typeUint32Id) {
+                  typeUint32Id = code.allocId();
+
+                  code.putIns  (spv::OpTypeInt, 4);
+                  code.putWord (typeUint32Id);
+                  code.putWord (32);
+                  code.putWord (0);
+                }
+
+                ConstantInfo c;
+                c.typeId = typeUint32Id;
+                c.value = topologyInfo->vertexCount;
+
+                uint32_t arrayLengthId = code.allocId();
+
+                code.putIns  (spv::OpConstant, 4);
+                code.putWord (c.typeId);
+                code.putWord (arrayLengthId);
+                code.putWord (c.value);
+
+                lengthId = uintConstantValueToId.insert({ c.value, arrayLengthId }).first;
+                constants.insert({ arrayLengthId, c });
+              }
+
+              ArrayTypeInfo t = { };
+              t.scalarTypeId = scalarTypeId;
+              t.arrayLengthId = lengthId->second;
+
+              arrayTypes.insert({ arrayTypeId, t });
+
+              code.putIns  (spv::OpTypeArray, 4);
+              code.putWord (arrayTypeId);
+              code.putWord (t.scalarTypeId);
+              code.putWord (t.arrayLengthId);
+
+              iter = SpirvInstructionIterator(code.data(), code.endInsertion() + len, code.dwords());
+              continue;
+            }
+          }
+        } break;
+
+        case spv::OpVariable: {
+          auto storageClass = spv::StorageClass(ins.arg(3));
+
+          if (storageClass == spv::StorageClassInput)
+            variableTypes.insert({ ins.arg(2), ins.arg(1) });
+        } break;
+
+        case spv::OpFunction: {
+          if (!functionOffset)
+            functionOffset = ins.offset();
+        } break;
+
+        case spv::OpAccessChain:
+        case spv::OpInBoundsAccessChain: {
+          bool nullChain = false;
+          auto var = variableTypes.find(ins.arg(3));
+
+          if (var == variableTypes.end()) {
+            // If we're recursively loading from a null access chain, skip
+            auto chain = nullAccessChains.find(ins.arg(3));
+            nullChain = chain != nullAccessChains.end();
+          } else {
+            // If the index is out of bounds, mark the access chain as
+            // dead so we can replace all loads with a null constant.
+            auto c = constants.find(ins.arg(4u));
+
+            if (c == constants.end())
+              break;
+
+            nullChain = c->second.value >= topologyInfo->vertexCount;
+          }
+
+          if (nullChain) {
+            nullAccessChains.insert(ins.arg(2));
+
+            code.beginInsertion(ins.offset());
+            code.erase(ins.length());
+
+            iter = SpirvInstructionIterator(code.data(), code.endInsertion(), code.dwords());
+            continue;
+          }
+        } break;
+
+        case spv::OpLoad: {
+          // If we're loading from a null access chain, replace with null constant load.
+          // We should never load the entire array at once, so ignore that case.
+          if (nullAccessChains.find(ins.arg(3)) != nullAccessChains.end()) {
+            auto var = nullVarsByType.find(ins.arg(1));
+
+            if (var == nullVarsByType.end()) {
+              var = nullVarsByType.insert({ ins.arg(1), code.allocId() }).first;
+              newNullVars.push_back(std::make_pair(var->second, ins.arg(1)));
+            }
+
+            ins.setArg(3, var->second);
+          }
+        } break;
+
+        default:;
+      }
+
+      iter++;
+    }
+
+    // Insert new null variables
+    code.beginInsertion(functionOffset);
+
+    for (auto v : newNullVars) {
+      auto nullConst = nullConstantsByType.find(v.second);
+
+      if (nullConst == nullConstantsByType.end()) {
+        uint32_t nullConstId = code.allocId();
+
+        code.putIns  (spv::OpConstantNull, 3u);
+        code.putWord (v.second);
+        code.putWord (nullConstId);
+
+        nullConst = nullConstantsByType.insert({ v.second, nullConstId }).first;
+      }
+
+      uint32_t pointerTypeId = code.allocId();
+
+      code.putIns  (spv::OpTypePointer, 4u);
+      code.putWord (pointerTypeId);
+      code.putWord (spv::StorageClassPrivate);
+      code.putWord (v.second);
+
+      code.putIns  (spv::OpVariable, 5u);
+      code.putWord (pointerTypeId);
+      code.putWord (v.first);
+      code.putWord (spv::StorageClassPrivate);
+      code.putWord (nullConst->second);
+    }
+
+    code.endInsertion();
+
+    // Add newly declared null variables to entry point
+    for (auto ins : code) {
+      if (ins.opCode() == spv::OpEntryPoint) {
+        uint32_t len = ins.length();
+        uint32_t token = ins.opCode() | ((len + newNullVars.size()) << 16);
+        ins.setArg(0, token);
+
+        code.beginInsertion(ins.offset() + len);
+
+        for (auto v : newNullVars)
+          code.putWord(v.first);
+
+        code.endInsertion();
+        break;
+      }
+    }
+  }
+
+
   DxvkShaderStageInfo::DxvkShaderStageInfo(const DxvkDevice* device)
   : m_device(device) {
 
diff --git a/src/dxvk/dxvk_shader.h b/src/dxvk/dxvk_shader.h
index efc64ab9a..b7f6d73ed 100644
--- a/src/dxvk/dxvk_shader.h
+++ b/src/dxvk/dxvk_shader.h
@@ -74,6 +74,7 @@ namespace dxvk {
     bool      fsDualSrcBlend  = false;
     bool      fsFlatShading   = false;
     uint32_t  undefinedInputs = 0;
+    VkPrimitiveTopology inputTopology = VK_PRIMITIVE_TOPOLOGY_MAX_ENUM;
 
     std::array<VkComponentMapping, MaxNumRenderTargets> rtSwizzles = { };
 
@@ -282,6 +283,19 @@ namespace dxvk {
             SpirvCodeBuffer&          code,
             uint32_t                  inputMask);
 
+    /**
+     * \brief Patches geometry shader input topology
+     *
+     * Rewrites the input primitive execution mode, replaces input
+     * array types with arrays sized to the new vertex count, and
+     * redirects loads using constant out-of-range vertex indices
+     * to null-initialized private variables.
+     * \param [in,out] code SPIR-V code buffer to patch
+     * \param [in] topology Input primitive topology to patch in
+     */
+    static void patchInputTopology(
+            SpirvCodeBuffer&          code,
+            VkPrimitiveTopology       topology);
+
   };
   
   