Skip to content

[mlir][vector] transpose(broadcast) -> broadcast canonicalization #135096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 16, 2025

Conversation

newling
Copy link
Contributor

@newling newling commented Apr 9, 2025

Example seen in the 'real world':

%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>

This PR adds a canonicalizer that rewrites the above as

  %1 = vector.broadcast %arg0 : vector<1xi8> to vector<8x1xi8>

It works by determining if a transpose is only shuffling contiguous broadcast dimensions.

@llvmbot
Copy link
Member

llvmbot commented Apr 9, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Example seen in the 'real world':

%0 = vector.broadcast %arg0 : vector&lt;1xi8&gt; to vector&lt;1x8xi8&gt;
%1 = vector.transpose %0, [1, 0] : vector&lt;1x8xi8&gt; to vector&lt;8x1xi8&gt;

This PR adds a canonicalizer that rewrites the above as

  %1 = vector.broadcast %arg0 : vector&lt;1xi8&gt; to vector&lt;8x1xi8&gt;

It works by determining if a transpose is only shuffling contiguous broadcast dimensions.


Full diff: https://github.com/llvm/llvm-project/pull/135096.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+104-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+74)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..05ff93da13aea 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6155,12 +6155,115 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
+///
+/// Example:
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+///                                                 to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
+
+    vector::BroadcastOp broadcast =
+        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast)
+      return false;
+
+    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+    bool inputIsScalar = !inputType;
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t inputRank = inputType.getRank();
+    int64_t outputRank = transpose.getType().getRank();
+    int64_t deltaRank = outputRank - inputRank;
+
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
+    if (inputIsScalar)
+      return true;
+
+    // Return true if all permutation destinations for indices in [low, high)
+    // are in [low, high), so the permutation is local to the group.
+    auto isGroupBound = [&](int low, int high) {
+      ArrayRef<int64_t> permutation = transpose.getPermutation();
+      for (int j = low; j < high; ++j) {
+        if (permutation[j] < low || permutation[j] >= high) {
+          return false;
+        }
+      }
+      return true;
+    };
+
+    // Groups are either contiguous sequences  of 1s and non-1s (1-element
+    // groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
+    // to broadcasting from 1x1x4x1x1x7.
+    //                      ^^^ ^ ^^^ ^
+    //               groups: 0  1  2  3
+    // Order preserving permutations for this example are ones that only permute
+    // within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
+    int low = 0;
+    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+      bool notOne = inputShape[inputIndex] != 1;
+      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
+      bool groupEndFound = notOne || prevNotOne;
+      if (groupEndFound) {
+        int high = inputIndex + deltaRank;
+        if (!isGroupBound(low, high)) {
+          return false;
+        }
+        low = high;
+      }
+    }
+    if (!isGroupBound(low, outputRank)) {
+      return false;
+    }
+
+    // The preceding logic ensures that by this point, the ouutput of the
+    // transpose is definitely broadcastable from the input shape. So we don't
+    // need to call 'vector::isBroadcastableTo', but asserting here just as a
+    // sanity check:
+    bool isBroadcastable =
+        vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
+        vector::BroadcastableToResult::Success;
+    assert(isBroadcastable &&
+           "(I think) it must be broadcastable at this point.");
+
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+    if (!canFoldIntoPrecedingBroadcast(transpose))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        transpose, transpose.getResultVectorType(), transpose.getVector());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..03a338985299d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2215,6 +2215,80 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
+func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+  return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
+func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+  return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+//       CHECK:  return %[[RES]] : vector<8x1xi8>
+func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+  %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+  return %1 : vector<8x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_mixed_example_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
+//       CHECK:  return %[[RES]] : vector<3x2x4x5x6x7xi8>
+func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+  %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+  %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
+  return %1 : vector<3x2x4x5x6x7xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_102_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_021_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @insert_1d_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice—this is very neatly crafted, thanks!

I left a few minor comments inline, and one broader question about whether this qualifies as a canonicalization - curious to hear your thoughts on that.

One additional note: it would be good to include negative tests. Also, I’d suggest putting more extensive testing under mlir/test/Vector, and keeping only minimal illustrative examples in canonicalization.mlir to avoid bloating that file.

Comment on lines 6265 to 6253
TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
context);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering whether this qualifies as canonicalisation. I am not an expert, so merely raising my concerns. From https://mlir.llvm.org/docs/Canonicalization/#general-design

Canonicalize shouldn’t lose the semantic of original operation: the original information should always be recoverable from the transformed IR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. The new broadcast is the same rank and 'volume' as the old one, so i can't think of sense in which it'll be more complex. So the removal of the transpose clinches it in my mind!

As an aside: It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization (i.e. something to guarantee every rewrite takes us closer to a fixed point = energy minimum).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for the discussion!

It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization

So nice :)

@newling newling force-pushed the broadcast_to_transpose_folder branch from 7decc51 to 7a90358 Compare April 10, 2025 17:47
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates!

I really like the idea behind vector-transpose-canonicalize.mlir — we should start doing something similar for other patterns as well! To help streamline this going forward, would you mind creating a subdirectory called canonicalize/ and moving the test file there? I'd also suggest dropping "canonicalize" from the file name itself, and keeping all your tests in that subdirectory (*).

Sorry for the extra churn — I hadn’t realized this was a viable option earlier.

(*) I know I previously suggested the opposite. At the time, I wrongly assumed we could only have one file dedicated to canonicalization tests. That was my mistake!

Comment on lines 6265 to 6253
TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
context);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for the discussion!

It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization

So nice :)

@newling newling force-pushed the broadcast_to_transpose_folder branch from 482da07 to d3fe38a Compare April 11, 2025 18:13
@newling
Copy link
Contributor Author

newling commented Apr 11, 2025

Looks nicer now with your suggestions @banach-space, thanks! Let me know if the new test directory looks sensible, happy to iterate again if needed. I guess canonicalize.mlir could be moved into it in the future, although that might cause more trouble than it's worth.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM — thank you for working on this, and for being so diligent with the documentation and testing 🙏🏻

Since this adds a canonicalization, let’s give it a day or two before merging, just to allow time for other reviewers to take a look.

Also, I noticed that you force-pushed — just a heads-up that we generally try to avoid that during the review process in LLVM:

(one of many bits of documentation that’s not easy to find!)

@@ -0,0 +1,114 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s

// This file contains some canonicalizations tests involving vector.transpose.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, it's totally valid (and something I personally encourage) to document what pattern specifically is being tested:

@newling
Copy link
Contributor Author

newling commented Apr 11, 2025

I just realized there is a folder a480d75 which is a subset of this PR (only considers case where source is scalar). I'll remove that folder

Also, I noticed that you force-pushed — just a heads-up that we generally try to avoid that during the review process in LLVM:

Ok, noted for next time. Thanks for pointing me to these guidelines (of which there don't seem to be too many, I should just read them..)

Since this adds a canonicalization, let’s give it a day or two before merging, just to allow time for other reviewers to take a look.

That's fine, no rush.

@qedawkins qedawkins merged commit 0daf20b into llvm:main Apr 16, 2025
11 checks passed
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
…vm#135096)

Example seen in the 'real world':
 
 ```
 %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
 %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
 ```
 
 This PR adds a canonicalizer that rewrites the above as
 
```
  %1 = vector.broadcast %arg0 : vector<1xi8> to vector<8x1xi8>
```

It works by determining if a transpose is only shuffling contiguous
broadcast dimensions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants