diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index f4bac9376f2ea..0ecbdba95cb19 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -562,8 +562,10 @@ class LLVM_DbgIntrOp traits = []> }]; } -def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr", - [DeclareOpInterfaceMethods]> { +def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr", [ + DeclareOpInterfaceMethods]> { let summary = "Describes how the address relates to a source language variable."; let arguments = (ins LLVM_AnyPointer:$addr, diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index e10e2d4e104c3..9db89361c7800 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -229,6 +229,36 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { (ins "const ::llvm::SmallPtrSetImpl &":$blockingUses, "::mlir::RewriterBase &":$rewriter) >, + InterfaceMethod<[{ + This method allows the promoted operation to visit the SSA values used + in place of the memory slot once the promotion process of the memory + slot is complete. + + If this method returns true, the `visitReplacedValues` method on this + operation will be called after the main mutation stage finishes + (i.e., after all ops have been processed with `removeBlockingUses`). + + Operations should only the replaced values if the intended + transformation applies to all the replaced values. Furthermore, replaced + values must not be deleted. + }], "bool", "requiresReplacedValues", (ins), [{}], + [{ return false; }] + >, + InterfaceMethod<[{ + Transforms the IR using the SSA values that replaced the memory slot. + + This method will only be called after all blocking uses have been + scheduled for removal and if `requiresReplacedValues` returned + true. + + The rewriter is located after the promotable operation on call. All IR + mutations must happen through the rewriter. During the transformation, + *no operation should be deleted*. + }], + "void", "visitReplacedValues", + (ins "::llvm::ArrayRef>":$mutatedDefs, + "::mlir::RewriterBase &":$rewriter), [{}], [{ return; }] + >, ]; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index f171bf7cc4bec..f893c554601e1 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -166,13 +166,6 @@ bool LLVM::StoreOp::canUsesBeRemoved( DeletionKind LLVM::StoreOp::removeBlockingUses( const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, RewriterBase &rewriter, Value reachingDefinition) { - // `canUsesBeRemoved` checked this blocking use must be the stored slot - // pointer. - for (Operation *user : slot.ptr.getUsers()) - if (auto declareOp = dyn_cast(user)) - rewriter.create(declareOp->getLoc(), getValue(), - declareOp.getVarInfo(), - declareOp.getLocationExpr()); return DeletionKind::Delete; } @@ -405,6 +398,18 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses( return DeletionKind::Keep; } +bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; } + +void LLVM::DbgDeclareOp::visitReplacedValues( + ArrayRef> definitions, + RewriterBase &rewriter) { + for (auto [op, value] : definitions) { + rewriter.setInsertionPointAfter(op); + rewriter.create(getLoc(), value, getVarInfo(), + getLocationExpr()); + } +} + //===----------------------------------------------------------------------===// // Interfaces for GEPOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index 80e3b79016329..abe565ea862f8 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -202,6 +202,7 @@ class MemorySlotPromoter { /// Contains the reaching definition at this operation. Reaching definitions /// are only computed for promotable memory operations with blocking uses. DenseMap reachingDefs; + DenseMap replacedValuesMap; DominanceInfo &dominance; MemorySlotPromotionInfo info; const Mem2RegStatistics &statistics; @@ -438,6 +439,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, assert(stored && "a memory operation storing to a slot must provide a " "new definition of the slot"); reachingDef = stored; + replacedValuesMap[memOp] = stored; } } } @@ -552,6 +554,10 @@ void MemorySlotPromoter::removeBlockingUses() { dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent()); llvm::SmallVector toErase; + // List of all replaced values in the slot. + llvm::SmallVector> replacedValuesList; + // Ops to visit with the `visitReplacedValues` method. + llvm::SmallVector toVisit; for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) { if (auto toPromoteMemOp = dyn_cast(toPromote)) { Value reachingDef = reachingDefs.lookup(toPromoteMemOp); @@ -565,7 +571,9 @@ void MemorySlotPromoter::removeBlockingUses() { slot, info.userToBlockingUses[toPromote], rewriter, reachingDef) == DeletionKind::Delete) toErase.push_back(toPromote); - + if (toPromoteMemOp.storesTo(slot)) + if (Value replacedValue = replacedValuesMap[toPromoteMemOp]) + replacedValuesList.push_back({toPromoteMemOp, replacedValue}); continue; } @@ -574,6 +582,12 @@ void MemorySlotPromoter::removeBlockingUses() { if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], rewriter) == DeletionKind::Delete) toErase.push_back(toPromote); + if (toPromoteBasic.requiresReplacedValues()) + toVisit.push_back(toPromoteBasic); + } + for (PromotableOpInterface op : toVisit) { + rewriter.setInsertionPointAfter(op); + op.visitReplacedValues(replacedValuesList, rewriter); } for (Operation *toEraseOp : toErase) diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir index f7ddb4a7abe5a..b7cbd787f06e4 100644 --- a/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir +++ b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir @@ -29,6 +29,27 @@ llvm.func @basic_store_load(%arg0: i64) -> i64 { llvm.return %2 : i64 } +// CHECK-LABEL: llvm.func @multiple_store_load +llvm.func @multiple_store_load(%arg0: i64) -> i64 { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NOT: = llvm.alloca + %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr + // CHECK-NOT: llvm.intr.dbg.declare + llvm.intr.dbg.declare #di_local_variable = %1 : !llvm.ptr + // CHECK-NOT: llvm.store + llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr + // CHECK-NOT: llvm.intr.dbg.declare + llvm.intr.dbg.declare #di_local_variable = %1 : !llvm.ptr + // CHECK: llvm.intr.dbg.value #[[$VAR]] = %[[LOADED:.*]] : i64 + // CHECK: llvm.intr.dbg.value #[[$VAR]] = %[[LOADED]] : i64 + // CHECK-NOT: llvm.intr.dbg.value + // CHECK-NOT: llvm.intr.dbg.declare + // CHECK-NOT: llvm.store + %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64 + // CHECK: llvm.return %[[LOADED]] : i64 + llvm.return %2 : i64 +} + // CHECK-LABEL: llvm.func @block_argument_value // CHECK-SAME: (%[[ARG0:.*]]: i64, {{.*}}) llvm.func @block_argument_value(%arg0: i64, %arg1: i1) -> i64 {