diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 5f32aca88a273..6bd924307376d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -20,6 +20,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <functional>
+
 namespace mlir {
 
 // Forward declarations.
@@ -324,6 +326,24 @@ namespace matcher {
 bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
 } // namespace matcher
 
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` returns the index of each new output.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df06937..fe5198d1ac6db 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f..51d3691fd107a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                            int64_t warpSize, Operation *op) {
   // If the types matches there is no distribution.
   if (expanded == distributed)
     return success();
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+  auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
   if (!expandedVecType || !distributedVecType)
-    return op->emitOpError("expected vector type for distributed operands.");
+    return op->emitOpError("expected shaped type for distributed operands.");
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
-        "expected distributed vectors to have same rank and element type.");
+        "expected distributed types to have same rank and element type.");
   SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
       continue;
     if (eDim % dDim != 0)
       return op->emitOpError()
-             << "expected expanded vector dimension #" << i << " (" << eDim
-             << ") to be a multipler of the distributed vector dimension ("
+             << "expected expanded type dimension #" << i << " (" << eDim
+             << ") to be a multiplier of the distributed type dimension ("
              << dDim << ")";
     scales[i] = eDim / dDim;
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364..c80c3179b5e02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
 
 } // namespace
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
 /// Helper to know if an op can be hoisted out of the region.
 static bool canBeHoisted(Operation *op,
                          function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
          isMemoryEffectFree(op) && op->getNumRegions() == 0;
 }
 
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                                const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf..9db0c172fec5c 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorUtils
   VectorUtils.cpp
+  VectorDistributeUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
new file mode 100644
index 0000000000000..91dac58bacf66
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -0,0 +1,91 @@
+//===- VectorDistributeUtils.cpp - MLIR Utilities VectorOps distribution -===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for Vector dialect warp distribution.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+using namespace mlir;
+
+mlir::OpOperand *
+mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                    const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value is already yielded, don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
+                                             yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87..148ff46ba41b7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
+  XeGPUDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRXeGPUDialect
+  MLIRVectorDialect
+  MLIRVectorUtils
+  MLIRArithDialect
+  MLIRFuncDialect
   MLIRPass
   MLIRTransforms
 )
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
new file mode 100644
index 0000000000000..78a010ff1c941
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,393 @@
+//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
+
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the create_nd_tdesc's arguments.
+/// The rewrite will create a subview of the size used by a single work item and
+/// appropriate offset. The distributed create_nd_tdesc points into the subview
+/// without offset. The tensor descriptor type is distributed according to the
+/// sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///   ...
+///   %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///                 : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///   vector.yield %td
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///                 : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///   vector.yield %arg0, %dead
+/// }
+/// %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
+///                 : memref<4x8xf32> to memref<4x1xf32>
+/// %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+///                 -> !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpTensorDescOp final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Sink a store_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0`. If the arguments of the store are passed
+/// through the warp op interface, they are propagated as returned values.
+/// Both the stored vector type and tensor descriptor types are distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+///                                !xegpu.tensor_desc<4x8xf32>
+///   vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   vector.yield
+/// }
+/// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpStoreNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the load's arguments.
+/// Both the loaded vector type and tensor descriptor types are distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///   ...
+///   %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
+///                                     vector<4x8xf32>
+///   vector.yield %ld
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   %dead = xegpu.load_nd %arg0, %arg1:
+///                 !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
+///   vector.yield %arg0, %arg1
+/// }
+/// %ld = xegpu.load_nd %r#0, %r#1: !xegpu.tensor_desc<4x1xf32>, vector<4x1xf32>
+///
+/// ```
+struct WarpOpLoadNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+                                               xegpu::SGMapAttr sgMap) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  auto newVectorType =
+      VectorType::get(distributedShape, originalT.getElementType(),
+                      originalT.getScalableDims());
+  return newVectorType;
+}
+
+FailureOr<xegpu::TensorDescType>
+getDistributedTensorDescType(xegpu::TensorDescType originalT,
+                             xegpu::SGMapAttr sgMap,
+                             xegpu::MemorySpace memSpace) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  xegpu::TensorDescType distributedDescType;
+  if (originalT.isScattered()) {
+
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(), originalT.getChunkSize(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  } else {
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(),
+        originalT.getBoundaryCheck(), originalT.getArrayLength(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  }
+  return distributedDescType;
+}
+} // namespace
+
+LogicalResult
+WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                               PatternRewriter &rewriter) const {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  Operation *lastNode = yield->getPrevNode();
+  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+  if (!storeOp)
+    return failure();
+
+  auto origType = storeOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+  if (storeOp.getTensorDescType().getShape().size() != 2)
+    return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+  DBGS() << "Matched store_nd: " << storeOp << "\n";
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(storeOp.getValueType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      storeOp.getTensorDescType(), sgMap,
+      storeOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp,
+          ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+          TypeRange{newTDescType, newVectorType}, newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newStoreOp =
+      cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+  rewriter.eraseOp(storeOp);
+  newStoreOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+
+  return success();
+}
+
+LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                            PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(warpOp,
+                                       "warp result is not a xegpu::LoadNd op");
+
+  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+  if (loadOp.getPacked())
+    return rewriter.notifyMatchFailure(
+        loadOp, "Packed load distribution not supported");
+
+  xegpu::TensorDescType origType = loadOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+  auto origShape = origType.getShape();
+  if (origShape.size() != 2)
+    return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(loadOp.getType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
+                                   loadOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  unsigned operandIdx = operand->getOperandNumber();
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+
+  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+      loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
+      loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
+      loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+  newLoadOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+
+  return success();
+}
+
+LogicalResult
+WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                    PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(
+        warpOp, "warp result is not a xegpu::CreateNdDesc op");
+  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
+  assert(descOp && "desc op must not be null");
+  unsigned operandIdx = operand->getOperandNumber();
+
+  // TODO: check if the memref is uniform in the region.
+  rewriter.setInsertionPoint(warpOp);
+  auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
+  assert(srcTypedVal && "source value must not be null");
+
+  auto descOffsets = descOp.getMixedOffsets();
+  if (descOffsets.size() != 2)
+    return rewriter.notifyMatchFailure(descOp,
+                                       "offsets size is expected to be 2");
+
+  xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        descOp, "the tensor descriptor lacks sg_map attribute");
+
+  auto layout = sgMap.getWiLayout();
+
+  // Calculate the offset within tensor descriptor for the current lane_id. The
+  // access to the proper element for a work item is done through a
+  // lane-specific subview (tdesc offsets are used as base, lane shift is added
+  // on top).
+  auto laneid = warpOp.getLaneid();
+  auto xDim =
+      rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
+  auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
+  auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
+
+  auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[0]);
+  auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[1]);
+  auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
+  auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      descOp.getType(), sgMap, descOp.getType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(descOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+  auto distributedShape = newTDescType.getShape();
+  // Use the base memref strides.
+  SmallVector<OpFoldResult> overwriteStrides =
+      getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
+  SmallVector<OpFoldResult> overwriteSizes =
+      getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto subview = rewriter.create<memref::SubViewOp>(
+      newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
+      overwriteSizes, overwriteStrides);
+  subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+
+  auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
+  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+      newWarpOp.getLoc(), newTDescType, subview,
+      getAsOpFoldResult({zero, zero}));
+
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
+
+  return success();
+}
+
+void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
+  patterns.add<WarpOpTensorDescOp>(patterns.getContext());
+  patterns.add<WarpOpStoreNd>(patterns.getContext());
+  patterns.add<WarpOpLoadNd>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 36d04bb77e3b9..6934657417792 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1665,7 +1665,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error@+1 {{expected expanded vector dimension #1 (8) to be a multipler of the distributed vector dimension (3)}}
+  // expected-error@+1 {{expected expanded type dimension #1 (8) to be a multiplier of the distributed type dimension (3)}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
     %0 = arith.constant dense<2>: vector<4x8xi32>
     vector.yield %0 : vector<4x8xi32>
@@ -1676,7 +1676,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed types to have same rank and element type.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>
@@ -1687,7 +1687,7 @@ func.func @warp_mismatch_rank(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected shaped type for distributed operands.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
new file mode 100644
index 0000000000000..ec01fc8268815
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -test-xegpu-distribute -split-input-file %s | FileCheck %s
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>
+
+// CHECK-LABEL: test_store_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}}, %{{.*}} : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x2xf16>)
+// CHECK: ^bb0(%[[src:.*]]: vector<24x32xf16>, %[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: vector.yield %[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
+// CHECK: xegpu.store_nd %[[res]]#1, %[[res]]#0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+
+func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> () {
+  %laneid = gpu.lane_id
+  vector.warp_execute_on_lane_0(%laneid)[16]
+  args(%src, %dst: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+  ^bb0(%arg0: vector<24x32xf16>, %arg1: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    xegpu.store_nd %arg0, %arg1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  }
+  return
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>
+
+// CHECK-LABEL: test_load_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: ^bb0(%[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: %[[dead:.*]] = xegpu.load_nd
+// CHECK: vector.yield %[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: %[[load:.*]] = xegpu.load_nd %[[res]]#1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<24x2xf16>
+// CHECK: return %[[load]]
+
+func.func @test_load_nd_distribution(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+  %laneid = gpu.lane_id
+  %r = vector.warp_execute_on_lane_0(%laneid)[16]
+  args(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+  ^bb0(%arg0: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    %0 = xegpu.load_nd %arg0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16> -> vector<24x32xf16>
+    vector.yield %0 : vector<24x32xf16>
+  }
+  return %r : vector<24x2xf16>
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>
+
+// CHECK-LABEL: test_create_nd_desc_distribution
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : memref<24x32xf16>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
+// CHECK: ^bb0(%[[dst:.*]]: memref<24x32xf16>)
+// CHECK: %[[dead:.*]] = xegpu.create_nd_tdesc
+// CHECK: vector.yield %[[dead]], %[[dst]] :
+// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>
+// CHECK: %[[view:.*]] = memref.subview %[[res]]#1[%[[C0]], %[[laneid]]] [24, 2] [1, 1] : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK: %[[desc:.*]] = xegpu.create_nd_tdesc %[[view]][0, 0] : memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK-SAME: -> !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: return %[[desc]]
+
+func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+  %laneid = gpu.lane_id
+  %r = vector.warp_execute_on_lane_0(%laneid)[16]
+  args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+  ^bb0(%arg0: memref<24x32xf16>):
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+    vector.yield %0 : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  }
+  return %r : !xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 29fb4441a24fd..a8fd70e6397a5 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
+add_subdirectory(XeGPU)
diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..c8fe0db5f6213
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRXeGPUTestPasses
+  TestXeGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRXeGPUTransforms
+  MLIRXeGPUDialect
+  MLIRSupport
+  )
+
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
new file mode 100644
index 0000000000000..eda68b8374813
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -0,0 +1,58 @@
+//===- TestXeGPUTransforms.cpp - Test XeGPU transforms and lowerings ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+using namespace mlir::vector;
+
+namespace {
+struct TestXeGPUDistribution
+    : public PassWrapper<TestXeGPUDistribution, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUDistribution)
+
+  TestXeGPUDistribution() = default;
+  TestXeGPUDistribution(const TestXeGPUDistribution &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final { return "test-xegpu-distribute"; }
+  StringRef getDescription() const final {
+    return "Test patterns for work item distribution of operations";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<vector::VectorDialect>();
+    registry.insert<arith::ArithDialect>();
+    registry.insert<memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateXeGPUDistributePatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestXeGPUTransforms() {
+  PassRegistration<TestXeGPUDistribution>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 8b79de58fa102..e4ffbbee7a1d9 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -47,6 +47,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
     MLIRTestVectorToSPIRV
+    MLIRXeGPUTestPasses
    MLIRLLVMTestPasses
  )
  set(test_libs ${test_libs}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 36b142484bb04..b53e9513b0598 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -151,6 +151,7 @@ void registerTestTransformDialectEraseSchedulePass();
 void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
+void registerTestXeGPUTransforms();
 void registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 void registerTestDialectConversionPasses();
@@ -286,6 +287,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestXeGPUTransforms();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH