diff --git a/src/dxso/dxso_compiler.cpp b/src/dxso/dxso_compiler.cpp index 54f0cc443..7a28cb3c4 100644 --- a/src/dxso/dxso_compiler.cpp +++ b/src/dxso/dxso_compiler.cpp @@ -2150,60 +2150,69 @@ namespace dxvk { void DxsoCompiler::emitMatrixAlu(const DxsoInstructionContext& ctx) { - const auto& src = ctx.src; - - DxsoRegMask mask = ctx.dst.mask; - - DxsoRegisterPointer dst = emitGetOperandPtr(ctx.dst); - - DxsoRegisterValue result; - result.type.ctype = dst.type.ctype; - result.type.ccount = mask.popCount(); - - DxsoVectorType scalarType = result.type; - scalarType.ccount = 1; - - const uint32_t typeId = getVectorTypeId(result.type); - const uint32_t scalarTypeId = getVectorTypeId(scalarType); - const DxsoOpcode opcode = ctx.instruction.opcode; uint32_t dotCount; - uint32_t iterCount; + uint32_t componentCount; switch (opcode) { case DxsoOpcode::M3x2: - dotCount = 3; - iterCount = 2; + dotCount = 3; + componentCount = 2; break; case DxsoOpcode::M3x3: - dotCount = 3; - iterCount = 3; + dotCount = 3; + componentCount = 3; break; case DxsoOpcode::M3x4: - dotCount = 3; - iterCount = 4; + dotCount = 3; + componentCount = 4; break; case DxsoOpcode::M4x3: - dotCount = 4; - iterCount = 3; + dotCount = 4; + componentCount = 3; break; case DxsoOpcode::M4x4: - dotCount = 4; - iterCount = 4; + dotCount = 4; + componentCount = 4; break; default: Logger::warn(str::format("DxsoCompiler::emitMatrixAlu: unimplemented op ", opcode)); return; } + DxsoRegisterPointer dst = emitGetOperandPtr(ctx.dst); + + // Fix the dst mask if componentCount != maskCount + // ie. M4x3 on .xyzw. + uint32_t maskCnt = 0; + uint8_t mask = 0; + for (uint32_t i = 0; i < 4 && maskCnt < componentCount; i++) { + if (ctx.dst.mask[i]) { + mask |= 1 << i; + maskCnt++; + } + } + DxsoRegMask dstMask = DxsoRegMask(mask); + + DxsoRegisterValue result; + result.type.ctype = dst.type.ctype; + result.type.ccount = componentCount; + + DxsoVectorType scalarType; + scalarType.ctype = result.type.ctype; + scalarType.ccount = 1; + + const uint32_t typeId = getVectorTypeId(result.type); + const uint32_t scalarTypeId = getVectorTypeId(scalarType); + DxsoRegMask srcMask(true, true, true, dotCount == 4); std::array indices; - DxsoRegister src0 = src[0]; - DxsoRegister src1 = src[1]; + DxsoRegister src0 = ctx.src[0]; + DxsoRegister src1 = ctx.src[1]; - for (uint32_t i = 0; i < iterCount; i++) { + for (uint32_t i = 0; i < componentCount; i++) { indices[i] = m_module.opDot(scalarTypeId, emitRegisterLoad(src0, srcMask).id, emitRegisterLoad(src1, srcMask).id); @@ -2212,9 +2221,9 @@ namespace dxvk { } result.id = m_module.opCompositeConstruct( - typeId, iterCount, indices.data()); + typeId, componentCount, indices.data()); - this->emitDstStore(dst, result, mask, ctx.dst.saturate, emitPredicateLoad(ctx), ctx.dst.shift, ctx.dst.id); + this->emitDstStore(dst, result, dstMask, ctx.dst.saturate, emitPredicateLoad(ctx), ctx.dst.shift, ctx.dst.id); }