diff --git a/src/d3d8/d3d8_device.cpp b/src/d3d8/d3d8_device.cpp index a2b95894..b1de9a8e 100644 --- a/src/d3d8/d3d8_device.cpp +++ b/src/d3d8/d3d8_device.cpp @@ -63,11 +63,6 @@ namespace dxvk { D3D8Device::~D3D8Device() { if (m_batcher) delete m_batcher; - - // Delete any remaining state blocks. - for (D3D8StateBlock* block : m_stateBlocks) { - delete block; - } } HRESULT STDMETHODCALLTYPE D3D8Device::GetInfo(DWORD DevInfoID, void* pDevInfoStruct, DWORD DevInfoStructSize) { @@ -981,21 +976,37 @@ namespace dxvk { Com pStateBlock9; HRESULT res = GetD3D9()->CreateStateBlock(d3d9::D3DSTATEBLOCKTYPE(Type), &pStateBlock9); - D3D8StateBlock* pStateBlock = new D3D8StateBlock(this, Type, pStateBlock9.ref()); - m_stateBlocks.insert(pStateBlock); - - *pToken = DWORD(reinterpret_cast(pStateBlock)); + m_stateBlocks.emplace(std::piecewise_construct, + std::forward_as_tuple(m_token), + std::forward_as_tuple(this, Type, pStateBlock9.ref())); + *pToken = m_token; + m_token++; return res; } HRESULT STDMETHODCALLTYPE D3D8Device::CaptureStateBlock(DWORD Token) { - return reinterpret_cast(Token)->Capture(); + auto stateBlockIter = m_stateBlocks.find(Token); + + if (unlikely(stateBlockIter == m_stateBlocks.end())) { + Logger::err("Invalid token passed to CaptureStateBlock"); + return D3DERR_INVALIDCALL; + } + + return stateBlockIter->second.Capture(); } HRESULT STDMETHODCALLTYPE D3D8Device::ApplyStateBlock(DWORD Token) { StateChange(); - return reinterpret_cast(Token)->Apply(); + + auto stateBlockIter = m_stateBlocks.find(Token); + + if (unlikely(stateBlockIter == m_stateBlocks.end())) { + Logger::err("Invalid token passed to ApplyStateBlock"); + return D3DERR_INVALIDCALL; + } + + return stateBlockIter->second.Apply(); } HRESULT STDMETHODCALLTYPE D3D8Device::DeleteStateBlock(DWORD Token) { @@ -1003,9 +1014,14 @@ namespace dxvk { if (unlikely(ShouldRecord())) return D3DERR_INVALIDCALL; - D3D8StateBlock* block = reinterpret_cast(Token); - m_stateBlocks.erase(block); - delete block; + auto stateBlockIter = m_stateBlocks.find(Token); + + if (unlikely(stateBlockIter == m_stateBlocks.end())) { + Logger::err("Invalid token passed to DeleteStateBlock"); + return D3DERR_INVALIDCALL; + } + + m_stateBlocks.erase(stateBlockIter); return D3D_OK; } @@ -1014,8 +1030,12 @@ namespace dxvk { if (unlikely(m_recorder != nullptr)) return D3DERR_INVALIDCALL; - m_recorder = new D3D8StateBlock(this); - m_stateBlocks.insert(m_recorder); + auto stateBlockIterPair = m_stateBlocks.emplace(std::piecewise_construct, + std::forward_as_tuple(m_token), + std::forward_as_tuple(this)); + m_recorder = &stateBlockIterPair.first->second; + m_recorderToken = m_token; + m_token++; return GetD3D9()->BeginStateBlock(); } @@ -1029,9 +1049,10 @@ namespace dxvk { m_recorder->SetD3D9(std::move(pStateBlock)); - *pToken = DWORD(reinterpret_cast(m_recorder)); + *pToken = m_recorderToken; m_recorder = nullptr; + m_recorderToken = -1; return res; } diff --git a/src/d3d8/d3d8_device.h b/src/d3d8/d3d8_device.h index 14172363..2e11f5f7 100644 --- a/src/d3d8/d3d8_device.h +++ b/src/d3d8/d3d8_device.h @@ -14,7 +14,7 @@ #include #include #include -#include +#include namespace dxvk { @@ -416,9 +416,11 @@ namespace dxvk { D3DPRESENT_PARAMETERS m_presentParams; - D3D8StateBlock* m_recorder = nullptr; - std::unordered_set m_stateBlocks; - D3D8Batcher* m_batcher = nullptr; + D3D8StateBlock* m_recorder = nullptr; + DWORD m_recorderToken = -1; + DWORD m_token = 0; + std::unordered_map m_stateBlocks; + D3D8Batcher* m_batcher = nullptr; struct D3D8VBO { Com buffer = nullptr;