[MLIR][XeGPU] Xegpu distribution patterns for load_nd, store_nd, and create_nd_tdesc. #112945
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-core

Author: Petr Kurapov (kurapov-peter)

Changes: This PR introduces distribution patterns for a portion of the xegpu dialect, similarly to the vector dialect, as well as moving some of the common functionality to the vector utilities. Xegpu op rewrite patterns distribute the vector and xegpu tensor descriptor types when sunk through the yield op of a vector.warp_execute_on_lane_0, according to the xegpu.sg_map attribute.

Patch is 43.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112945.diff

15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 5f32aca88a2734..6bd924307376dc 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -20,6 +20,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <utility>
+
namespace mlir {
// Forward declarations.
@@ -324,6 +326,24 @@ namespace matcher {
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
} // namespace matcher
+
+/// Return a value yielded by `warpOp` which statifies the filter lamdba
+/// condition and is not dead.
+OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+ const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes,
+ llvm::SmallVector<size_t> &indices);
+
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df069372..fe5198d1ac6dba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..51d3691fd107ae 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
// If the types matches there is no distribution.
if (expanded == distributed)
return success();
- auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
- auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+ auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+ auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
if (!expandedVecType || !distributedVecType)
- return op->emitOpError("expected vector type for distributed operands.");
+ return op->emitOpError("expected shaped type for distributed operands.");
if (expandedVecType.getRank() != distributedVecType.getRank() ||
expandedVecType.getElementType() != distributedVecType.getElementType())
return op->emitOpError(
- "expected distributed vectors to have same rank and element type.");
+ "expected distributed types to have same rank and element type.");
SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
continue;
if (eDim % dDim != 0)
return op->emitOpError()
- << "expected expanded vector dimension #" << i << " (" << eDim
- << ") to be a multipler of the distributed vector dimension ("
+ << "expected expanded type dimension #" << i << " (" << eDim
+ << ") to be a multipler of the distributed type dimension ("
<< dDim << ")";
scales[i] = eDim / dDim;
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..c80c3179b5e025 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
} // namespace
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
- RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes) {
- // Create a new op before the existing one, with the extra operands.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(warpOp);
- auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
- warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
- warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
- Region &opBody = warpOp.getBodyRegion();
- Region &newOpBody = newWarpOp.getBodyRegion();
- Block &newOpFirstBlock = newOpBody.front();
- rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
- rewriter.eraseBlock(&newOpFirstBlock);
- assert(newWarpOp.getWarpRegion().hasOneBlock() &&
- "expected WarpOp with single block");
-
- auto yield =
- cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
- rewriter.modifyOpInPlace(
- yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
- return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
- RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes,
- llvm::SmallVector<size_t> &indices) {
- SmallVector<Type> types(warpOp.getResultTypes().begin(),
- warpOp.getResultTypes().end());
- auto yield = cast<vector::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
- for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
- if (yieldValues.insert(std::get<0>(newRet))) {
- types.push_back(std::get<1>(newRet));
- indices.push_back(yieldValues.size() - 1);
- } else {
- // If the value already exit the region don't create a new output.
- for (auto [idx, yieldOperand] :
- llvm::enumerate(yieldValues.getArrayRef())) {
- if (yieldOperand == std::get<0>(newRet)) {
- indices.push_back(idx);
- break;
- }
- }
- }
- }
- yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues.getArrayRef(), types);
- rewriter.replaceOp(warpOp,
- newWarpOp.getResults().take_front(warpOp.getNumResults()));
- return newWarpOp;
-}
-
/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
- const std::function<bool(Operation *)> &fn) {
- auto yield = cast<vector::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- for (OpOperand &yieldOperand : yield->getOpOperands()) {
- Value yieldValues = yieldOperand.get();
- Operation *definedOp = yieldValues.getDefiningOp();
- if (definedOp && fn(definedOp)) {
- if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
- return &yieldOperand;
- }
- }
- return {};
-}
-
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf2..9db0c172fec5ce 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorUtils
VectorUtils.cpp
+ VectorDistributeUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
new file mode 100644
index 00000000000000..f41581c6d47f2a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -0,0 +1,96 @@
+//===- VectorDistributeUtils.cpp - MLIR Utilities VectorOps distribution -===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+using namespace mlir;
+
+/// Return a value yielded by `warpOp` which statifies the filter lamdba
+/// condition and is not dead.
+mlir::OpOperand *
+mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+ const std::function<bool(Operation *)> &fn) {
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
+ Value yieldValues = yieldOperand.get();
+ Operation *definedOp = yieldValues.getDefiningOp();
+ if (definedOp && fn(definedOp)) {
+ if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+ return &yieldOperand;
+ }
+ }
+ return {};
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes) {
+ // Create a new op before the existing one, with the extra operands.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(warpOp);
+ auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
+ warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+ warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+ Region &opBody = warpOp.getBodyRegion();
+ Region &newOpBody = newWarpOp.getBodyRegion();
+ Block &newOpFirstBlock = newOpBody.front();
+ rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+ rewriter.eraseBlock(&newOpFirstBlock);
+ assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+ "expected WarpOp with single block");
+
+ auto yield =
+ cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+ rewriter.modifyOpInPlace(
+ yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+ return newWarpOp;
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes,
+ llvm::SmallVector<size_t> &indices) {
+ SmallVector<Type> types(warpOp.getResultTypes().begin(),
+ warpOp.getResultTypes().end());
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+ yield.getOperands().end());
+ for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+ if (yieldValues.insert(std::get<0>(newRet))) {
+ types.push_back(std::get<1>(newRet));
+ indices.push_back(yieldValues.size() - 1);
+ } else {
+ // If the value already exit the region don't create a new output.
+ for (auto [idx, yieldOperand] :
+ llvm::enumerate(yieldValues.getArrayRef())) {
+ if (yieldOperand == std::get<0>(newRet)) {
+ indices.push_back(idx);
+ break;
+ }
+ }
+ }
+ }
+ yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+ vector::WarpExecuteOnLane0Op newWarpOp =
+ moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
+ yieldValues.getArrayRef(), types);
+ rewriter.replaceOp(warpOp,
+ newWarpOp.getResults().take_front(warpOp.getNumResults()));
+ return newWarpOp;
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87d..148ff46ba41b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
+ XeGPUDistribute.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRIR
MLIRMemRefDialect
MLIRXeGPUDialect
+ MLIRVectorDialect
+ MLIRVectorUtils
+ MLIRArithDialect
+ MLIRFuncDialect
MLIRPass
MLIRTransforms
)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
new file mode 100644
index 00000000000000..78a010ff1c941b
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,393 @@
+//===- XeGPUDistribute.cpp - XeGPU ditribute ops to work items --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
+
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the create_nd_tdesc's arguments.
+/// The rewrite will create a subview of the size used by a single work item and
+/// appropriate offset. The distributed create_nd_tdesc points into the subview
+/// without offset. The tensor descriptor types is distributed according to
+/// sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+/// (!xegpu.tensor_desc<4x8xf32>) {
+/// ...
+/// %td = xegpu.create_nd_tdesc %arg0[0, 0]
+/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+/// vector.yield %td
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+/// vector.yield %arg0, %dead
+/// }
+/// %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
+/// : memref<4x8xf32> to memref<4x1xf32>
+/// %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+/// -> !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpTensorDescOp final
+ : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+ using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Sink a store_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
+/// through the warp op interface they would be propagated as returned values.
+/// Both the stored vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// vector.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+/// !xegpu.tensor_desc<4x8xf32>
+/// vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// vector.yield
+/// }
+/// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpStoreNd final
+ : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+ using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the load's arguments.
+/// Both the loaded vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+/// (!xegpu.tensor_desc<4x8xf32>) {
+/// ...
+/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
+/// vector<4x8xf32> vector.yield %ld
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// %dead = xegpu.load_nd %arg0, %arg1:
+/// !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
+/// vector.yield %arg0, %arg1
+/// }
+/// xegpu.store_nd %r#0, %r#1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpLoadNd final
+ : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+ using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override;
+};
+
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+ ...
[truncated]
/// vector.yield %arg0, %dead
/// }
/// %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
///   : memref<4x8xf32> to memref<4x1xf32>
Why is a subview needed here? One of the major purposes of create_nd_tdesc is to carry the memref information down to the low level. Doesn't the SIMT version of the intrinsic need this?
This is a solution to a type-consistency problem I encountered on the way. When distributing a type without the view, you end up with unequal sizes for the descriptor and the result vector type, which breaks the validation of the op. A subview is a natural way of resolving it, I think, since it is exactly what we are doing here: creating a subview for a single lane.
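For illustration, a minimal sketch of how such a per-lane view can be built; the helper name is hypothetical and the offsets/sizes are hard-coded to the 4x8-over-8-lanes example from the pattern documentation above, not taken from the patch:

```cpp
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Illustrative only: carve out the 4x1 per-lane view of a memref<4x8xf32>
// distributed over 8 lanes along the inner dimension. The distributed
// create_nd_tdesc can then point into this view with zero offsets, so the
// descriptor shape and the per-lane vector shape stay consistent.
static Value createPerLaneView(PatternRewriter &rewriter, Location loc,
                               Value source /*memref<4x8xf32>*/, Value laneId) {
  SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), laneId};
  SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(4),
                                     rewriter.getIndexAttr(1)};
  SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
                                       rewriter.getIndexAttr(1)};
  // Result type is inferred from the static offsets/sizes/strides.
  return rewriter.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                            strides);
}
```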
The XeGPU part looks good to me. I would suggest someone look at the changes to the Vector dialect.
There was no feedback on the draft, so I assume the change to vector is acceptable? Could you please confirm @ThomasRaoux, @nicolasvasilache, @matthias-springer?
If I read the code correctly, most of the vector dialect changes are about moving some utils to VectorDistribute.cpp. This part looks okay to me, just a nit about function comments.
The other part that I don't follow is the change in VectorOps.cpp. Can you elaborate a little more about why?
Perhaps I missed something... I don't see any changes related to verifyDistributedType. Why do we need the change?
/// Return a value yielded by `warpOp` which statifies the filter lamdba
/// condition and is not dead.
The function comment is already added to the *.h files, and we don't need to dup it here. Can you remove it (and the other function comments in this file)?
sure, done
Yup, most of it is a trivial code move.
This is to aid with xegpu ops sinking through the warp op:

```
%res = vector.warp_execute_on_lane_0(%laneid) {
  ...
  %desc = xegpu.create_nd_tdesc %somesrc[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
  vector.yield %desc: !xegpu.tensor_desc<24x32xf16> // original type for the whole subgroup, not distributed
}
xegpu.someop %res : !xegpu.tensor_desc<24x2xf16> // the type is distributed
```

A tensor descriptor is a result of the warp op here, so its type also has to pass the distributed-type check, hence the relaxation in verifyDistributedType.
I see, thanks for the detail! I'm okay with the change; I think we also need to update the documentation?
I'm not very familiar with the op, so I'm passing the review to Mahesh.
This is a big PR touching a lot of core parts of MLIR. Blocking till I can review it properly.
Drive-by suggestion: The vector dialect is shared by many pieces, and it is generally easier on everyone to land such changes separately (if I scanned this right, the vector changes would even qualify for an NFC, which helps on review scope/latency even further). This doesn't just help review latency, but it also helps integrators because in case of issues, we tend to not expect that patches labeled for a specific target (XeGPU) can affect anything else.
Can you split out the vector dialect changes?
/// Return a value yielded by `warpOp` which statifies the filter lamdba
/// condition and is not dead.
OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
                         const std::function<bool(Operation *)> &fn);

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes);

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` return the index of each new output.
vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices);
These utilities should be in a separate header file VectorDistributeUtils.h, and also they shouldn't be in the mlir namespace? These are utilities related to a specific transformation. I'm not sure if exposing them to the entire namespace is a good idea.
done in #114208
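To illustrate how these utilities are meant to compose, here is a simplified, hypothetical pattern in the spirit of the existing vector distribution patterns; it is not code from this PR, and it assumes (for brevity) a single-operand op whose operand distributes to the same type as its result:

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" // utility declarations added by this PR
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

// Illustrative pattern: sink a single-operand, single-result, memory-effect
// free op that feeds the warp op's yield out of the region.
struct SinkUnaryOpThroughWarpOp final
    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a live warp result whose producer we know how to move out.
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return op->getNumOperands() == 1 && op->getNumResults() == 1 &&
             isMemoryEffectFree(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *producer = yieldOperand->get().getDefiningOp();
    unsigned resultIdx = yieldOperand->getOperandNumber();
    // Assumption: the producer's operand is distributed to the same type as
    // the warp result it feeds.
    Type distributedTy = warpOp.getResult(resultIdx).getType();

    // Rebuild the warp op so that it additionally yields the producer's
    // operand; `newRetIndices` tells us where the extra result landed.
    SmallVector<Value> newYieldedValues = {producer->getOperand(0)};
    SmallVector<Type> newReturnTypes = {distributedTy};
    SmallVector<size_t> newRetIndices;
    vector::WarpExecuteOnLane0Op newWarpOp =
        moveRegionToNewWarpOpAndAppendReturns(
            rewriter, warpOp, newYieldedValues, newReturnTypes, newRetIndices);

    // Recreate the producer after the warp op on the distributed value and
    // route users of the original result to it. The op left inside the
    // region becomes dead and is expected to be cleaned up by DCE.
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *cloned = rewriter.clone(*producer);
    cloned->setOperand(0, newWarpOp.getResult(newRetIndices[0]));
    cloned->getResult(0).setType(distributedTy);
    rewriter.replaceAllUsesWith(newWarpOp.getResult(resultIdx),
                                cloned->getResult(0));
    return success();
  }
};
```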
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+  auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
   if (!expandedVecType || !distributedVecType)
-    return op->emitOpError("expected vector type for distributed operands.");
+    return op->emitOpError("expected shaped type for distributed operands.");
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
-        "expected distributed vectors to have same rank and element type.");
+        "expected distributed types to have same rank and element type.");
Can you split out these vector op changes into a separate patch? These are unrelated to XeGPU and are separate from this patch. We also need to update the documentation for these ops if we are doing this.
sure, #114215
I don't think relaxing a
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
   // If the types matches there is no distribution.
   if (expanded == distributed)
     return success();
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
If you want to generalize distribution to work on more than vector types, this whole thing needs to be moved out of the vector dialect and made an interface. Checking for just ShapedType here seems like a violation.
Agree with Mahesh about introducing a DistributionTypeInterface if that helps generalizing and decoupling things. I would also suggest that we introduce a distribution dialect that can hold the operations related to distribution, such as warp_execute_on_lane0.
Thanks all! I'll split the patch into three portions: an NFC for vector, the vector op change, and the xegpu part.
I'll keep and repurpose this PR for the remaining xegpu changes.
This PR introduces distribution patterns for a portion of the xegpu dialect, similarly to the vector dialect, as well as moving some of the common functionality to the vector utilities.

Xegpu op rewrite patterns distribute the vector and xegpu tensor descriptor types when sunk through the yield op of a vector.warp_execute_on_lane_0, according to the xegpu.sg_map attribute. The validation of distributed types in the warp_execute_on_lane_0 op was hence relaxed to allow ShapedType return values to have distributed shapes.
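For completeness, a minimal sketch of how the new entry point might be exercised, for example from a test pass; the wrapper function below is illustrative and not part of the patch:

```cpp
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative driver: collect the XeGPU distribution patterns added by this
// PR and apply them greedily to a root operation (e.g. a gpu.func).
static LogicalResult distributeXeGPU(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUDistributePatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```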