From f76a7c285cfd3e3159fb35dd60576d1f85c43b22 Mon Sep 17 00:00:00 2001 From: Philip Rebohle Date: Mon, 9 Jan 2023 14:39:55 +0100 Subject: [PATCH] [dxvk] Rework DxvkShaderPipelineLibrary to work with multiple shaders --- src/dxvk/dxvk_graphics.cpp | 21 +++- src/dxvk/dxvk_shader.cpp | 210 +++++++++++++++++++++++++++---------- src/dxvk/dxvk_shader.h | 57 ++++++++-- 3 files changed, 221 insertions(+), 67 deletions(-) diff --git a/src/dxvk/dxvk_graphics.cpp b/src/dxvk/dxvk_graphics.cpp index aada10e0f..424e5c36d 100644 --- a/src/dxvk/dxvk_graphics.cpp +++ b/src/dxvk/dxvk_graphics.cpp @@ -1094,10 +1094,15 @@ namespace dxvk { return false; if (m_shaders.fs != nullptr) { - // If the fragment shader has inputs not produced by the - // vertex shader, we need to patch the fragment shader - uint32_t vsIoMask = m_shaders.vs->info().outputMask; + // If the fragment shader has inputs not produced by the last + // pre-rasterization stage, we need to patch the fragment shader uint32_t fsIoMask = m_shaders.fs->info().inputMask; + uint32_t vsIoMask = m_shaders.vs->info().outputMask; + + if (m_shaders.gs != nullptr) + vsIoMask = m_shaders.gs->info().outputMask; + else if (m_shaders.tes != nullptr) + vsIoMask = m_shaders.tes->info().outputMask; if ((vsIoMask & fsIoMask) != fsIoMask) return false; @@ -1226,10 +1231,16 @@ namespace dxvk { DxvkShaderStageInfo stageInfo(m_device); if (flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT) { - stageInfo.addStage(VK_SHADER_STAGE_VERTEX_BIT, m_vsLibrary->getModuleIdentifier(), &key.scState.scInfo); + stageInfo.addStage(VK_SHADER_STAGE_VERTEX_BIT, m_vsLibrary->getModuleIdentifier(VK_SHADER_STAGE_VERTEX_BIT), &key.scState.scInfo); + if (m_shaders.tcs != nullptr) + stageInfo.addStage(VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, m_vsLibrary->getModuleIdentifier(VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT), &key.scState.scInfo); + if (m_shaders.tes != nullptr) + stageInfo.addStage(VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, 
m_vsLibrary->getModuleIdentifier(VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT), &key.scState.scInfo); + if (m_shaders.gs != nullptr) + stageInfo.addStage(VK_SHADER_STAGE_GEOMETRY_BIT, m_vsLibrary->getModuleIdentifier(VK_SHADER_STAGE_GEOMETRY_BIT), &key.scState.scInfo); if (m_shaders.fs != nullptr) - stageInfo.addStage(VK_SHADER_STAGE_FRAGMENT_BIT, m_fsLibrary->getModuleIdentifier(), &key.scState.scInfo); + stageInfo.addStage(VK_SHADER_STAGE_FRAGMENT_BIT, m_fsLibrary->getModuleIdentifier(VK_SHADER_STAGE_FRAGMENT_BIT), &key.scState.scInfo); } else { stageInfo.addStage(VK_SHADER_STAGE_VERTEX_BIT, getShaderCode(m_shaders.vs, key.shState.vsInfo), &key.scState.scInfo); diff --git a/src/dxvk/dxvk_shader.cpp b/src/dxvk/dxvk_shader.cpp index d9d166553..740276dab 100644 --- a/src/dxvk/dxvk_shader.cpp +++ b/src/dxvk/dxvk_shader.cpp @@ -893,9 +893,21 @@ namespace dxvk { const DxvkBindingLayoutObjects* layout) : m_device (device), m_stats (&manager->m_stats), - m_shader (shader), m_layout (layout) { - + if (shader) { + switch (shader->info().stage) { + case VK_SHADER_STAGE_VERTEX_BIT: + m_shaders.vs = shader; + break; + case VK_SHADER_STAGE_FRAGMENT_BIT: + m_shaders.fs = shader; + break; + case VK_SHADER_STAGE_COMPUTE_BIT: + m_shaders.cs = shader; + break; + default: ; + } + } } @@ -904,17 +916,19 @@ namespace dxvk { } - VkShaderModuleIdentifierEXT DxvkShaderPipelineLibrary::getModuleIdentifier() { + VkShaderModuleIdentifierEXT DxvkShaderPipelineLibrary::getModuleIdentifier( + VkShaderStageFlagBits stage) { std::lock_guard lock(m_identifierMutex); + auto identifier = getShaderIdentifier(stage); - if (!m_identifier.identifierSize) { + if (!identifier->identifierSize) { // Unfortunate, but we'll have to decode the // shader code here to retrieve the identifier - SpirvCodeBuffer spirvCode = this->getShaderCode(); - this->generateModuleIdentifierLocked(spirvCode); + SpirvCodeBuffer spirvCode = this->getShaderCode(stage); + this->generateModuleIdentifierLocked(identifier, spirvCode); 
} - return m_identifier; + return *identifier; } @@ -925,9 +939,7 @@ namespace dxvk { if (m_device->mustTrackPipelineLifetime()) m_useCount += 1; - VkShaderStageFlagBits stage = getShaderStage(); - - VkPipeline& pipeline = (stage == VK_SHADER_STAGE_VERTEX_BIT && !args.depthClipEnable) + VkPipeline& pipeline = (m_shaders.vs && !args.depthClipEnable) ? m_pipelineNoDepthClip : m_pipeline; @@ -988,22 +1000,20 @@ namespace dxvk { VkPipeline DxvkShaderPipelineLibrary::compileShaderPipelineLocked( const DxvkShaderPipelineLibraryCompileArgs& args) { - VkShaderStageFlagBits stage = getShaderStage(); - VkPipeline pipeline = VK_NULL_HANDLE; - - if (m_shader) - m_shader->notifyLibraryCompile(); + this->notifyLibraryCompile(); // If this is not the first time we're compiling the pipeline, // try to get a cache hit using the shader module identifier // so that we don't have to decompress our SPIR-V shader again. + VkPipeline pipeline = VK_NULL_HANDLE; + if (m_compiledOnce && canUsePipelineCacheControl()) { - pipeline = this->compileShaderPipeline(args, stage, + pipeline = this->compileShaderPipeline(args, VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT); } if (!pipeline) - pipeline = this->compileShaderPipeline(args, stage, 0); + pipeline = this->compileShaderPipeline(args, 0); // Well that didn't work if (!pipeline) @@ -1012,7 +1022,7 @@ namespace dxvk { // Increment stat counter the first time this // shader pipeline gets compiled successfully if (!m_compiledOnce) { - if (stage == VK_SHADER_STAGE_COMPUTE_BIT) + if (m_shaders.cs) m_stats->numComputePipelines += 1; else m_stats->numGraphicsLibraries += 1; @@ -1026,46 +1036,49 @@ namespace dxvk { VkPipeline DxvkShaderPipelineLibrary::compileShaderPipeline( const DxvkShaderPipelineLibraryCompileArgs& args, - VkShaderStageFlagBits stage, VkPipelineCreateFlags flags) { DxvkShaderStageInfo stageInfo(m_device); + VkShaderStageFlags stageMask = getShaderStages(); { std::lock_guard lock(m_identifierMutex); + VkShaderStageFlags 
stages = stageMask; - if (flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT) { - // Fail if we have no idenfitier for whatever reason, caller - // should fall back to the slow path if this happens - if (!m_identifier.identifierSize) - return VK_NULL_HANDLE; + while (stages) { + auto stage = VkShaderStageFlagBits(stages & -stages); + auto identifier = getShaderIdentifier(stage); - stageInfo.addStage(stage, m_identifier, nullptr); - } else { - // Decompress code and generate identifier as needed - SpirvCodeBuffer spirvCode = this->getShaderCode(); + if (flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT) { + // Fail if we have no identifier for whatever reason, caller + // should fall back to the slow path if this happens + if (!identifier->identifierSize) + return VK_NULL_HANDLE; - if (!m_identifier.identifierSize) - this->generateModuleIdentifierLocked(spirvCode); + stageInfo.addStage(stage, *identifier, nullptr); + } else { + // Decompress code and generate identifier as needed + SpirvCodeBuffer spirvCode = this->getShaderCode(stage); - stageInfo.addStage(stage, std::move(spirvCode), nullptr); + if (!identifier->identifierSize) + this->generateModuleIdentifierLocked(identifier, spirvCode); + + stageInfo.addStage(stage, std::move(spirvCode), nullptr); + } + + stages &= stages - 1; } } - switch (stage) { - case VK_SHADER_STAGE_VERTEX_BIT: - return compileVertexShaderPipeline(args, stageInfo, flags); - break; + if (stageMask & VK_SHADER_STAGE_VERTEX_BIT) + return compileVertexShaderPipeline(args, stageInfo, flags); - case VK_SHADER_STAGE_FRAGMENT_BIT: - return compileFragmentShaderPipeline(stageInfo, flags); - break; + if (stageMask & VK_SHADER_STAGE_FRAGMENT_BIT) + return compileFragmentShaderPipeline(stageInfo, flags); - case VK_SHADER_STAGE_COMPUTE_BIT: - return compileComputeShaderPipeline(stageInfo, flags); + if (stageMask & VK_SHADER_STAGE_COMPUTE_BIT) + return compileComputeShaderPipeline(stageInfo, flags); - default: - // Should 
be unreachable - return VK_NULL_HANDLE; - } + // Should be unreachable + return VK_NULL_HANDLE; } @@ -1167,7 +1180,7 @@ namespace dxvk { dynamicStates[dynamicStateCount++] = VK_DYNAMIC_STATE_DEPTH_BOUNDS; } - bool hasSampleRateShading = m_shader && m_shader->flags().test(DxvkShaderFlag::HasSampleRateShading); + bool hasSampleRateShading = m_shaders.fs && m_shaders.fs->flags().test(DxvkShaderFlag::HasSampleRateShading); bool hasDynamicMultisampleState = hasSampleRateShading && m_device->features().extExtendedDynamicState3.extendedDynamicState3RasterizationSamples && m_device->features().extExtendedDynamicState3.extendedDynamicState3SampleMask; @@ -1252,20 +1265,23 @@ namespace dxvk { } - SpirvCodeBuffer DxvkShaderPipelineLibrary::getShaderCode() const { + SpirvCodeBuffer DxvkShaderPipelineLibrary::getShaderCode(VkShaderStageFlagBits stage) const { // As a special case, it is possible that we have to deal with // a null shader, but the pipeline library extension requires // us to always specify a fragment shader for fragment stages, // so we need to return a dummy shader in that case. 
- if (!m_shader) + DxvkShader* shader = getShader(stage); + + if (!shader) return SpirvCodeBuffer(dxvk_dummy_frag); - return m_shader->getCode(m_layout, DxvkShaderModuleCreateInfo()); + return shader->getCode(m_layout, DxvkShaderModuleCreateInfo()); } void DxvkShaderPipelineLibrary::generateModuleIdentifierLocked( - const SpirvCodeBuffer& spirvCode) { + VkShaderModuleIdentifierEXT* identifier, + const SpirvCodeBuffer& spirvCode) { auto vk = m_device->vkd(); if (!canUsePipelineCacheControl()) @@ -1276,17 +1292,101 @@ namespace dxvk { info.pCode = spirvCode.data(); vk->vkGetShaderModuleCreateInfoIdentifierEXT( - vk->device(), &info, &m_identifier); + vk->device(), &info, identifier); } - VkShaderStageFlagBits DxvkShaderPipelineLibrary::getShaderStage() const { - VkShaderStageFlagBits stage = VK_SHADER_STAGE_FRAGMENT_BIT; + VkShaderStageFlags DxvkShaderPipelineLibrary::getShaderStages() const { + if (m_shaders.vs) { + VkShaderStageFlags result = VK_SHADER_STAGE_VERTEX_BIT; - if (m_shader != nullptr) - stage = m_shader->info().stage; + if (m_shaders.tcs) + result |= VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT; - return stage; + if (m_shaders.tes) + result |= VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; + + if (m_shaders.gs) + result |= VK_SHADER_STAGE_GEOMETRY_BIT; + + return result; + } + + if (m_shaders.cs) + return VK_SHADER_STAGE_COMPUTE_BIT; + + // Must be a fragment shader even if fs is null + return VK_SHADER_STAGE_FRAGMENT_BIT; + } + + + DxvkShader* DxvkShaderPipelineLibrary::getShader( + VkShaderStageFlagBits stage) const { + switch (stage) { + case VK_SHADER_STAGE_VERTEX_BIT: + return m_shaders.vs; + + case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT: + return m_shaders.tcs; + + case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT: + return m_shaders.tes; + + case VK_SHADER_STAGE_GEOMETRY_BIT: + return m_shaders.gs; + + case VK_SHADER_STAGE_FRAGMENT_BIT: + return m_shaders.fs; + + case VK_SHADER_STAGE_COMPUTE_BIT: + return m_shaders.cs; + + default: + return nullptr; + 
} + } + + + VkShaderModuleIdentifierEXT* DxvkShaderPipelineLibrary::getShaderIdentifier( + VkShaderStageFlagBits stage) { + switch (stage) { + case VK_SHADER_STAGE_VERTEX_BIT: + return &m_identifiers.vs; + + case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT: + return &m_identifiers.tcs; + + case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT: + return &m_identifiers.tes; + + case VK_SHADER_STAGE_GEOMETRY_BIT: + return &m_identifiers.gs; + + case VK_SHADER_STAGE_FRAGMENT_BIT: + return &m_identifiers.fs; + + case VK_SHADER_STAGE_COMPUTE_BIT: + return &m_identifiers.cs; + + default: + return nullptr; + } + } + + + void DxvkShaderPipelineLibrary::notifyLibraryCompile() const { + if (m_shaders.vs) { + // Only notify the shader itself if we're actually + // building the shader's standalone pipeline library + if (!m_shaders.tcs && !m_shaders.tes && !m_shaders.gs) + m_shaders.vs->notifyLibraryCompile(); + } + + if (m_shaders.fs) + m_shaders.fs->notifyLibraryCompile(); + + if (m_shaders.cs) + m_shaders.cs->notifyLibraryCompile(); } diff --git a/src/dxvk/dxvk_shader.h b/src/dxvk/dxvk_shader.h index 6e5a21057..9ab46fb4f 100644 --- a/src/dxvk/dxvk_shader.h +++ b/src/dxvk/dxvk_shader.h @@ -396,6 +396,38 @@ namespace dxvk { }; + /** + * \brief Shader set + * + * Stores a set of shader pointers + * for use in a pipeline library. + */ + struct DxvkShaderSet { + DxvkShader* vs = nullptr; + DxvkShader* tcs = nullptr; + DxvkShader* tes = nullptr; + DxvkShader* gs = nullptr; + DxvkShader* fs = nullptr; + DxvkShader* cs = nullptr; + }; + + + /** + * \brief Shader identifier set + * + * Stores a set of shader module identifiers + * for use in a pipeline library. 
+ */ + struct DxvkShaderIdentifierSet { + VkShaderModuleIdentifierEXT vs = { VK_STRUCTURE_TYPE_SHADER_MODULE_IDENTIFIER_EXT }; + VkShaderModuleIdentifierEXT tcs = { VK_STRUCTURE_TYPE_SHADER_MODULE_IDENTIFIER_EXT }; + VkShaderModuleIdentifierEXT tes = { VK_STRUCTURE_TYPE_SHADER_MODULE_IDENTIFIER_EXT }; + VkShaderModuleIdentifierEXT gs = { VK_STRUCTURE_TYPE_SHADER_MODULE_IDENTIFIER_EXT }; + VkShaderModuleIdentifierEXT fs = { VK_STRUCTURE_TYPE_SHADER_MODULE_IDENTIFIER_EXT }; + VkShaderModuleIdentifierEXT cs = { VK_STRUCTURE_TYPE_SHADER_MODULE_IDENTIFIER_EXT }; + }; + + /** * \brief Shader pipeline library * @@ -423,9 +455,11 @@ namespace dxvk { * Can be used to compile an optimized pipeline using the same * shader code, but without having to wait for the pipeline * library for this shader shader to compile first. + * \param [in] stage Shader stage to query * \returns Shader module identifier */ - VkShaderModuleIdentifierEXT getModuleIdentifier(); + VkShaderModuleIdentifierEXT getModuleIdentifier( + VkShaderStageFlagBits stage); /** * \brief Acquires pipeline handle for the given set of arguments @@ -461,7 +495,7 @@ namespace dxvk { const DxvkDevice* m_device; DxvkPipelineStats* m_stats; - DxvkShader* m_shader; + DxvkShaderSet m_shaders; const DxvkBindingLayoutObjects* m_layout; dxvk::mutex m_mutex; @@ -471,7 +505,7 @@ namespace dxvk { bool m_compiledOnce = false; dxvk::mutex m_identifierMutex; - VkShaderModuleIdentifierEXT m_identifier = { VK_STRUCTURE_TYPE_SHADER_MODULE_IDENTIFIER_EXT }; + DxvkShaderIdentifierSet m_identifiers; void destroyShaderPipelinesLocked(); @@ -480,7 +514,6 @@ namespace dxvk { VkPipeline compileShaderPipeline( const DxvkShaderPipelineLibraryCompileArgs& args, - VkShaderStageFlagBits stage, VkPipelineCreateFlags flags); VkPipeline compileVertexShaderPipeline( @@ -496,12 +529,22 @@ namespace dxvk { const DxvkShaderStageInfo& stageInfo, VkPipelineCreateFlags flags); - SpirvCodeBuffer getShaderCode() const; + SpirvCodeBuffer getShaderCode( + 
VkShaderStageFlagBits stage) const; void generateModuleIdentifierLocked( - const SpirvCodeBuffer& spirvCode); + VkShaderModuleIdentifierEXT* identifier, + const SpirvCodeBuffer& spirvCode); - VkShaderStageFlagBits getShaderStage() const; + VkShaderStageFlags getShaderStages() const; + + DxvkShader* getShader( + VkShaderStageFlagBits stage) const; + + VkShaderModuleIdentifierEXT* getShaderIdentifier( + VkShaderStageFlagBits stage); + + void notifyLibraryCompile() const; bool canUsePipelineCacheControl() const;