diff --git a/src/d3d11/d3d11_context.cpp b/src/d3d11/d3d11_context.cpp index 455aaf0cb..71183b403 100644 --- a/src/d3d11/d3d11_context.cpp +++ b/src/d3d11/d3d11_context.cpp @@ -254,7 +254,7 @@ namespace dxvk { m_state.pr.predicateValue = FALSE; // Make sure to apply all state - RestoreState(); + ResetState(); } @@ -3775,6 +3775,93 @@ namespace dxvk { } + void D3D11DeviceContext::ResetState() { + EmitCs([ + cBlendState = m_defaultBlendState.prvRef(), + cDepthStencilState = m_defaultDepthStencilState.prvRef(), + cRasterizerState = m_defaultRasterizerState.prvRef() + ] (DxvkContext* ctx) { + // Reset render targets + ctx->bindRenderTargets(DxvkRenderTargets()); + + // Reset vertex input state + ctx->setInputLayout(0, nullptr, 0, nullptr); + + DxvkInputAssemblyState iaState; + InitDefaultPrimitiveTopology(&iaState); + + ctx->setInputAssemblyState(iaState); + + // Reset render states + cBlendState->BindToContext(ctx, D3D11_DEFAULT_SAMPLE_MASK); + cDepthStencilState->BindToContext(ctx); + cRasterizerState->BindToContext(ctx); + + // Reset dynamic states + ctx->setBlendConstants(DxvkBlendConstants { 1.0f, 1.0f, 1.0f, 1.0f }); + ctx->setStencilReference(D3D11_DEFAULT_STENCIL_REFERENCE); + + // Reset viewports + auto viewport = VkViewport(); + auto scissor = VkRect2D(); + + ctx->setViewports(1, &viewport, &scissor); + + // Reset predication + ctx->setPredicate(DxvkBufferSlice(), 0); + + // Unbind indirect draw buffer + ctx->bindDrawBuffers(DxvkBufferSlice(), DxvkBufferSlice()); + + // Unbind index and vertex buffers + ctx->bindIndexBuffer(DxvkBufferSlice(), VK_INDEX_TYPE_UINT32); + + for (uint32_t i = 0; i < D3D11_IA_VERTEX_INPUT_RESOURCE_SLOT_COUNT; i++) + ctx->bindVertexBuffer(i, DxvkBufferSlice(), 0); + + // Unbind transform feedback buffers + for (uint32_t i = 0; i < D3D11_SO_BUFFER_SLOT_COUNT; i++) + ctx->bindXfbBuffer(i, DxvkBufferSlice(), DxvkBufferSlice()); + + // Unbind per-shader stage resources + for (uint32_t i = 0; i < 6; i++) { + auto programType = DxbcProgramType(i); + ctx->bindShader(GetShaderStage(programType), nullptr); + + // Unbind constant buffers, including the shader's ICB + auto cbSlotId = computeConstantBufferBinding(programType, 0); + + for (uint32_t j = 0; j <= D3D11_COMMONSHADER_CONSTANT_BUFFER_API_SLOT_COUNT; j++) + ctx->bindResourceBuffer(cbSlotId + j, DxvkBufferSlice()); + + // Unbind shader resource views + auto srvSlotId = computeSrvBinding(programType, 0); + + for (uint32_t j = 0; j < D3D11_COMMONSHADER_INPUT_RESOURCE_SLOT_COUNT; j++) + ctx->bindResourceView(srvSlotId + j, nullptr, nullptr); + + // Unbind texture samplers + auto samplerSlotId = computeSamplerBinding(programType, 0); + + for (uint32_t j = 0; j < D3D11_COMMONSHADER_SAMPLER_SLOT_COUNT; j++) + ctx->bindResourceSampler(samplerSlotId + j, nullptr); + + // Unbind UAVs for supported stages + if (programType == DxbcProgramType::PixelShader + || programType == DxbcProgramType::ComputeShader) { + auto uavSlotId = computeUavBinding(programType, 0); + auto ctrSlotId = computeUavCounterBinding(programType, 0); + + for (uint32_t j = 0; j < D3D11_1_UAV_SLOT_COUNT; j++) { + ctx->bindResourceView (uavSlotId, nullptr, nullptr); + ctx->bindResourceBuffer (ctrSlotId, DxvkBufferSlice()); + } + } + } + }); + } + + void D3D11DeviceContext::RestoreState() { BindFramebuffer(); @@ -4172,4 +4259,12 @@ namespace dxvk { return m_parent->AllocCsChunk(m_csFlags); } + + void D3D11DeviceContext::InitDefaultPrimitiveTopology( + DxvkInputAssemblyState* pIaState) { + pIaState->primitiveTopology = VK_PRIMITIVE_TOPOLOGY_MAX_ENUM; + pIaState->primitiveRestart = VK_FALSE; + pIaState->patchVertexCount = 0; + } + } diff --git a/src/d3d11/d3d11_context.h b/src/d3d11/d3d11_context.h index b513c70d0..56e7e84c9 100644 --- a/src/d3d11/d3d11_context.h +++ b/src/d3d11/d3d11_context.h @@ -838,6 +838,8 @@ namespace dxvk { UINT* pFirstConstant, UINT* pNumConstants); + void ResetState(); + void RestoreState(); template @@ -902,6 +904,9 @@ namespace dxvk { DxvkCsChunkRef AllocCsChunk(); + static void InitDefaultPrimitiveTopology( + DxvkInputAssemblyState* pIaState); + template const D3D11CommonShader* GetCommonShader(T* pShader) const { return pShader != nullptr ? pShader->GetCommonShader() : nullptr;