1
0
mirror of https://github.com/doitsujin/dxvk.git synced 2025-01-18 20:52:10 +01:00

[dxso] Handle extraneous writemasks in matrix ops

This commit is contained in:
Joshua Ashton 2020-11-24 15:58:21 +00:00
parent f74071ac0a
commit c282ec7976

View File

@ -2150,60 +2150,69 @@ namespace dxvk {
void DxsoCompiler::emitMatrixAlu(const DxsoInstructionContext& ctx) { void DxsoCompiler::emitMatrixAlu(const DxsoInstructionContext& ctx) {
const auto& src = ctx.src;
DxsoRegMask mask = ctx.dst.mask;
DxsoRegisterPointer dst = emitGetOperandPtr(ctx.dst);
DxsoRegisterValue result;
result.type.ctype = dst.type.ctype;
result.type.ccount = mask.popCount();
DxsoVectorType scalarType = result.type;
scalarType.ccount = 1;
const uint32_t typeId = getVectorTypeId(result.type);
const uint32_t scalarTypeId = getVectorTypeId(scalarType);
const DxsoOpcode opcode = ctx.instruction.opcode; const DxsoOpcode opcode = ctx.instruction.opcode;
uint32_t dotCount; uint32_t dotCount;
uint32_t iterCount; uint32_t componentCount;
switch (opcode) { switch (opcode) {
case DxsoOpcode::M3x2: case DxsoOpcode::M3x2:
dotCount = 3; dotCount = 3;
iterCount = 2; componentCount = 2;
break; break;
case DxsoOpcode::M3x3: case DxsoOpcode::M3x3:
dotCount = 3; dotCount = 3;
iterCount = 3; componentCount = 3;
break; break;
case DxsoOpcode::M3x4: case DxsoOpcode::M3x4:
dotCount = 3; dotCount = 3;
iterCount = 4; componentCount = 4;
break; break;
case DxsoOpcode::M4x3: case DxsoOpcode::M4x3:
dotCount = 4; dotCount = 4;
iterCount = 3; componentCount = 3;
break; break;
case DxsoOpcode::M4x4: case DxsoOpcode::M4x4:
dotCount = 4; dotCount = 4;
iterCount = 4; componentCount = 4;
break; break;
default: default:
Logger::warn(str::format("DxsoCompiler::emitMatrixAlu: unimplemented op ", opcode)); Logger::warn(str::format("DxsoCompiler::emitMatrixAlu: unimplemented op ", opcode));
return; return;
} }
DxsoRegisterPointer dst = emitGetOperandPtr(ctx.dst);
// Fix the dst mask if it enables more components than componentCount:
// keep only the first componentCount enabled ones, e.g. M4x3 on .xyzw.
uint32_t maskCnt = 0;
uint8_t mask = 0;
for (uint32_t i = 0; i < 4 && maskCnt < componentCount; i++) {
if (ctx.dst.mask[i]) {
mask |= 1 << i;
maskCnt++;
}
}
DxsoRegMask dstMask = DxsoRegMask(mask);
DxsoRegisterValue result;
result.type.ctype = dst.type.ctype;
result.type.ccount = componentCount;
DxsoVectorType scalarType;
scalarType.ctype = result.type.ctype;
scalarType.ccount = 1;
const uint32_t typeId = getVectorTypeId(result.type);
const uint32_t scalarTypeId = getVectorTypeId(scalarType);
DxsoRegMask srcMask(true, true, true, dotCount == 4); DxsoRegMask srcMask(true, true, true, dotCount == 4);
std::array<uint32_t, 4> indices; std::array<uint32_t, 4> indices;
DxsoRegister src0 = src[0]; DxsoRegister src0 = ctx.src[0];
DxsoRegister src1 = src[1]; DxsoRegister src1 = ctx.src[1];
for (uint32_t i = 0; i < iterCount; i++) { for (uint32_t i = 0; i < componentCount; i++) {
indices[i] = m_module.opDot(scalarTypeId, indices[i] = m_module.opDot(scalarTypeId,
emitRegisterLoad(src0, srcMask).id, emitRegisterLoad(src0, srcMask).id,
emitRegisterLoad(src1, srcMask).id); emitRegisterLoad(src1, srcMask).id);
@ -2212,9 +2221,9 @@ namespace dxvk {
} }
result.id = m_module.opCompositeConstruct( result.id = m_module.opCompositeConstruct(
typeId, iterCount, indices.data()); typeId, componentCount, indices.data());
this->emitDstStore(dst, result, mask, ctx.dst.saturate, emitPredicateLoad(ctx), ctx.dst.shift, ctx.dst.id); this->emitDstStore(dst, result, dstMask, ctx.dst.saturate, emitPredicateLoad(ctx), ctx.dst.shift, ctx.dst.id);
} }