Commit bb907b2

[ValueTracking] don't recursively compute known bits using multiple llvm.assumes
This is an alternative to D99759 to avoid the compile-time explosion seen in:
https://llvm.org/PR49785

Another potential solution would make the exclusion logic stronger to avoid
blowing up, but note that we reduced the complexity of the exclusion mechanism
in D16204 because it was too costly.

So I'm questioning the need for recursion/exclusion entirely - what is the
optimization value vs. cost of recursively computing known bits based on
assumptions?

This was built into the implementation from the start with 60db058, and we
have kept adding code/cost to deal with that capability.

By clearing the query's AssumptionCache inside computeKnownBitsFromAssume(),
this patch retains all existing assume functionality except refining known
bits based on even more assumptions.

We have 1 regression test that shows a difference in optimization power.

Differential Revision: https://reviews.llvm.org/D100573
1 parent b06c55a commit bb907b2
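
To make the mechanism concrete, here is a minimal, self-contained C++ sketch of the pattern the patch adopts; it is not LLVM code, and SimpleQuery, AssumeList, and countAssumeVisits are illustrative stand-ins. The idea is simply that the recursive call receives a copy of the query whose assumption-cache pointer has been cleared, so walking an assume's operands can never pull in further assumes.

#include <cstdio>
#include <vector>

// Stand-in for AssumptionCache: just a list of "assume" identifiers.
struct AssumeList { std::vector<int> Assumes; };

// Stand-in for the Query struct: a null AC means "do not consult assumptions".
struct SimpleQuery {
  const AssumeList *AC;
  int Depth;
};

// Counts how many assume visits a recursive known-bits-style walk performs.
static int countAssumeVisits(SimpleQuery Q) {
  if (!Q.AC || Q.Depth > 6) // depth cap, as in the real analysis
    return 0;
  int Visits = 0;
  for (int Assume : Q.AC->Assumes) {
    (void)Assume;
    ++Visits;
    SimpleQuery Recurse = Q; // the patch's pattern: copy the query ...
    Recurse.AC = nullptr;    // ... and clear its assumption cache ...
    Recurse.Depth++;         // ... so the recursion adds no more assume visits.
    Visits += countAssumeVisits(Recurse);
  }
  return Visits;
}

int main() {
  AssumeList AC{{1, 2, 3}};
  SimpleQuery Q{&AC, 0};
  // With the cleared cache each assume is visited exactly once (prints 3);
  // without it, the visit count would grow with every level of recursion.
  std::printf("assume visits: %d\n", countAssumeVisits(Q));
  return 0;
}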

2 files changed: +42, -59 lines
llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 33 additions & 55 deletions
@@ -107,40 +107,13 @@ struct Query {
   // provide it currently.
   OptimizationRemarkEmitter *ORE;
 
-  /// Set of assumptions that should be excluded from further queries.
-  /// This is because of the potential for mutual recursion to cause
-  /// computeKnownBits to repeatedly visit the same assume intrinsic. The
-  /// classic case of this is assume(x = y), which will attempt to determine
-  /// bits in x from bits in y, which will attempt to determine bits in y from
-  /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call
-  /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo
-  /// (all of which can call computeKnownBits), and so on.
-  std::array<const Value *, MaxAnalysisRecursionDepth> Excluded;
-
   /// If true, it is safe to use metadata during simplification.
   InstrInfoQuery IIQ;
 
-  unsigned NumExcluded = 0;
-
   Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI,
         const DominatorTree *DT, bool UseInstrInfo,
         OptimizationRemarkEmitter *ORE = nullptr)
       : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {}
-
-  Query(const Query &Q, const Value *NewExcl)
-      : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ),
-        NumExcluded(Q.NumExcluded) {
-    Excluded = Q.Excluded;
-    Excluded[NumExcluded++] = NewExcl;
-    assert(NumExcluded <= Excluded.size());
-  }
-
-  bool isExcluded(const Value *Value) const {
-    if (NumExcluded == 0)
-      return false;
-    auto End = Excluded.begin() + NumExcluded;
-    return std::find(Excluded.begin(), End, Value) != End;
-  }
 };
 
 } // end anonymous namespace
@@ -632,8 +605,6 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getFunction() == Q.CxtI->getFunction() &&
            "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -681,8 +652,6 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
            "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -713,6 +682,15 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (!Cmp)
       continue;
 
+    // We are attempting to compute known bits for the operands of an assume.
+    // Do not try to use other assumptions for those recursive calls because
+    // that can lead to mutual recursion and a compile-time explosion.
+    // An example of the mutual recursion: computeKnownBits can call
+    // isKnownNonZero which calls computeKnownBitsFromAssume (this function)
+    // and so on.
+    Query QueryNoAC = Q;
+    QueryNoAC.AC = nullptr;
+
     // Note that ptrtoint may change the bitwidth.
     Value *A, *B;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
@@ -727,17 +705,17 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         Known.Zero |= RHSKnown.Zero;
         Known.One |= RHSKnown.One;
         // assume(v & b = a)
       } else if (match(Cmp,
                        m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits MaskKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // known bits from the RHS to V.
@@ -748,9 +726,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                             m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits MaskKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // inverted known bits from the RHS to V.
@@ -761,9 +739,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                        m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V.
@@ -774,9 +752,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                             m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V.
@@ -787,9 +765,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                        m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V. For those bits in B that are known to be one,
@@ -803,9 +781,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                             m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V. For those bits in B that are
@@ -819,7 +797,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                             m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
@@ -832,7 +810,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                             m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         RHSKnown.One.lshrInPlace(C);
@@ -844,7 +822,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                             m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.Zero << C;
@@ -854,7 +832,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                             m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.One << C;
@@ -866,7 +844,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -879,7 +857,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -892,7 +870,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
      if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
        KnownBits RHSKnown =
-           computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+           computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
        if (RHSKnown.isNegative()) {
          // We know that the sign bit is one.
@@ -905,7 +883,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
      if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
        KnownBits RHSKnown =
-           computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+           computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
        if (RHSKnown.isZero() || RHSKnown.isNegative()) {
          // We know that the sign bit is one.
@@ -918,7 +896,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
      if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
        KnownBits RHSKnown =
-           computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+           computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
        // Whatever high bits in c are zero are known to be zero.
        Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@@ -929,7 +907,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
      if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
        KnownBits RHSKnown =
-           computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+           computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
        // If the RHS is known zero, then this assumption must be wrong (nothing
        // is unsigned less than zero). Signal a conflict and get out of here.
@@ -941,7 +919,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
 
        // Whatever high bits in c are zero are known to be zero (if c is a power
        // of 2, then one more).
-       if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I)))
+       if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, QueryNoAC))
          Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1);
        else
          Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());

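Several of the branches above propagate assumption facts through a mask; for example, for assume((v & b) == a) the comment says that wherever b is known to be one, v must agree with a. The following stand-alone snippet illustrates that rule with plain integers rather than LLVM's KnownBits class; the Bits struct and propagateThroughAnd are illustrative names, not LLVM APIs.

#include <cassert>
#include <cstdint>

// Stand-in for KnownBits: a bit set in Zero (resp. One) is known to be 0 (resp. 1).
struct Bits {
  uint64_t Zero;
  uint64_t One;
};

// For assume((v & b) == a): wherever the mask b is known one, v's bit equals a's bit.
static Bits propagateThroughAnd(Bits RHSKnown, Bits MaskKnown) {
  Bits V{0, 0};
  V.Zero = RHSKnown.Zero & MaskKnown.One;
  V.One = RHSKnown.One & MaskKnown.One;
  return V;
}

int main() {
  // assume((v & 7) == 1): the mask and the RHS are fully known constants.
  Bits Mask{~UINT64_C(7), 7};
  Bits RHS{~UINT64_C(1), 1};
  Bits V = propagateThroughAnd(RHS, Mask);
  assert((V.One & 7) == 1);  // the low bit of v is known to be 1
  assert((V.Zero & 7) == 6); // the next two bits of v are known to be 0
  return 0;
}
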
llvm/test/Transforms/InstCombine/assume.ll

Lines changed: 9 additions & 4 deletions
@@ -175,15 +175,20 @@ entry:
   ret i32 %and1
 }
 
-define i32 @bar4(i32 %a, i32 %b) {
-; CHECK-LABEL: @bar4(
+; If we allow recursive known bits queries based on
+; assumptions, we could do better here:
+; a == b and a & 7 == 1, so b & 7 == 1, so b & 3 == 1, so return 1.
+
+define i32 @known_bits_recursion_via_assumes(i32 %a, i32 %b) {
+; CHECK-LABEL: @known_bits_recursion_via_assumes(
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AND1:%.*]] = and i32 [[B:%.*]], 3
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[A:%.*]], 7
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], 1
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B]]
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP2]])
-; CHECK-NEXT:    ret i32 1
+; CHECK-NEXT:    ret i32 [[AND1]]
 ;
 entry:
   %and1 = and i32 %b, 3

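The new comment in the test spells out the reasoning a recursive assume-based query could still perform: a == b and (a & 7) == 1 imply (b & 7) == 1, hence (b & 3) == 1, hence the function could return the constant 1. As a quick sanity check of that arithmetic (ordinary C++, not part of the patch), the little program below verifies the chain exhaustively over a range of values.

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t a = 0; a < 1024; ++a) {
    if ((a & 7) != 1)
      continue;           // models assume((a & 7) == 1)
    uint32_t b = a;       // models assume(a == b)
    assert((b & 7) == 1); // known bits of a transfer to b
    assert((b & 3) == 1); // so the returned value (b & 3) is always 1
  }
  return 0;
}
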