Skip to content

Commit 6e92d3f

Browse files
[mlir][Test] Add a test pass to act as a sink towards LLVM conversion
This allows writing simple e2e tests where we can check for the proper materialization of specific LLVM IR (e.g. `llvm.intr.fmuladd`). Differential Revision: https://reviews.llvm.org/D138776
1 parent 6e4cea5 commit 6e92d3f

File tree

8 files changed

+185
-0
lines changed

8 files changed

+185
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-opt %s --test-transform-dialect-interpreter -test-transform-dialect-erase-schedule --test-lower-to-llvm --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: llvm.func @matmul_tensors
4+
func.func @matmul_tensors(
5+
%arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>)
6+
-> tensor<2x6xf32> {
7+
// CHECK-NOT: linalg
8+
// CHECK: llvm.intr.fmuladd{{.*}}
9+
%0 = linalg.matmul ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>)
10+
outs(%arg2: tensor<2x6xf32>)
11+
-> tensor<2x6xf32>
12+
return %0 : tensor<2x6xf32>
13+
}
14+
15+
transform.sequence failures(propagate) {
16+
^bb1(%module_op: !pdl.operation):
17+
%0 = transform.structured.match ops{["linalg.matmul"]} in %module_op
18+
%1, %loops:3 = transform.structured.tile %0 [2, 2, 2]
19+
%2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
20+
transform.structured.vectorize %2
21+
transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
22+
{bufferize_function_boundaries = true}
23+
%func = transform.structured.match ops{["func.func"]} in %module_op
24+
transform.vector.lower_vectors %func { multireduction_lowering = "innerreduce"}
25+
}

mlir/test/lib/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_subdirectory(DLTI)
44
add_subdirectory(Func)
55
add_subdirectory(GPU)
66
add_subdirectory(Linalg)
7+
add_subdirectory(LLVM)
78
add_subdirectory(Math)
89
add_subdirectory(MemRef)
910
add_subdirectory(NVGPU)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Exclude tests from libMLIR.so
2+
add_mlir_library(MLIRLLVMTestPasses
3+
TestLowerToLLVM.cpp
4+
5+
EXCLUDE_FROM_LIBMLIR
6+
7+
LINK_LIBS PUBLIC
8+
MLIRAffineToStandard
9+
MLIRFuncDialect
10+
MLIRFuncToLLVM
11+
MLIRIndexToLLVM
12+
MLIRIR
13+
MLIRLinalgToLLVM
14+
MLIRLLVMDialect
15+
MLIRLinalgTransforms
16+
MLIRMathToLLVM
17+
MLIRMemRefToLLVM
18+
MLIRPass
19+
MLIRReconcileUnrealizedCasts
20+
MLIRSCFToControlFlow
21+
MLIRTransforms
22+
MLIRVectorToLLVM
23+
MLIRVectorToSCF
24+
)
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
//===- TestLowerToLLVM.cpp - Test lowering to LLVM as a sink pass ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass for testing the lowering to LLVM as a generally
10+
// usable sink pass.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
15+
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16+
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
17+
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
18+
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
19+
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
20+
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
21+
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
22+
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
23+
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
24+
#include "mlir/Dialect/Func/IR/FuncOps.h"
25+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26+
#include "mlir/Dialect/Linalg/Passes.h"
27+
#include "mlir/IR/DialectRegistry.h"
28+
#include "mlir/Pass/Pass.h"
29+
#include "mlir/Pass/PassManager.h"
30+
#include "mlir/Transforms/Passes.h"
31+
32+
using namespace mlir;
33+
34+
namespace {
35+
struct TestLowerToLLVM
36+
: public PassWrapper<TestLowerToLLVM, OperationPass<ModuleOp>> {
37+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToLLVM)
38+
39+
TestLowerToLLVM() = default;
40+
TestLowerToLLVM(const TestLowerToLLVM &pass) : PassWrapper(pass) {}
41+
42+
StringRef getArgument() const final { return "test-lower-to-llvm"; }
43+
StringRef getDescription() const final {
44+
return "Test lowering to LLVM as a generally usable sink pass";
45+
}
46+
void getDependentDialects(DialectRegistry &registry) const override {
47+
registry.insert<LLVM::LLVMDialect>();
48+
}
49+
50+
Option<bool> reassociateFPReductions{
51+
*this, "reassociate-fp-reductions",
52+
llvm::cl::desc("Allow reassociation og FP reductions"),
53+
llvm::cl::init(false)};
54+
55+
void runOnOperation() final;
56+
};
57+
} // namespace
58+
59+
void TestLowerToLLVM::runOnOperation() {
60+
MLIRContext *context = &this->getContext();
61+
RewritePatternSet patterns(context);
62+
63+
// TODO: it is feasible to scope lowering at arbitrary level and introduce
64+
// unrealized casts, but there needs to be the final module-wise cleanup in
65+
// the end. Keep module-level for now.
66+
PassManager pm(&getContext());
67+
68+
// Blanket-convert any remaining high-level vector ops to loops if any remain.
69+
pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
70+
// Blanket-convert any remaining linalg ops to loops if any remain.
71+
pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
72+
// Blanket-convert any remaining affine ops if any remain.
73+
pm.addPass(createLowerAffinePass());
74+
// Convert SCF to CF (always needed).
75+
pm.addPass(createConvertSCFToCFPass());
76+
// Sprinkle some cleanups.
77+
pm.addPass(createCanonicalizerPass());
78+
pm.addPass(createCSEPass());
79+
// Blanket-convert any remaining linalg ops to LLVM if any remain.
80+
pm.addPass(createConvertLinalgToLLVMPass());
81+
// Convert vector to LLVM (always needed).
82+
pm.addPass(createConvertVectorToLLVMPass(
83+
// TODO: add more options on a per-need basis.
84+
LowerVectorToLLVMOptions().enableReassociateFPReductions(
85+
reassociateFPReductions)));
86+
// Convert Math to LLVM (always needed).
87+
pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
88+
// Convert MemRef to LLVM (always needed).
89+
pm.addPass(createMemRefToLLVMConversionPass());
90+
// Convert Func to LLVM (always needed).
91+
pm.addPass(createConvertFuncToLLVMPass());
92+
// Convert Index to LLVM (always needed).
93+
pm.addPass(createConvertIndexToLLVMPass());
94+
// Convert remaining unrealized_casts (always needed).
95+
pm.addPass(createReconcileUnrealizedCastsPass());
96+
if (failed(pm.run(getOperation()))) {
97+
getOperation()->dump();
98+
return signalPassFailure();
99+
}
100+
}
101+
102+
namespace mlir {
103+
namespace test {
104+
void registerTestLowerToLLVM() { PassRegistration<TestLowerToLLVM>(); }
105+
} // namespace test
106+
} // namespace mlir

mlir/tools/mlir-opt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ if(MLIR_INCLUDE_TESTS)
3838
MLIRTestTransforms
3939
MLIRTilingInterfaceTestPasses
4040
MLIRVectorTestPasses
41+
MLIRLLVMTestPasses
4142
)
4243
endif()
4344

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ void registerTestLivenessPass();
9898
void registerTestLoopFusion();
9999
void registerTestLoopMappingPass();
100100
void registerTestLoopUnrollingPass();
101+
void registerTestLowerToLLVM();
101102
void registerTestMatchReductionPass();
102103
void registerTestMathAlgebraicSimplificationPass();
103104
void registerTestMathPolynomialApproximationPass();
@@ -201,6 +202,7 @@ void registerTestPasses() {
201202
mlir::test::registerTestLoopFusion();
202203
mlir::test::registerTestLoopMappingPass();
203204
mlir::test::registerTestLoopUnrollingPass();
205+
mlir::test::registerTestLowerToLLVM();
204206
mlir::test::registerTestMatchReductionPass();
205207
mlir::test::registerTestMathAlgebraicSimplificationPass();
206208
mlir::test::registerTestMathPolynomialApproximationPass();

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6873,6 +6873,7 @@ cc_binary(
68736873
"//mlir/test:TestGPU",
68746874
"//mlir/test:TestIR",
68756875
"//mlir/test:TestLinalg",
6876+
"//mlir/test:TestLLVM",
68766877
"//mlir/test:TestMath",
68776878
"//mlir/test:TestMemRef",
68786879
"//mlir/test:TestNVGPU",

utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,31 @@ cc_library(
583583
],
584584
)
585585

586+
cc_library(
587+
name = "TestLLVM",
588+
srcs = glob(["lib/Dialect/LLVM/*.cpp"]),
589+
defines = ["MLIR_CUDA_CONVERSIONS_ENABLED"],
590+
includes = ["lib/Dialect/Test"],
591+
deps = [
592+
"//mlir:AffineToStandard",
593+
"//mlir:FuncDialect",
594+
"//mlir:FuncToLLVM",
595+
"//mlir:IndexToLLVM",
596+
"//mlir:IR",
597+
"//mlir:LinalgToLLVM",
598+
"//mlir:LinalgTransforms",
599+
"//mlir:LLVMDialect",
600+
"//mlir:MathToLLVM",
601+
"//mlir:MemRefToLLVM",
602+
"//mlir:Pass",
603+
"//mlir:ReconcileUnrealizedCasts",
604+
"//mlir:SCFToControlFlow",
605+
"//mlir:Transforms",
606+
"//mlir:VectorToLLVM",
607+
"//mlir:VectorToSCF",
608+
],
609+
)
610+
586611
cc_library(
587612
name = "TestMath",
588613
srcs = glob(["lib/Dialect/Math/*.cpp"]),

0 commit comments

Comments
 (0)