1
0
mirror of https://github.com/doitsujin/dxvk.git synced 2025-02-27 04:54:15 +01:00

[dxvk] Add pass to patch GS input topology if necessary

This commit is contained in:
Philip Rebohle 2025-02-26 16:58:53 +01:00
parent 3b9dd40605
commit d05abbf8fd
2 changed files with 312 additions and 1 deletions

View File

@ -209,6 +209,12 @@ namespace dxvk {
for (uint32_t u : bit::BitMask(state.undefinedInputs))
eliminateInput(spirvCode, u);
// Patch primitive topology as necessary
if (m_info.stage == VK_SHADER_STAGE_GEOMETRY_BIT
&& state.inputTopology != m_info.inputTopology
&& state.inputTopology != VK_PRIMITIVE_TOPOLOGY_MAX_ENUM)
patchInputTopology(spirvCode, state.inputTopology);
// Emit fragment shader swizzles as necessary
if (m_info.stage == VK_SHADER_STAGE_FRAGMENT_BIT)
emitOutputSwizzles(spirvCode, m_info.outputMask, state.rtSwizzles.data());
@ -475,7 +481,7 @@ namespace dxvk {
}
}
}
void DxvkShader::emitOutputSwizzles(
SpirvCodeBuffer& code,
@ -828,6 +834,306 @@ namespace dxvk {
}
void DxvkShader::patchInputTopology(SpirvCodeBuffer& code, VkPrimitiveTopology topology) {
struct TopologyInfo {
VkPrimitiveTopology topology;
spv::ExecutionMode mode;
uint32_t vertexCount;
};
static const std::array<TopologyInfo, 5> s_topologies = {{
{ VK_PRIMITIVE_TOPOLOGY_POINT_LIST, spv::ExecutionModeInputPoints, 1u },
{ VK_PRIMITIVE_TOPOLOGY_LINE_LIST, spv::ExecutionModeInputLines, 2u },
{ VK_PRIMITIVE_TOPOLOGY_LINE_LIST_WITH_ADJACENCY, spv::ExecutionModeInputLinesAdjacency, 4u },
{ VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST, spv::ExecutionModeTriangles, 3u },
{ VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST_WITH_ADJACENCY, spv::ExecutionModeInputTrianglesAdjacency, 6u },
}};
const TopologyInfo* topologyInfo = nullptr;
for (const auto& top : s_topologies) {
if (top.topology == topology) {
topologyInfo = &top;
break;
}
}
if (!topologyInfo)
return;
uint32_t typeUint32Id = 0u;
uint32_t typeSint32Id = 0u;
struct ConstantInfo {
uint32_t typeId;
uint32_t value;
};
struct ArrayTypeInfo {
uint32_t arrayLengthId;
uint32_t scalarTypeId;
uint32_t replaceTypeId;
};
struct PointerTypeInfo {
uint32_t objectTypeId;
};
std::unordered_map<uint32_t, uint32_t> nullConstantsByType;
std::unordered_map<uint32_t, ConstantInfo> constants;
std::unordered_map<uint32_t, uint32_t> uintConstantValueToId;
std::unordered_map<uint32_t, ArrayTypeInfo> arrayTypes;
std::unordered_map<uint32_t, PointerTypeInfo> pointerTypes;
std::unordered_map<uint32_t, uint32_t> variableTypes;
std::unordered_set<uint32_t> nullAccessChains;
std::unordered_map<uint32_t, uint32_t> nullVarsByType;
std::vector<std::pair<uint32_t, uint32_t>> newNullVars;
uint32_t functionOffset = 0u;
for (auto iter = code.begin(); iter != code.end(); ) {
auto ins = *iter;
switch (ins.opCode()) {
case spv::OpExecutionMode: {
bool isTopology = false;
for (const auto& top : s_topologies)
isTopology |= spv::ExecutionMode(ins.arg(2)) == top.mode;
if (isTopology)
ins.setArg(2, uint32_t(topologyInfo->mode));
} break;
case spv::OpConstant: {
if (ins.arg(1) == typeUint32Id || ins.arg(1) == typeSint32Id) {
ConstantInfo c = { };
c.typeId = ins.arg(1);
c.value = ins.arg(3);
constants.insert({ ins.arg(2), c });
uintConstantValueToId.insert({ ins.arg(3), ins.arg(2) });
}
} break;
case spv::OpConstantNull: {
nullConstantsByType.insert({ ins.arg(1), ins.arg(2) });
} break;
case spv::OpTypeInt: {
if (ins.arg(2u) == 32u) {
if (ins.arg(3u))
typeSint32Id = ins.arg(1u);
else
typeUint32Id = ins.arg(1u);
}
} break;
case spv::OpTypeArray: {
ArrayTypeInfo t = { };
t.arrayLengthId = ins.arg(3);
t.scalarTypeId = ins.arg(2);
t.replaceTypeId = 0u;
arrayTypes.insert({ ins.arg(1), t });
} break;
case spv::OpTypePointer: {
// We know that all input arrays use the vertex count as their outer
// array size, so it is safe for us to simply replace the array type
// of any pointer type declaration with an appropriately sized array.
auto storageClass = spv::StorageClass(ins.arg(2));
if (storageClass == spv::StorageClassInput) {
uint32_t len = ins.length();
uint32_t arrayTypeId = 0u;
uint32_t scalarTypeId = 0u;
PointerTypeInfo t = { };
t.objectTypeId = ins.arg(3);
auto entry = arrayTypes.find(t.objectTypeId);
if (entry != arrayTypes.end()) {
if (!entry->second.replaceTypeId) {
arrayTypeId = code.allocId();
scalarTypeId = entry->second.scalarTypeId;
entry->second.replaceTypeId = arrayTypeId;
}
t.objectTypeId = entry->second.replaceTypeId;
ins.setArg(3, t.objectTypeId);
}
pointerTypes.insert({ ins.arg(1), t });
// If we replaced the array type, emit it before the pointer type
// decoration as necessary. It is legal to delcare identical array
// types multiple times.
if (arrayTypeId) {
code.beginInsertion(ins.offset());
auto lengthId = uintConstantValueToId.find(topologyInfo->vertexCount);
if (lengthId == uintConstantValueToId.end()) {
if (!typeUint32Id) {
typeUint32Id = code.allocId();
code.putIns (spv::OpTypeInt, 4);
code.putWord (typeUint32Id);
code.putWord (32);
code.putWord (0);
}
ConstantInfo c;
c.typeId = typeUint32Id;
c.value = topologyInfo->vertexCount;
uint32_t arrayLengthId = code.allocId();
code.putIns (spv::OpConstant, 4);
code.putWord (c.typeId);
code.putWord (arrayLengthId);
code.putWord (c.value);
lengthId = uintConstantValueToId.insert({ c.value, arrayLengthId }).first;
constants.insert({ arrayLengthId, c });
}
ArrayTypeInfo t = { };
t.scalarTypeId = scalarTypeId;
t.arrayLengthId = lengthId->second;
arrayTypes.insert({ arrayTypeId, t });
code.putIns (spv::OpTypeArray, 4);
code.putWord (arrayTypeId);
code.putWord (t.scalarTypeId);
code.putWord (t.arrayLengthId);
iter = SpirvInstructionIterator(code.data(), code.endInsertion() + len, code.dwords());
continue;
}
}
} break;
case spv::OpVariable: {
auto storageClass = spv::StorageClass(ins.arg(3));
if (storageClass == spv::StorageClassInput)
variableTypes.insert({ ins.arg(2), ins.arg(1) });
} break;
case spv::OpFunction: {
if (!functionOffset)
functionOffset = ins.offset();
} break;
case spv::OpAccessChain:
case spv::OpInBoundsAccessChain: {
bool nullChain = false;
auto var = variableTypes.find(ins.arg(3));
if (var == variableTypes.end()) {
// If we're recursively loading from a null access chain, skip
auto chain = nullAccessChains.find(ins.arg(3));
nullChain = chain != nullAccessChains.end();
} else {
// If the index is out of bounds, mark the access chain as
// dead so we can replace all loads with a null constant.
auto c = constants.find(ins.arg(4u));
if (c == constants.end())
break;
nullChain = c->second.value >= topologyInfo->vertexCount;
}
if (nullChain) {
nullAccessChains.insert(ins.arg(2));
code.beginInsertion(ins.offset());
code.erase(ins.length());
iter = SpirvInstructionIterator(code.data(), code.endInsertion(), code.dwords());
continue;
}
} break;
case spv::OpLoad: {
// If we're loading from a null access chain, replace with null constant load.
// We should never load the entire array at once, so ignore that case.
if (nullAccessChains.find(ins.arg(3)) != nullAccessChains.end()) {
auto var = nullVarsByType.find(ins.arg(1));
if (var == nullVarsByType.end()) {
var = nullVarsByType.insert({ ins.arg(1), code.allocId() }).first;
newNullVars.push_back(std::make_pair(var->second, ins.arg(1)));
}
ins.setArg(3, var->second);
}
} break;
default:;
}
iter++;
}
// Insert new null variables
code.beginInsertion(functionOffset);
for (auto v : newNullVars) {
auto nullConst = nullConstantsByType.find(v.second);
if (nullConst == nullConstantsByType.end()) {
uint32_t nullConstId = code.allocId();
code.putIns (spv::OpConstantNull, 3u);
code.putWord (v.second);
code.putWord (nullConstId);
nullConst = nullConstantsByType.insert({ v.second, nullConstId }).first;
}
uint32_t pointerTypeId = code.allocId();
code.putIns (spv::OpTypePointer, 4u);
code.putWord (pointerTypeId);
code.putWord (spv::StorageClassPrivate);
code.putWord (v.second);
code.putIns (spv::OpVariable, 5u);
code.putWord (pointerTypeId);
code.putWord (v.first);
code.putWord (spv::StorageClassPrivate);
code.putWord (nullConst->second);
}
code.endInsertion();
// Add newly declared null variables to entry point
for (auto ins : code) {
if (ins.opCode() == spv::OpEntryPoint) {
uint32_t len = ins.length();
uint32_t token = ins.opCode() | ((len + newNullVars.size()) << 16);
ins.setArg(0, token);
code.beginInsertion(ins.offset() + len);
for (auto v : newNullVars)
code.putWord(v.first);
code.endInsertion();
break;
}
}
}
DxvkShaderStageInfo::DxvkShaderStageInfo(const DxvkDevice* device)
: m_device(device) {

View File

@ -74,6 +74,7 @@ namespace dxvk {
bool fsDualSrcBlend = false;
bool fsFlatShading = false;
uint32_t undefinedInputs = 0;
VkPrimitiveTopology inputTopology = VK_PRIMITIVE_TOPOLOGY_MAX_ENUM;
std::array<VkComponentMapping, MaxNumRenderTargets> rtSwizzles = { };
@ -282,6 +283,10 @@ namespace dxvk {
SpirvCodeBuffer& code,
uint32_t inputMask);
static void patchInputTopology(
SpirvCodeBuffer& code,
VkPrimitiveTopology topology);
};