# Backport of commit 6a6c527ee287 (available in 17.0). To avoid merge conflicts the whole function was copied from patch LLVM 15.0. # It enables better code geneartion for FP16 convert on SPR: int64/uint64 -> half (vcvtqq2ph / vcvtuqq2ph) for x8 target. diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 8bb7e81e19bb..049c68b7cc3d 100644 --- llvm/lib/Target/X86/X86ISelLowering.cpp.orig +++ llvm/lib/Target/X86/X86ISelLowering.cpp @@ -53682,22 +53682,56 @@ static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG, if (!Subtarget.hasF16C() || Subtarget.useSoftFloat()) return SDValue(); - if (Subtarget.hasFP16()) - return SDValue(); - + bool IsStrict = N->isStrictFPOpcode(); EVT VT = N->getValueType(0); - SDValue Src = N->getOperand(0); + SDValue Src = N->getOperand(IsStrict ? 1 : 0); EVT SrcVT = Src.getValueType(); if (!VT.isVector() || VT.getVectorElementType() != MVT::f16 || SrcVT.getVectorElementType() != MVT::f32) return SDValue(); + SDLoc dl(N); + + SDValue Cvt, Chain; unsigned NumElts = VT.getVectorNumElements(); - if (NumElts == 1 || !isPowerOf2_32(NumElts)) + if (Subtarget.hasFP16()) { + // Combine (v8f16 fp_round(concat_vectors(v4f32 (xint_to_fp v4i64), ..))) + // into (v8f16 vector_shuffle(v8f16 (CVTXI2P v4i64), ..)) + if (NumElts == 8 && Src.getOpcode() == ISD::CONCAT_VECTORS) { + SDValue Cvt0, Cvt1; + SDValue Op0 = Src.getOperand(0); + SDValue Op1 = Src.getOperand(1); + bool IsOp0Strict = Op0->isStrictFPOpcode(); + if (Op0.getOpcode() != Op1.getOpcode() || + Op0.getOperand(IsOp0Strict ? 1 : 0).getValueType() != MVT::v4i64 || + Op1.getOperand(IsOp0Strict ? 1 : 0).getValueType() != MVT::v4i64) { + return SDValue(); + } + int Mask[8] = {0, 1, 2, 3, 8, 9, 10, 11}; + if (IsStrict) { + assert(IsOp0Strict && "Op0 must be strict node"); + unsigned Opc = Op0.getOpcode() == ISD::STRICT_SINT_TO_FP + ? X86ISD::STRICT_CVTSI2P + : X86ISD::STRICT_CVTUI2P; + Cvt0 = DAG.getNode(Opc, dl, {MVT::v8f16, MVT::Other}, + {Op0.getOperand(0), Op0.getOperand(1)}); + Cvt1 = DAG.getNode(Opc, dl, {MVT::v8f16, MVT::Other}, + {Op1.getOperand(0), Op1.getOperand(1)}); + Cvt = DAG.getVectorShuffle(MVT::v8f16, dl, Cvt0, Cvt1, Mask); + return DAG.getMergeValues({Cvt, Cvt0.getValue(1)}, dl); + } + unsigned Opc = Op0.getOpcode() == ISD::SINT_TO_FP ? X86ISD::CVTSI2P + : X86ISD::CVTUI2P; + Cvt0 = DAG.getNode(Opc, dl, MVT::v8f16, Op0.getOperand(0)); + Cvt1 = DAG.getNode(Opc, dl, MVT::v8f16, Op1.getOperand(0)); + return Cvt = DAG.getVectorShuffle(MVT::v8f16, dl, Cvt0, Cvt1, Mask); + } return SDValue(); + } - SDLoc dl(N); + if (NumElts == 1 || !isPowerOf2_32(NumElts)) + return SDValue(); // Widen to at least 4 input elements. if (NumElts < 4) @@ -53705,10 +53739,16 @@ static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG, DAG.getConstantFP(0.0, dl, SrcVT)); // Destination is v8i16 with at least 8 elements. - EVT CvtVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, - std::max(8U, NumElts)); - SDValue Cvt = DAG.getNode(X86ISD::CVTPS2PH, dl, CvtVT, Src, - DAG.getTargetConstant(4, dl, MVT::i32)); + EVT CvtVT = + EVT::getVectorVT(*DAG.getContext(), MVT::i16, std::max(8U, NumElts)); + SDValue Rnd = DAG.getTargetConstant(4, dl, MVT::i32); + if (IsStrict) { + Cvt = DAG.getNode(X86ISD::STRICT_CVTPS2PH, dl, {CvtVT, MVT::Other}, + {N->getOperand(0), Src, Rnd}); + Chain = Cvt.getValue(1); + } else { + Cvt = DAG.getNode(X86ISD::CVTPS2PH, dl, CvtVT, Src, Rnd); + } // Extract down to real number of elements. if (NumElts < 8) { @@ -53717,7 +53757,12 @@ static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG, DAG.getIntPtrConstant(0, dl)); } - return DAG.getBitcast(VT, Cvt); + Cvt = DAG.getBitcast(VT, Cvt); + + if (IsStrict) + return DAG.getMergeValues({Cvt, Chain}, dl); + + return Cvt; } static SDValue combineMOVDQ2Q(SDNode *N, SelectionDAG &DAG) {