1
0
mirror of https://github.com/doitsujin/dxvk.git synced 2024-12-02 19:24:12 +01:00

[spirv] Implement basic dead code elimination

Fixes invalid SPIR-V in Trails through Daybreak.
This commit is contained in:
Philip Rebohle 2024-08-12 15:15:36 +02:00 committed by Philip Rebohle
parent 6308266a0f
commit 813b653645
4 changed files with 122 additions and 3 deletions

View File

@ -50,6 +50,18 @@ namespace dxvk {
} }
void SpirvCodeBuffer::append(const SpirvInstruction& ins) {
const size_t size = m_code.size();
m_code.resize(size + ins.length());
for (uint32_t i = 0; i < ins.length(); i++)
m_code[size + i] = ins.arg(i);
m_ptr += ins.length();
}
void SpirvCodeBuffer::append(const SpirvCodeBuffer& other) { void SpirvCodeBuffer::append(const SpirvCodeBuffer& other) {
if (other.size() != 0) { if (other.size() != 0) {
const size_t size = m_code.size(); const size_t size = m_code.size();

View File

@ -89,6 +89,14 @@ namespace dxvk {
*/ */
uint32_t allocId(); uint32_t allocId();
/**
* \brief Appends an instruction
*
* Slightly faster than individually adding words.
* \param [in] ins Instruction
*/
void append(const SpirvInstruction& ins);
/** /**
* \brief Merges two code buffers * \brief Merges two code buffers
* *

View File

@ -15,7 +15,7 @@ namespace dxvk {
} }
SpirvCodeBuffer SpirvModule::compile() const { SpirvCodeBuffer SpirvModule::compile() {
SpirvCodeBuffer result; SpirvCodeBuffer result;
result.putHeader(m_version, m_id); result.putHeader(m_version, m_id);
result.append(m_capabilities); result.append(m_capabilities);
@ -28,7 +28,35 @@ namespace dxvk {
result.append(m_annotations); result.append(m_annotations);
result.append(m_typeConstDefs); result.append(m_typeConstDefs);
result.append(m_variables); result.append(m_variables);
result.append(m_code);
// Perform some crude dead code elimination. In some cases, our compilers
// may emit invalid code, such as an unreachable block branching to a loop's
// continue block, but those cases cannot be reasonably detected up-front.
std::unordered_set<uint32_t> reachableBlocks;
std::unordered_set<uint32_t> mergeBlocks;
classifyBlocks(reachableBlocks, mergeBlocks);
bool reachable = true;
for (auto ins : m_code) {
if (ins.opCode() == spv::OpFunctionEnd) {
reachable = true;
result.append(ins);
} else if (ins.opCode() == spv::OpLabel) {
uint32_t labelId = ins.arg(1);
if ((reachable = reachableBlocks.find(labelId) != reachableBlocks.end())) {
result.append(ins);
} else if (mergeBlocks.find(labelId) != mergeBlocks.end()) {
result.append(ins);
result.putIns(spv::OpUnreachable, 1);
}
} else if (reachable) {
result.append(ins);
}
}
return result; return result;
} }
@ -3905,4 +3933,69 @@ namespace dxvk {
} }
} }
void SpirvModule::classifyBlocks(
std::unordered_set<uint32_t>& reachableBlocks,
std::unordered_set<uint32_t>& mergeBlocks) {
std::unordered_multimap<uint32_t, uint32_t> branches;
std::queue<uint32_t> blockQueue;
uint32_t blockId = 0;
for (auto ins : m_code) {
switch (ins.opCode()) {
case spv::OpLabel: {
uint32_t id = ins.arg(1);
if (!blockId)
branches.insert({ 0u, id });
blockId = id;
} break;
case spv::OpFunction: {
blockId = 0u;
} break;
case spv::OpBranch: {
branches.insert({ blockId, ins.arg(1) });
} break;
case spv::OpBranchConditional: {
branches.insert({ blockId, ins.arg(2) });
branches.insert({ blockId, ins.arg(3) });
} break;
case spv::OpSwitch: {
branches.insert({ blockId, ins.arg(2) });
for (uint32_t i = 4; i < ins.length(); i += 2)
branches.insert({ blockId, ins.arg(i) });
} break;
case spv::OpSelectionMerge:
case spv::OpLoopMerge: {
mergeBlocks.insert(ins.arg(1));
} break;
default:;
}
}
blockQueue.push(0);
while (!blockQueue.empty()) {
uint32_t id = blockQueue.front();
auto range = branches.equal_range(id);
for (auto i = range.first; i != range.second; i++) {
if (reachableBlocks.insert(i->second).second)
blockQueue.push(i->second);
}
blockQueue.pop();
}
}
} }

View File

@ -1,5 +1,7 @@
#pragma once #pragma once
#include <queue>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "spirv_code_buffer.h" #include "spirv_code_buffer.h"
@ -59,7 +61,7 @@ namespace dxvk {
~SpirvModule(); ~SpirvModule();
SpirvCodeBuffer compile() const; SpirvCodeBuffer compile();
size_t getInsertionPtr() { size_t getInsertionPtr() {
return m_code.getInsertionPtr(); return m_code.getInsertionPtr();
@ -1326,6 +1328,10 @@ namespace dxvk {
bool isInterfaceVar( bool isInterfaceVar(
spv::StorageClass sclass) const; spv::StorageClass sclass) const;
void classifyBlocks(
std::unordered_set<uint32_t>& reachableBlocks,
std::unordered_set<uint32_t>& mergeBlocks);
}; };
} }