Skip to content

[MLIR][Vector][NFC] Move helper functions to vector distribution utils #114208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

kurapov-peter
Copy link
Contributor

The first portion of #112945.

@llvmbot
Copy link
Member

llvmbot commented Oct 30, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Petr Kurapov (kurapov-peter)

Changes

The first portion of #112945.


Full diff: https://github.com/llvm/llvm-project/pull/114208.diff

4 Files Affected:

  • (added) mlir/include/mlir/Dialect/Vector/Utils/VectorDistributionUtils.h (+37)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+1-79)
  • (modified) mlir/lib/Dialect/Vector/Utils/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Utils/VectorDistributionUtils.cpp (+93)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorDistributionUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorDistributionUtils.h
new file mode 100644
index 00000000000000..460f1f2a49d89c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorDistributionUtils.h
@@ -0,0 +1,37 @@
+//===- VectorDistributionUtils.h - Distribution Utilities -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_DISTRIBUTION_UTILS_VECTORUTILS_H_
+#define MLIR_DIALECT_VECTOR_DISTRIBITION_UTILS_VECTORUTILS_H_
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+#include <utility>
+
+namespace mlir {
+namespace vector {
+/// Return a value yielded by `warpOp` which statifies the filter lamdba
+/// condition and is not dead.
+OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output.
+WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_DISTRIBITION_UTILS_VECTORUTILS_H_
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..1745345a640dbc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorDistributionUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
 
 } // namespace
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
 /// Helper to know if an op can be hoisted out of the region.
 static bool canBeHoisted(Operation *op,
                          function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
          isMemoryEffectFree(op) && op->getNumRegions() == 0;
 }
 
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                                const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf2..41fd677c845ae2 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorUtils
   VectorUtils.cpp
+  VectorDistributionUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributionUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributionUtils.cpp
new file mode 100644
index 00000000000000..19df2105de861f
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributionUtils.cpp
@@ -0,0 +1,93 @@
+//===- VectorDistributionUtils.cpp - Distribution tools for VectorOps -----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements distribution utility methods.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorDistributionUtils.h"
+
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+mlir::OpOperand *
+mlir::vector::getWarpResult(WarpExecuteOnLane0Op warpOp,
+                            const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+WarpExecuteOnLane0Op mlir::vector::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+WarpExecuteOnLane0Op mlir::vector::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exit the region don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand what you are trying to do, but this all screaming for an interface where the distribution could be implemented on operations that aren't vector type.

@kurapov-peter
Copy link
Contributor Author

I understand what you are trying to do, but this all screaming for an interface where the distribution could be implemented on operations that aren't vector type.

This change has nothing to do with interfaces, why are you blocking it?

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if an interface would be a nice idea, it would need users to make sure we're designing it right. This patch is an NFC movement that may help us get there, but let's not put the cart in front of the horses.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this is just taking internal implementation details of an op and arbitrarily exposing it.

But, I dont really have a stake in this. I am going to unblock this.

@MaheshRavishankar MaheshRavishankar dismissed their stale review November 1, 2024 07:42

Abandon my objection

@kurapov-peter
Copy link
Contributor Author

Done in #119264

@kurapov-peter kurapov-peter deleted the pakurapo/vector-utils-nfc branch December 16, 2024 12:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants