Skip to content

[GlobalIsel] combine ext of trunc with flags #87115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 8, 2024

Conversation

tschuett
Copy link

@tschuett tschuett commented Mar 29, 2024

@llvmbot
Copy link
Member

llvmbot commented Mar 29, 2024

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-globalisel

Author: Thorsten Schütt (tschuett)

Changes

#85592

https://discourse.llvm.org/t/rfc-add-nowrap-flags-to-trunc/77453


Full diff: https://github.com/llvm/llvm-project/pull/87115.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h (+10)
  • (modified) llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h (+52)
  • (modified) llvm/include/llvm/Target/GlobalISel/Combine.td (+18-1)
  • (modified) llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp (+66)
  • (added) llvm/test/CodeGen/AArch64/GlobalISel/combine-with-flags.mir (+156)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 28d9cf6260d620..c29998cd42a770 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -805,6 +805,12 @@ class CombinerHelper {
   /// Match constant LHS ops that should be commuted.
   bool matchCommuteConstantToRHS(MachineInstr &MI);
 
+  /// Combine sext of trunc.
+  bool matchSextOfTrunc(const MachineOperand &MO, BuildFnTy &MatchInfo);
+
+  /// Combine zext of trunc.
+  bool matchZextOfTrunc(const MachineOperand &MO, BuildFnTy &MatchInfo);
+
   /// Match constant LHS FP ops that should be commuted.
   bool matchCommuteFPConstantToRHS(MachineInstr &MI);
 
@@ -823,6 +829,10 @@ class CombinerHelper {
   /// Combine addos.
   bool matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo);
 
+  /// Use a function which takes in a MachineIRBuilder to perform a combine.
+  /// By default, it erases the instruction \p MI from the function.
+  void applyBuildFnMO(const MachineOperand &MO, BuildFnTy &MatchInfo);
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 261cfcf504d5fe..d5523e5ee0c7ab 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -739,6 +739,58 @@ class GOr : public GLogicalBinOp {
   };
 };
 
+/// Represents a cast operation.
+/// It models the llvm::CastInst concept.
+/// The exception is bitcast.
+class GCastOp : public GenericMachineInstr {
+public:
+  Register getSrcReg() const { return getOperand(1).getReg(); }
+
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_ADDRSPACE_CAST:
+    case TargetOpcode::G_FPEXT:
+    case TargetOpcode::G_FPTOSI:
+    case TargetOpcode::G_FPTOUI:
+    case TargetOpcode::G_FPTRUNC:
+    case TargetOpcode::G_INTTOPTR:
+    case TargetOpcode::G_PTRTOINT:
+    case TargetOpcode::G_SEXT:
+    case TargetOpcode::G_SITOFP:
+    case TargetOpcode::G_TRUNC:
+    case TargetOpcode::G_UITOFP:
+    case TargetOpcode::G_ZEXT:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
+
+/// Represents a sext.
+class GSext : public GCastOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    return MI->getOpcode() == TargetOpcode::G_SEXT;
+  };
+};
+
+/// Represents a zext.
+class GZext : public GCastOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    return MI->getOpcode() == TargetOpcode::G_ZEXT;
+  };
+};
+
+/// Represents a trunc.
+class GTrunc : public GCastOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    return MI->getOpcode() == TargetOpcode::G_TRUNC;
+  };
+};
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 72d3c0ea69bcd2..e6443d8fa8ab39 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -180,6 +180,8 @@ def FmContract  : MIFlagEnum<"FmContract">;
 def FmAfn       : MIFlagEnum<"FmAfn">;
 def FmReassoc   : MIFlagEnum<"FmReassoc">;
 def IsExact     : MIFlagEnum<"IsExact">;
+def NoSWrap     : MIFlagEnum<"NoSWrap">;
+def NoUWrap     : MIFlagEnum<"NoUWrap">;
 
 def MIFlags;
 // def not; -> Already defined as a SDNode
@@ -1305,6 +1307,20 @@ def match_addos : GICombineRule<
         [{ return Helper.matchAddOverflow(*${root}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
+def sext_trunc : GICombineRule<
+   (defs root:$root, build_fn_matchinfo:$matchinfo),
+   (match (G_TRUNC $src, $x, (MIFlags NoSWrap)),
+          (G_SEXT $root, $src),
+   [{ return Helper.matchSextOfTrunc(${root}, ${matchinfo}); }]),
+   (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+def zext_trunc : GICombineRule<
+   (defs root:$root, build_fn_matchinfo:$matchinfo),
+   (match (G_TRUNC $src, $x, (MIFlags NoUWrap)),
+          (G_ZEXT $root, $src),
+   [{ return Helper.matchZextOfTrunc(${root}, ${matchinfo}); }]),
+   (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
 // Combines concat operations
 def concat_matchinfo : GIDefMatchData<"SmallVector<Register>">;
 def combine_concat_vector : GICombineRule<
@@ -1388,7 +1404,8 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
     and_or_disjoint_mask, fma_combines, fold_binop_into_select,
     sub_add_reg, select_to_minmax, redundant_binop_in_equality,
     fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
-    combine_concat_vector, double_icmp_zero_and_or_combine, match_addos]>;
+    combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
+    sext_trunc, zext_trunc]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
 // this group have been selected based on experiments to balance code size and
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 98e7c73a801f59..baf6c98a386322 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -7138,3 +7138,69 @@ bool CombinerHelper::matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo) {
 
   return false;
 }
+
+void CombinerHelper::applyBuildFnMO(const MachineOperand &MO,
+                                    BuildFnTy &MatchInfo) {
+  MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
+  Builder.setInstrAndDebugLoc(*Root);
+  MatchInfo(Builder);
+  Root->eraseFromParent();
+}
+
+bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO,
+                                      BuildFnTy &MatchInfo) {
+  GSext *Sext = getOpcodeDef<GSext>(MO.getReg(), MRI);
+  if (!Sext)
+    return false;
+
+  GTrunc *Trunc = getOpcodeDef<GTrunc>(Sext->getSrcReg(), MRI);
+  if (!Trunc)
+    return false;
+
+  // The trunc must have the nsw flag.
+  if (!Trunc->getFlag(MachineInstr::MIFlag::NoSWrap))
+    return false;
+
+  Register Dst = Sext->getReg(0);
+  Register Src = Trunc->getSrcReg();
+
+  LLT DstTy = MRI.getType(Dst);
+  LLT SrcTy = MRI.getType(Src);
+
+  // The types have to match for a no-op.
+  if (DstTy != SrcTy)
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
+
+  return true;
+}
+
+bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO,
+                                      BuildFnTy &MatchInfo) {
+  GZext *Zext = getOpcodeDef<GZext>(MO.getReg(), MRI);
+  if (!Zext)
+    return false;
+
+  GTrunc *Trunc = getOpcodeDef<GTrunc>(Zext->getSrcReg(), MRI);
+  if (!Trunc)
+    return false;
+
+  // The trunc must have the nuw flag.
+  if (!Trunc->getFlag(MachineInstr::MIFlag::NoUWrap))
+    return false;
+
+  Register Dst = Zext->getReg(0);
+  Register Src = Trunc->getSrcReg();
+
+  LLT DstTy = MRI.getType(Dst);
+  LLT SrcTy = MRI.getType(Src);
+
+  // The types have to match for a no-op.
+  if (DstTy != SrcTy)
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
+
+  return true;
+}
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-with-flags.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-with-flags.mir
new file mode 100644
index 00000000000000..8bcfb1fec1f23b
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-with-flags.mir
@@ -0,0 +1,156 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+# RUN: llc -run-pass=aarch64-prelegalizer-combiner -verify-machineinstrs -mtriple aarch64-unknown-unknown %s -o - | FileCheck %s
+
+---
+name:            zext_trunc_nuw
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: zext_trunc_nuw
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: $x1 = COPY [[COPY]](s64)
+    %0:_(s64) = COPY $x0
+    %2:_(s32) = nuw G_TRUNC %0
+    %3:_(s64) = G_ZEXT  %2
+    $x1 = COPY %3
+...
+---
+name:            zext_trunc_nsw
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: zext_trunc_nsw
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = nsw G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[TRUNC]](s32)
+    ; CHECK-NEXT: $x1 = COPY [[ZEXT]](s64)
+    %0:_(s64) = COPY $x0
+    %2:_(s32) = nsw G_TRUNC %0
+    %3:_(s64) = G_ZEXT  %2
+    $x1 = COPY %3
+...
+---
+name:            zext_trunc
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: zext_trunc
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[TRUNC]](s32)
+    ; CHECK-NEXT: $x1 = COPY [[ZEXT]](s64)
+    %0:_(s64) = COPY $x0
+    %2:_(s32) = G_TRUNC %0
+    %3:_(s64) = G_ZEXT  %2
+    $x1 = COPY %3
+...
+---
+name:            zext_trunc_nuw_vector
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: zext_trunc_nuw_vector
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s32) = COPY $w1
+    ; CHECK-NEXT: %bv0:_(<4 x s32>) = G_BUILD_VECTOR [[COPY]](s32), [[COPY1]](s32), [[COPY]](s32), [[COPY1]](s32)
+    ; CHECK-NEXT: $q0 = COPY %bv0(<4 x s32>)
+    ; CHECK-NEXT: RET_ReallyLR implicit $w0
+    %0:_(s32) = COPY $w0
+    %1:_(s32) = COPY $w1
+    %2:_(s32) = COPY $w2
+    %3:_(s32) = COPY $w3
+    %bv0:_(<4 x s32>) = G_BUILD_VECTOR %0:_(s32), %1:_(s32), %0:_(s32), %1:_(s32)
+    %trunc:_(<4 x s16>) = nuw G_TRUNC %bv0
+    %zext:_(<4 x s32>) = G_ZEXT  %trunc
+    $q0 = COPY %zext(<4 x s32>)
+    RET_ReallyLR implicit $w0
+...
+---
+name:            sext_trunc_nsw
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: sext_trunc_nsw
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: $x1 = COPY [[COPY]](s64)
+    %0:_(s64) = COPY $x0
+    %2:_(s32) = nsw G_TRUNC %0
+    %3:_(s64) = G_SEXT  %2
+    $x1 = COPY %3
+...
+---
+name:            sext_trunc_nuw
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: sext_trunc_nuw
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = nuw G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: [[SEXT:%[0-9]+]]:_(s64) = G_SEXT [[TRUNC]](s32)
+    ; CHECK-NEXT: $x1 = COPY [[SEXT]](s64)
+    %0:_(s64) = COPY $x0
+    %2:_(s32) = nuw G_TRUNC %0
+    %3:_(s64) = G_SEXT  %2
+    $x1 = COPY %3
+...
+---
+name:            sext_trunc
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: sext_trunc
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s32) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: [[SEXT:%[0-9]+]]:_(s64) = G_SEXT [[TRUNC]](s32)
+    ; CHECK-NEXT: $x1 = COPY [[SEXT]](s64)
+    %0:_(s64) = COPY $x0
+    %2:_(s32) = G_TRUNC %0
+    %3:_(s64) = G_SEXT  %2
+    $x1 = COPY %3
+...
+---
+name:            sext_trunc_nsw_types_wrong
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: sext_trunc_nsw_types_wrong
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s16) = nsw G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: [[SEXT:%[0-9]+]]:_(s32) = G_SEXT [[TRUNC]](s16)
+    ; CHECK-NEXT: $w1 = COPY [[SEXT]](s32)
+    %0:_(s64) = COPY $x0
+    %2:_(s16) = nsw G_TRUNC %0
+    %3:_(s32) = G_SEXT  %2
+    $w1 = COPY %3
+...
+---
+name:            sext_trunc_nsw_nuw
+body:             |
+  bb.0:
+    liveins: $w0, $w1
+    ; CHECK-LABEL: name: sext_trunc_nsw_nuw
+    ; CHECK: liveins: $w0, $w1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: $x1 = COPY [[COPY]](s64)
+    %0:_(s64) = COPY $x0
+    %2:_(s32) = nsw nuw G_TRUNC %0
+    %3:_(s64) = G_SEXT  %2
+    $x1 = COPY %3
+...

@tschuett
Copy link
Author

I can remove the flags tests. We need a mechanism to determine the type of x in ext(trunc(x)). The Sext, Zext, and Trunk wrappers make my live easier to walk over the expression.

@arsenm arsenm requested a review from Pierre-vh April 2, 2024 13:24
@tschuett
Copy link
Author

tschuett commented Apr 3, 2024

def zext_trunc_fold_matchinfo : GIDefMatchData<"Register">;

should become

def zext_trunc_kb : GICombineRule<
   (defs root:$root, build_fn_matchinfo:$matchinfo),
   (match (G_TRUNC $src, $x, (MIFlags not NoUWrap)),
          (G_ZEXT $root, $src),

APInt DemandedElts =
Ty.isVector() ? APInt::getAllOnes(Ty.getNumElements()) : APInt(1, 1);
Ty.isFixedVector() ? APInt::getAllOnes(Ty.getNumElements()) : APInt(1, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a separate patch? What's the impact?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a G_SPLAT_VECTOR test and a second combine zext_trunc_fold_matchinfo that relies on known bits to prove that zext(trunc(x)) is a noop. It crashed.

Copy link
Author

@tschuett tschuett Apr 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes are copy-pasted from the DAG to support scalable vectors in known bits.

@tschuett
Copy link
Author

Ping.

@tschuett tschuett force-pushed the ext-trunc-flags branch from e27d9ee to 9b6cc0c Compare May 8, 2024 03:41
@tschuett tschuett merged commit 737e0bc into llvm:main May 8, 2024
@tschuett tschuett deleted the ext-trunc-flags branch May 8, 2024 12:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants