diff --git a/src/dxvk/dxvk_compute.cpp b/src/dxvk/dxvk_compute.cpp index d85070f2..ee908d20 100644 --- a/src/dxvk/dxvk_compute.cpp +++ b/src/dxvk/dxvk_compute.cpp @@ -94,13 +94,13 @@ namespace dxvk { VkSpecializationInfo specInfo = specData.getSpecInfo(); - DxvkShaderModuleCreateInfo moduleInfo; - moduleInfo.fsDualSrcBlend = false; - - auto csm = m_shaders.cs->createShaderModule(vk, m_bindings, moduleInfo); + DxvkShaderStageInfo stageInfo(m_device); + stageInfo.addStage(VK_SHADER_STAGE_COMPUTE_BIT, + m_shaders.cs->getCode(m_bindings, DxvkShaderModuleCreateInfo()), + &specInfo); VkComputePipelineCreateInfo info = { VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO }; - info.stage = csm.stageInfo(&specInfo); + info.stage = *stageInfo.getStageInfos(); info.layout = m_bindings->getPipelineLayout(); info.basePipelineIndex = -1; diff --git a/src/dxvk/dxvk_graphics.cpp b/src/dxvk/dxvk_graphics.cpp index 2f5ea8c3..a51ca90c 100644 --- a/src/dxvk/dxvk_graphics.cpp +++ b/src/dxvk/dxvk_graphics.cpp @@ -1,3 +1,5 @@ +#include + #include "../util/util_time.h" #include "dxvk_device.h" @@ -621,19 +623,19 @@ namespace dxvk { specData.set(getSpecId(i), state.sc.specConstants[i], 0u); VkSpecializationInfo specInfo = specData.getSpecInfo(); - - auto vsm = createShaderModule(m_shaders.vs, state); - auto tcsm = createShaderModule(m_shaders.tcs, state); - auto tesm = createShaderModule(m_shaders.tes, state); - auto gsm = createShaderModule(m_shaders.gs, state); - auto fsm = createShaderModule(m_shaders.fs, state); - std::vector stages; - if (vsm) stages.push_back(vsm.stageInfo(&specInfo)); - if (tcsm) stages.push_back(tcsm.stageInfo(&specInfo)); - if (tesm) stages.push_back(tesm.stageInfo(&specInfo)); - if (gsm) stages.push_back(gsm.stageInfo(&specInfo)); - if (fsm) stages.push_back(fsm.stageInfo(&specInfo)); + // Build stage infos for all provided shaders + DxvkShaderStageInfo stageInfo(m_device); + stageInfo.addStage(VK_SHADER_STAGE_VERTEX_BIT, getShaderCode(m_shaders.vs, state), &specInfo); + + if (m_shaders.tcs != nullptr) + stageInfo.addStage(VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, getShaderCode(m_shaders.tcs, state), &specInfo); + if (m_shaders.tes != nullptr) + stageInfo.addStage(VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, getShaderCode(m_shaders.tes, state), &specInfo); + if (m_shaders.gs != nullptr) + stageInfo.addStage(VK_SHADER_STAGE_GEOMETRY_BIT, getShaderCode(m_shaders.gs, state), &specInfo); + if (m_shaders.fs != nullptr) + stageInfo.addStage(VK_SHADER_STAGE_FRAGMENT_BIT, getShaderCode(m_shaders.fs, state), &specInfo); DxvkGraphicsPipelineVertexInputState viState(m_device, state); DxvkGraphicsPipelinePreRasterizationState prState(m_device, state, m_shaders.gs.ptr()); @@ -645,8 +647,8 @@ namespace dxvk { dyInfo.pDynamicStates = dynamicStates.data(); VkGraphicsPipelineCreateInfo info = { VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO, &foState.rtInfo }; - info.stageCount = stages.size(); - info.pStages = stages.data(); + info.stageCount = stageInfo.getStageCount(); + info.pStages = stageInfo.getStageInfos(); info.pVertexInputState = &viState.viInfo; info.pInputAssemblyState = &viState.iaInfo; info.pTessellationState = &prState.tsInfo; @@ -692,14 +694,11 @@ namespace dxvk { } - DxvkShaderModule DxvkGraphicsPipeline::createShaderModule( + SpirvCodeBuffer DxvkGraphicsPipeline::getShaderCode( const Rc& shader, const DxvkGraphicsPipelineStateInfo& state) const { auto vk = m_device->vkd(); - if (shader == nullptr) - return DxvkShaderModule(); - const DxvkShaderCreateInfo& shaderInfo = shader->info(); DxvkShaderModuleCreateInfo info; @@ -729,7 +728,8 @@ namespace dxvk { } info.undefinedInputs = (providedInputs & consumedInputs) ^ consumedInputs; - return shader->createShaderModule(vk, m_bindings, info); + + return shader->getCode(m_bindings, info); } diff --git a/src/dxvk/dxvk_graphics.h b/src/dxvk/dxvk_graphics.h index 50a056f5..c4c616a3 100644 --- a/src/dxvk/dxvk_graphics.h +++ b/src/dxvk/dxvk_graphics.h @@ -401,7 +401,7 @@ namespace dxvk { void destroyPipeline( VkPipeline pipeline) const; - DxvkShaderModule createShaderModule( + SpirvCodeBuffer getShaderCode( const Rc& shader, const DxvkGraphicsPipelineStateInfo& state) const; diff --git a/src/dxvk/dxvk_shader.cpp b/src/dxvk/dxvk_shader.cpp index 88336505..856aa897 100644 --- a/src/dxvk/dxvk_shader.cpp +++ b/src/dxvk/dxvk_shader.cpp @@ -7,60 +7,6 @@ namespace dxvk { - DxvkShaderModule::DxvkShaderModule() - : m_vkd(nullptr), m_stage() { - - } - - - DxvkShaderModule::DxvkShaderModule(DxvkShaderModule&& other) - : m_vkd(std::move(other.m_vkd)) { - this->m_stage = other.m_stage; - other.m_stage = VkPipelineShaderStageCreateInfo(); - } - - - DxvkShaderModule::DxvkShaderModule( - const Rc& vkd, - const Rc& shader, - const SpirvCodeBuffer& code) - : m_vkd(vkd), m_stage() { - m_stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - m_stage.pNext = nullptr; - m_stage.flags = 0; - m_stage.stage = shader->info().stage; - m_stage.module = VK_NULL_HANDLE; - m_stage.pName = "main"; - m_stage.pSpecializationInfo = nullptr; - - VkShaderModuleCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.codeSize = code.size(); - info.pCode = code.data(); - - if (m_vkd->vkCreateShaderModule(m_vkd->device(), &info, nullptr, &m_stage.module) != VK_SUCCESS) - throw DxvkError("DxvkComputePipeline::DxvkComputePipeline: Failed to create shader module"); - } - - - DxvkShaderModule::~DxvkShaderModule() { - if (m_vkd != nullptr) { - m_vkd->vkDestroyShaderModule( - m_vkd->device(), m_stage.module, nullptr); - } - } - - - DxvkShaderModule& DxvkShaderModule::operator = (DxvkShaderModule&& other) { - this->m_vkd = std::move(other.m_vkd); - this->m_stage = other.m_stage; - other.m_stage = VkPipelineShaderStageCreateInfo(); - return *this; - } - - DxvkShader::DxvkShader( const DxvkShaderCreateInfo& info, SpirvCodeBuffer&& spirv) @@ -194,14 +140,6 @@ namespace dxvk { } - DxvkShaderModule DxvkShader::createShaderModule( - const Rc& vkd, - const DxvkBindingLayoutObjects* layout, - const DxvkShaderModuleCreateInfo& info) { - return DxvkShaderModule(vkd, this, getCode(layout, info)); - } - - bool DxvkShader::canUsePipelineLibrary() const { // Pipeline libraries are unsupported for geometry and // tessellation stages since we'd need to compile them @@ -428,6 +366,62 @@ namespace dxvk { } + DxvkShaderStageInfo::DxvkShaderStageInfo(const DxvkDevice* device) + : m_device(device) { + + } + + void DxvkShaderStageInfo::addStage( + VkShaderStageFlagBits stage, + SpirvCodeBuffer&& code, + const VkSpecializationInfo* specInfo) { + // Take ownership of the SPIR-V code buffer + auto& codeBuffer = m_codeBuffers[m_stageCount]; + codeBuffer = std::move(code); + + // For graphics pipelines, as long as graphics pipeline libraries are + // enabled, we do not need to create a shader module object and can + // instead chain the create info to the shader stage info struct. + // For compute pipelines, this doesn't work and we still need a module. + auto& moduleInfo = m_moduleInfos[m_stageCount]; + moduleInfo = { VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO }; + moduleInfo.codeSize = codeBuffer.size(); + moduleInfo.pCode = codeBuffer.data(); + + VkShaderModule shaderModule = VK_NULL_HANDLE; + if (!m_device->features().extGraphicsPipelineLibrary.graphicsPipelineLibrary || stage == VK_SHADER_STAGE_COMPUTE_BIT) { + auto vk = m_device->vkd(); + + if (vk->vkCreateShaderModule(vk->device(), &moduleInfo, nullptr, &shaderModule)) + throw DxvkError("DxvkShaderStageInfo: Failed to create shader module"); + } + + // Set up shader stage info with the data provided + auto& stageInfo = m_stageInfos[m_stageCount]; + stageInfo = { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO }; + + if (!stageInfo.module) + stageInfo.pNext = &moduleInfo; + + stageInfo.stage = stage; + stageInfo.module = shaderModule; + stageInfo.pName = "main"; + stageInfo.pSpecializationInfo = specInfo; + + m_stageCount++; + } + + + DxvkShaderStageInfo::~DxvkShaderStageInfo() { + auto vk = m_device->vkd(); + + for (uint32_t i = 0; i < m_stageCount; i++) { + if (m_stageInfos[i].module) + vk->vkDestroyShaderModule(vk->device(), m_stageInfos[i].module, nullptr); + } + } + + DxvkShaderPipelineLibrary::DxvkShaderPipelineLibrary( const DxvkDevice* device, const DxvkShader* shader, diff --git a/src/dxvk/dxvk_shader.h b/src/dxvk/dxvk_shader.h index 0a0d29c5..1f559215 100644 --- a/src/dxvk/dxvk_shader.h +++ b/src/dxvk/dxvk_shader.h @@ -135,21 +135,6 @@ namespace dxvk { const DxvkBindingLayoutObjects* layout, const DxvkShaderModuleCreateInfo& state) const; - /** - * \brief Creates a shader module - * - * Remaps resource binding and descriptor set - * numbers to match the given binding layout. - * \param [in] vkd Vulkan device functions - * \param [in] layout Binding layout - * \param [in] info Module create info - * \returns The shader module - */ - DxvkShaderModule createShaderModule( - const Rc& vkd, - const DxvkBindingLayoutObjects* layout, - const DxvkShaderModuleCreateInfo& info); - /** * \brief Tests whether this shader supports pipeline libraries * @@ -252,49 +237,54 @@ namespace dxvk { * context will create pipeline objects on the * fly when executing draw calls. */ - class DxvkShaderModule { + class DxvkShaderStageInfo { public: - DxvkShaderModule(); + DxvkShaderStageInfo(const DxvkDevice* device); - DxvkShaderModule(DxvkShaderModule&& other); - - DxvkShaderModule( - const Rc& vkd, - const Rc& shader, - const SpirvCodeBuffer& code); - - ~DxvkShaderModule(); + DxvkShaderStageInfo (DxvkShaderStageInfo&& other) = delete; + DxvkShaderStageInfo& operator = (DxvkShaderStageInfo&& other) = delete; + + ~DxvkShaderStageInfo(); - DxvkShaderModule& operator = (DxvkShaderModule&& other); - /** - * \brief Shader stage creation info - * - * \param [in] specInfo Specialization info - * \returns Shader stage create info + * \brief Counts shader stages + * \returns Shader stage count */ - VkPipelineShaderStageCreateInfo stageInfo( - const VkSpecializationInfo* specInfo) const { - VkPipelineShaderStageCreateInfo stage = m_stage; - stage.pSpecializationInfo = specInfo; - return stage; + uint32_t getStageCount() const { + return m_stageCount; } - + /** - * \brief Checks whether module is valid - * \returns \c true if module is valid + * \brief Queries shader stage infos + * \returns Pointer to shader stage infos */ - operator bool () const { - return m_stage.module != VK_NULL_HANDLE; + const VkPipelineShaderStageCreateInfo* getStageInfos() const { + return m_stageInfos.data(); } - + + /** + * \brief Adds a shader stage with specialization info + * + * \param [in] stage Shader stage + * \param [in] code SPIR-V code + * \param [in] specinfo Specialization info + */ + void addStage( + VkShaderStageFlagBits stage, + SpirvCodeBuffer&& code, + const VkSpecializationInfo* specInfo); + private: - - Rc m_vkd; - VkPipelineShaderStageCreateInfo m_stage; - + + const DxvkDevice* m_device; + + std::array m_codeBuffers; + std::array m_moduleInfos = { }; + std::array m_stageInfos = { }; + uint32_t m_stageCount = 0; + };