diff --git a/src/d3d11/d3d11_context.cpp b/src/d3d11/d3d11_context.cpp index 1fdf030e..f299df2c 100644 --- a/src/d3d11/d3d11_context.cpp +++ b/src/d3d11/d3d11_context.cpp @@ -1281,7 +1281,10 @@ namespace dxvk { if (m_state.vs.shader != shader) { m_state.vs.shader = shader; - BindShader(shader, VK_SHADER_STAGE_VERTEX_BIT); + + BindShader( + DxbcProgramType::VertexShader, + GetCommonShader(shader)); } } @@ -1407,7 +1410,10 @@ namespace dxvk { if (m_state.hs.shader != shader) { m_state.hs.shader = shader; - BindShader(shader, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT); + + BindShader( + DxbcProgramType::HullShader, + GetCommonShader(shader)); } } @@ -1533,7 +1539,10 @@ namespace dxvk { if (m_state.ds.shader != shader) { m_state.ds.shader = shader; - BindShader(shader, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT); + + BindShader( + DxbcProgramType::DomainShader, + GetCommonShader(shader)); } } @@ -1659,7 +1668,10 @@ namespace dxvk { if (m_state.gs.shader != shader) { m_state.gs.shader = shader; - BindShader(shader, VK_SHADER_STAGE_GEOMETRY_BIT); + + BindShader( + DxbcProgramType::GeometryShader, + GetCommonShader(shader)); } } @@ -1785,7 +1797,10 @@ namespace dxvk { if (m_state.ps.shader != shader) { m_state.ps.shader = shader; - BindShader(shader, VK_SHADER_STAGE_FRAGMENT_BIT); + + BindShader( + DxbcProgramType::PixelShader, + GetCommonShader(shader)); } } @@ -1911,7 +1926,10 @@ namespace dxvk { if (m_state.cs.shader != shader) { m_state.cs.shader = shader; - BindShader(shader, VK_SHADER_STAGE_COMPUTE_BIT); + + BindShader( + DxbcProgramType::ComputeShader, + GetCommonShader(shader)); } } @@ -2536,6 +2554,18 @@ namespace dxvk { } + void D3D11DeviceContext::BindShader( + DxbcProgramType ShaderStage, + const D3D11CommonShader* pShaderModule) { + EmitCs([ + cStage = GetShaderStage(ShaderStage), + cShader = pShaderModule != nullptr ? pShaderModule->GetShader() : nullptr + ] (DxvkContext* ctx) { + ctx->bindShader(cStage, cShader); + }); + } + + void D3D11DeviceContext::BindFramebuffer(BOOL Spill) { // NOTE According to the Microsoft docs, we are supposed to // unbind overlapping shader resource views. Since this comes @@ -2840,12 +2870,12 @@ namespace dxvk { void D3D11DeviceContext::RestoreState() { BindFramebuffer(m_state.om.isUavRendering); - BindShader(m_state.vs.shader.ptr(), VK_SHADER_STAGE_VERTEX_BIT); - BindShader(m_state.hs.shader.ptr(), VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT); - BindShader(m_state.ds.shader.ptr(), VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT); - BindShader(m_state.gs.shader.ptr(), VK_SHADER_STAGE_GEOMETRY_BIT); - BindShader(m_state.ps.shader.ptr(), VK_SHADER_STAGE_FRAGMENT_BIT); - BindShader(m_state.cs.shader.ptr(), VK_SHADER_STAGE_COMPUTE_BIT); + BindShader(DxbcProgramType::VertexShader, GetCommonShader(m_state.vs.shader.ptr())); + BindShader(DxbcProgramType::HullShader, GetCommonShader(m_state.hs.shader.ptr())); + BindShader(DxbcProgramType::DomainShader, GetCommonShader(m_state.ds.shader.ptr())); + BindShader(DxbcProgramType::GeometryShader, GetCommonShader(m_state.gs.shader.ptr())); + BindShader(DxbcProgramType::PixelShader, GetCommonShader(m_state.ps.shader.ptr())); + BindShader(DxbcProgramType::ComputeShader, GetCommonShader(m_state.cs.shader.ptr())); ApplyInputLayout(); ApplyPrimitiveTopology(); diff --git a/src/d3d11/d3d11_context.h b/src/d3d11/d3d11_context.h index a3cfa68d..d3ad9e35 100644 --- a/src/d3d11/d3d11_context.h +++ b/src/d3d11/d3d11_context.h @@ -667,21 +667,13 @@ namespace dxvk { void ApplyViewportState(); + void BindShader( + DxbcProgramType ShaderStage, + const D3D11CommonShader* pShaderModule); + void BindFramebuffer( BOOL Spill); - template - void BindShader( - T* pShader, - VkShaderStageFlagBits Stage) { - EmitCs([ - cShader = pShader != nullptr ? pShader->GetShader() : nullptr, - cStage = Stage - ] (DxvkContext* ctx) { - ctx->bindShader(cStage, cShader); - }); - } - void BindVertexBuffer( UINT Slot, D3D11Buffer* pBuffer, @@ -783,6 +775,11 @@ namespace dxvk { DxvkDataSlice AllocUpdateBufferSlice(size_t Size); + template + const D3D11CommonShader* GetCommonShader(T* pShader) const { + return pShader != nullptr ? pShader->GetCommonShader() : nullptr; + } + template void EmitCs(Cmd&& command) { if (!m_csChunk->push(command)) { diff --git a/src/d3d11/d3d11_device.cpp b/src/d3d11/d3d11_device.cpp index b7db2a39..3696f177 100644 --- a/src/d3d11/d3d11_device.cpp +++ b/src/d3d11/d3d11_device.cpp @@ -1073,7 +1073,7 @@ namespace dxvk { ID3D11ClassLinkage* pClassLinkage, ID3D11VertexShader** ppVertexShader) { InitReturnPtr(ppVertexShader); - D3D11ShaderModule module; + D3D11CommonShader module; DxbcModuleInfo moduleInfo; moduleInfo.options = m_dxbcOptions; @@ -1097,7 +1097,7 @@ namespace dxvk { ID3D11ClassLinkage* pClassLinkage, ID3D11GeometryShader** ppGeometryShader) { InitReturnPtr(ppGeometryShader); - D3D11ShaderModule module; + D3D11CommonShader module; DxbcModuleInfo moduleInfo; moduleInfo.options = m_dxbcOptions; @@ -1140,7 +1140,7 @@ namespace dxvk { ID3D11ClassLinkage* pClassLinkage, ID3D11PixelShader** ppPixelShader) { InitReturnPtr(ppPixelShader); - D3D11ShaderModule module; + D3D11CommonShader module; DxbcModuleInfo moduleInfo; moduleInfo.options = m_dxbcOptions; @@ -1164,7 +1164,7 @@ namespace dxvk { ID3D11ClassLinkage* pClassLinkage, ID3D11HullShader** ppHullShader) { InitReturnPtr(ppHullShader); - D3D11ShaderModule module; + D3D11CommonShader module; DxbcModuleInfo moduleInfo; moduleInfo.options = m_dxbcOptions; @@ -1188,7 +1188,7 @@ namespace dxvk { ID3D11ClassLinkage* pClassLinkage, ID3D11DomainShader** ppDomainShader) { InitReturnPtr(ppDomainShader); - D3D11ShaderModule module; + D3D11CommonShader module; DxbcModuleInfo moduleInfo; moduleInfo.options = m_dxbcOptions; @@ -1212,7 +1212,7 @@ namespace dxvk { ID3D11ClassLinkage* pClassLinkage, ID3D11ComputeShader** ppComputeShader) { InitReturnPtr(ppComputeShader); - D3D11ShaderModule module; + D3D11CommonShader module; DxbcModuleInfo moduleInfo; moduleInfo.options = m_dxbcOptions; @@ -1852,7 +1852,7 @@ namespace dxvk { HRESULT D3D11Device::CreateShaderModule( - D3D11ShaderModule* pShaderModule, + D3D11CommonShader* pShaderModule, const void* pShaderBytecode, size_t BytecodeLength, ID3D11ClassLinkage* pClassLinkage, diff --git a/src/d3d11/d3d11_device.h b/src/d3d11/d3d11_device.h index 82d43101..28d801a5 100644 --- a/src/d3d11/d3d11_device.h +++ b/src/d3d11/d3d11_device.h @@ -20,6 +20,7 @@ namespace dxvk { class DxgiAdapter; class D3D11Buffer; + class D3D11CommonShader; class D3D11CommonTexture; class D3D11Counter; class D3D11DeviceContext; @@ -27,7 +28,6 @@ namespace dxvk { class D3D11Predicate; class D3D11Presenter; class D3D11Query; - class D3D11ShaderModule; class D3D11Texture1D; class D3D11Texture2D; class D3D11Texture3D; @@ -372,7 +372,7 @@ namespace dxvk { D3D11ShaderModuleSet m_shaderModules; HRESULT CreateShaderModule( - D3D11ShaderModule* pShaderModule, + D3D11CommonShader* pShaderModule, const void* pShaderBytecode, size_t BytecodeLength, ID3D11ClassLinkage* pClassLinkage, diff --git a/src/d3d11/d3d11_shader.cpp b/src/d3d11/d3d11_shader.cpp index 5a09db08..9e809396 100644 --- a/src/d3d11/d3d11_shader.cpp +++ b/src/d3d11/d3d11_shader.cpp @@ -34,11 +34,11 @@ namespace dxvk { } - D3D11ShaderModule:: D3D11ShaderModule() { } - D3D11ShaderModule::~D3D11ShaderModule() { } + D3D11CommonShader:: D3D11CommonShader() { } + D3D11CommonShader::~D3D11CommonShader() { } - D3D11ShaderModule::D3D11ShaderModule( + D3D11CommonShader::D3D11CommonShader( const D3D11ShaderKey* pShaderKey, const DxbcModuleInfo* pDxbcModuleInfo, const void* pShaderBytecode, @@ -85,13 +85,13 @@ namespace dxvk { m_shader->read(readStream); } } - + D3D11ShaderModuleSet:: D3D11ShaderModuleSet() { } D3D11ShaderModuleSet::~D3D11ShaderModuleSet() { } - D3D11ShaderModule D3D11ShaderModuleSet::GetShaderModule( + D3D11CommonShader D3D11ShaderModuleSet::GetShaderModule( const DxbcModuleInfo* pDxbcModuleInfo, const void* pShaderBytecode, size_t BytecodeLength, @@ -108,7 +108,7 @@ namespace dxvk { // This shader has not been compiled yet, so we have to create a // new module. This takes a while, so we won't lock the structure. - D3D11ShaderModule module(&key, pDxbcModuleInfo, pShaderBytecode, BytecodeLength); + D3D11CommonShader module(&key, pDxbcModuleInfo, pShaderBytecode, BytecodeLength); // Insert the new module into the lookup table. If another thread // has compiled the same shader in the meantime, we should return diff --git a/src/d3d11/d3d11_shader.h b/src/d3d11/d3d11_shader.h index 92126ac2..2c942348 100644 --- a/src/d3d11/d3d11_shader.h +++ b/src/d3d11/d3d11_shader.h @@ -57,23 +57,25 @@ namespace dxvk { /** - * \brief Shader module + * \brief Common shader object * * Stores the compiled SPIR-V shader and the SHA-1 * hash of the original DXBC shader, which can be * used to identify the shader. */ - class D3D11ShaderModule { + class D3D11CommonShader { public: - D3D11ShaderModule(); - D3D11ShaderModule( + D3D11CommonShader(); + D3D11CommonShader( const D3D11ShaderKey* pShaderKey, const DxbcModuleInfo* pDxbcModuleInfo, const void* pShaderBytecode, size_t BytecodeLength); - ~D3D11ShaderModule(); + ~D3D11CommonShader(); + + DxbcProgramType GetProgramType() const; Rc GetShader() const { return m_shader; @@ -103,8 +105,8 @@ namespace dxvk { public: - D3D11Shader(D3D11Device* device, const D3D11ShaderModule& module) - : m_device(device), m_module(module) { } + D3D11Shader(D3D11Device* device, const D3D11CommonShader& shader) + : m_device(device), m_shader(shader) { } ~D3D11Shader() { } @@ -126,18 +128,14 @@ namespace dxvk { *ppDevice = m_device.ref(); } - Rc STDMETHODCALLTYPE GetShader() const { - return m_module.GetShader(); - } - - const std::string& GetName() const { - return m_module.GetName(); + const D3D11CommonShader* GetCommonShader() const { + return &m_shader; } private: Com m_device; - D3D11ShaderModule m_module; + D3D11CommonShader m_shader; }; @@ -164,7 +162,7 @@ namespace dxvk { D3D11ShaderModuleSet(); ~D3D11ShaderModuleSet(); - D3D11ShaderModule GetShaderModule( + D3D11CommonShader GetShaderModule( const DxbcModuleInfo* pDxbcModuleInfo, const void* pShaderBytecode, size_t BytecodeLength, @@ -176,7 +174,7 @@ namespace dxvk { std::unordered_map< D3D11ShaderKey, - D3D11ShaderModule, + D3D11CommonShader, D3D11ShaderKeyHash> m_modules; }; diff --git a/src/d3d11/d3d11_util.cpp b/src/d3d11/d3d11_util.cpp index a537b6f1..b6643547 100644 --- a/src/d3d11/d3d11_util.cpp +++ b/src/d3d11/d3d11_util.cpp @@ -115,5 +115,17 @@ namespace dxvk { return memoryFlags; } + + + VkShaderStageFlagBits GetShaderStage(DxbcProgramType ProgramType) { + switch (ProgramType) { + case DxbcProgramType::VertexShader: return VK_SHADER_STAGE_VERTEX_BIT; + case DxbcProgramType::HullShader: return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT; + case DxbcProgramType::DomainShader: return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; + case DxbcProgramType::GeometryShader: return VK_SHADER_STAGE_GEOMETRY_BIT; + case DxbcProgramType::PixelShader: return VK_SHADER_STAGE_FRAGMENT_BIT; + case DxbcProgramType::ComputeShader: return VK_SHADER_STAGE_COMPUTE_BIT; + } + } } \ No newline at end of file diff --git a/src/d3d11/d3d11_util.h b/src/d3d11/d3d11_util.h index 9673f513..b562d368 100644 --- a/src/d3d11/d3d11_util.h +++ b/src/d3d11/d3d11_util.h @@ -37,4 +37,7 @@ namespace dxvk { VkMemoryPropertyFlags GetMemoryFlagsForUsage( D3D11_USAGE Usage); + VkShaderStageFlagBits GetShaderStage( + DxbcProgramType ProgramType); + } \ No newline at end of file