diff --git a/src/d3d11/d3d11_device.cpp b/src/d3d11/d3d11_device.cpp index 74eec122a..00d8dffd4 100644 --- a/src/d3d11/d3d11_device.cpp +++ b/src/d3d11/d3d11_device.cpp @@ -624,10 +624,14 @@ namespace dxvk { moduleInfo.options = m_dxbcOptions; moduleInfo.tess = nullptr; moduleInfo.xfb = nullptr; + + Sha1Hash hash = Sha1Hash::compute( + pShaderBytecode, BytecodeLength); if (FAILED(this->CreateShaderModule(&module, + DxvkShaderKey(VK_SHADER_STAGE_VERTEX_BIT, hash), pShaderBytecode, BytecodeLength, pClassLinkage, - &moduleInfo, DxbcProgramType::VertexShader))) + &moduleInfo))) return E_INVALIDARG; if (ppVertexShader == nullptr) @@ -651,9 +655,13 @@ namespace dxvk { moduleInfo.tess = nullptr; moduleInfo.xfb = nullptr; + Sha1Hash hash = Sha1Hash::compute( + pShaderBytecode, BytecodeLength); + if (FAILED(this->CreateShaderModule(&module, + DxvkShaderKey(VK_SHADER_STAGE_GEOMETRY_BIT, hash), pShaderBytecode, BytecodeLength, pClassLinkage, - &moduleInfo, DxbcProgramType::GeometryShader))) + &moduleInfo))) return E_INVALIDARG; if (ppGeometryShader == nullptr) @@ -686,12 +694,8 @@ namespace dxvk { // Zero-init some counterss so that we can increment // them while walking over the stream output entries - DxbcXfbInfo xfb; - xfb.entryCount = 0; + DxbcXfbInfo xfb = { }; - for (uint32_t i = 0; i < D3D11_SO_BUFFER_SLOT_COUNT; i++) - xfb.strides[i] = 0; - for (uint32_t i = 0; i < NumEntries; i++) { const D3D11_SO_DECLARATION_ENTRY* so = &pSODeclaration[i]; @@ -727,7 +731,16 @@ namespace dxvk { if (RasterizedStream != D3D11_SO_NO_RASTERIZED_STREAM) Logger::err("D3D11: CreateGeometryShaderWithStreamOutput: Rasterized stream not supported"); + + // Compute hash from both the xfb info and the source + // code, because both influence the generated code + std::array chunks = {{ + { pShaderBytecode, BytecodeLength }, + { &xfb, sizeof(xfb) }, + }}; + Sha1Hash hash = Sha1Hash::compute(chunks.size(), chunks.data()); + // Create the actual shader module DxbcModuleInfo moduleInfo; moduleInfo.options = m_dxbcOptions; @@ -735,8 +748,9 @@ namespace dxvk { moduleInfo.xfb = &xfb; if (FAILED(this->CreateShaderModule(&module, + DxvkShaderKey(VK_SHADER_STAGE_GEOMETRY_BIT, hash), pShaderBytecode, BytecodeLength, pClassLinkage, - &moduleInfo, DxbcProgramType::GeometryShader))) + &moduleInfo))) return E_INVALIDARG; if (ppGeometryShader == nullptr) @@ -760,9 +774,13 @@ namespace dxvk { moduleInfo.tess = nullptr; moduleInfo.xfb = nullptr; + Sha1Hash hash = Sha1Hash::compute( + pShaderBytecode, BytecodeLength); + if (FAILED(this->CreateShaderModule(&module, + DxvkShaderKey(VK_SHADER_STAGE_FRAGMENT_BIT, hash), pShaderBytecode, BytecodeLength, pClassLinkage, - &moduleInfo, DxbcProgramType::PixelShader))) + &moduleInfo))) return E_INVALIDARG; if (ppPixelShader == nullptr) @@ -792,9 +810,12 @@ namespace dxvk { if (tessInfo.maxTessFactor >= 8.0f) moduleInfo.tess = &tessInfo; + Sha1Hash hash = Sha1Hash::compute( + pShaderBytecode, BytecodeLength); + if (FAILED(this->CreateShaderModule(&module, - pShaderBytecode, BytecodeLength, pClassLinkage, - &moduleInfo, DxbcProgramType::HullShader))) + DxvkShaderKey(VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, hash), + pShaderBytecode, BytecodeLength, pClassLinkage, &moduleInfo))) return E_INVALIDARG; if (ppHullShader == nullptr) @@ -818,9 +839,12 @@ namespace dxvk { moduleInfo.tess = nullptr; moduleInfo.xfb = nullptr; + Sha1Hash hash = Sha1Hash::compute( + pShaderBytecode, BytecodeLength); + if (FAILED(this->CreateShaderModule(&module, - pShaderBytecode, BytecodeLength, pClassLinkage, - &moduleInfo, DxbcProgramType::DomainShader))) + DxvkShaderKey(VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, hash), + pShaderBytecode, BytecodeLength, pClassLinkage, &moduleInfo))) return E_INVALIDARG; if (ppDomainShader == nullptr) @@ -844,9 +868,13 @@ namespace dxvk { moduleInfo.tess = nullptr; moduleInfo.xfb = nullptr; + Sha1Hash hash = Sha1Hash::compute( + pShaderBytecode, BytecodeLength); + if (FAILED(this->CreateShaderModule(&module, + DxvkShaderKey(VK_SHADER_STAGE_COMPUTE_BIT, hash), pShaderBytecode, BytecodeLength, pClassLinkage, - &moduleInfo, DxbcProgramType::ComputeShader))) + &moduleInfo))) return E_INVALIDARG; if (ppComputeShader == nullptr) @@ -1512,17 +1540,17 @@ namespace dxvk { HRESULT D3D11Device::CreateShaderModule( D3D11CommonShader* pShaderModule, + DxvkShaderKey ShaderKey, const void* pShaderBytecode, size_t BytecodeLength, ID3D11ClassLinkage* pClassLinkage, - const DxbcModuleInfo* pModuleInfo, - DxbcProgramType ProgramType) { + const DxbcModuleInfo* pModuleInfo) { if (pClassLinkage != nullptr) Logger::warn("D3D11Device::CreateShaderModule: Class linkage not supported"); try { *pShaderModule = m_shaderModules.GetShaderModule(this, - pModuleInfo, pShaderBytecode, BytecodeLength, ProgramType); + &ShaderKey, pModuleInfo, pShaderBytecode, BytecodeLength); return S_OK; } catch (const DxvkError& e) { Logger::err(e.message()); diff --git a/src/d3d11/d3d11_device.h b/src/d3d11/d3d11_device.h index 0580b890a..c8fccffc8 100644 --- a/src/d3d11/d3d11_device.h +++ b/src/d3d11/d3d11_device.h @@ -394,11 +394,11 @@ namespace dxvk { HRESULT CreateShaderModule( D3D11CommonShader* pShaderModule, + DxvkShaderKey ShaderKey, const void* pShaderBytecode, size_t BytecodeLength, ID3D11ClassLinkage* pClassLinkage, - const DxbcModuleInfo* pModuleInfo, - DxbcProgramType ProgramType); + const DxbcModuleInfo* pModuleInfo); HRESULT GetFormatSupportFlags( DXGI_FORMAT Format, diff --git a/src/d3d11/d3d11_shader.cpp b/src/d3d11/d3d11_shader.cpp index ba2ef8f1e..970ce3925 100644 --- a/src/d3d11/d3d11_shader.cpp +++ b/src/d3d11/d3d11_shader.cpp @@ -81,23 +81,21 @@ namespace dxvk { D3D11CommonShader D3D11ShaderModuleSet::GetShaderModule( D3D11Device* pDevice, + const DxvkShaderKey* pShaderKey, const DxbcModuleInfo* pDxbcModuleInfo, const void* pShaderBytecode, - size_t BytecodeLength, - DxbcProgramType ProgramType) { - // Compute the shader's unique key so that we can perform a lookup - DxvkShaderKey key(GetShaderStage(ProgramType), pShaderBytecode, BytecodeLength); - + size_t BytecodeLength) { + // Use the shader's unique key for the lookup { std::unique_lock lock(m_mutex); - auto entry = m_modules.find(key); + auto entry = m_modules.find(*pShaderKey); if (entry != m_modules.end()) return entry->second; } // This shader has not been compiled yet, so we have to create a // new module. This takes a while, so we won't lock the structure. - D3D11CommonShader module(pDevice, &key, + D3D11CommonShader module(pDevice, pShaderKey, pDxbcModuleInfo, pShaderBytecode, BytecodeLength); // Insert the new module into the lookup table. If another thread @@ -105,7 +103,7 @@ namespace dxvk { // that object instead and discard the newly created module. { std::unique_lock lock(m_mutex); - auto status = m_modules.insert({ key, module }); + auto status = m_modules.insert({ *pShaderKey, module }); if (!status.second) return status.first->second; } diff --git a/src/d3d11/d3d11_shader.h b/src/d3d11/d3d11_shader.h index 75f8c5feb..6c3c5ebab 100644 --- a/src/d3d11/d3d11_shader.h +++ b/src/d3d11/d3d11_shader.h @@ -142,10 +142,10 @@ namespace dxvk { D3D11CommonShader GetShaderModule( D3D11Device* pDevice, + const DxvkShaderKey* pShaderKey, const DxbcModuleInfo* pDxbcModuleInfo, const void* pShaderBytecode, - size_t BytecodeLength, - DxbcProgramType ProgramType); + size_t BytecodeLength); private: diff --git a/src/dxvk/dxvk_shader_key.h b/src/dxvk/dxvk_shader_key.h index 657438f4c..c4e787e3f 100644 --- a/src/dxvk/dxvk_shader_key.h +++ b/src/dxvk/dxvk_shader_key.h @@ -35,6 +35,17 @@ namespace dxvk { const void* sourceCode, size_t sourceSize); + /** + * \brief Creates shader key + * + * \param [in] stage Shader stage + * \param [in] hash Shader hash + */ + DxvkShaderKey( + VkShaderStageFlagBits stage, + Sha1Hash hash) + : m_type(stage), m_sha1(hash) { } + /** * \brief Generates string from shader key * \returns String representation of the key