1
0
mirror of https://github.com/doitsujin/dxvk.git synced 2025-01-18 20:52:10 +01:00

[dxso] Handle extraneous writemasks in matrix ops

This commit is contained in:
Joshua Ashton 2020-11-24 15:58:21 +00:00
parent f74071ac0a
commit c282ec7976

View File

@ -2150,60 +2150,69 @@ namespace dxvk {
void DxsoCompiler::emitMatrixAlu(const DxsoInstructionContext& ctx) {
const auto& src = ctx.src;
DxsoRegMask mask = ctx.dst.mask;
DxsoRegisterPointer dst = emitGetOperandPtr(ctx.dst);
DxsoRegisterValue result;
result.type.ctype = dst.type.ctype;
result.type.ccount = mask.popCount();
DxsoVectorType scalarType = result.type;
scalarType.ccount = 1;
const uint32_t typeId = getVectorTypeId(result.type);
const uint32_t scalarTypeId = getVectorTypeId(scalarType);
const DxsoOpcode opcode = ctx.instruction.opcode;
uint32_t dotCount;
uint32_t iterCount;
uint32_t componentCount;
switch (opcode) {
case DxsoOpcode::M3x2:
dotCount = 3;
iterCount = 2;
dotCount = 3;
componentCount = 2;
break;
case DxsoOpcode::M3x3:
dotCount = 3;
iterCount = 3;
dotCount = 3;
componentCount = 3;
break;
case DxsoOpcode::M3x4:
dotCount = 3;
iterCount = 4;
dotCount = 3;
componentCount = 4;
break;
case DxsoOpcode::M4x3:
dotCount = 4;
iterCount = 3;
dotCount = 4;
componentCount = 3;
break;
case DxsoOpcode::M4x4:
dotCount = 4;
iterCount = 4;
dotCount = 4;
componentCount = 4;
break;
default:
Logger::warn(str::format("DxsoCompiler::emitMatrixAlu: unimplemented op ", opcode));
return;
}
DxsoRegisterPointer dst = emitGetOperandPtr(ctx.dst);
// Fix the dst mask if componentCount != maskCount
// ie. M4x3 on .xyzw.
uint32_t maskCnt = 0;
uint8_t mask = 0;
for (uint32_t i = 0; i < 4 && maskCnt < componentCount; i++) {
if (ctx.dst.mask[i]) {
mask |= 1 << i;
maskCnt++;
}
}
DxsoRegMask dstMask = DxsoRegMask(mask);
DxsoRegisterValue result;
result.type.ctype = dst.type.ctype;
result.type.ccount = componentCount;
DxsoVectorType scalarType;
scalarType.ctype = result.type.ctype;
scalarType.ccount = 1;
const uint32_t typeId = getVectorTypeId(result.type);
const uint32_t scalarTypeId = getVectorTypeId(scalarType);
DxsoRegMask srcMask(true, true, true, dotCount == 4);
std::array<uint32_t, 4> indices;
DxsoRegister src0 = src[0];
DxsoRegister src1 = src[1];
DxsoRegister src0 = ctx.src[0];
DxsoRegister src1 = ctx.src[1];
for (uint32_t i = 0; i < iterCount; i++) {
for (uint32_t i = 0; i < componentCount; i++) {
indices[i] = m_module.opDot(scalarTypeId,
emitRegisterLoad(src0, srcMask).id,
emitRegisterLoad(src1, srcMask).id);
@ -2212,9 +2221,9 @@ namespace dxvk {
}
result.id = m_module.opCompositeConstruct(
typeId, iterCount, indices.data());
typeId, componentCount, indices.data());
this->emitDstStore(dst, result, mask, ctx.dst.saturate, emitPredicateLoad(ctx), ctx.dst.shift, ctx.dst.id);
this->emitDstStore(dst, result, dstMask, ctx.dst.saturate, emitPredicateLoad(ctx), ctx.dst.shift, ctx.dst.id);
}