@@ -2323,6 +2323,176 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+// Helper functions for softmax decomposition.
+// @{
+
+// Helper function to produce the iterator types (reduction or parallel) and
+// affine maps for the iterators used in the decomposition of softmax.
+// This method creates:
+// If allParallel == true:
+// - iterator type: {parallel, ..., parallel}
+// - affine maps:
+// -- identity with inputRank dimensions.
+// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//    where N == inputRank.
+//
+// If allParallel == false:
+// - iterator type at dim(i) == parallel for i != \p dim and
+//   dim(dim) == reduction.
+// - affine maps:
+// -- identity with inputRank dimensions.
+// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//    where N == inputRank.
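+//
+// For example, with inputRank == 3 and dim == 1, this produces:
+// - iterator types: {parallel, reduction, parallel}
+// - affine maps:
+// -- (d0, d1, d2) -> (d0, d1, d2)
+// -- (d0, d1, d2) -> (d0, d2)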
+static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
+computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
+                                    int64_t dim, bool allParallel = false) {
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  if (!allParallel)
+    iteratorTypes[dim] = utils::IteratorType::reduction;
+  MLIRContext *ctxt = builder.getContext();
+  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
+  SmallVector<AffineExpr, 2> affineExprs;
+  for (int i = 0; i < inputRank; i++) {
+    if (i != dim)
+      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
+  }
+  auto reductionMap =
+      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
+  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
+  return std::make_tuple(iteratorTypes, indexingMaps);
+}
+
+// Helper function to produce a linalg.generic that computes a reduction on
+// dimension \p dim with the operation type \p T.
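+//
+// As an illustration (shapes and SSA names here are hypothetical), reducing a
+// tensor<2x4xf32> on dim == 1 with T == arith::AddFOp builds roughly:
+//
+//   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+//                                    affine_map<(d0, d1) -> (d0)>],
+//                   iterator_types = ["parallel", "reduction"]}
+//       ins(%input : tensor<2x4xf32>) outs(%init : tensor<2xf32>) {
+//   ^bb0(%in: f32, %out: f32):
+//     %0 = arith.addf %in, %out : f32
+//     linalg.yield %0 : f32
+//   } -> tensor<2xf32>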
2368
+ template <typename T>
2369
+ static Value reduce (OpBuilder &builder, Location loc, Value input, Value output,
2370
+ int64_t dim) {
2371
+ auto inputType = cast<ShapedType>(input.getType ());
2372
+ ArrayRef<int64_t > inputShape = inputType.getShape ();
2373
+ int64_t inputRank = inputShape.size ();
2374
+ auto [iteratorTypes, indexingMaps] =
2375
+ computeIteratorTypesAndIndexingMaps (builder, inputRank, dim);
2376
+ assert (indexingMaps.size () == 2 &&
2377
+ " We should have two maps: 1 for the input, 1 for the output" );
2378
+ assert (indexingMaps[0 ].isIdentity () && " input map should be identity" );
2379
+
2380
+ auto genericOp = builder.create <linalg::GenericOp>(
2381
+ loc, output.getType (), input, output, indexingMaps, iteratorTypes,
2382
+ [&](OpBuilder &b, Location loc, ValueRange args) {
2383
+ Value result = b.create <T>(loc, args[0 ], args[1 ]);
2384
+ b.create <linalg::YieldOp>(loc, result);
2385
+ });
2386
+ return genericOp.getResult (0 );
2387
+ }
2388
+
2389
+ // / Produce a linalg generic that computes the second step of the softmax
2390
+ // / decomposition: res = exp(input - max), where \p max is the max of \p input
2391
+ // / on dimension \p dim.
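+///
+/// Note that \p max is indexed through the rank-reducing map returned by
+/// computeIteratorTypesAndIndexingMaps, which is what broadcasts the N-1
+/// dimensional max back onto the N-dimensional \p input.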
+static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
+                              Value max, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 && "We should have one map for each input");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+  // Add the affine map for the output argument.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
+      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
+        Value result = b.create<math::ExpOp>(loc, diff);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the final step of the softmax
+/// decomposition.
+/// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
+///             yield n / d
+///           }
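+///
+/// Here \p denominator is N-1 dimensional and is read through the same
+/// rank-reducing map, so every element of \p numerator is divided by the sum
+/// of its slice along \p dim.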
+static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
+                        Value denominator, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(numerator.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 &&
+         "We should have one map for each input (2)");
+  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
+  // Add the affine map for the output tensor.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+// @} End helper functions for softmax decomposition.
+
+/// Given an N-dimensional tensor x, this method converts
+/// softmax(x) to the following sequence of operations:
+///
+/// 1. Compute the max of x along dimension d. This results
+///    in an N-1 dimensional tensor m.
+///    m = max(x, dim = d)
+///
+/// 2. Subtract a broadcasted m from x and exponentiate. This results in
+///    an N-dimensional tensor z.
+///    z = exp(x - m)
+///
+/// 3. Compute the sum of z along dimension d. This results in
+///    an N-1 dimensional tensor l.
+///    l = sum(z, dim = d)
+///
+/// 4. Divide z by l. This gives the N-dimensional softmax.
+///    softmax = z / l
+///
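+/// For example, for x = [1, 2, 3] and d = 0:
+///   m = 3
+///   z = [exp(-2), exp(-1), exp(0)] ~= [0.135, 0.368, 1.000]
+///   l ~= 1.503
+///   softmax(x) ~= [0.090, 0.245, 0.665]
+///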
+FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(*this);
+  Location loc = getLoc();
+  Value input = getInput();
+  ShapedType inputType = getInputOperandType();
+  Type elementType = inputType.getElementType();
+  int64_t reductionDim = getDimension();
+  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
+  Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  dims.erase(dims.begin() + reductionDim);
+  // Step 1: Compute max along dim.
+  Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value neutralForMaxF =
+      arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
+  Value neutralForMaxFInit =
+      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
+  Value max =
+      reduce<arith::MaxFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+
+  // Step 2: Subtract max from input and exponentiate.
+  Value numerator =
+      buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
+
+  // Step 3: Compute sum along dim.
+  Value zero =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
+  Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
+  Value denominator =
+      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+
+  // Step 4: Compute softmax.
+  Value result =
+      buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
+  return SmallVector<Value>{result};
+}
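+
+// For illustration only: a hypothetical rewrite pattern (names below are not
+// part of this patch) could apply the decomposition and replace the op with:
+//
+//   FailureOr<SmallVector<Value>> replacements =
+//       softmaxOp.decomposeOperation(rewriter);
+//   if (failed(replacements))
+//     return failure();
+//   rewriter.replaceOp(softmaxOp, *replacements);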
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//