diff --git a/src/d3d9/d3d9_device.cpp b/src/d3d9/d3d9_device.cpp index 8f21204eb..ff91bdc22 100644 --- a/src/d3d9/d3d9_device.cpp +++ b/src/d3d9/d3d9_device.cpp @@ -5117,8 +5117,8 @@ namespace dxvk { auto& rs = m_state.renderStates; if constexpr (Item == D3D9RenderStateItem::AlphaRef) { - float alpha = float(rs[D3DRS_ALPHAREF] & 0xFF) / 255.0f; - UpdatePushConstant(&alpha); + uint32_t alpha = rs[D3DRS_ALPHAREF]; + UpdatePushConstant(&alpha); } else if constexpr (Item == D3D9RenderStateItem::FogColor) { Vector4 color; diff --git a/src/d3d9/d3d9_fixed_function.cpp b/src/d3d9/d3d9_fixed_function.cpp index 6f60e5ad2..bd9edc575 100644 --- a/src/d3d9/d3d9_fixed_function.cpp +++ b/src/d3d9/d3d9_fixed_function.cpp @@ -215,6 +215,71 @@ namespace dxvk { spvModule.opBranchConditional(isNotAlways, atestBeginLabel, atestSkipLabel); spvModule.opLabel(atestBeginLabel); + // The lower 8 bits of the alpha ref contain the actual reference value + // from the API, the upper bits store the accuracy bit count minus 8. + // So if we want 12 bits of accuracy (i.e. 0-4095), that value will be 4. + uint32_t uintType = spvModule.defIntType(32, 0); + + // Check if the given bit precision is supported + uint32_t precisionIntLabel = spvModule.allocateId(); + uint32_t precisionFloatLabel = spvModule.allocateId(); + uint32_t precisionEndLabel = spvModule.allocateId(); + + uint32_t useIntPrecision = spvModule.opULessThanEqual(boolType, + ctx.alphaPrecisionId, spvModule.constu32(8)); + + spvModule.opSelectionMerge(precisionEndLabel, spv::SelectionControlMaskNone); + spvModule.opBranchConditional(useIntPrecision, precisionIntLabel, precisionFloatLabel); + spvModule.opLabel(precisionIntLabel); + + // Adjust alpha ref to the given range + uint32_t alphaRefIdInt = spvModule.opBitwiseOr(uintType, + spvModule.opShiftLeftLogical(uintType, ctx.alphaRefId, ctx.alphaPrecisionId), + spvModule.opShiftRightLogical(uintType, ctx.alphaRefId, + spvModule.opISub(uintType, spvModule.constu32(8), ctx.alphaPrecisionId))); + + // Convert alpha ref to float since we'll do the comparison based on that + uint32_t floatType = spvModule.defFloatType(32); + alphaRefIdInt = spvModule.opConvertUtoF(floatType, alphaRefIdInt); + + // Adjust alpha to the given range and round + uint32_t alphaFactorId = spvModule.opISub(uintType, + spvModule.opShiftLeftLogical(uintType, spvModule.constu32(256), ctx.alphaPrecisionId), + spvModule.constu32(1)); + alphaFactorId = spvModule.opConvertUtoF(floatType, alphaFactorId); + + uint32_t alphaIdInt = spvModule.opRoundEven(floatType, + spvModule.opFMul(floatType, ctx.alphaId, alphaFactorId)); + + spvModule.opBranch(precisionEndLabel); + spvModule.opLabel(precisionFloatLabel); + + // If we're not using integer precision, normalize the alpha ref + uint32_t alphaRefIdFloat = spvModule.opFDiv(floatType, + spvModule.opConvertUtoF(floatType, ctx.alphaRefId), + spvModule.constf32(255.0f)); + + spvModule.opBranch(precisionEndLabel); + spvModule.opLabel(precisionEndLabel); + + std::array alphaRefLabels = { + SpirvPhiLabel { alphaRefIdInt, precisionIntLabel }, + SpirvPhiLabel { alphaRefIdFloat, precisionFloatLabel }, + }; + + uint32_t alphaRefId = spvModule.opPhi(floatType, + alphaRefLabels.size(), + alphaRefLabels.data()); + + std::array alphaIdLabels = { + SpirvPhiLabel { alphaIdInt, precisionIntLabel }, + SpirvPhiLabel { ctx.alphaId, precisionFloatLabel }, + }; + + uint32_t alphaId = spvModule.opPhi(floatType, + alphaIdLabels.size(), + alphaIdLabels.data()); + // switch (alpha_func) { ... } spvModule.opSelectionMerge(atestTestLabel, spv::SelectionControlMaskNone); spvModule.opSwitch(ctx.alphaFuncId, @@ -231,12 +296,12 @@ namespace dxvk { atestVariables[i].varId = [&] { switch (VkCompareOp(atestCaseLabels[i].literal)) { case VK_COMPARE_OP_NEVER: return spvModule.constBool(false); - case VK_COMPARE_OP_LESS: return spvModule.opFOrdLessThan (boolType, ctx.alphaId, ctx.alphaRefId); - case VK_COMPARE_OP_EQUAL: return spvModule.opFOrdEqual (boolType, ctx.alphaId, ctx.alphaRefId); - case VK_COMPARE_OP_LESS_OR_EQUAL: return spvModule.opFOrdLessThanEqual (boolType, ctx.alphaId, ctx.alphaRefId); - case VK_COMPARE_OP_GREATER: return spvModule.opFOrdGreaterThan (boolType, ctx.alphaId, ctx.alphaRefId); - case VK_COMPARE_OP_NOT_EQUAL: return spvModule.opFOrdNotEqual (boolType, ctx.alphaId, ctx.alphaRefId); - case VK_COMPARE_OP_GREATER_OR_EQUAL: return spvModule.opFOrdGreaterThanEqual(boolType, ctx.alphaId, ctx.alphaRefId); + case VK_COMPARE_OP_LESS: return spvModule.opFOrdLessThan (boolType, alphaId, alphaRefId); + case VK_COMPARE_OP_EQUAL: return spvModule.opFOrdEqual (boolType, alphaId, alphaRefId); + case VK_COMPARE_OP_LESS_OR_EQUAL: return spvModule.opFOrdLessThanEqual (boolType, alphaId, alphaRefId); + case VK_COMPARE_OP_GREATER: return spvModule.opFOrdGreaterThan (boolType, alphaId, alphaRefId); + case VK_COMPARE_OP_NOT_EQUAL: return spvModule.opFOrdNotEqual (boolType, alphaId, alphaRefId); + case VK_COMPARE_OP_GREATER_OR_EQUAL: return spvModule.opFOrdGreaterThanEqual(boolType, alphaId, alphaRefId); default: case VK_COMPARE_OP_ALWAYS: return spvModule.constBool(true); } @@ -271,6 +336,7 @@ namespace dxvk { uint32_t SetupRenderStateBlock(SpirvModule& spvModule, uint32_t count) { uint32_t floatType = spvModule.defFloatType(32); + uint32_t uintType = spvModule.defIntType(32, 0); uint32_t vec3Type = spvModule.defVectorType(floatType, 3); std::array rsMembers = {{ @@ -278,7 +344,8 @@ namespace dxvk { floatType, floatType, floatType, - floatType, + + uintType, floatType, floatType, @@ -2300,7 +2367,7 @@ namespace dxvk { void D3D9FFShaderCompiler::alphaTestPS() { - uint32_t floatPtr = m_module.defPointerType(m_floatType, spv::StorageClassPushConstant); + uint32_t uintPtr = m_module.defPointerType(m_uint32Type, spv::StorageClassPushConstant); auto oC0 = m_ps.out.COLOR; @@ -2309,8 +2376,9 @@ namespace dxvk { D3D9AlphaTestContext alphaTestContext; alphaTestContext.alphaFuncId = m_spec.get(m_module, m_specUbo, SpecAlphaCompareOp); + alphaTestContext.alphaPrecisionId = m_spec.get(m_module, m_specUbo, SpecAlphaPrecisionBits); alphaTestContext.alphaRefId = m_module.opLoad(m_floatType, - m_module.opAccessChain(floatPtr, m_rsBlock, 1, &alphaRefMember)); + m_module.opAccessChain(uintPtr, m_rsBlock, 1, &alphaRefMember)); alphaTestContext.alphaId = m_module.opCompositeExtract(m_floatType, m_module.opLoad(m_vec4Type, oC0), 1, &alphaComponentId); diff --git a/src/d3d9/d3d9_fixed_function.h b/src/d3d9/d3d9_fixed_function.h index 492c265fd..4fe8d82c0 100644 --- a/src/d3d9/d3d9_fixed_function.h +++ b/src/d3d9/d3d9_fixed_function.h @@ -40,6 +40,7 @@ namespace dxvk { struct D3D9AlphaTestContext { uint32_t alphaId; + uint32_t alphaPrecisionId; uint32_t alphaFuncId; uint32_t alphaRefId; }; diff --git a/src/d3d9/d3d9_spec_constants.h b/src/d3d9/d3d9_spec_constants.h index ba45cccb9..78a850301 100644 --- a/src/d3d9/d3d9_spec_constants.h +++ b/src/d3d9/d3d9_spec_constants.h @@ -22,6 +22,7 @@ namespace dxvk { SpecSamplerNull, // 1 bit for 20 samplers | Bits: 20 SpecProjectionType, // 1 bit for 6 PS 1.x samplers | Bits: 6 + SpecAlphaPrecisionBits, // Range: 0 -> 8 or 0xF | Bits: 4 SpecVertexShaderBools, // 16 bools | Bits: 16 SpecPixelShaderBools, // 16 bools | Bits: 16 @@ -56,6 +57,7 @@ namespace dxvk { { 2, 0, 20 }, // SamplerNull { 2, 20, 6 }, // ProjectionType + { 2, 26, 4 }, // AlphaPrecisionBits { 3, 0, 16 }, // VertexShaderBools { 3, 16, 16 }, // PixelShaderBools diff --git a/src/dxso/dxso_compiler.cpp b/src/dxso/dxso_compiler.cpp index 05028daa3..15436fb92 100644 --- a/src/dxso/dxso_compiler.cpp +++ b/src/dxso/dxso_compiler.cpp @@ -3694,9 +3694,9 @@ void DxsoCompiler::emitControlFlowGenericLoop( void DxsoCompiler::emitPsProcessing() { - uint32_t boolType = m_module.defBoolType(); uint32_t floatType = m_module.defFloatType(32); - uint32_t floatPtr = m_module.defPointerType(floatType, spv::StorageClassPushConstant); + uint32_t uintType = m_module.defIntType(32, 0); + uint32_t uintPtr = m_module.defPointerType(uintType, spv::StorageClassPushConstant); // Implement alpha test and fog DxsoRegister color0; @@ -3712,8 +3712,9 @@ void DxsoCompiler::emitControlFlowGenericLoop( D3D9AlphaTestContext alphaTestContext; alphaTestContext.alphaFuncId = m_spec.get(m_module, m_specUbo, SpecAlphaCompareOp); - alphaTestContext.alphaRefId = m_module.opLoad(floatType, - m_module.opAccessChain(floatPtr, m_rsBlock, 1, &alphaRefMember)); + alphaTestContext.alphaPrecisionId = m_spec.get(m_module, m_specUbo, SpecAlphaPrecisionBits); + alphaTestContext.alphaRefId = m_module.opLoad(uintType, + m_module.opAccessChain(uintPtr, m_rsBlock, 1, &alphaRefMember)); alphaTestContext.alphaId = m_module.opCompositeExtract(floatType, m_module.opLoad(m_module.defVectorType(floatType, 4), oC0.id), 1, &alphaComponentId);