diff --git a/llvm/include/llvm/Analysis/InstSimplifyFolder.h b/llvm/include/llvm/Analysis/InstSimplifyFolder.h index 430c3edc2f0dc..d4ae4dcc918cf 100644 --- a/llvm/include/llvm/Analysis/InstSimplifyFolder.h +++ b/llvm/include/llvm/Analysis/InstSimplifyFolder.h @@ -22,6 +22,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetFolder.h" +#include "llvm/IR/CmpPredicate.h" #include "llvm/IR/IRBuilderFolder.h" #include "llvm/IR/Instruction.h" diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h index cf7d3e044188a..fa291eeef198b 100644 --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -44,6 +44,7 @@ class DataLayout; class DominatorTree; class Function; class Instruction; +class CmpPredicate; class LoadInst; struct LoopStandardAnalysisResults; class Pass; @@ -152,11 +153,11 @@ Value *simplifyOrInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); Value *simplifyXorInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); /// Given operands for an ICmpInst, fold the result or return null. -Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *simplifyICmpInst(CmpPredicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q); /// Given operands for an FCmpInst, fold the result or return null. -Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); /// Given operands for a SelectInst, fold the result or return null. @@ -200,7 +201,7 @@ Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef Mask, //=== Helper functions for higher up the class hierarchy. /// Given operands for a CmpInst, fold the result or return null. -Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q); /// Given operand for a UnaryOperator, fold the result or return null. diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h index c408e0a39cd18..8aa024a72afc8 100644 --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -1255,8 +1255,7 @@ std::optional isImpliedCondition(const Value *LHS, const Value *RHS, const DataLayout &DL, bool LHSIsTrue = true, unsigned Depth = 0); -std::optional isImpliedCondition(const Value *LHS, - CmpInst::Predicate RHSPred, +std::optional isImpliedCondition(const Value *LHS, CmpPredicate RHSPred, const Value *RHSOp0, const Value *RHSOp1, const DataLayout &DL, bool LHSIsTrue = true, @@ -1267,8 +1266,8 @@ std::optional isImpliedCondition(const Value *LHS, std::optional isImpliedByDomCondition(const Value *Cond, const Instruction *ContextI, const DataLayout &DL); -std::optional isImpliedByDomCondition(CmpInst::Predicate Pred, - const Value *LHS, const Value *RHS, +std::optional isImpliedByDomCondition(CmpPredicate Pred, const Value *LHS, + const Value *RHS, const Instruction *ContextI, const DataLayout &DL); diff --git a/llvm/include/llvm/IR/CmpPredicate.h b/llvm/include/llvm/IR/CmpPredicate.h new file mode 100644 index 0000000000000..4b1be7beb2b66 --- /dev/null +++ b/llvm/include/llvm/IR/CmpPredicate.h @@ -0,0 +1,62 @@ +//===- CmpPredicate.h - CmpInst Predicate with samesign information -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A CmpInst::Predicate with any samesign information (applicable to ICmpInst). +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_CMPPREDICATE_H +#define LLVM_IR_CMPPREDICATE_H + +#include "llvm/IR/InstrTypes.h" + +namespace llvm { +/// An abstraction over a floating-point predicate, and a pack of an integer +/// predicate with samesign information. Some functions in ICmpInst construct +/// and return this type in place of a Predicate. +class CmpPredicate { + CmpInst::Predicate Pred; + bool HasSameSign; + +public: + /// Constructed implictly with a either Predicate and samesign information, or + /// just a Predicate, dropping samesign information. + CmpPredicate(CmpInst::Predicate Pred, bool HasSameSign = false) + : Pred(Pred), HasSameSign(HasSameSign) { + assert(!HasSameSign || CmpInst::isIntPredicate(Pred)); + } + + /// Implictly converts to the underlying Predicate, dropping samesign + /// information. + operator CmpInst::Predicate() const { return Pred; } + + /// Query samesign information, for optimizations. + bool hasSameSign() const { return HasSameSign; } + + /// Compares two CmpPredicates taking samesign into account and returns the + /// canonicalized CmpPredicate if they match. An alternative to operator==. + /// + /// For example, + /// samesign ult + samesign ult -> samesign ult + /// samesign ult + ult -> ult + /// samesign ult + slt -> slt + /// ult + ult -> ult + /// ult + slt -> std::nullopt + static std::optional getMatching(CmpPredicate A, + CmpPredicate B); + + /// An operator== on the underlying Predicate. + bool operator==(CmpInst::Predicate P) const { return Pred == P; } + + /// There is no operator== defined on CmpPredicate. Use getMatching instead to + /// get the canonicalized matching CmpPredicate. + bool operator==(CmpPredicate) const = delete; +}; +} // namespace llvm + +#endif diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h index 605964af5d676..a42bf6bca1b9f 100644 --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -24,6 +24,7 @@ #include "llvm/ADT/iterator.h" #include "llvm/ADT/iterator_range.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/CmpPredicate.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GEPNoWrapFlags.h" @@ -1203,6 +1204,33 @@ class ICmpInst: public CmpInst { #endif } + /// @returns the predicate along with samesign information. + CmpPredicate getCmpPredicate() const { + return {getPredicate(), hasSameSign()}; + } + + /// @returns the inverse predicate along with samesign information: static + /// variant. + static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred) { + return {getInversePredicate(Pred), Pred.hasSameSign()}; + } + + /// @returns the inverse predicate along with samesign information. + CmpPredicate getInverseCmpPredicate() const { + return getInverseCmpPredicate(getCmpPredicate()); + } + + /// @returns the swapped predicate along with samesign information: static + /// variant. + static CmpPredicate getSwappedCmpPredicate(CmpPredicate Pred) { + return {getSwappedPredicate(Pred), Pred.hasSameSign()}; + } + + /// @returns the swapped predicate. + Predicate getSwappedCmpPredicate() const { + return getSwappedPredicate(getCmpPredicate()); + } + /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc. /// @returns the predicate that would be the result if the operand were /// regarded as signed. @@ -1212,7 +1240,7 @@ class ICmpInst: public CmpInst { } /// Return the signed version of the predicate: static variant. - static Predicate getSignedPredicate(Predicate pred); + static Predicate getSignedPredicate(Predicate Pred); /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc. /// @returns the predicate that would be the result if the operand were @@ -1223,14 +1251,15 @@ class ICmpInst: public CmpInst { } /// Return the unsigned version of the predicate: static variant. - static Predicate getUnsignedPredicate(Predicate pred); + static Predicate getUnsignedPredicate(Predicate Pred); - /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert + /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ /// @returns the unsigned version of the signed predicate pred or /// the signed version of the signed predicate pred. - static Predicate getFlippedSignednessPredicate(Predicate pred); + /// Static variant. + static Predicate getFlippedSignednessPredicate(Predicate Pred); - /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert + /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ /// @returns the unsigned version of the signed predicate pred or /// the signed version of the signed predicate pred. Predicate getFlippedSignednessPredicate() const { diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h index 3075b7ebae59e..71592058e3456 100644 --- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h +++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h @@ -157,7 +157,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner { /// conditional branch or select to create a compare with a canonical /// (inverted) predicate which is then more likely to be matched with other /// values. - static bool isCanonicalPredicate(CmpInst::Predicate Pred) { + static bool isCanonicalPredicate(CmpPredicate Pred) { switch (Pred) { case CmpInst::ICMP_NE: case CmpInst::ICMP_ULE: @@ -185,10 +185,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner { } std::optional> static getFlippedStrictnessPredicateAndConstant(CmpInst:: - Predicate - Pred, + CmpPredicate, + Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate + Pred, Constant *C); static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) { diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 1a5bbbc7dfceb..05e8f5761c13c 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -63,9 +63,9 @@ static Value *simplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *simplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &, const SimplifyQuery &, unsigned); -static Value *simplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, - unsigned); -static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +static Value *simplifyCmpInst(CmpPredicate, Value *, Value *, + const SimplifyQuery &, unsigned); +static Value *simplifyICmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse); static Value *simplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned); static Value *simplifyXorInst(Value *, Value *, const SimplifyQuery &, @@ -132,8 +132,7 @@ static Constant *getFalse(Type *Ty) { return ConstantInt::getFalse(Ty); } static Constant *getTrue(Type *Ty) { return ConstantInt::getTrue(Ty); } /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"? -static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS, - Value *RHS) { +static bool isSameCompare(Value *V, CmpPredicate Pred, Value *LHS, Value *RHS) { CmpInst *Cmp = dyn_cast(V); if (!Cmp) return false; @@ -150,10 +149,9 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS, /// %cmp = icmp sle i32 %sel, %rhs /// Compose new comparison by substituting %sel with either %tv or %fv /// and see if it simplifies. -static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, Value *Cond, - const SimplifyQuery &Q, unsigned MaxRecurse, - Constant *TrueOrFalse) { +static Value *simplifyCmpSelCase(CmpPredicate Pred, Value *LHS, Value *RHS, + Value *Cond, const SimplifyQuery &Q, + unsigned MaxRecurse, Constant *TrueOrFalse) { Value *SimplifiedCmp = simplifyCmpInst(Pred, LHS, RHS, Q, MaxRecurse); if (SimplifiedCmp == Cond) { // %cmp simplified to the select condition (%cond). @@ -167,18 +165,16 @@ static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS, } /// Simplify comparison with true branch of select -static Value *simplifyCmpSelTrueCase(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, Value *Cond, - const SimplifyQuery &Q, +static Value *simplifyCmpSelTrueCase(CmpPredicate Pred, Value *LHS, Value *RHS, + Value *Cond, const SimplifyQuery &Q, unsigned MaxRecurse) { return simplifyCmpSelCase(Pred, LHS, RHS, Cond, Q, MaxRecurse, getTrue(Cond->getType())); } /// Simplify comparison with false branch of select -static Value *simplifyCmpSelFalseCase(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, Value *Cond, - const SimplifyQuery &Q, +static Value *simplifyCmpSelFalseCase(CmpPredicate Pred, Value *LHS, Value *RHS, + Value *Cond, const SimplifyQuery &Q, unsigned MaxRecurse) { return simplifyCmpSelCase(Pred, LHS, RHS, Cond, Q, MaxRecurse, getFalse(Cond->getType())); @@ -471,9 +467,8 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, /// We can simplify %cmp1 to true, because both branches of select are /// less than 3. We compose new comparison by substituting %tmp with both /// branches of select and see if it can be simplified. -static Value *threadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const SimplifyQuery &Q, - unsigned MaxRecurse) { +static Value *threadCmpOverSelect(CmpPredicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -564,7 +559,7 @@ static Value *threadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, /// comparison by seeing whether comparing with all of the incoming phi values /// yields the same result every time. If so returns the common result, /// otherwise returns null. -static Value *threadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, +static Value *threadCmpOverPHI(CmpPredicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) @@ -1001,7 +996,7 @@ Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, /// Given a predicate and two operands, return true if the comparison is true. /// This is a helper for div/rem simplification where we return some other value /// when we can prove a relationship between the operands. -static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS, +static bool isICmpTrue(CmpPredicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { Value *V = simplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse); Constant *C = dyn_cast_or_null(V); @@ -2601,7 +2596,7 @@ static Type *getCompareTy(Value *Op) { /// Rummage around inside V looking for something equivalent to the comparison /// "LHS Pred RHS". Return such a value if found, otherwise return null. /// Helper function for analyzing max/min idioms. -static Value *extractEquivalentCondition(Value *V, CmpInst::Predicate Pred, +static Value *extractEquivalentCondition(Value *V, CmpPredicate Pred, Value *LHS, Value *RHS) { SelectInst *SI = dyn_cast(V); if (!SI) @@ -2710,8 +2705,8 @@ static bool haveNonOverlappingStorage(const Value *V1, const Value *V2) { // If the C and C++ standards are ever made sufficiently restrictive in this // area, it may be possible to update LLVM's semantics accordingly and reinstate // this optimization. -static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const SimplifyQuery &Q) { +static Constant *computePointerICmp(CmpPredicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { assert(LHS->getType() == RHS->getType() && "Must have same types"); const DataLayout &DL = Q.DL; const TargetLibraryInfo *TLI = Q.TLI; @@ -2859,8 +2854,8 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS, } /// Fold an icmp when its operands have i1 scalar type. -static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const SimplifyQuery &Q) { +static Value *simplifyICmpOfBools(CmpPredicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { Type *ITy = getCompareTy(LHS); // The return type. Type *OpTy = LHS->getType(); // The operand type. if (!OpTy->isIntOrIntVectorTy(1)) @@ -2962,8 +2957,8 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, } /// Try hard to fold icmp with zero RHS because this is a common case. -static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const SimplifyQuery &Q) { +static Value *simplifyICmpWithZero(CmpPredicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { if (!match(RHS, m_Zero())) return nullptr; @@ -3022,7 +3017,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, return nullptr; } -static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, +static Value *simplifyICmpWithConstant(CmpPredicate Pred, Value *LHS, Value *RHS, const InstrInfoQuery &IIQ) { Type *ITy = getCompareTy(RHS); // The return type. @@ -3115,8 +3110,8 @@ static void getUnsignedMonotonicValues(SmallPtrSetImpl &Res, Value *V, } } -static Value *simplifyICmpUsingMonotonicValues(ICmpInst::Predicate Pred, - Value *LHS, Value *RHS) { +static Value *simplifyICmpUsingMonotonicValues(CmpPredicate Pred, Value *LHS, + Value *RHS) { if (Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_ULT) return nullptr; @@ -3133,9 +3128,8 @@ static Value *simplifyICmpUsingMonotonicValues(ICmpInst::Predicate Pred, return nullptr; } -static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, - BinaryOperator *LBO, Value *RHS, - const SimplifyQuery &Q, +static Value *simplifyICmpWithBinOpOnLHS(CmpPredicate Pred, BinaryOperator *LBO, + Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { Type *ITy = getCompareTy(RHS); // The return type. @@ -3254,8 +3248,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, // *) C1 < C2 && C1 >= 0, or // *) C2 < C1 && C1 <= 0. // -static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const InstrInfoQuery &IIQ) { +static bool trySimplifyICmpWithAdds(CmpPredicate Pred, Value *LHS, Value *RHS, + const InstrInfoQuery &IIQ) { // TODO: only support icmp slt for now. if (Pred != CmpInst::ICMP_SLT || !IIQ.UseInstrInfo) return false; @@ -3279,8 +3273,8 @@ static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS, /// TODO: A large part of this logic is duplicated in InstCombine's /// foldICmpBinOp(). We should be able to share that and avoid the code /// duplication. -static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const SimplifyQuery &Q, +static Value *simplifyICmpWithBinOp(CmpPredicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { BinaryOperator *LBO = dyn_cast(LHS); BinaryOperator *RBO = dyn_cast(RHS); @@ -3513,8 +3507,8 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, /// simplify integer comparisons where at least one operand of the compare /// matches an integer min/max idiom. -static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const SimplifyQuery &Q, +static Value *simplifyICmpWithMinMax(CmpPredicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { Type *ITy = getCompareTy(LHS); // The return type. Value *A, *B; @@ -3698,7 +3692,7 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, return nullptr; } -static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate, +static Value *simplifyICmpWithDominatingAssume(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q) { // Gracefully handle instructions that have not been inserted yet. @@ -3721,8 +3715,8 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate, return nullptr; } -static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred, - Value *LHS, Value *RHS) { +static Value *simplifyICmpWithIntrinsicOnLHS(CmpPredicate Pred, Value *LHS, + Value *RHS) { auto *II = dyn_cast(LHS); if (!II) return nullptr; @@ -3770,9 +3764,8 @@ static std::optional getRange(Value *V, /// Given operands for an ICmpInst, see if we can fold the result. /// If not, this returns null. -static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +static Value *simplifyICmpInst(CmpPredicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { - CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); if (Constant *CLHS = dyn_cast(LHS)) { @@ -4085,17 +4078,16 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return nullptr; } -Value *llvm::simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *llvm::simplifyICmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q) { return ::simplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit); } /// Given operands for an FCmpInst, see if we can fold the result. /// If not, this returns null. -static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +static Value *simplifyFCmpInst(CmpPredicate Pred, Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { - CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!"); if (Constant *CLHS = dyn_cast(LHS)) { @@ -4320,7 +4312,7 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, return nullptr; } -Value *llvm::simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *llvm::simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q) { return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); } @@ -4557,7 +4549,7 @@ static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X, } static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS, - ICmpInst::Predicate Pred, Value *TVal, + CmpPredicate Pred, Value *TVal, Value *FVal) { // Canonicalize common cmp+sel operand as CmpLHS. if (CmpRHS == TVal || CmpRHS == FVal) { @@ -4631,8 +4623,8 @@ static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS, /// An alternative way to test if a bit is set or not uses sgt/slt instead of /// eq/ne. static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS, - ICmpInst::Predicate Pred, - Value *TrueVal, Value *FalseVal) { + CmpPredicate Pred, Value *TrueVal, + Value *FalseVal) { if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) return simplifySelectBitTest(TrueVal, FalseVal, Res->X, &Res->Mask, Res->Pred == ICmpInst::ICMP_EQ); @@ -6142,14 +6134,14 @@ Value *llvm::simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, } /// Given operands for a CmpInst, see if we can fold the result. -static Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +static Value *simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate)) + if (CmpInst::isIntPredicate(Predicate)) return simplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse); return simplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse); } -Value *llvm::simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *llvm::simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q) { return ::simplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit); } @@ -7187,7 +7179,7 @@ static Value *simplifyInstructionWithOperands(Instruction *I, case Instruction::Xor: return simplifyXorInst(NewOps[0], NewOps[1], Q, MaxRecurse); case Instruction::ICmp: - return simplifyICmpInst(cast(I)->getPredicate(), NewOps[0], + return simplifyICmpInst(cast(I)->getCmpPredicate(), NewOps[0], NewOps[1], Q, MaxRecurse); case Instruction::FCmp: return simplifyFCmpInst(cast(I)->getPredicate(), NewOps[0], diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index d81546d0c9fed..f2c6949e535d2 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -9379,7 +9379,7 @@ static std::optional isImpliedCondICmps(const ICmpInst *LHS, (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) && (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) && match(L0, m_c_Add(m_Specific(L1), m_Specific(R1)))) - return LPred == RPred; + return CmpPredicate::getMatching(LPred, RPred).has_value(); if (LPred == RPred) return isImpliedCondOperands(LPred, L0, L1, R0, R1); @@ -9392,7 +9392,7 @@ static std::optional isImpliedCondICmps(const ICmpInst *LHS, /// expect the RHS to be an icmp and the LHS to be an 'and', 'or', or a 'select' /// instruction. static std::optional -isImpliedCondAndOr(const Instruction *LHS, CmpInst::Predicate RHSPred, +isImpliedCondAndOr(const Instruction *LHS, CmpPredicate RHSPred, const Value *RHSOp0, const Value *RHSOp1, const DataLayout &DL, bool LHSIsTrue, unsigned Depth) { // The LHS must be an 'or', 'and', or a 'select' instruction. @@ -9422,7 +9422,7 @@ isImpliedCondAndOr(const Instruction *LHS, CmpInst::Predicate RHSPred, } std::optional -llvm::isImpliedCondition(const Value *LHS, CmpInst::Predicate RHSPred, +llvm::isImpliedCondition(const Value *LHS, CmpPredicate RHSPred, const Value *RHSOp0, const Value *RHSOp1, const DataLayout &DL, bool LHSIsTrue, unsigned Depth) { // Bail out when we hit the limit. @@ -9476,7 +9476,7 @@ std::optional llvm::isImpliedCondition(const Value *LHS, const Value *RHS, if (const ICmpInst *RHSCmp = dyn_cast(RHS)) { if (auto Implied = isImpliedCondition( - LHS, RHSCmp->getPredicate(), RHSCmp->getOperand(0), + LHS, RHSCmp->getCmpPredicate(), RHSCmp->getOperand(0), RHSCmp->getOperand(1), DL, LHSIsTrue, Depth)) return InvertRHS ? !*Implied : *Implied; return std::nullopt; @@ -9553,7 +9553,7 @@ std::optional llvm::isImpliedByDomCondition(const Value *Cond, return std::nullopt; } -std::optional llvm::isImpliedByDomCondition(CmpInst::Predicate Pred, +std::optional llvm::isImpliedByDomCondition(CmpPredicate Pred, const Value *LHS, const Value *RHS, const Instruction *ContextI, diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 065ce3a017283..4f07a4c4dd017 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -3842,9 +3842,8 @@ std::optional ICmpInst::compare(const KnownBits &LHS, } CmpInst::Predicate ICmpInst::getFlippedSignednessPredicate(Predicate pred) { - assert(CmpInst::isRelational(pred) && - "Call only with non-equality predicates!"); - + if (CmpInst::isEquality(pred)) + return pred; if (isSigned(pred)) return getUnsignedPredicate(pred); if (isUnsigned(pred)) @@ -3916,6 +3915,23 @@ bool CmpInst::isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) { return isImpliedTrueByMatchingCmp(Pred1, getInversePredicate(Pred2)); } +//===----------------------------------------------------------------------===// +// CmpPredicate Implementation +//===----------------------------------------------------------------------===// + +std::optional CmpPredicate::getMatching(CmpPredicate A, + CmpPredicate B) { + if (A.Pred == B.Pred) + return A.HasSameSign == B.HasSameSign ? A : CmpPredicate(A.Pred); + if (A.HasSameSign && + A.Pred == ICmpInst::getFlippedSignednessPredicate(B.Pred)) + return B.Pred; + if (B.HasSameSign && + B.Pred == ICmpInst::getFlippedSignednessPredicate(A.Pred)) + return A.Pred; + return {}; +} + //===----------------------------------------------------------------------===// // SwitchInst Implementation //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 5871973776683..783c34e21b484 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -631,7 +631,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, GEPNoWrapFlags NW, /// We can look through PHIs, GEPs and casts in order to determine a common base /// between GEPLHS and RHS. static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, - ICmpInst::Predicate Cond, + CmpPredicate Cond, const DataLayout &DL, InstCombiner &IC) { // FIXME: Support vector of pointers. @@ -675,8 +675,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, /// Fold comparisons between a GEP instruction and something else. At this point /// we know that the GEP is on the LHS of the comparison. Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, - ICmpInst::Predicate Cond, - Instruction &I) { + CmpPredicate Cond, Instruction &I) { // Don't transform signed compares of GEPs into index compares. Even if the // GEP is inbounds, the final add of the base pointer can have signed overflow // and would change the result of the icmp. @@ -912,7 +911,7 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) { /// Fold "icmp pred (X+C), X". Instruction *InstCombinerImpl::foldICmpAddOpConst(Value *X, const APInt &C, - ICmpInst::Predicate Pred) { + CmpPredicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -3960,8 +3959,8 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, } static Instruction * -foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred, - SaturatingInst *II, const APInt &C, +foldICmpUSubSatOrUAddSatWithConstant(CmpPredicate Pred, SaturatingInst *II, + const APInt &C, InstCombiner::BuilderTy &Builder) { // This transform may end up producing more than one instruction for the // intrinsic, so limit it to one user of the intrinsic. @@ -4045,7 +4044,7 @@ foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred, } static Instruction * -foldICmpOfCmpIntrinsicWithConstant(ICmpInst::Predicate Pred, IntrinsicInst *I, +foldICmpOfCmpIntrinsicWithConstant(CmpPredicate Pred, IntrinsicInst *I, const APInt &C, InstCombiner::BuilderTy &Builder) { std::optional NewPredicate = std::nullopt; @@ -4244,9 +4243,8 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) { return nullptr; } -Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred, - SelectInst *SI, Value *RHS, - const ICmpInst &I) { +Instruction *InstCombinerImpl::foldSelectICmp(CmpPredicate Pred, SelectInst *SI, + Value *RHS, const ICmpInst &I) { // Try to fold the comparison into the select arms, which will cause the // select to be converted into a logical and/or. auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * { @@ -4415,7 +4413,7 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q, /// The Mask can be a constant, too. /// For some predicates, the operands are commutative. /// For others, x can only be on a specific side. -static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0, +static Value *foldICmpWithLowBitMaskedVal(CmpPredicate Pred, Value *Op0, Value *Op1, const SimplifyQuery &Q, InstCombiner &IC) { @@ -5526,8 +5524,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, /// Fold icmp Pred min|max(X, Y), Z. Instruction *InstCombinerImpl::foldICmpWithMinMax(Instruction &I, MinMaxIntrinsic *MinMax, - Value *Z, - ICmpInst::Predicate Pred) { + Value *Z, CmpPredicate Pred) { Value *X = MinMax->getLHS(); Value *Y = MinMax->getRHS(); if (ICmpInst::isSigned(Pred) && !MinMax->isSigned()) @@ -6880,8 +6877,8 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) { return nullptr; } -std::optional> -InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, +std::optional> +InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) { assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && "Only for relational integer predicates."); @@ -7287,7 +7284,7 @@ static Instruction *foldReductionIdiom(ICmpInst &I, } // This helper will be called with icmp operands in both orders. -Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred, +Instruction *InstCombinerImpl::foldICmpCommutative(CmpPredicate Pred, Value *Op0, Value *Op1, ICmpInst &CxtI) { // Try to optimize 'icmp GEP, P' or 'icmp P, GEP'. @@ -7415,7 +7412,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = simplifyICmpInst(I.getPredicate(), Op0, Op1, Q)) + if (Value *V = simplifyICmpInst(I.getCmpPredicate(), Op0, Op1, Q)) return replaceInstUsesWith(I, V); // Comparing -val or val with non-zero is the same as just comparing val @@ -7522,10 +7519,10 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) return Res; - if (Instruction *Res = foldICmpCommutative(I.getPredicate(), Op0, Op1, I)) + if (Instruction *Res = foldICmpCommutative(I.getCmpPredicate(), Op0, Op1, I)) return Res; if (Instruction *Res = - foldICmpCommutative(I.getSwappedPredicate(), Op1, Op0, I)) + foldICmpCommutative(I.getSwappedCmpPredicate(), Op1, Op0, I)) return Res; if (I.isCommutative()) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 0508ed48fc19c..28474fec8238e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -652,10 +652,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final /// folded operation. void PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN); - Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, - ICmpInst::Predicate Cond, Instruction &I); - Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI, - Value *RHS, const ICmpInst &I); + Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, CmpPredicate Cond, + Instruction &I); + Instruction *foldSelectICmp(CmpPredicate Pred, SelectInst *SI, Value *RHS, + const ICmpInst &I); bool foldAllocaCmp(AllocaInst *Alloca); Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI, GetElementPtrInst *GEP, @@ -663,8 +663,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final ConstantInt *AndCst = nullptr); Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC); - Instruction *foldICmpAddOpConst(Value *X, const APInt &C, - ICmpInst::Predicate Pred); + Instruction *foldICmpAddOpConst(Value *X, const APInt &C, CmpPredicate Pred); Instruction *foldICmpWithCastOp(ICmpInst &ICmp); Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp); @@ -678,7 +677,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final const APInt &C); Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); Instruction *foldICmpWithMinMax(Instruction &I, MinMaxIntrinsic *MinMax, - Value *Z, ICmpInst::Predicate Pred); + Value *Z, CmpPredicate Pred); Instruction *foldICmpEquality(ICmpInst &Cmp); Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I); Instruction *foldSignBitTest(ICmpInst &I); @@ -736,8 +735,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final const APInt &C); Instruction *foldICmpBitCast(ICmpInst &Cmp); Instruction *foldICmpWithTrunc(ICmpInst &Cmp); - Instruction *foldICmpCommutative(ICmpInst::Predicate Pred, Value *Op0, - Value *Op1, ICmpInst &CxtI); + Instruction *foldICmpCommutative(CmpPredicate Pred, Value *Op0, Value *Op1, + ICmpInst &CxtI); // Helpers of visitSelectInst(). Instruction *foldSelectOfBools(SelectInst &SI); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 32f2a30afad48..3325a1868ebde 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1753,9 +1753,9 @@ static Value *simplifyInstructionWithPHI(Instruction &I, PHINode *PN, if (TerminatorBI && TerminatorBI->isConditional() && TerminatorBI->getSuccessor(0) != TerminatorBI->getSuccessor(1) && ICmp) { bool LHSIsTrue = TerminatorBI->getSuccessor(0) == PN->getParent(); - std::optional ImpliedCond = - isImpliedCondition(TerminatorBI->getCondition(), ICmp->getPredicate(), - Ops[0], Ops[1], DL, LHSIsTrue); + std::optional ImpliedCond = isImpliedCondition( + TerminatorBI->getCondition(), ICmp->getCmpPredicate(), Ops[0], Ops[1], + DL, LHSIsTrue); if (ImpliedCond) return ConstantInt::getBool(I.getType(), ImpliedCond.value()); } diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp index 0af812564c026..b4dbc4ed435aa 100644 --- a/llvm/unittests/IR/InstructionsTest.cpp +++ b/llvm/unittests/IR/InstructionsTest.cpp @@ -1923,5 +1923,27 @@ TEST(InstructionsTest, AtomicSyncscope) { EXPECT_TRUE(LLVMIsAtomicSingleThread(CmpXchg)); } +TEST(InstructionsTest, CmpPredicate) { + CmpPredicate P0(CmpInst::ICMP_ULE, false), P1(CmpInst::ICMP_ULE, true), + P2(CmpInst::ICMP_SLE, false), P3(CmpInst::ICMP_SLT, false); + CmpPredicate Q0 = P0, Q1 = P1, Q2 = P2; + CmpInst::Predicate R0 = P0, R1 = P1, R2 = P2; + + EXPECT_EQ(*CmpPredicate::getMatching(P0, P1), CmpInst::ICMP_ULE); + EXPECT_EQ(CmpPredicate::getMatching(P0, P1)->hasSameSign(), false); + EXPECT_EQ(*CmpPredicate::getMatching(P1, P1), CmpInst::ICMP_ULE); + EXPECT_EQ(CmpPredicate::getMatching(P1, P1)->hasSameSign(), true); + EXPECT_EQ(CmpPredicate::getMatching(P0, P2), std::nullopt); + EXPECT_EQ(*CmpPredicate::getMatching(P1, P2), CmpInst::ICMP_SLE); + EXPECT_EQ(CmpPredicate::getMatching(P1, P2)->hasSameSign(), false); + EXPECT_EQ(CmpPredicate::getMatching(P1, P3), std::nullopt); + EXPECT_FALSE(Q0.hasSameSign()); + EXPECT_TRUE(Q1.hasSameSign()); + EXPECT_FALSE(Q2.hasSameSign()); + EXPECT_EQ(P0, R0); + EXPECT_EQ(P1, R1); + EXPECT_EQ(P2, R2); +} + } // end anonymous namespace } // end namespace llvm