Skip to content

Commit 448adfe

Browse files
committed
[mlir] Only conditionally lower CF branching ops to LLVM
Previously cf.br cf.cond_br and cf.switch always lowered to their LLVM equivalents. These ops are all ops that take in some values of given types and jump to other blocks with argument lists of the same types. If the types are not the same, a verification failure will later occur. This led to confusions, as everything works when func->llvm and cf->llvm lowering both occur because func->llvm updates the blocks and argument lists while cf->llvm updates the branching ops. Without func->llvm though, there will potentially be a type mismatch. This change now only lowers the CF ops if they will later pass verification. This is possible because the parent op and its blocks will be updated before the contained branching ops, so they can test their new operand types against the types of the blocks they jump to. Another plan was to have func->llvm only update the entry block signature and to allow cf->llvm to update all other blocks, but this had 2 problems: 1. This would create a FuncOp lowering in cf->llvm lowering which is awkward 2. This new pattern would only be applied if the containing FuncOp is marked invalid. This is infeasible with the shared LLVM type conversion/target infrastructure. See previous discussions at https://discourse.llvm.org/t/lowering-cf-to-llvm/63863 and #55301 Differential Revision: https://reviews.llvm.org/D130971
1 parent d0541b4 commit 448adfe

File tree

3 files changed

+145
-20
lines changed

3 files changed

+145
-20
lines changed

mlir/docs/TargetLLVMIR.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ are expected to closely match the corresponding LLVM IR instructions and
1616
intrinsics. This minimizes the dependency on LLVM IR libraries in MLIR as well
1717
as reduces the churn in case of changes.
1818

19+
Note that many different dialects can be lowered to LLVM but are provided as
20+
different sets of patterns and have different passes available to mlir-opt.
21+
However, this is primarily useful for testing and prototyping, and using the
22+
collection of patterns together is highly recommended. One place this is
23+
important and visible is the ControlFlow dialect's branching operations which
24+
will fail to apply if their types mismatch with the blocks they jump to in the
25+
parent op.
26+
1927
SPIR-V to LLVM dialect conversion has a
2028
[dedicated document](SPIRVToLLVMDialectConversion.md).
2129

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 95 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/BuiltinOps.h"
2323
#include "mlir/IR/PatternMatch.h"
2424
#include "mlir/Transforms/DialectConversion.h"
25+
#include "llvm/ADT/StringRef.h"
2526
#include <functional>
2627

2728
using namespace mlir;
@@ -71,34 +72,108 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
7172
}
7273
};
7374

74-
// Base class for LLVM IR lowering terminator operations with successors.
75-
template <typename SourceOp, typename TargetOp>
76-
struct OneToOneLLVMTerminatorLowering
77-
: public ConvertOpToLLVMPattern<SourceOp> {
78-
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
79-
using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
75+
/// The cf->LLVM lowerings for branching ops require that the blocks they jump
76+
/// to first have updated types which should be handled by a pattern operating
77+
/// on the parent op.
78+
static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
79+
ValueRange operands,
80+
ValueRange blockArgs, Location loc,
81+
llvm::StringRef messagePrefix) {
82+
for (const auto &idxAndTypes :
83+
llvm::enumerate(llvm::zip(blockArgs, operands))) {
84+
int64_t i = idxAndTypes.index();
85+
Value argValue =
86+
rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
87+
Type operandType = std::get<1>(idxAndTypes.value()).getType();
88+
// In the case of an invalid jump, the block argument will have been
89+
// remapped to an UnrealizedConversionCast. In the case of a valid jump,
90+
// there might still be a no-op conversion cast with both types being equal.
91+
// Consider both of these details to see if the jump would be invalid.
92+
if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
93+
argValue.getDefiningOp())) {
94+
if (op.getOperandTypes().front() != operandType) {
95+
return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
96+
diag << messagePrefix;
97+
diag << "mismatched types from operand # " << i << " ";
98+
diag << operandType;
99+
diag << " not compatible with destination block argument type ";
100+
diag << argValue.getType();
101+
diag << " which should be converted with the parent op.";
102+
});
103+
}
104+
}
105+
}
106+
return success();
107+
}
108+
109+
/// Ensure that all block types were updated and then create an LLVM::BrOp
110+
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
111+
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
80112

81113
LogicalResult
82-
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
114+
matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
83115
ConversionPatternRewriter &rewriter) const override {
84-
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
85-
op->getSuccessors(), op->getAttrs());
116+
if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
117+
op.getSuccessor()->getArguments(),
118+
op.getLoc(),
119+
/*messagePrefix=*/"")))
120+
return failure();
121+
122+
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
123+
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
86124
return success();
87125
}
88126
};
89127

90-
// FIXME: this should be tablegen'ed as well.
91-
struct BranchOpLowering
92-
: public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
93-
using Base::Base;
94-
};
95-
struct CondBranchOpLowering
96-
: public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
97-
using Base::Base;
128+
/// Ensure that all block types were updated and then create an LLVM::CondBrOp
129+
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
130+
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
131+
132+
LogicalResult
133+
matchAndRewrite(cf::CondBranchOp op,
134+
typename cf::CondBranchOp::Adaptor adaptor,
135+
ConversionPatternRewriter &rewriter) const override {
136+
if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
137+
op.getFalseDest()->getArguments(),
138+
op.getLoc(), "in false case branch ")))
139+
return failure();
140+
if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
141+
op.getTrueDest()->getArguments(),
142+
op.getLoc(), "in true case branch ")))
143+
return failure();
144+
145+
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
146+
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
147+
return success();
148+
}
98149
};
99-
struct SwitchOpLowering
100-
: public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
101-
using Base::Base;
150+
151+
/// Ensure that all block types were updated and then create an LLVM::SwitchOp
152+
struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
153+
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
154+
155+
LogicalResult
156+
matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
157+
ConversionPatternRewriter &rewriter) const override {
158+
if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
159+
op.getDefaultDestination()->getArguments(),
160+
op.getLoc(), "in switch default case ")))
161+
return failure();
162+
163+
for (const auto &i : llvm::enumerate(
164+
llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
165+
if (failed(verifyMatchingValues(
166+
rewriter, std::get<0>(i.value()),
167+
std::get<1>(i.value())->getArguments(), op.getLoc(),
168+
"in switch case " + std::to_string(i.index()) + " "))) {
169+
return failure();
170+
}
171+
}
172+
173+
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
174+
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
175+
return success();
176+
}
102177
};
103178

104179
} // namespace
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: mlir-opt %s -convert-cf-to-llvm | FileCheck %s
2+
3+
func.func @name(%flag: i32, %pred: i1){
4+
// Test cf.br lowering failure with type mismatch
5+
// CHECK: cf.br
6+
%c0 = arith.constant 0 : index
7+
cf.br ^bb1(%c0 : index)
8+
9+
// Test cf.cond_br lowering failure with type mismatch in false_dest
10+
// CHECK: cf.cond_br
11+
^bb1(%0: index): // 2 preds: ^bb0, ^bb2
12+
%c1 = arith.constant 1 : i1
13+
%c2 = arith.constant 1 : index
14+
cf.cond_br %pred, ^bb2(%c1: i1), ^bb3(%c2: index)
15+
16+
// Test cf.cond_br lowering failure with type mismatch in true_dest
17+
// CHECK: cf.cond_br
18+
^bb2(%1: i1):
19+
%c3 = arith.constant 1 : i1
20+
%c4 = arith.constant 1 : index
21+
cf.cond_br %pred, ^bb3(%c4: index), ^bb2(%c3: i1)
22+
23+
// Test cf.switch lowering failure with type mismatch in default case
24+
// CHECK: cf.switch
25+
^bb3(%2: index): // pred: ^bb1
26+
%c5 = arith.constant 1 : i1
27+
%c6 = arith.constant 1 : index
28+
cf.switch %flag : i32, [
29+
default: ^bb1(%c6 : index),
30+
42: ^bb4(%c5 : i1)
31+
]
32+
33+
// Test cf.switch lowering failure with type mismatch in non-default case
34+
// CHECK: cf.switch
35+
^bb4(%3: i1): // pred: ^bb1
36+
%c7 = arith.constant 1 : i1
37+
%c8 = arith.constant 1 : index
38+
cf.switch %flag : i32, [
39+
default: ^bb2(%c7 : i1),
40+
41: ^bb1(%c8 : index)
41+
]
42+
}

0 commit comments

Comments
 (0)