Commit 9be8219

[mlir][Linalg] Add an interface to decompose complex ops
This patch adds an interface, named AggregatedOpInterface, that decomposes complex operations into simpler ones.

For now, make the interface specific to Linalg because, although the concept is general, the way to materialize it needs some maturing. Use that interface with the softmax operator.

Differential Revision: https://reviews.llvm.org/D154363
1 parent d53d842 commit 9be8219

File tree: 6 files changed, +310 −0 lines changed


mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 30 additions & 0 deletions
@@ -897,4 +897,34 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let verifyWithRegions = 1;
 }
 
+def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
+  let description = [{
+    Interface for decomposing aggregated operations into a sequence of simpler
+    ops.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Method to decompose the operation into simpler operations.
+
+        On success, this method returns one `Value` per result in the
+        original operation.
+        The order of the returned values must match the order of the
+        original values.
+        In other words, the returned vector can be used directly with
+        `RewriterBase::replaceOp(this, returnedValues)`.
+      }],
+      /*retType=*/"FailureOr<SmallVector<Value>>",
+      /*methodName=*/"decomposeOperation",
+      /*args=*/(ins
+        "OpBuilder &":$b),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return {};
+      }]
+    >
+  ];
+}
+
 #endif // LINALG_IR_LINALGINTERFACES
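The contract in the interface description maps directly onto rewriter code. A minimal sketch of a client, assuming an `op : Operation *` and a `rewriter : RewriterBase &` coming from the surrounding rewrite infrastructure (this mirrors what the new transform op later in this commit does):

  // Hedged sketch: decompose `op` if it implements the new interface.
  if (auto decomposableOp = dyn_cast<AggregatedOpInterface>(op)) {
    FailureOr<SmallVector<Value>> newResults =
        decomposableOp.decomposeOperation(rewriter);
    if (succeeded(newResults)) {
      // The values come back in result order, so they can be handed
      // straight to replaceOp, as the interface documentation promises.
      rewriter.replaceOp(decomposableOp, *newResults);
    }
  }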

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 #define LINALG_OPS
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -93,6 +94,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 27 additions & 0 deletions
@@ -1199,6 +1199,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    TODO
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp.
 //===----------------------------------------------------------------------===//
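Although the description of structured.decompose_interface is still a TODO in this patch, the declared assembly format already fixes the concrete syntax; the new test at the bottom of this commit exercises it as:

  %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op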

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 170 additions & 0 deletions
@@ -2323,6 +2323,176 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+// Helper functions for softmax decomposition.
+// @{
+
+// Helper function to produce the iterator types (reduction or parallel) and
+// affine maps for the iterators used in the decomposition of softmax.
+// This method creates:
+// If allParallel == true:
+//  - iterator type: {parallel, ..., parallel}
+//  - affine maps:
+//  -- identity with inputRank dimensions.
+//  -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//     where N == inputRank.
+//
+// If allParallel == false:
+//  - iterator type at dim(i) == parallel for i != \p dim and
+//    dim(dim) == reduction.
+//  - affine map:
+//  -- identity with inputRank dimensions.
+//  -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//     where N == inputRank.
+static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
+computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
+                                    int64_t dim, bool allParallel = false) {
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  if (!allParallel)
+    iteratorTypes[dim] = utils::IteratorType::reduction;
+  MLIRContext *ctxt = builder.getContext();
+  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
+  SmallVector<AffineExpr, 2> affineExprs;
+  for (int i = 0; i < inputRank; i++) {
+    if (i != dim)
+      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
+  }
+  auto reductionMap =
+      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
+  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
+  return std::make_tuple(iteratorTypes, indexingMaps);
+}
+
+// Helper function to produce a linalg.generic that computes a reduction on
+// dimension \p dim with the operation type \p T.
+template <typename T>
+static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
+                    int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] =
+      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
+  assert(indexingMaps.size() == 2 &&
+         "We should have two maps: 1 for the input, 1 for the output");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<T>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the second step of the softmax
+/// decomposition: res = exp(input - max), where \p max is the max of \p input
+/// on dimension \p dim.
+static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
+                              Value max, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 && "We should have one map for each input");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+  // Add the affine map for the output argument.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
+      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
+        Value result = b.create<math::ExpOp>(loc, diff);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the final step of the softmax
+/// decomposition.
+/// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
+///   yield  n / d
+/// }
+static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
+                        Value denominator, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(numerator.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 &&
+         "We should have one map for each input (2)");
+  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
+  // Add the affine map for the output tensor.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+// @} End helper functions for softmax decomposition.
+
+/// Given an N-dimensional tensor x, this method converts
+/// softmax(x) to the following sequence of operations:
+///
+/// 1. Compute the max of x along dimension d. This results
+///    in a N-1 dimensional tensor m.
+///    m = max(x, dim = d)
+///
+/// 2. Subtract a broadcasted m from x and exponentiate. This results in
+///    a N dimensional tensor z.
+///    z = exp(x - m)
+///
+/// 3. Compute the sum of z along dimension d. This results in
+///    a N-1 dimensional tensor l.
+///    l = sum(z, dim = d)
+///
+/// 4. Divide z and l. This gives the N-dimensional softmax.
+///    softmax = z / l
+///
+FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(*this);
+  Location loc = getLoc();
+  Value input = getInput();
+  ShapedType inputType = getInputOperandType();
+  Type elementType = inputType.getElementType();
+  int64_t reductionDim = getDimension();
+  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
+  Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  dims.erase(dims.begin() + reductionDim);
+  // Step 1: Compute max along dim.
+  Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value neutralForMaxF =
+      arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
+  Value neutralForMaxFInit =
+      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
+  Value max =
+      reduce<arith::MaxFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+
+  // Step 2: Subtract max from input and exponentiate.
+  Value numerator =
+      buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
+
+  // Step 3: Compute sum along dim.
+  Value zero =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
+  Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
+  Value denominator =
+      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+
+  // Step 4: Compute softmax.
+  Value result =
+      buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
+  return SmallVector<Value>{result};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
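The four steps documented above decomposeOperation map one-to-one onto ordinary scalar code. As a standalone illustration (plain C++, not part of the patch), here is the same computation on a hypothetical 4-element slice; subtracting the max before exponentiating keeps exp from overflowing without changing the result, since exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)).

  #include <algorithm>
  #include <array>
  #include <cmath>
  #include <cstddef>
  #include <cstdio>

  int main() {
    std::array<float, 4> x = {1.0f, 2.0f, 3.0f, 4.0f};
    // Step 1: m = max(x) along the reduction dimension.
    float m = *std::max_element(x.begin(), x.end());
    // Step 2: z = exp(x - m), elementwise.
    std::array<float, 4> z;
    for (std::size_t i = 0; i < x.size(); ++i)
      z[i] = std::exp(x[i] - m);
    // Step 3: l = sum(z) along the reduction dimension.
    float l = 0.0f;
    for (float v : z)
      l += v;
    // Step 4: softmax = z / l, elementwise.
    for (std::size_t i = 0; i < z.size(); ++i)
      std::printf("%f\n", z[i] / l);
    return 0;
  }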

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 32 additions & 0 deletions
@@ -335,6 +335,38 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
   return emitDefaultSilenceableFailure(target);
 }
 
+//===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+// Decompose the target operation if it implements the AggregatedOpInterface.
+// Push the decomposed operations (the ones that replace the values produced by
+// \p target) in the `results`.
+DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+  if (!decomposableOp) {
+    failed(rewriter.notifyMatchFailure(target,
+                                       "payload is not a decomposable op"));
+    return emitDefaultSilenceableFailure(target);
+  }
+
+  FailureOr<SmallVector<Value>> maybeNewResults =
+      decomposableOp.decomposeOperation(rewriter);
+  if (failed(maybeNewResults))
+    return emitDefaultSilenceableFailure(target);
+
+  rewriter.replaceOp(decomposableOp, *maybeNewResults);
+  for (Value val : *maybeNewResults) {
+    Operation *definition = val.getDefiningOp();
+    if (definition)
+      results.push_back(definition);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // EliminateLinalgOpAnchoredEmptyTensorsOp
 //===----------------------------------------------------------------------===//
mlir/test/Dialect/Linalg/transform-op-decompose.mlir

Lines changed: 49 additions & 0 deletions
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
 
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
@@ -199,8 +202,54 @@ func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
   return %0 : tensor<?x?x1x?xf32>
 }
 
+func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @softmax(
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-DAG:   %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-DAG:   %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG:   %[[CST:.+]] = arith.constant 0xFF800000 : f32
+// CHECK:       %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:       %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:    "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
+// CHECK:       ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:         %[[D8:.+]] = arith.maxf %[[IN]], %[[OUT]] : f32
+// CHECK:         linalg.yield %[[D8]] : f32
+// CHECK:       } -> tensor<2x16xf32>
+// CHECK:       %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:    ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:    outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK:       ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:         %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
+// CHECK:         %[[D9:.+]] = math.exp %[[D8]] : f32
+// CHECK:         linalg.yield %[[D9]] : f32
+// CHECK:       } -> tensor<2x16x32xf32>
+// CHECK:       %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:       %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:    "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK:       ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:         %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
+// CHECK:         linalg.yield %[[D8]] : f32
+// CHECK:       } -> tensor<2x16xf32>
+// CHECK:       %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:    ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:    outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK:       ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:         %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
+// CHECK:         linalg.yield %[[D8]] : f32
+// CHECK:       } -> tensor<2x16x32xf32>
+// CHECK:       return %[[D7]] : tensor<2x16x32xf32>
+// CHECK:     }
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op
+
+  %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
 }