1
0
mirror of https://github.com/doitsujin/dxvk.git synced 2024-12-03 04:24:11 +01:00

[d3d11] Refactor shader binding

This way we can get rid of an unnecessary template and make future
extensions possible.
This commit is contained in:
Philip Rebohle 2018-07-30 19:34:48 +02:00
parent b7bdd9de38
commit 7f0f7ac048
No known key found for this signature in database
GPG Key ID: C8CC613427A31C99
8 changed files with 95 additions and 55 deletions

View File

@ -1281,7 +1281,10 @@ namespace dxvk {
if (m_state.vs.shader != shader) { if (m_state.vs.shader != shader) {
m_state.vs.shader = shader; m_state.vs.shader = shader;
BindShader(shader, VK_SHADER_STAGE_VERTEX_BIT);
BindShader(
DxbcProgramType::VertexShader,
GetCommonShader(shader));
} }
} }
@ -1407,7 +1410,10 @@ namespace dxvk {
if (m_state.hs.shader != shader) { if (m_state.hs.shader != shader) {
m_state.hs.shader = shader; m_state.hs.shader = shader;
BindShader(shader, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
BindShader(
DxbcProgramType::HullShader,
GetCommonShader(shader));
} }
} }
@ -1533,7 +1539,10 @@ namespace dxvk {
if (m_state.ds.shader != shader) { if (m_state.ds.shader != shader) {
m_state.ds.shader = shader; m_state.ds.shader = shader;
BindShader(shader, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
BindShader(
DxbcProgramType::DomainShader,
GetCommonShader(shader));
} }
} }
@ -1659,7 +1668,10 @@ namespace dxvk {
if (m_state.gs.shader != shader) { if (m_state.gs.shader != shader) {
m_state.gs.shader = shader; m_state.gs.shader = shader;
BindShader(shader, VK_SHADER_STAGE_GEOMETRY_BIT);
BindShader(
DxbcProgramType::GeometryShader,
GetCommonShader(shader));
} }
} }
@ -1785,7 +1797,10 @@ namespace dxvk {
if (m_state.ps.shader != shader) { if (m_state.ps.shader != shader) {
m_state.ps.shader = shader; m_state.ps.shader = shader;
BindShader(shader, VK_SHADER_STAGE_FRAGMENT_BIT);
BindShader(
DxbcProgramType::PixelShader,
GetCommonShader(shader));
} }
} }
@ -1911,7 +1926,10 @@ namespace dxvk {
if (m_state.cs.shader != shader) { if (m_state.cs.shader != shader) {
m_state.cs.shader = shader; m_state.cs.shader = shader;
BindShader(shader, VK_SHADER_STAGE_COMPUTE_BIT);
BindShader(
DxbcProgramType::ComputeShader,
GetCommonShader(shader));
} }
} }
@ -2536,6 +2554,18 @@ namespace dxvk {
} }
void D3D11DeviceContext::BindShader(
DxbcProgramType ShaderStage,
const D3D11CommonShader* pShaderModule) {
EmitCs([
cStage = GetShaderStage(ShaderStage),
cShader = pShaderModule != nullptr ? pShaderModule->GetShader() : nullptr
] (DxvkContext* ctx) {
ctx->bindShader(cStage, cShader);
});
}
void D3D11DeviceContext::BindFramebuffer(BOOL Spill) { void D3D11DeviceContext::BindFramebuffer(BOOL Spill) {
// NOTE According to the Microsoft docs, we are supposed to // NOTE According to the Microsoft docs, we are supposed to
// unbind overlapping shader resource views. Since this comes // unbind overlapping shader resource views. Since this comes
@ -2840,12 +2870,12 @@ namespace dxvk {
void D3D11DeviceContext::RestoreState() { void D3D11DeviceContext::RestoreState() {
BindFramebuffer(m_state.om.isUavRendering); BindFramebuffer(m_state.om.isUavRendering);
BindShader(m_state.vs.shader.ptr(), VK_SHADER_STAGE_VERTEX_BIT); BindShader(DxbcProgramType::VertexShader, GetCommonShader(m_state.vs.shader.ptr()));
BindShader(m_state.hs.shader.ptr(), VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT); BindShader(DxbcProgramType::HullShader, GetCommonShader(m_state.hs.shader.ptr()));
BindShader(m_state.ds.shader.ptr(), VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT); BindShader(DxbcProgramType::DomainShader, GetCommonShader(m_state.ds.shader.ptr()));
BindShader(m_state.gs.shader.ptr(), VK_SHADER_STAGE_GEOMETRY_BIT); BindShader(DxbcProgramType::GeometryShader, GetCommonShader(m_state.gs.shader.ptr()));
BindShader(m_state.ps.shader.ptr(), VK_SHADER_STAGE_FRAGMENT_BIT); BindShader(DxbcProgramType::PixelShader, GetCommonShader(m_state.ps.shader.ptr()));
BindShader(m_state.cs.shader.ptr(), VK_SHADER_STAGE_COMPUTE_BIT); BindShader(DxbcProgramType::ComputeShader, GetCommonShader(m_state.cs.shader.ptr()));
ApplyInputLayout(); ApplyInputLayout();
ApplyPrimitiveTopology(); ApplyPrimitiveTopology();

View File

@ -667,21 +667,13 @@ namespace dxvk {
void ApplyViewportState(); void ApplyViewportState();
void BindShader(
DxbcProgramType ShaderStage,
const D3D11CommonShader* pShaderModule);
void BindFramebuffer( void BindFramebuffer(
BOOL Spill); BOOL Spill);
template<typename T>
void BindShader(
T* pShader,
VkShaderStageFlagBits Stage) {
EmitCs([
cShader = pShader != nullptr ? pShader->GetShader() : nullptr,
cStage = Stage
] (DxvkContext* ctx) {
ctx->bindShader(cStage, cShader);
});
}
void BindVertexBuffer( void BindVertexBuffer(
UINT Slot, UINT Slot,
D3D11Buffer* pBuffer, D3D11Buffer* pBuffer,
@ -783,6 +775,11 @@ namespace dxvk {
DxvkDataSlice AllocUpdateBufferSlice(size_t Size); DxvkDataSlice AllocUpdateBufferSlice(size_t Size);
template<typename T>
const D3D11CommonShader* GetCommonShader(T* pShader) const {
return pShader != nullptr ? pShader->GetCommonShader() : nullptr;
}
template<typename Cmd> template<typename Cmd>
void EmitCs(Cmd&& command) { void EmitCs(Cmd&& command) {
if (!m_csChunk->push(command)) { if (!m_csChunk->push(command)) {

View File

@ -1073,7 +1073,7 @@ namespace dxvk {
ID3D11ClassLinkage* pClassLinkage, ID3D11ClassLinkage* pClassLinkage,
ID3D11VertexShader** ppVertexShader) { ID3D11VertexShader** ppVertexShader) {
InitReturnPtr(ppVertexShader); InitReturnPtr(ppVertexShader);
D3D11ShaderModule module; D3D11CommonShader module;
DxbcModuleInfo moduleInfo; DxbcModuleInfo moduleInfo;
moduleInfo.options = m_dxbcOptions; moduleInfo.options = m_dxbcOptions;
@ -1097,7 +1097,7 @@ namespace dxvk {
ID3D11ClassLinkage* pClassLinkage, ID3D11ClassLinkage* pClassLinkage,
ID3D11GeometryShader** ppGeometryShader) { ID3D11GeometryShader** ppGeometryShader) {
InitReturnPtr(ppGeometryShader); InitReturnPtr(ppGeometryShader);
D3D11ShaderModule module; D3D11CommonShader module;
DxbcModuleInfo moduleInfo; DxbcModuleInfo moduleInfo;
moduleInfo.options = m_dxbcOptions; moduleInfo.options = m_dxbcOptions;
@ -1140,7 +1140,7 @@ namespace dxvk {
ID3D11ClassLinkage* pClassLinkage, ID3D11ClassLinkage* pClassLinkage,
ID3D11PixelShader** ppPixelShader) { ID3D11PixelShader** ppPixelShader) {
InitReturnPtr(ppPixelShader); InitReturnPtr(ppPixelShader);
D3D11ShaderModule module; D3D11CommonShader module;
DxbcModuleInfo moduleInfo; DxbcModuleInfo moduleInfo;
moduleInfo.options = m_dxbcOptions; moduleInfo.options = m_dxbcOptions;
@ -1164,7 +1164,7 @@ namespace dxvk {
ID3D11ClassLinkage* pClassLinkage, ID3D11ClassLinkage* pClassLinkage,
ID3D11HullShader** ppHullShader) { ID3D11HullShader** ppHullShader) {
InitReturnPtr(ppHullShader); InitReturnPtr(ppHullShader);
D3D11ShaderModule module; D3D11CommonShader module;
DxbcModuleInfo moduleInfo; DxbcModuleInfo moduleInfo;
moduleInfo.options = m_dxbcOptions; moduleInfo.options = m_dxbcOptions;
@ -1188,7 +1188,7 @@ namespace dxvk {
ID3D11ClassLinkage* pClassLinkage, ID3D11ClassLinkage* pClassLinkage,
ID3D11DomainShader** ppDomainShader) { ID3D11DomainShader** ppDomainShader) {
InitReturnPtr(ppDomainShader); InitReturnPtr(ppDomainShader);
D3D11ShaderModule module; D3D11CommonShader module;
DxbcModuleInfo moduleInfo; DxbcModuleInfo moduleInfo;
moduleInfo.options = m_dxbcOptions; moduleInfo.options = m_dxbcOptions;
@ -1212,7 +1212,7 @@ namespace dxvk {
ID3D11ClassLinkage* pClassLinkage, ID3D11ClassLinkage* pClassLinkage,
ID3D11ComputeShader** ppComputeShader) { ID3D11ComputeShader** ppComputeShader) {
InitReturnPtr(ppComputeShader); InitReturnPtr(ppComputeShader);
D3D11ShaderModule module; D3D11CommonShader module;
DxbcModuleInfo moduleInfo; DxbcModuleInfo moduleInfo;
moduleInfo.options = m_dxbcOptions; moduleInfo.options = m_dxbcOptions;
@ -1852,7 +1852,7 @@ namespace dxvk {
HRESULT D3D11Device::CreateShaderModule( HRESULT D3D11Device::CreateShaderModule(
D3D11ShaderModule* pShaderModule, D3D11CommonShader* pShaderModule,
const void* pShaderBytecode, const void* pShaderBytecode,
size_t BytecodeLength, size_t BytecodeLength,
ID3D11ClassLinkage* pClassLinkage, ID3D11ClassLinkage* pClassLinkage,

View File

@ -20,6 +20,7 @@ namespace dxvk {
class DxgiAdapter; class DxgiAdapter;
class D3D11Buffer; class D3D11Buffer;
class D3D11CommonShader;
class D3D11CommonTexture; class D3D11CommonTexture;
class D3D11Counter; class D3D11Counter;
class D3D11DeviceContext; class D3D11DeviceContext;
@ -27,7 +28,6 @@ namespace dxvk {
class D3D11Predicate; class D3D11Predicate;
class D3D11Presenter; class D3D11Presenter;
class D3D11Query; class D3D11Query;
class D3D11ShaderModule;
class D3D11Texture1D; class D3D11Texture1D;
class D3D11Texture2D; class D3D11Texture2D;
class D3D11Texture3D; class D3D11Texture3D;
@ -372,7 +372,7 @@ namespace dxvk {
D3D11ShaderModuleSet m_shaderModules; D3D11ShaderModuleSet m_shaderModules;
HRESULT CreateShaderModule( HRESULT CreateShaderModule(
D3D11ShaderModule* pShaderModule, D3D11CommonShader* pShaderModule,
const void* pShaderBytecode, const void* pShaderBytecode,
size_t BytecodeLength, size_t BytecodeLength,
ID3D11ClassLinkage* pClassLinkage, ID3D11ClassLinkage* pClassLinkage,

View File

@ -34,11 +34,11 @@ namespace dxvk {
} }
D3D11ShaderModule:: D3D11ShaderModule() { } D3D11CommonShader:: D3D11CommonShader() { }
D3D11ShaderModule::~D3D11ShaderModule() { } D3D11CommonShader::~D3D11CommonShader() { }
D3D11ShaderModule::D3D11ShaderModule( D3D11CommonShader::D3D11CommonShader(
const D3D11ShaderKey* pShaderKey, const D3D11ShaderKey* pShaderKey,
const DxbcModuleInfo* pDxbcModuleInfo, const DxbcModuleInfo* pDxbcModuleInfo,
const void* pShaderBytecode, const void* pShaderBytecode,
@ -91,7 +91,7 @@ namespace dxvk {
D3D11ShaderModuleSet::~D3D11ShaderModuleSet() { } D3D11ShaderModuleSet::~D3D11ShaderModuleSet() { }
D3D11ShaderModule D3D11ShaderModuleSet::GetShaderModule( D3D11CommonShader D3D11ShaderModuleSet::GetShaderModule(
const DxbcModuleInfo* pDxbcModuleInfo, const DxbcModuleInfo* pDxbcModuleInfo,
const void* pShaderBytecode, const void* pShaderBytecode,
size_t BytecodeLength, size_t BytecodeLength,
@ -108,7 +108,7 @@ namespace dxvk {
// This shader has not been compiled yet, so we have to create a // This shader has not been compiled yet, so we have to create a
// new module. This takes a while, so we won't lock the structure. // new module. This takes a while, so we won't lock the structure.
D3D11ShaderModule module(&key, pDxbcModuleInfo, pShaderBytecode, BytecodeLength); D3D11CommonShader module(&key, pDxbcModuleInfo, pShaderBytecode, BytecodeLength);
// Insert the new module into the lookup table. If another thread // Insert the new module into the lookup table. If another thread
// has compiled the same shader in the meantime, we should return // has compiled the same shader in the meantime, we should return

View File

@ -57,23 +57,25 @@ namespace dxvk {
/** /**
* \brief Shader module * \brief Common shader object
* *
* Stores the compiled SPIR-V shader and the SHA-1 * Stores the compiled SPIR-V shader and the SHA-1
* hash of the original DXBC shader, which can be * hash of the original DXBC shader, which can be
* used to identify the shader. * used to identify the shader.
*/ */
class D3D11ShaderModule { class D3D11CommonShader {
public: public:
D3D11ShaderModule(); D3D11CommonShader();
D3D11ShaderModule( D3D11CommonShader(
const D3D11ShaderKey* pShaderKey, const D3D11ShaderKey* pShaderKey,
const DxbcModuleInfo* pDxbcModuleInfo, const DxbcModuleInfo* pDxbcModuleInfo,
const void* pShaderBytecode, const void* pShaderBytecode,
size_t BytecodeLength); size_t BytecodeLength);
~D3D11ShaderModule(); ~D3D11CommonShader();
DxbcProgramType GetProgramType() const;
Rc<DxvkShader> GetShader() const { Rc<DxvkShader> GetShader() const {
return m_shader; return m_shader;
@ -103,8 +105,8 @@ namespace dxvk {
public: public:
D3D11Shader(D3D11Device* device, const D3D11ShaderModule& module) D3D11Shader(D3D11Device* device, const D3D11CommonShader& shader)
: m_device(device), m_module(module) { } : m_device(device), m_shader(shader) { }
~D3D11Shader() { } ~D3D11Shader() { }
@ -126,18 +128,14 @@ namespace dxvk {
*ppDevice = m_device.ref(); *ppDevice = m_device.ref();
} }
Rc<DxvkShader> STDMETHODCALLTYPE GetShader() const { const D3D11CommonShader* GetCommonShader() const {
return m_module.GetShader(); return &m_shader;
}
const std::string& GetName() const {
return m_module.GetName();
} }
private: private:
Com<D3D11Device> m_device; Com<D3D11Device> m_device;
D3D11ShaderModule m_module; D3D11CommonShader m_shader;
}; };
@ -164,7 +162,7 @@ namespace dxvk {
D3D11ShaderModuleSet(); D3D11ShaderModuleSet();
~D3D11ShaderModuleSet(); ~D3D11ShaderModuleSet();
D3D11ShaderModule GetShaderModule( D3D11CommonShader GetShaderModule(
const DxbcModuleInfo* pDxbcModuleInfo, const DxbcModuleInfo* pDxbcModuleInfo,
const void* pShaderBytecode, const void* pShaderBytecode,
size_t BytecodeLength, size_t BytecodeLength,
@ -176,7 +174,7 @@ namespace dxvk {
std::unordered_map< std::unordered_map<
D3D11ShaderKey, D3D11ShaderKey,
D3D11ShaderModule, D3D11CommonShader,
D3D11ShaderKeyHash> m_modules; D3D11ShaderKeyHash> m_modules;
}; };

View File

@ -116,4 +116,16 @@ namespace dxvk {
return memoryFlags; return memoryFlags;
} }
VkShaderStageFlagBits GetShaderStage(DxbcProgramType ProgramType) {
switch (ProgramType) {
case DxbcProgramType::VertexShader: return VK_SHADER_STAGE_VERTEX_BIT;
case DxbcProgramType::HullShader: return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
case DxbcProgramType::DomainShader: return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
case DxbcProgramType::GeometryShader: return VK_SHADER_STAGE_GEOMETRY_BIT;
case DxbcProgramType::PixelShader: return VK_SHADER_STAGE_FRAGMENT_BIT;
case DxbcProgramType::ComputeShader: return VK_SHADER_STAGE_COMPUTE_BIT;
}
}
} }

View File

@ -37,4 +37,7 @@ namespace dxvk {
VkMemoryPropertyFlags GetMemoryFlagsForUsage( VkMemoryPropertyFlags GetMemoryFlagsForUsage(
D3D11_USAGE Usage); D3D11_USAGE Usage);
VkShaderStageFlagBits GetShaderStage(
DxbcProgramType ProgramType);
} }