diff --git a/src/d3d11/d3d11_context.cpp b/src/d3d11/d3d11_context.cpp index a7ed1d3f2..6e2d45406 100644 --- a/src/d3d11/d3d11_context.cpp +++ b/src/d3d11/d3d11_context.cpp @@ -1196,7 +1196,17 @@ namespace dxvk { ID3D11ComputeShader* pComputeShader, ID3D11ClassInstance* const* ppClassInstances, UINT NumClassInstances) { - Logger::err("D3D11DeviceContext::CSSetShader: Not implemented"); + auto shader = static_cast(pComputeShader); + + if (NumClassInstances != 0) + Logger::err("D3D11DeviceContext::CSSetShader: Class instances not supported"); + + if (m_state.cs.shader != shader) { + m_state.cs.shader = shader; + + m_context->bindShader(VK_SHADER_STAGE_COMPUTE_BIT, + shader != nullptr ? shader->GetShader() : nullptr); + } } diff --git a/src/d3d11/d3d11_device.cpp b/src/d3d11/d3d11_device.cpp index 9bbb7591c..eba73107e 100644 --- a/src/d3d11/d3d11_device.cpp +++ b/src/d3d11/d3d11_device.cpp @@ -677,8 +677,18 @@ namespace dxvk { SIZE_T BytecodeLength, ID3D11ClassLinkage* pClassLinkage, ID3D11ComputeShader** ppComputeShader) { - Logger::err("D3D11Device::CreateComputeShader: Not implemented"); - return E_NOTIMPL; + D3D11ShaderModule module; + + if (FAILED(this->CreateShaderModule(&module, + pShaderBytecode, BytecodeLength, pClassLinkage))) + return E_INVALIDARG; + + if (ppComputeShader != nullptr) { + *ppComputeShader = ref(new D3D11ComputeShader( + this, std::move(module))); + } + + return S_OK; } diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index 80843127d..a88b5496e 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -37,6 +37,7 @@ namespace dxvk { case DxbcProgramType::VertexShader: this->emitVsInit(); break; case DxbcProgramType::GeometryShader: this->emitGsInit(); break; case DxbcProgramType::PixelShader: this->emitPsInit(); break; + case DxbcProgramType::ComputeShader: this->emitCsInit(); break; default: throw DxvkError("DxbcCompiler: Unsupported program type"); } } @@ -116,6 +117,7 @@ namespace dxvk { case DxbcProgramType::VertexShader: this->emitVsFinalize(); break; case DxbcProgramType::GeometryShader: this->emitGsFinalize(); break; case DxbcProgramType::PixelShader: this->emitPsFinalize(); break; + case DxbcProgramType::ComputeShader: this->emitCsFinalize(); break; default: throw DxvkError("DxbcCompiler: Unsupported program type"); } @@ -2447,7 +2449,7 @@ namespace dxvk { void DxbcCompiler::emitGsInitBuiltins(uint32_t vertexCount) { - + // TODO implement } @@ -2460,6 +2462,11 @@ namespace dxvk { } + void DxbcCompiler::emitCsInitBuiltins() { + // TODO implement + } + + void DxbcCompiler::emitVsInit() { m_module.enableCapability(spv::CapabilityShader); m_module.enableCapability(spv::CapabilityClipDistance); @@ -2572,6 +2579,27 @@ namespace dxvk { } + void DxbcCompiler::emitCsInit() { + m_module.enableCapability(spv::CapabilityShader); + + // There are no input or output + // variables for compute shaders + emitCsInitBuiltins(); + + // Main function of the compute shader + m_cs.functionId = m_module.allocateId(); + m_module.setDebugName(m_ps.functionId, "cs_main"); + + m_module.functionBegin( + m_module.defVoidType(), + m_cs.functionId, + m_module.defFunctionType( + m_module.defVoidType(), 0, nullptr), + spv::FunctionControlMaskNone); + m_module.opLabel(m_module.allocateId()); + } + + void DxbcCompiler::emitVsFinalize() { this->emitInputSetup(); m_module.opFunctionCall( @@ -2601,6 +2629,13 @@ namespace dxvk { } + void DxbcCompiler::emitCsFinalize() { + m_module.opFunctionCall( + m_module.defVoidType(), + m_cs.functionId, 0, nullptr); + } + + void DxbcCompiler::emitDclInputArray(uint32_t vertexCount) { DxbcArrayType info; info.ctype = DxbcScalarType::Float32; diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h index 92e721c77..cb6322304 100644 --- a/src/dxbc/dxbc_compiler.h +++ b/src/dxbc/dxbc_compiler.h @@ -121,6 +121,14 @@ namespace dxvk { }; + /** + * \brief Compute shader-specific structure + */ + struct DxbcCompilerCsPart { + uint32_t functionId = 0; + }; + + enum class DxbcCfgBlockType : uint32_t { If, Loop, }; @@ -253,6 +261,7 @@ namespace dxvk { DxbcCompilerVsPart m_vs; DxbcCompilerGsPart m_gs; DxbcCompilerPsPart m_ps; + DxbcCompilerCsPart m_cs; ///////////////////////////////////////////////////// // Shader interface and metadata declaration methods @@ -500,18 +509,21 @@ namespace dxvk { void emitVsInitBuiltins(); void emitGsInitBuiltins(uint32_t vertexCount); void emitPsInitBuiltins(); + void emitCsInitBuiltins(); ///////////////////////////////// // Shader initialization methods void emitVsInit(); void emitGsInit(); void emitPsInit(); + void emitCsInit(); /////////////////////////////// // Shader finalization methods void emitVsFinalize(); void emitGsFinalize(); void emitPsFinalize(); + void emitCsFinalize(); ////////////// // Misc stuff