diff --git a/src/dxvk/dxvk_graphics.cpp b/src/dxvk/dxvk_graphics.cpp index 8e295e609..1070859d8 100644 --- a/src/dxvk/dxvk_graphics.cpp +++ b/src/dxvk/dxvk_graphics.cpp @@ -720,6 +720,7 @@ namespace dxvk { if (shaderInfo.stage == VK_SHADER_STAGE_FRAGMENT_BIT) { info.fsDualSrcBlend = state.useDualSourceBlending(); + info.fsFlatShading = state.rs.flatShading() && shader->info().flatShadingInputs; for (uint32_t i = 0; i < MaxNumRenderTargets; i++) { if ((shaderInfo.outputMask & (1u << i)) && state.writesRenderTarget(i)) diff --git a/src/dxvk/dxvk_shader.cpp b/src/dxvk/dxvk_shader.cpp index 0f556b680..d64305659 100644 --- a/src/dxvk/dxvk_shader.cpp +++ b/src/dxvk/dxvk_shader.cpp @@ -12,6 +12,7 @@ namespace dxvk { bool DxvkShaderModuleCreateInfo::eq(const DxvkShaderModuleCreateInfo& other) const { bool eq = fsDualSrcBlend == other.fsDualSrcBlend + && fsFlatShading == other.fsFlatShading && undefinedInputs == other.undefinedInputs; for (uint32_t i = 0; i < rtSwizzles.size() && eq; i++) { @@ -28,6 +29,7 @@ namespace dxvk { size_t DxvkShaderModuleCreateInfo::hash() const { DxvkHashState hash; hash.add(uint32_t(fsDualSrcBlend)); + hash.add(uint32_t(fsFlatShading)); hash.add(undefinedInputs); for (uint32_t i = 0; i < rtSwizzles.size(); i++) { @@ -189,6 +191,10 @@ namespace dxvk { if (m_info.stage == VK_SHADER_STAGE_FRAGMENT_BIT) emitOutputSwizzles(spirvCode, m_info.outputMask, state.rtSwizzles.data()); + // Emit input decorations for flat shading as necessary + if (m_info.stage == VK_SHADER_STAGE_FRAGMENT_BIT && state.fsFlatShading) + emitFlatShadingDeclarations(spirvCode, m_info.flatShadingInputs); + return spirvCode; } @@ -696,6 +702,95 @@ namespace dxvk { } + void DxvkShader::emitFlatShadingDeclarations( + SpirvCodeBuffer& code, + uint32_t inputMask) { + if (!inputMask) + return; + + struct VarInfo { + uint32_t varId; + size_t decorationOffset; + }; + + std::unordered_set candidates; + std::unordered_map decorations; + std::vector flatVars; + + size_t decorateOffset = 0; + + for (auto ins : code) { + switch (ins.opCode()) { + case spv::OpDecorate: { + decorateOffset = ins.offset() + ins.length(); + uint32_t varId = ins.arg(1); + + switch (ins.arg(2)) { + case spv::DecorationLocation: { + uint32_t location = ins.arg(3); + + if (inputMask & (1u << location)) + candidates.insert(varId); + } break; + + case spv::DecorationFlat: + case spv::DecorationCentroid: + case spv::DecorationSample: + case spv::DecorationNoPerspective: { + decorations.insert({ varId, ins.offset() + 2 }); + } break; + + default: ; + } + } break; + + case spv::OpVariable: { + if (ins.arg(3) == spv::StorageClassInput) { + uint32_t varId = ins.arg(2); + + // Only consider variables that have a desired location + if (candidates.find(varId) != candidates.end()) { + VarInfo varInfo; + varInfo.varId = varId; + varInfo.decorationOffset = 0; + + auto decoration = decorations.find(varId); + if (decoration != decorations.end()) + varInfo.decorationOffset = decoration->second; + + flatVars.push_back(varInfo); + } + } + } break; + + default: + break; + } + } + + // Change existing decorations as necessary + for (const auto& var : flatVars) { + if (var.decorationOffset) { + uint32_t* rawCode = code.data(); + rawCode[var.decorationOffset] = spv::DecorationFlat; + } + } + + // Insert new decorations for remaining variables + code.beginInsertion(decorateOffset); + + for (const auto& var : flatVars) { + if (!var.decorationOffset) { + code.putIns(spv::OpDecorate, 3); + code.putWord(var.varId); + code.putWord(spv::DecorationFlat); + } + } + + code.endInsertion(); + } + + DxvkShaderStageInfo::DxvkShaderStageInfo(const DxvkDevice* device) : m_device(device) { diff --git a/src/dxvk/dxvk_shader.h b/src/dxvk/dxvk_shader.h index 0a628054c..6363e3a41 100644 --- a/src/dxvk/dxvk_shader.h +++ b/src/dxvk/dxvk_shader.h @@ -46,6 +46,8 @@ namespace dxvk { /// Input and output register mask uint32_t inputMask = 0; uint32_t outputMask = 0; + /// Flat shading input mask + uint32_t flatShadingInputs = 0; /// Push constant range uint32_t pushConstOffset = 0; uint32_t pushConstSize = 0; @@ -64,6 +66,7 @@ namespace dxvk { */ struct DxvkShaderModuleCreateInfo { bool fsDualSrcBlend = false; + bool fsFlatShading = false; uint32_t undefinedInputs = 0; std::array rtSwizzles = { }; @@ -237,6 +240,10 @@ namespace dxvk { uint32_t outputMask, const VkComponentMapping* swizzles); + static void emitFlatShadingDeclarations( + SpirvCodeBuffer& code, + uint32_t inputMask); + };