diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 02d05780a7ac1..6d7ac2be951dd 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -24,10 +24,10 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">; // Multiply two integer attributes and create a new one with the result. def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">; -// TODO: Canonicalizations currently doesn't take into account integer overflow -// flags and always reset them to default (wraparound) which is safe but can -// inhibit later optimizations. Individual patterns must be reviewed for -// better handling of overflow flags. +// Merge overflow flags from 2 ops, selecting the most conservative combination. +def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">; + +// Default overflow flag (all wraparounds allowed). defvar DefOverflow = ConstantEnumCase; class cast : NativeCodeCall<"::mlir::cast<" # type # ">($0)">; @@ -45,7 +45,7 @@ def AddIAddConstant : (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), (ConstantLikeMatcher APIntAttr:$c1), $ovf2), (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; // addi(subi(x, c0), c1) -> addi(x, c1 - c0) def AddISubConstantRHS : @@ -53,7 +53,7 @@ def AddISubConstantRHS : (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), (ConstantLikeMatcher APIntAttr:$c1), $ovf2), (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; // addi(subi(c0, x), c1) -> subi(c0 + c1, x) def AddISubConstantLHS : @@ -61,7 +61,7 @@ def AddISubConstantLHS : (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), (ConstantLikeMatcher APIntAttr:$c1), $ovf2), (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x, - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; def IsScalarOrSplatNegativeOne : Constraint; // addi(muli(x, -1), y) -> subi(y, x) @@ -81,7 +81,7 @@ def AddIMulNegativeOneLhs : Pat<(Arith_AddIOp (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $y, $ovf2), - (Arith_SubIOp $y, $x, DefOverflow), + (Arith_SubIOp $y, $x, DefOverflow), // TODO: overflow flags [(IsScalarOrSplatNegativeOne $c0)]>; // muli(muli(x, c0), c1) -> muli(x, c0 * c1) @@ -90,7 +90,7 @@ def MulIMulIConstant : (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), (ConstantLikeMatcher APIntAttr:$c1), $ovf2), (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)), - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; //===----------------------------------------------------------------------===// // AddUIExtendedOp @@ -113,7 +113,7 @@ def SubIRHSAddConstant : (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), (ConstantLikeMatcher APIntAttr:$c1), $ovf2), (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), - DefOverflow)>; + DefOverflow)>; // TODO: overflow flags // subi(c1, addi(x, c0)) -> subi(c1 - c0, x) def SubILHSAddConstant : @@ -121,7 +121,7 @@ def SubILHSAddConstant : (ConstantLikeMatcher APIntAttr:$c1), (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2), (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x, - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; // subi(subi(x, c0), c1) -> subi(x, c0 + c1) def SubIRHSSubConstantRHS : @@ -129,7 +129,7 @@ def SubIRHSSubConstantRHS : (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), (ConstantLikeMatcher APIntAttr:$c1), $ovf2), (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; // subi(subi(c0, x), c1) -> subi(c0 - c1, x) def SubIRHSSubConstantLHS : @@ -137,7 +137,7 @@ def SubIRHSSubConstantLHS : (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), (ConstantLikeMatcher APIntAttr:$c1), $ovf2), (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x, - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; // subi(c1, subi(x, c0)) -> subi(c0 + c1, x) def SubILHSSubConstantRHS : @@ -145,7 +145,7 @@ def SubILHSSubConstantRHS : (ConstantLikeMatcher APIntAttr:$c1), (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2), (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x, - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; // subi(c1, subi(c0, x)) -> addi(x, c1 - c0) def SubILHSSubConstantLHS : @@ -153,12 +153,13 @@ def SubILHSSubConstantLHS : (ConstantLikeMatcher APIntAttr:$c1), (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2), (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), - DefOverflow)>; + (MergeOverflow $ovf1, $ovf2))>; // subi(subi(a, b), a) -> subi(0, b) def SubISubILHSRHSLHS : Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2), - (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>; + (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, + (MergeOverflow $ovf1, $ovf2))>; //===----------------------------------------------------------------------===// // MulSIExtendedOp diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index a1568d0ebba3a..a0b50251c6b67 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -64,6 +64,14 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies()); } +// Merge overflow flags from 2 ops, selecting the most conservative combination. +static IntegerOverflowFlagsAttr +mergeOverflowFlags(IntegerOverflowFlagsAttr val1, + IntegerOverflowFlagsAttr val2) { + return IntegerOverflowFlagsAttr::get(val1.getContext(), + val1.getValue() & val2.getValue()); +} + /// Invert an integer comparison predicate. arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { switch (pred) { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index f7ce2123a93c6..e4f95bb0545a2 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -833,6 +833,30 @@ func.func @tripleAddAdd(%arg0: index) -> index { return %add2 : index } +// CHECK-LABEL: @tripleAddAddOvf1 +// CHECK: %[[cres:.+]] = arith.constant 59 : index +// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow : index +// CHECK: return %[[add]] +func.func @tripleAddAddOvf1(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.addi %c17, %arg0 overflow : index + %add2 = arith.addi %c42, %add1 overflow : index + return %add2 : index +} + +// CHECK-LABEL: @tripleAddAddOvf2 +// CHECK: %[[cres:.+]] = arith.constant 59 : index +// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index +// CHECK: return %[[add]] +func.func @tripleAddAddOvf2(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.addi %c17, %arg0 overflow : index + %add2 = arith.addi %c42, %add1 overflow : index + return %add2 : index +} + // CHECK-LABEL: @tripleAddSub0 // CHECK: %[[cres:.+]] = arith.constant 59 : index // CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index @@ -845,6 +869,18 @@ func.func @tripleAddSub0(%arg0: index) -> index { return %add2 : index } +// CHECK-LABEL: @tripleAddSub0Ovf +// CHECK: %[[cres:.+]] = arith.constant 59 : index +// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow : index +// CHECK: return %[[add]] +func.func @tripleAddSub0Ovf(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.subi %c17, %arg0 overflow : index + %add2 = arith.addi %c42, %add1 overflow : index + return %add2 : index +} + // CHECK-LABEL: @tripleAddSub1 // CHECK: %[[cres:.+]] = arith.constant 25 : index // CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index @@ -857,6 +893,18 @@ func.func @tripleAddSub1(%arg0: index) -> index { return %add2 : index } +// CHECK-LABEL: @tripleAddSub1Ovf +// CHECK: %[[cres:.+]] = arith.constant 25 : index +// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow : index +// CHECK: return %[[add]] +func.func @tripleAddSub1Ovf(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.subi %arg0, %c17 overflow : index + %add2 = arith.addi %c42, %add1 overflow : index + return %add2 : index +} + // CHECK-LABEL: @tripleSubAdd0 // CHECK: %[[cres:.+]] = arith.constant 25 : index // CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index @@ -869,6 +917,18 @@ func.func @tripleSubAdd0(%arg0: index) -> index { return %add2 : index } +// CHECK-LABEL: @tripleSubAdd0Ovf +// CHECK: %[[cres:.+]] = arith.constant 25 : index +// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow : index +// CHECK: return %[[add]] +func.func @tripleSubAdd0Ovf(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.addi %c17, %arg0 overflow : index + %add2 = arith.subi %c42, %add1 overflow : index + return %add2 : index +} + // CHECK-LABEL: @tripleSubAdd1 // CHECK: %[[cres:.+]] = arith.constant -25 : index // CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index @@ -891,6 +951,16 @@ func.func @subSub0(%arg0: index, %arg1: index) -> index { return %sub2 : index } +// CHECK-LABEL: @subSub0Ovf +// CHECK: %[[c0:.+]] = arith.constant 0 : index +// CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 overflow : index +// CHECK: return %[[add]] +func.func @subSub0Ovf(%arg0: index, %arg1: index) -> index { + %sub1 = arith.subi %arg0, %arg1 overflow : index + %sub2 = arith.subi %sub1, %arg0 overflow : index + return %sub2 : index +} + // CHECK-LABEL: @tripleSubSub0 // CHECK: %[[cres:.+]] = arith.constant 25 : index // CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index @@ -903,6 +973,19 @@ func.func @tripleSubSub0(%arg0: index) -> index { return %add2 : index } +// CHECK-LABEL: @tripleSubSub0Ovf +// CHECK: %[[cres:.+]] = arith.constant 25 : index +// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow : index +// CHECK: return %[[add]] +func.func @tripleSubSub0Ovf(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.subi %c17, %arg0 overflow : index + %add2 = arith.subi %c42, %add1 overflow : index + return %add2 : index +} + + // CHECK-LABEL: @tripleSubSub1 // CHECK: %[[cres:.+]] = arith.constant -25 : index // CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index @@ -915,6 +998,18 @@ func.func @tripleSubSub1(%arg0: index) -> index { return %add2 : index } +// CHECK-LABEL: @tripleSubSub1Ovf +// CHECK: %[[cres:.+]] = arith.constant -25 : index +// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow : index +// CHECK: return %[[add]] +func.func @tripleSubSub1Ovf(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.subi %c17, %arg0 overflow : index + %add2 = arith.subi %add1, %c42 overflow : index + return %add2 : index +} + // CHECK-LABEL: @tripleSubSub2 // CHECK: %[[cres:.+]] = arith.constant 59 : index // CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index @@ -927,6 +1022,18 @@ func.func @tripleSubSub2(%arg0: index) -> index { return %add2 : index } +// CHECK-LABEL: @tripleSubSub2Ovf +// CHECK: %[[cres:.+]] = arith.constant 59 : index +// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow : index +// CHECK: return %[[add]] +func.func @tripleSubSub2Ovf(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.subi %arg0, %c17 overflow : index + %add2 = arith.subi %c42, %add1 overflow : index + return %add2 : index +} + // CHECK-LABEL: @tripleSubSub3 // CHECK: %[[cres:.+]] = arith.constant 59 : index // CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] : index @@ -939,6 +1046,18 @@ func.func @tripleSubSub3(%arg0: index) -> index { return %add2 : index } +// CHECK-LABEL: @tripleSubSub3Ovf +// CHECK: %[[cres:.+]] = arith.constant 59 : index +// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] overflow : index +// CHECK: return %[[add]] +func.func @tripleSubSub3Ovf(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %add1 = arith.subi %arg0, %c17 overflow : index + %add2 = arith.subi %add1, %c42 overflow : index + return %add2 : index +} + // CHECK-LABEL: @subAdd1 // CHECK-NEXT: return %arg0 func.func @subAdd1(%arg0: index, %arg1 : index) -> index {