
[MLIR][XeGPU] Xegpu distribution patterns for load_nd, store_nd, and create_nd_tdesc. #112945


Closed
wants to merge 4 commits into from

Conversation

kurapov-peter
Contributor

This PR introduces distribution patterns for a portion of the XeGPU dialect, similar to those in the vector dialect, and moves some of the common functionality into the vector utilities.

The XeGPU op rewrite patterns distribute the vector and XeGPU tensor descriptor types when they are sunk through the yield op of a vector.warp_execute_on_lane_0, according to the xegpu.sg_map attribute. The validation of distributed types in warp_execute_on_lane_0 was therefore relaxed to allow ShapedType return values to have distributed shapes.
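
Schematically, the relaxation allows the warp op result to carry the per-lane type while the region still yields the full subgroup-level type. The following is a condensed sketch based on the examples in the patch; the sg_map, shapes, and operands are illustrative, not taken verbatim from the code:

```mlir
#sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>

// Inside the region the descriptor covers the whole 4x8 tile; at the region
// boundary the result has the distributed per-lane shape (4x1 for a
// wi_layout of [1, 8]). Ops outside the region consume the distributed type.
%r = vector.warp_execute_on_lane_0(%laneid) -> (!xegpu.tensor_desc<4x1xf32>) {
  %td = xegpu.create_nd_tdesc %src[0, 0]
          : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
  vector.yield %td : !xegpu.tensor_desc<4x8xf32>
}
```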

@llvmbot
Member

llvmbot commented Oct 18, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-core

Author: Petr Kurapov (kurapov-peter)

Changes

This PR introduces distribution patterns for a portion of the XeGPU dialect, similar to those in the vector dialect, and moves some of the common functionality into the vector utilities.

The XeGPU op rewrite patterns distribute the vector and XeGPU tensor descriptor types when they are sunk through the yield op of a vector.warp_execute_on_lane_0, according to the xegpu.sg_map attribute. The validation of distributed types in warp_execute_on_lane_0 was therefore relaxed to allow ShapedType return values to have distributed shapes.


Patch is 43.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112945.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+20)
  • (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+1-79)
  • (modified) mlir/lib/Dialect/Vector/Utils/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp (+96)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt (+5)
  • (added) mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp (+393)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+3-3)
  • (added) mlir/test/Dialect/XeGPU/xegpu-distribute.mlir (+81)
  • (modified) mlir/test/lib/Dialect/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Dialect/XeGPU/CMakeLists.txt (+17)
  • (added) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+58)
  • (modified) mlir/tools/mlir-opt/CMakeLists.txt (+1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 5f32aca88a2734..6bd924307376dc 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -20,6 +20,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <utility>
+
 namespace mlir {
 
 // Forward declarations.
@@ -324,6 +326,24 @@ namespace matcher {
 bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
 
 } // namespace matcher
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df069372..fe5198d1ac6dba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..51d3691fd107ae 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
   // If the types matches there is no distribution.
   if (expanded == distributed)
     return success();
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+  auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
   if (!expandedVecType || !distributedVecType)
-    return op->emitOpError("expected vector type for distributed operands.");
+    return op->emitOpError("expected shaped type for distributed operands.");
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
-        "expected distributed vectors to have same rank and element type.");
+        "expected distributed types to have same rank and element type.");
 
   SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
       continue;
     if (eDim % dDim != 0)
       return op->emitOpError()
-             << "expected expanded vector dimension #" << i << " (" << eDim
-             << ") to be a multipler of the distributed vector dimension ("
+             << "expected expanded type dimension #" << i << " (" << eDim
+             << ") to be a multiplier of the distributed type dimension ("
              << dDim << ")";
     scales[i] = eDim / dDim;
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..c80c3179b5e025 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
 
 } // namespace
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
 /// Helper to know if an op can be hoisted out of the region.
 static bool canBeHoisted(Operation *op,
                          function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
          isMemoryEffectFree(op) && op->getNumRegions() == 0;
 }
 
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                                const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf2..9db0c172fec5ce 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorUtils
   VectorUtils.cpp
+  VectorDistributeUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
new file mode 100644
index 00000000000000..f41581c6d47f2a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -0,0 +1,96 @@
+//===- VectorDistributeUtils.cpp - MLIR Utilities VectorOps distribution  -===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+using namespace mlir;
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+mlir::OpOperand *
+mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                    const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value is already yielded, don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
+                                             yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87d..148ff46ba41b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
+  XeGPUDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRXeGPUDialect
+  MLIRVectorDialect
+  MLIRVectorUtils
+  MLIRArithDialect
+  MLIRFuncDialect
   MLIRPass
   MLIRTransforms
 )
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
new file mode 100644
index 00000000000000..78a010ff1c941b
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,393 @@
+//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
+
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the create_nd_tdesc's arguments.
+/// The rewrite will create a subview of the size used by a single work item and
+/// appropriate offset. The distributed create_nd_tdesc points into the subview
+/// without offset. The tensor descriptor type is distributed according to
+/// the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///     ...
+///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     vector.yield %td
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     vector.yield %arg0, %dead
+///   }
+///   %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
+///                               : memref<4x8xf32> to memref<4x1xf32>
+///   %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+///                                 -> !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpTensorDescOp final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Sink a store_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0`. If the arguments for the store are passed
+/// through the warp op interface, they are propagated as returned values.
+/// Both the stored vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+///                                 !xegpu.tensor_desc<4x8xf32>
+///     vector.yield
+///   }
+/// ```
+/// To
+/// ```
+///   %r = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     vector.yield
+///   }
+///   xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpStoreNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the load's arguments.
+/// Both the loaded vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///     ...
+///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
+///                                        vector<4x8xf32>
+///     vector.yield %ld
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     %dead = xegpu.load_nd %arg0, %arg1:
+///         !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
+///     vector.yield %arg0, %arg1
+///   }
+///   xegpu.store_nd %r#0, %r#1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpLoadNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+     ...
[truncated]
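
For orientation: per the file list above, the truncated part of the patch adds a test pass (TestXeGPUTransforms.cpp) and registers it with mlir-opt. Its body is not visible here, so the following is only a hypothetical sketch of how such a pass typically wires up the new populateXeGPUDistributePatterns entry point; the pass name and registration details are illustrative, not the actual file contents:

```cpp
// Hypothetical test-pass sketch; the real TestXeGPUTransforms.cpp is cut off
// by the diff truncation above.
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
struct TestXeGPUDistribute
    : public PassWrapper<TestXeGPUDistribute, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUDistribute)

  StringRef getArgument() const final { return "test-xegpu-distribute"; }
  StringRef getDescription() const final {
    return "Test XeGPU distribution patterns";
  }

  void runOnOperation() override {
    // Collect the new distribution patterns and apply them greedily.
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUDistributePatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
// A registration hook (name unknown here) would then be called from mlir-opt.
```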

@kurapov-peter kurapov-peter linked an issue Oct 22, 2024 that may be closed by this pull request
/// vector.yield %arg0, %dead
/// }
/// %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
/// : memref<4x8xf32> to memref<4x1xf32>
Contributor

Why is a subview needed here? One of the major purposes of create_nd_tdesc is to carry the memref information down to the lower levels. Doesn't the SIMT version of the intrinsic need this?

Contributor Author

This is a solution to a type consistency problem I encountered along the way. When distributing the type without the view, you end up with unequal sizes for the descriptor and the result vector type, which breaks the op's validation. A subview is a natural way of resolving it, I think, since it is exactly what we are doing here: creating a subview for a single lane.
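
To make the shapes concrete, this is the relationship the emitted subview establishes (condensed from the pattern's doc comment above; the laneid-based offset is schematic): each lane gets a 4x1 view of the 4x8 tile, so the distributed descriptor describes exactly the memory it points at and the per-lane load/store types line up.

```mlir
// Per-lane view of the original 4x8 tile for wi_layout = [1, 8].
%view = memref.subview %arg0[0, %laneid] [4, 1] [1, 1]
          : memref<4x8xf32> to memref<4x1xf32>
// The distributed descriptor has the same 4x1 shape as its source view.
%td = xegpu.create_nd_tdesc %view[0, 0]
        : memref<4x1xf32> -> !xegpu.tensor_desc<4x1xf32>
```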

Contributor

@chencha3 chencha3 left a comment

The XeGPU part looks good to me. I would suggest that someone look at the changes to the Vector dialect.

@kurapov-peter
Contributor Author

There was no feedback for the draft, so I assume the change to vector is acceptable? Could you please confirm @ThomasRaoux, @nicolasvasilache, @matthias-springer?

Contributor

@hanhanW hanhanW left a comment

If I read the code correctly, most of the vector dialect changes are about moving some utils to VectorDistribute.cpp. This part looks okay to me, just a nit about the function comments.

The other part that I don't follow is the change in VectorOps.cpp. Can you elaborate a little more about why?

Contributor

Perhaps I missed something... I don't see any changes related to verifyDistributedType. Why do we need the change?

Comment on lines 17 to 18
/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
Contributor

The function comment is already present in the *.h file, so we don't need to duplicate it here. Can you remove it (and the other function comments in this file)?

Contributor Author

Sure, done.

@kurapov-peter
Contributor Author

If I read the code correctly, most of the vector dialect changes are about moving some utils to VectorDistribute.cpp. This part looks okay to me, just a nit about the function comments.

Yup, most of it is a trivial code move.

The other part that I don't follow is the change in VectorOps.cpp. Can you elaborate a little more about why?

This is to aid with sinking xegpu ops through the yield op of warp_execute_on_lane_0. Ops in xegpu take an xegpu.tensor_desc as an argument. This type roughly describes a physical memory tile to be processed by a subgroup. So, to move to SIMT, each logical thread should own some portion of this tile (just like with vectors). Hence, I'm distributing the xegpu.tensor_desc type in a similar way to vectors. Here's a pseudo-IR example:

%res = vector.warp_execute_on_lane_0(%laneid) {
  ...
  %desc = xegpu.create_nd_tdesc %somesrc[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
  vector.yield %desc: !xegpu.tensor_desc<24x32xf16> // original type for the whole subgroup, not distributed
}

xegpu.someop %res : !xegpu.tensor_desc<24x2xf16> // the type is distributed

A tensor descriptor that is a result of the warp_execute_on_lane_0 is yielded with the type for the whole tile, but outside of the region it should be distributed. Currently, this is not allowed, since the validation checks the type against VectorType and fails. So, I allow a ShapedType instead.

@hanhanW
Contributor

hanhanW commented Oct 29, 2024


I see, thanks for the details! I'm okay with the change; I think we also need to update the documentation?

Operands are vector values distributed on all lanes that may be used by

@MaheshRavishankar MaheshRavishankar self-requested a review October 29, 2024 17:49
@hanhanW hanhanW dismissed their stale review October 29, 2024 17:51

I'm not very familiar with the op, so I'm passing the review to Mahesh.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

This is a big PR touching a lot of core parts of MLIR. Blocking till I can review it properly.

@stellaraccident
Contributor

Drive-by suggestion: the vector dialect is shared by many pieces, and it is generally easier on everyone to land such changes separately (if I scanned this right, the vector changes would even qualify as an NFC, which helps with review scope/latency even further). This doesn't just help review latency; it also helps integrators, because in case of issues we don't expect patches labeled for a specific target (XeGPU) to affect anything else.

Member

@Groverkss Groverkss left a comment

Can you split out the vector dialect changes?

Comment on lines +329 to +346

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
const std::function<bool(Operation *)> &fn);

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes);

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` return the index of each new output.
vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes,
llvm::SmallVector<size_t> &indices);

Member

These utilities should be in a separate header file, VectorDistributeUtils.h; also, they shouldn't be in the mlir namespace. These are utilities related to a specific transformation, and I'm not sure exposing them to the entire namespace is a good idea.
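
For illustration, a rough sketch of what such a dedicated header could look like (the file name and the mlir::vector namespace placement follow the suggestion above and are hypothetical; the declarations mirror the ones in the patch, and the actual split was done in the follow-up PR referenced below):

```cpp
//===- VectorDistributeUtils.h - Vector distribution helpers ---*- C++ -*-===//
// Hypothetical header sketch following the review suggestion.
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORDISTRIBUTEUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORDISTRIBUTEUTILS_H_

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include <functional>

namespace mlir {
namespace vector {

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                         const std::function<bool(Operation *)> &fn);

/// Helper to create a new WarpExecuteOnLane0Op with a different signature.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes);

/// Helper to create a new WarpExecuteOnLane0Op with extra outputs;
/// `indices` returns the index of each new output.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices);

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORDISTRIBUTEUTILS_H_
```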

Contributor Author

Done in #114208.

Comment on lines -6561 to +6568
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+  auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
   if (!expandedVecType || !distributedVecType)
-    return op->emitOpError("expected vector type for distributed operands.");
+    return op->emitOpError("expected shaped type for distributed operands.");
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
-        "expected distributed vectors to have same rank and element type.");
+        "expected distributed types to have same rank and element type.");
Member

Can you split out these vector op changes into a separate patch? These are unrelated to XeGPU and are separate from this patch. We also need to update the documentation for these ops if we are doing this.

Contributor Author

Sure, #114215.

@MaheshRavishankar
Contributor

This PR introduces distribution patterns for a portion of the XeGPU dialect, similar to those in the vector dialect, and moves some of the common functionality into the vector utilities.

The XeGPU op rewrite patterns distribute the vector and XeGPU tensor descriptor types when they are sunk through the yield op of a vector.warp_execute_on_lane_0, according to the xegpu.sg_map attribute. The validation of distributed types in warp_execute_on_lane_0 was therefore relaxed to allow ShapedType return values to have distributed shapes.

I don't think relaxing a vector verification to allow ShapedType makes sense. If you are really relying on something like vector distribution, you probably need to implement an interface so that you can implement a different op with that interface. That's just a suggestion. There are other suggestions here as well, in terms of smaller PRs etc., that might help, but I think we need more information and a better design to achieve what you are looking for.

@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
   // If the types matches there is no distribution.
   if (expanded == distributed)
     return success();
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
Contributor

If you want to generalize distribution to work on more than vector types, this whole thing needs to be moved out of the vector dialect and made into an interface. Checking for just ShapedType here seems like a violation.

Contributor

@dcaballe dcaballe left a comment

I agree with Mahesh about introducing a DistributionTypeInterface if that helps with generalizing and decoupling things. I would also suggest introducing a distribution dialect that can hold the operations related to distribution, such as warp_execute_on_lane0.
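
For a sense of what that could buy, here is a hypothetical sketch of how the warp-op verifier could defer to such an interface instead of special-casing ShapedType. DistributionTypeInterface is the name suggested above; the interface, its getDistributedType method, and this rewrite of verifyDistributedType do not exist in the patch and are purely illustrative:

```cpp
// Hypothetical: each dialect would implement DistributionTypeInterface for
// its distributable types, keeping the warp-op verifier dialect-agnostic.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // No distribution if the types match.
  if (expanded == distributed)
    return success();
  auto iface = dyn_cast<DistributionTypeInterface>(expanded);
  if (!iface)
    return op->emitOpError("expected a distributable type");
  FailureOr<Type> expected = iface.getDistributedType(warpSize);
  if (failed(expected) || *expected != distributed)
    return op->emitOpError("type is not a valid distribution of the operand");
  return success();
}
```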

@kurapov-peter
Contributor Author

Thanks all! I'll split the patch into three portions: an NFC for vector, the vector op change, and the xegpu part.
Regarding the distribution abstraction - I agree; this is the question I raised in the draft but got no feedback on. A distribution interface alone does not add much, so I didn't create it right away. The op itself doesn't belong in vector and should be moved, but that is a separate topic. I'll create an RFC for that.


@kurapov-peter
Contributor Author

I'll keep and repurpose this PR for the remaining xegpu changes.

Development

Successfully merging this pull request may close these issues.

XeGPU simt lowering
8 participants