diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2200af0f67a86..8d8e861c84157 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
     return getShapes().front();
   }
 
-  // TODO: Support folding with more than 2 input shapes
-  if (getShapes().size() > 2)
+  if (!adaptor.getShapes().front())
     return nullptr;
 
-  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
-          .getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+  SmallVector<int64_t, 6> resultShape(
+      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
           .getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
 
-  // If the shapes are not compatible, we can't fold it.
-  // TODO: Fold to an "error".
-  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
-    return nullptr;
+  for (auto next : adaptor.getShapes().drop_front()) {
+    if (!next)
+      return nullptr;
+    auto nextShape = llvm::to_vector<6>(
+        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+    SmallVector<int64_t, 6> tmpShape;
+    // If the shapes are not compatible, we can't fold it.
+    // TODO: Fold to an "error".
+    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+      return nullptr;
+
+    resultShape.clear();
+    std::copy(tmpShape.begin(), tmpShape.end(),
+              std::back_inserter(resultShape));
+  }
 
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a7aa25eae2644..6e62a33037eb8 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
     if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
       // One or both dimensions is unknown. Follow TensorFlow behavior:
       // - If either dimension is greater than 1, we assume that the program is
-      //   correct, and the other dimension will be broadcast to match it.
+      //   correct, and the other dimension will be broadcasted to match it.
       // - If either dimension is 1, the other dimension is the output.
       if (*i1 > 1) {
         *iR = *i1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..9ed4837a2fe7e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
 
 // -----
 
+// Variadic case including extent tensors.
+// CHECK-LABEL: @broadcast_variadic
+func.func @broadcast_variadic() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2, 10] : !shape.shape
+  %0 = shape.const_shape [2, 1] : tensor<2xindex>
+  %1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
+  %2 = shape.const_shape [1, 10] : tensor<2xindex>
+  %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
+  return %3 : !shape.shape
+}
+
+// -----
+
// Rhs is a scalar.
 // CHECK-LABEL: func @f
 func.func @f(%arg0 : !shape.shape) -> !shape.shape {