mirror of
https://github.com/doitsujin/dxvk.git
synced 2025-01-18 20:52:10 +01:00
[dxso] Handle extraneous writemasks in matrix ops
This commit is contained in:
parent
f74071ac0a
commit
c282ec7976
@ -2150,60 +2150,69 @@ namespace dxvk {
|
||||
|
||||
|
||||
void DxsoCompiler::emitMatrixAlu(const DxsoInstructionContext& ctx) {
|
||||
const auto& src = ctx.src;
|
||||
|
||||
DxsoRegMask mask = ctx.dst.mask;
|
||||
|
||||
DxsoRegisterPointer dst = emitGetOperandPtr(ctx.dst);
|
||||
|
||||
DxsoRegisterValue result;
|
||||
result.type.ctype = dst.type.ctype;
|
||||
result.type.ccount = mask.popCount();
|
||||
|
||||
DxsoVectorType scalarType = result.type;
|
||||
scalarType.ccount = 1;
|
||||
|
||||
const uint32_t typeId = getVectorTypeId(result.type);
|
||||
const uint32_t scalarTypeId = getVectorTypeId(scalarType);
|
||||
|
||||
const DxsoOpcode opcode = ctx.instruction.opcode;
|
||||
|
||||
uint32_t dotCount;
|
||||
uint32_t iterCount;
|
||||
uint32_t componentCount;
|
||||
|
||||
switch (opcode) {
|
||||
case DxsoOpcode::M3x2:
|
||||
dotCount = 3;
|
||||
iterCount = 2;
|
||||
componentCount = 2;
|
||||
break;
|
||||
case DxsoOpcode::M3x3:
|
||||
dotCount = 3;
|
||||
iterCount = 3;
|
||||
componentCount = 3;
|
||||
break;
|
||||
case DxsoOpcode::M3x4:
|
||||
dotCount = 3;
|
||||
iterCount = 4;
|
||||
componentCount = 4;
|
||||
break;
|
||||
case DxsoOpcode::M4x3:
|
||||
dotCount = 4;
|
||||
iterCount = 3;
|
||||
componentCount = 3;
|
||||
break;
|
||||
case DxsoOpcode::M4x4:
|
||||
dotCount = 4;
|
||||
iterCount = 4;
|
||||
componentCount = 4;
|
||||
break;
|
||||
default:
|
||||
Logger::warn(str::format("DxsoCompiler::emitMatrixAlu: unimplemented op ", opcode));
|
||||
return;
|
||||
}
|
||||
|
||||
DxsoRegisterPointer dst = emitGetOperandPtr(ctx.dst);
|
||||
|
||||
// Fix the dst mask if componentCount != maskCount
|
||||
// e.g. M4x3 on .xyzw: only the first three enabled components (.xyz) are written.
|
||||
uint32_t maskCnt = 0;
|
||||
uint8_t mask = 0;
|
||||
for (uint32_t i = 0; i < 4 && maskCnt < componentCount; i++) {
|
||||
if (ctx.dst.mask[i]) {
|
||||
mask |= 1 << i;
|
||||
maskCnt++;
|
||||
}
|
||||
}
|
||||
DxsoRegMask dstMask = DxsoRegMask(mask);
|
||||
|
||||
DxsoRegisterValue result;
|
||||
result.type.ctype = dst.type.ctype;
|
||||
result.type.ccount = componentCount;
|
||||
|
||||
DxsoVectorType scalarType;
|
||||
scalarType.ctype = result.type.ctype;
|
||||
scalarType.ccount = 1;
|
||||
|
||||
const uint32_t typeId = getVectorTypeId(result.type);
|
||||
const uint32_t scalarTypeId = getVectorTypeId(scalarType);
|
||||
|
||||
DxsoRegMask srcMask(true, true, true, dotCount == 4);
|
||||
std::array<uint32_t, 4> indices;
|
||||
|
||||
DxsoRegister src0 = src[0];
|
||||
DxsoRegister src1 = src[1];
|
||||
DxsoRegister src0 = ctx.src[0];
|
||||
DxsoRegister src1 = ctx.src[1];
|
||||
|
||||
for (uint32_t i = 0; i < iterCount; i++) {
|
||||
for (uint32_t i = 0; i < componentCount; i++) {
|
||||
indices[i] = m_module.opDot(scalarTypeId,
|
||||
emitRegisterLoad(src0, srcMask).id,
|
||||
emitRegisterLoad(src1, srcMask).id);
|
||||
@ -2212,9 +2221,9 @@ namespace dxvk {
|
||||
}
|
||||
|
||||
result.id = m_module.opCompositeConstruct(
|
||||
typeId, iterCount, indices.data());
|
||||
typeId, componentCount, indices.data());
|
||||
|
||||
this->emitDstStore(dst, result, mask, ctx.dst.saturate, emitPredicateLoad(ctx), ctx.dst.shift, ctx.dst.id);
|
||||
this->emitDstStore(dst, result, dstMask, ctx.dst.saturate, emitPredicateLoad(ctx), ctx.dst.shift, ctx.dst.id);
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user