diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index 91d46c9f..435703a0 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -1952,7 +1952,7 @@ namespace dxvk { const bool isImm = ins.dstCount == 2; const bool isUav = ins.dst[ins.dstCount - 1].type == DxbcOperandType::UnorderedAccessView; - // Retrieve destination pointer for the atomic operation + // Retrieve destination pointer for the atomic operation> const DxbcRegisterPointer pointer = emitGetAtomicPointer( ins.dst[ins.dstCount - 1], ins.src[0]); @@ -4226,6 +4226,14 @@ namespace dxvk { // of obtaining the final pointer are used. const bool isUav = operand.type == DxbcOperandType::UnorderedAccessView; + // If the resource is an UAV, we need to specify a format + // for the image type. Atomic ops are only allowed for + // 32-bit scalar integer formats. + if (isUav) { + m_module.setImageTypeFormat(resourceInfo.typeId, + getScalarImageFormat(resourceInfo.stype)); + } + // Compute the actual address into the resource const DxbcRegisterValue addressValue = [&] { switch (resourceInfo.type) { @@ -5837,6 +5845,7 @@ namespace dxvk { case DxbcOperandType::Resource: { DxbcBufferInfo result; result.image = m_textures.at(registerId).imageInfo; + result.stype = m_textures.at(registerId).sampledType; result.type = m_textures.at(registerId).type; result.typeId = m_textures.at(registerId).imageTypeId; result.varId = m_textures.at(registerId).varId; @@ -5848,6 +5857,7 @@ namespace dxvk { case DxbcOperandType::UnorderedAccessView: { DxbcBufferInfo result; result.image = m_uavs.at(registerId).imageInfo; + result.stype = m_uavs.at(registerId).sampledType; result.type = m_uavs.at(registerId).type; result.typeId = m_uavs.at(registerId).imageTypeId; result.varId = m_uavs.at(registerId).varId; @@ -5859,6 +5869,7 @@ namespace dxvk { case DxbcOperandType::ThreadGroupSharedMemory: { DxbcBufferInfo result; result.image = { spv::DimBuffer, 0, 0, 0 }; + result.stype = DxbcScalarType::Uint32; result.type = m_gRegs.at(registerId).type; result.typeId = m_module.defPointerType( getScalarTypeId(DxbcScalarType::Uint32), @@ -5934,6 +5945,16 @@ namespace dxvk { } + spv::ImageFormat DxbcCompiler::getScalarImageFormat(DxbcScalarType type) const { + switch (type) { + case DxbcScalarType::Float32: return spv::ImageFormatR32f; + case DxbcScalarType::Sint32: return spv::ImageFormatR32i; + case DxbcScalarType::Uint32: return spv::ImageFormatR32ui; + default: throw DxvkError("DxbcCompiler: Unhandled scalar resource type"); + } + } + + uint32_t DxbcCompiler::getScalarTypeId(DxbcScalarType type) { switch (type) { case DxbcScalarType::Uint32: return m_module.defIntType(32, 0); diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h index a74c92f7..9fa7e192 100644 --- a/src/dxbc/dxbc_compiler.h +++ b/src/dxbc/dxbc_compiler.h @@ -288,6 +288,7 @@ namespace dxvk { struct DxbcBufferInfo { DxbcImageInfo image; + DxbcScalarType stype; DxbcResourceType type; uint32_t typeId; uint32_t varId; @@ -977,6 +978,9 @@ namespace dxvk { VkImageViewType getViewType( DxbcResourceDim dim) const; + spv::ImageFormat getScalarImageFormat( + DxbcScalarType type) const; + /////////////////////////// // Type definition methods uint32_t getScalarTypeId( diff --git a/src/spirv/spirv_module.cpp b/src/spirv/spirv_module.cpp index b0fd813e..ed1533ac 100644 --- a/src/spirv/spirv_module.cpp +++ b/src/spirv/spirv_module.cpp @@ -599,14 +599,18 @@ namespace dxvk { uint32_t multisample, uint32_t sampled, spv::ImageFormat format) { - std::array args = { - sampledType, dimensionality, - depth, arrayed, multisample, - sampled, format - }; + uint32_t resultId = this->allocateId(); - return this->defType(spv::OpTypeImage, - args.size(), args.data()); + m_typeConstDefs.putIns (spv::OpTypeImage, 9); + m_typeConstDefs.putWord(resultId); + m_typeConstDefs.putWord(sampledType); + m_typeConstDefs.putWord(dimensionality); + m_typeConstDefs.putWord(depth); + m_typeConstDefs.putWord(arrayed); + m_typeConstDefs.putWord(multisample); + m_typeConstDefs.putWord(sampled); + m_typeConstDefs.putWord(format); + return resultId; } @@ -616,6 +620,20 @@ namespace dxvk { } + void SpirvModule::setImageTypeFormat( + uint32_t imageType, + spv::ImageFormat format) { + for (auto ins : m_typeConstDefs) { + bool match = ins.arg(1) == imageType; + + if (match) { + ins.setArg(8, format); + return; + } + } + } + + uint32_t SpirvModule::newVar( uint32_t pointerType, spv::StorageClass storageClass) { diff --git a/src/spirv/spirv_module.h b/src/spirv/spirv_module.h index 041b6012..c1528ca5 100644 --- a/src/spirv/spirv_module.h +++ b/src/spirv/spirv_module.h @@ -254,6 +254,10 @@ namespace dxvk { uint32_t defSampledImageType( uint32_t imageType); + void setImageTypeFormat( + uint32_t imageType, + spv::ImageFormat format); + uint32_t newVar( uint32_t pointerType, spv::StorageClass storageClass);