
[MLIR][Vector] Allow any shaped type to be distributed for vector.warp_execute_on_lane_0's return values #114215


Closed
wants to merge 2 commits

Conversation

kurapov-peter (Contributor)

Allow any shaped type to be distributed for vector.warp_execute_on_lane_0's return values.

The second part of #112945.
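
To illustrate what the relaxed verifier is meant to accept, here is a hypothetical sketch (function and value names are invented for this example): a memref result distributed across the 32 lanes of a warp. Each lane would receive a memref<1x4xf32> piece of the expanded memref<32x4xf32>; whether the surrounding distribution patterns support memrefs end-to-end is beyond this verifier change.

// Hypothetical sketch: a non-vector shaped type as the distributed result.
// Rank (2) and element type (f32) match, and each expanded dimension is a
// multiple of the distributed one (32 % 1 == 0, 4 % 4 == 0).
func.func @warp_distribute_memref(%laneid: index) -> memref<1x4xf32> {
  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (memref<1x4xf32>) {
    %buf = memref.alloc() : memref<32x4xf32>
    vector.yield %buf : memref<32x4xf32>
  }
  return %r : memref<1x4xf32>
}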

llvmbot (Member) commented Oct 30, 2024

@llvm/pr-subscribers-mlir

Author: Petr Kurapov (kurapov-peter)

Changes

Allow any shaped type to be distributed for vector.warp_execute_on_lane_0's return values.

The second part of #112945.


Full diff: https://github.com/llvm/llvm-project/pull/114215.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-6)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1853ae04f45d90..af5a2a276042ca 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
   // If the types matches there is no distribution.
   if (expanded == distributed)
     return success();
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+  auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
   if (!expandedVecType || !distributedVecType)
-    return op->emitOpError("expected vector type for distributed operands.");
+    return op->emitOpError("expected shaped type for distributed operands.");
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
-        "expected distributed vectors to have same rank and element type.");
+        "expected distributed types to have same rank and element type.");
 
   SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
       continue;
     if (eDim % dDim != 0)
       return op->emitOpError()
-             << "expected expanded vector dimension #" << i << " (" << eDim
-             << ") to be a multipler of the distributed vector dimension ("
+             << "expected expanded type dimension #" << i << " (" << eDim
+             << ") to be a multipler of the distributed type dimension ("
              << dDim << ")";
     scales[i] = eDim / dDim;
   }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 56039d04549aa5..f2b7685d79effb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1665,7 +1665,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error@+1 {{expected expanded vector dimension #1 (8) to be a multipler of the distributed vector dimension (3)}}
+  // expected-error@+1 {{expected expanded type dimension #1 (8) to be a multipler of the distributed type dimension (3)}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
     %0 = arith.constant dense<2>: vector<4x8xi32>
     vector.yield %0 : vector<4x8xi32>
@@ -1676,7 +1676,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed types to have same rank and element type.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>
@@ -1687,7 +1687,7 @@ func.func @warp_mismatch_rank(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected shaped type for distributed operands.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>
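
To make the checks in verifyDistributedType concrete, here is a hypothetical walk-through (shapes chosen for this sketch): distributing tensor<32x4xi32> as tensor<1x4xi32>. The ranks match (2), the element types match (i32), and each expanded dimension is a multiple of the corresponding distributed one, giving per-dimension scales of 32/1 = 32 and 4/4 = 1.

// Hypothetical sketch of a distribution these checks accept once shaped
// types are allowed; any further verifier conditions outside this hunk
// still apply.
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (tensor<1x4xi32>) {
  %t = arith.constant dense<2> : tensor<32x4xi32>
  vector.yield %t : tensor<32x4xi32>
}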

kurapov-peter changed the title [MLIR][Vector] Allow any shaped typed to be distributed for vector.wa… [MLIR][Vector] Allow any shaped type to be distributed for vector.wa… Oct 30, 2024
kurapov-peter requested a review from kuhar as a code owner October 30, 2024 12:19
MaheshRavishankar (Contributor) left a comment


Just translating my objection from the earlier PR to this: it seems wrong for a verifier in the vector dialect to check that the result is a ShapedType. That seems like a layering violation.

kurapov-peter (Contributor, Author)

> Just translating my objection from the earlier PR to this: it seems wrong for a verifier in the vector dialect to check that the result is a ShapedType. That seems like a layering violation.

I would like to unblock further XeGPU lowering development without requiring a whole new dialect, since that could take a long time to land. I agree that this patch is not what the end solution should look like, and I'm happy to improve it gradually.

I also considered an interface while working on this, but an interface alone doesn't solve the layering problem. So, between the two approaches, I picked the one with fewer changes. Please let me know if I missed any real issues that an interface would solve.

rengolin (Member)

> Just translating my objection from the earlier PR to this: it seems wrong for a verifier in the vector dialect to check that the result is a ShapedType. That seems like a layering violation.

Agreed. This could be an interface that both "vector types" implement, with the distribution working on the interface. It would also make it easier to extend other vector types to work the same way, or to change internal behaviour without breaking the distribution.

kurapov-peter (Contributor, Author)

As discussed in https://discourse.llvm.org/t/rfc-extending-vector-distribution-to-support-other-types/82833, the op was moved to the gpu dialect instead, where it no longer needs to be vector-specific.
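
For context, a sketch of how the relocated op might be spelled after the move (an assumption based on the linked RFC, not something shown in this PR; the exact syntax and terminator in the gpu dialect may differ):

// Assumed post-move form (sketch): the warp op hosted in the gpu dialect,
// with the gpu dialect's own yield terminator.
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %v = arith.constant dense<1.0> : vector<32xf32>
  gpu.yield %v : vector<32xf32>
}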

kurapov-peter deleted the vector-warp-shape-type branch December 10, 2024 11:13