From c40116716197b58aef4982b13315d2c12f0c3b13 Mon Sep 17 00:00:00 2001 From: Philip Rebohle Date: Sun, 17 Jul 2022 13:49:24 +0200 Subject: [PATCH] [dxvk] Introduce SPIR-V pass to inject render target swizzles --- src/dxvk/dxvk_shader.cpp | 262 +++++++++++++++++++++++++++++++++++++++ src/dxvk/dxvk_shader.h | 11 +- 2 files changed, 272 insertions(+), 1 deletion(-) diff --git a/src/dxvk/dxvk_shader.cpp b/src/dxvk/dxvk_shader.cpp index c05130f5a..f53dc7ebd 100644 --- a/src/dxvk/dxvk_shader.cpp +++ b/src/dxvk/dxvk_shader.cpp @@ -379,6 +379,268 @@ namespace dxvk { } + void DxvkShader::emitOutputSwizzles( + SpirvCodeBuffer& code, + uint32_t outputMask, + const VkComponentMapping* swizzles) { + // Skip this step entirely if all relevant + // outputs use the identity swizzle + bool requiresEpilogue = false; + + for (auto index : bit::BitMask(outputMask)) + requiresEpilogue |= !util::isIdentityMapping(swizzles[index]); + + if (!requiresEpilogue) + return; + + // Gather some information. We need to scan pointer types with + // the output storage class to find the base vector type, and + // we need to scan vector types to find the component count. + uint32_t entryPointId = 0; + uint32_t functionId = 0; + + size_t epilogueOffset = 0; + size_t variableOffset = 0; + + struct VarInfo { + uint32_t varId; + uint32_t typeId; + uint32_t location; + uint32_t componentCount; + uint32_t componentTypeId; + }; + + struct VarIdInfo { + uint32_t location; + }; + + struct TypeIdInfo { + uint32_t componentCount; + uint32_t baseTypeId; + }; + + union IdInfo { + VarIdInfo var; + TypeIdInfo type; + }; + + // Stores type information depending on type category: + // OpTypePointer: type id -> base type id + // OpTypeVector: type id -> component count + // OpTypeFloat/Int: type id -> 1 + std::unordered_map idInfo; + std::vector varInfos; + + SpirvInstruction prev; + + for (auto ins : code) { + switch (ins.opCode()) { + case spv::OpEntryPoint: { + entryPointId = ins.arg(2); + } break; + + case spv::OpDecorate: { + if (ins.arg(2) == spv::DecorationLocation) { + IdInfo info; + info.var.location = ins.arg(3); + idInfo.insert({ ins.arg(1), info }); + } + } break; + + case spv::OpTypeVector: { + IdInfo info; + info.type.componentCount = ins.arg(3); + info.type.baseTypeId = ins.arg(2); + idInfo.insert({ ins.arg(1), info }); + } break; + + case spv::OpTypeInt: + case spv::OpTypeFloat: { + IdInfo info; + info.type.componentCount = 1; + info.type.baseTypeId = 0; + idInfo.insert({ ins.arg(1), info }); + } break; + + case spv::OpTypePointer: { + if (ins.arg(2) == spv::StorageClassOutput) { + IdInfo info; + info.type.componentCount = 0; + info.type.baseTypeId = ins.arg(3); + idInfo.insert({ ins.arg(1), info }); + } + } break; + + case spv::OpVariable: { + if (!variableOffset) + variableOffset = ins.offset(); + + if (ins.arg(3) == spv::StorageClassOutput) { + uint32_t ptrId = ins.arg(1); + uint32_t varId = ins.arg(2); + + auto ptrEntry = idInfo.find(ptrId); + auto varEntry = idInfo.find(varId); + + if (ptrEntry != idInfo.end() + && varEntry != idInfo.end()) { + uint32_t typeId = ptrEntry->second.type.baseTypeId; + + auto typeEntry = idInfo.find(typeId); + if (typeEntry != idInfo.end()) { + VarInfo info; + info.varId = varId; + info.typeId = typeId; + info.location = varEntry->second.var.location; + info.componentCount = typeEntry->second.type.componentCount; + info.componentTypeId = (info.componentCount == 1) + ? typeId : typeEntry->second.type.baseTypeId; + + varInfos.push_back(info); + } + } + } + } break; + + case spv::OpFunction: { + functionId = ins.arg(2); + } break; + + case spv::OpFunctionEnd: { + if (entryPointId == functionId) + epilogueOffset = prev.offset(); + } break; + + default: + prev = ins; + } + + if (epilogueOffset) + break; + } + + // Oops, this shouldn't happen + if (!epilogueOffset) + return; + + code.beginInsertion(epilogueOffset); + + struct ConstInfo { + uint32_t constId; + uint32_t typeId; + uint32_t value; + }; + + std::vector consts; + + for (const auto& var : varInfos) { + uint32_t storeId = 0; + + if (var.componentCount == 1) { + if (util::getComponentIndex(swizzles[var.location].r, 0) != 0) { + storeId = code.allocId(); + + ConstInfo constInfo; + constInfo.constId = storeId; + constInfo.typeId = var.componentTypeId; + constInfo.value = 0; + consts.push_back(constInfo); + } + } else { + uint32_t constId = 0; + + std::array indices = {{ + util::getComponentIndex(swizzles[var.location].r, 0), + util::getComponentIndex(swizzles[var.location].g, 1), + util::getComponentIndex(swizzles[var.location].b, 2), + util::getComponentIndex(swizzles[var.location].a, 3), + }}; + + bool needsSwizzle = false; + + for (uint32_t i = 0; i < var.componentCount && !constId; i++) { + needsSwizzle |= indices[i] != i; + + if (indices[i] >= var.componentCount) + constId = code.allocId(); + } + + if (needsSwizzle) { + uint32_t loadId = code.allocId(); + code.putIns(spv::OpLoad, 4); + code.putWord(var.typeId); + code.putWord(loadId); + code.putWord(var.varId); + + if (!constId) { + storeId = code.allocId(); + code.putIns(spv::OpVectorShuffle, 5 + var.componentCount); + code.putWord(var.typeId); + code.putWord(storeId); + code.putWord(loadId); + code.putWord(loadId); + + for (uint32_t i = 0; i < var.componentCount; i++) + code.putWord(indices[i]); + } else { + std::array ids = { }; + + ConstInfo constInfo; + constInfo.constId = constId; + constInfo.typeId = var.componentTypeId; + constInfo.value = 0; + consts.push_back(constInfo); + + for (uint32_t i = 0; i < var.componentCount; i++) { + if (indices[i] < var.componentCount) { + ids[i] = code.allocId(); + + code.putIns(spv::OpCompositeExtract, 5); + code.putWord(var.componentTypeId); + code.putWord(ids[i]); + code.putWord(loadId); + code.putWord(indices[i]); + } else { + ids[i] = constId; + } + } + + storeId = code.allocId(); + code.putIns(spv::OpCompositeConstruct, 3 + var.componentCount); + code.putWord(var.typeId); + code.putWord(storeId); + + for (uint32_t i = 0; i < var.componentCount; i++) + code.putWord(ids[i]); + } + } + } + + if (storeId) { + code.putIns(spv::OpStore, 3); + code.putWord(var.varId); + code.putWord(storeId); + } + } + + code.endInsertion(); + + // If necessary, insert constants + if (!consts.empty()) { + code.beginInsertion(variableOffset); + + for (const auto& c : consts) { + code.putIns(spv::OpConstant, 4); + code.putWord(c.typeId); + code.putWord(c.constId); + code.putWord(c.value); + } + + code.endInsertion(); + } + } + + DxvkShaderStageInfo::DxvkShaderStageInfo(const DxvkDevice* device) : m_device(device) { diff --git a/src/dxvk/dxvk_shader.h b/src/dxvk/dxvk_shader.h index 6ed5c632e..905c8ed91 100644 --- a/src/dxvk/dxvk_shader.h +++ b/src/dxvk/dxvk_shader.h @@ -80,6 +80,8 @@ namespace dxvk { struct DxvkShaderModuleCreateInfo { bool fsDualSrcBlend = false; uint32_t undefinedInputs = 0; + + std::array rtSwizzles = { }; }; @@ -227,7 +229,14 @@ namespace dxvk { DxvkBindingLayout m_bindings; - static void eliminateInput(SpirvCodeBuffer& code, uint32_t location); + static void eliminateInput( + SpirvCodeBuffer& code, + uint32_t location); + + static void emitOutputSwizzles( + SpirvCodeBuffer& code, + uint32_t outputMask, + const VkComponentMapping* swizzles); };