diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index b37d8032..330f139d 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -545,11 +545,16 @@ namespace dxvk { DxbcComponentMask mask) { const DxbcPointer ptr = this->getOperandPtr(operand); + // The value to store is actually allowed to be scalar, + // so we might need to create a vector from it. + if (value.type.componentCount == 1) + value = m_gen->regVector(value, mask.componentCount()); + // Cast source value to destination register type. // TODO verify that this actually works as intended. DxbcValueType dstType; dstType.componentType = ptr.type.valueType.componentType; - dstType.componentCount = mask.componentCount(); + dstType.componentCount = value.type.componentCount; value = m_gen->regCast(value, dstType); m_gen->regStore(ptr, value, mask); diff --git a/src/dxbc/gen/dxbc_gen_common.cpp b/src/dxbc/gen/dxbc_gen_common.cpp index b07d041d..b60e08d9 100644 --- a/src/dxbc/gen/dxbc_gen_common.cpp +++ b/src/dxbc/gen/dxbc_gen_common.cpp @@ -533,6 +533,24 @@ namespace dxvk { } + DxbcValue DxbcCodeGen::regVector( + const DxbcValue& src, + uint32_t size) { + if (size == 1) + return src; + + std::array ids = { + src.valueId, src.valueId, src.valueId, src.valueId, + }; + + DxbcValue result; + result.type = DxbcValueType(src.type.componentType, size); + result.valueId = m_module.opCompositeConstruct( + this->defValueType(result.type), size, ids.data()); + return result; + } + + DxbcValue DxbcCodeGen::regLoad(const DxbcPointer& ptr) { DxbcValue result; result.type = ptr.type.valueType; diff --git a/src/dxbc/gen/dxbc_gen_common.h b/src/dxbc/gen/dxbc_gen_common.h index bedccaa4..26377a81 100644 --- a/src/dxbc/gen/dxbc_gen_common.h +++ b/src/dxbc/gen/dxbc_gen_common.h @@ -145,6 +145,10 @@ namespace dxvk { const DxbcValue& src, DxbcComponentMask mask); + DxbcValue regVector( + const DxbcValue& src, + uint32_t size); + DxbcValue regLoad( const DxbcPointer& ptr); diff --git a/src/spirv/spirv_module.cpp b/src/spirv/spirv_module.cpp index b3137f3f..a03f98ab 100644 --- a/src/spirv/spirv_module.cpp +++ b/src/spirv/spirv_module.cpp @@ -578,6 +578,22 @@ namespace dxvk { } + uint32_t SpirvModule::opCompositeConstruct( + uint32_t resultType, + uint32_t valueCount, + const uint32_t* valueArray) { + uint32_t resultId = this->allocateId(); + + m_code.putIns (spv::OpCompositeConstruct, 3 + valueCount); + m_code.putWord(resultType); + m_code.putWord(resultId); + + for (uint32_t i = 0; i < valueCount; i++) + m_code.putWord(valueArray[i]); + return resultId; + } + + uint32_t SpirvModule::opCompositeExtract( uint32_t resultType, uint32_t composite, diff --git a/src/spirv/spirv_module.h b/src/spirv/spirv_module.h index 76072c93..02bb0750 100644 --- a/src/spirv/spirv_module.h +++ b/src/spirv/spirv_module.h @@ -214,6 +214,11 @@ namespace dxvk { uint32_t resultType, uint32_t operand); + uint32_t opCompositeConstruct( + uint32_t resultType, + uint32_t valueCount, + const uint32_t* valueArray); + uint32_t opCompositeExtract( uint32_t resultType, uint32_t composite,