-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][vector] transpose(broadcast) -> broadcast canonicalization #135096
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: James Newling (newling) ChangesExample seen in the 'real world':
This PR adds a canonicalizer that rewrites the above as
It works by determining if a transpose is only shuffling contiguous broadcast dimensions. Full diff: https://github.com/llvm/llvm-project/pull/135096.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..05ff93da13aea 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6155,12 +6155,115 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
+///
+/// Example:
+/// ```
+/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+/// to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+ static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
+
+ vector::BroadcastOp broadcast =
+ transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcast)
+ return false;
+
+ auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+ bool inputIsScalar = !inputType;
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t inputRank = inputType.getRank();
+ int64_t outputRank = transpose.getType().getRank();
+ int64_t deltaRank = outputRank - inputRank;
+
+ // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
+ if (inputIsScalar)
+ return true;
+
+ // Return true if all permutation destinations for indices in [low, high)
+ // are in [low, high), so the permutation is local to the group.
+ auto isGroupBound = [&](int low, int high) {
+ ArrayRef<int64_t> permutation = transpose.getPermutation();
+ for (int j = low; j < high; ++j) {
+ if (permutation[j] < low || permutation[j] >= high) {
+ return false;
+ }
+ }
+ return true;
+ };
+
+ // Groups are either contiguous sequences of 1s and non-1s (1-element
+ // groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
+ // to broadcasting from 1x1x4x1x1x7.
+ // ^^^ ^ ^^^ ^
+ // groups: 0 1 2 3
+ // Order preserving permutations for this example are ones that only permute
+ // within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
+ int low = 0;
+ for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+ bool notOne = inputShape[inputIndex] != 1;
+ bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
+ bool groupEndFound = notOne || prevNotOne;
+ if (groupEndFound) {
+ int high = inputIndex + deltaRank;
+ if (!isGroupBound(low, high)) {
+ return false;
+ }
+ low = high;
+ }
+ }
+ if (!isGroupBound(low, outputRank)) {
+ return false;
+ }
+
+ // The preceding logic ensures that by this point, the ouutput of the
+ // transpose is definitely broadcastable from the input shape. So we don't
+ // need to call 'vector::isBroadcastableTo', but asserting here just as a
+ // sanity check:
+ bool isBroadcastable =
+ vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
+ vector::BroadcastableToResult::Success;
+ assert(isBroadcastable &&
+ "(I think) it must be broadcastable at this point.");
+
+ return true;
+ }
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+ if (!canFoldIntoPrecedingBroadcast(transpose))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ transpose, transpose.getResultVectorType(), transpose.getVector());
+
+ return success();
+ }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat>(context);
+ TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..03a338985299d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2215,6 +2215,80 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
// -----
+// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+// CHECK: return %[[RES]] : vector<2x3x4xi8>
+func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+ return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+// CHECK: return %[[RES]] : vector<2x3x4xi8>
+func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+ return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+// CHECK: return %[[RES]] : vector<8x1xi8>
+func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+ %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+ return %1 : vector<8x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_mixed_example_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
+// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
+func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+ %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+ %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
+ return %1 : vector<3x2x4x5x6x7xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_102_nofold
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+// CHECK: return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+ %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+ return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_021_nofold
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+// CHECK: return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+ %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+ return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
// CHECK-LABEL: func.func @insert_1d_constant
// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice—this is very neatly crafted, thanks!
I left a few minor comments inline, and one broader question about whether this qualifies as a canonicalization - curious to hear your thoughts on that.
One additional note: it would be good to include negative tests. Also, I’d suggest putting more extensive testing under mlir/test/Vector
, and keeping only minimal illustrative examples in canonicalization.mlir
to avoid bloating that file.
TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>( | ||
context); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering whether this qualifies as canonicalisation. I am not an expert, so merely raising my concerns. From https://mlir.llvm.org/docs/Canonicalization/#general-design
Canonicalize shouldn’t lose the semantic of original operation: the original information should always be recoverable from the transformed IR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. The new broadcast is the same rank and 'volume' as the old one, so i can't think of sense in which it'll be more complex. So the removal of the transpose clinches it in my mind!
As an aside: It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization (i.e. something to guarantee every rewrite takes us closer to a fixed point = energy minimum).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, thanks for the discussion!
It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization
So nice :)
7decc51
to
7a90358
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates!
I really like the idea behind vector-transpose-canonicalize.mlir
— we should start doing something similar for other patterns as well! To help streamline this going forward, would you mind creating a subdirectory called canonicalize/
and moving the test file there? I'd also suggest dropping "canonicalize" from the file name itself, and keeping all your tests in that subdirectory (*).
Sorry for the extra churn — I hadn’t realized this was a viable option earlier.
(*) I know I previously suggested the opposite. At the time, I wrongly assumed we could only have one file dedicated to canonicalization tests. That was my mistake!
TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>( | ||
context); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, thanks for the discussion!
It would be nice if MLIR/dialects defined an 'energy function' defining what classifies as a canonicalization
So nice :)
482da07
to
d3fe38a
Compare
Looks nicer now with your suggestions @banach-space, thanks! Let me know if the new test directory looks sensible, happy to iterate again if needed. I guess canonicalize.mlir could be moved into it in the future, although that might cause more trouble than it's worth. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM — thank you for working on this, and for being so diligent with the documentation and testing 🙏🏻
Since this adds a canonicalization, let’s give it a day or two before merging, just to allow time for other reviewers to take a look.
Also, I noticed that you force-pushed — just a heads-up that we generally try to avoid that during the review process in LLVM:
(one of many bits of documentation that’s not easy to find!)
@@ -0,0 +1,114 @@ | |||
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s | |||
|
|||
// This file contains some canonicalizations tests involving vector.transpose. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note, it's totally valid (and something I personally encourage) to document what pattern specifically is being tested:
I just realized there is a folder a480d75 which is a subset of this PR (only considers case where source is scalar). I'll remove that folder
Ok, noted for next time. Thanks for pointing me to these guidelines (of which there don't seem to be too many, I should just read them..)
That's fine, no rush. |
…vm#135096) Example seen in the 'real world': ``` %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8> %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8> ``` This PR adds a canonicalizer that rewrites the above as ``` %1 = vector.broadcast %arg0 : vector<1xi8> to vector<8x1xi8> ``` It works by determining if a transpose is only shuffling contiguous broadcast dimensions.
Example seen in the 'real world':
This PR adds a canonicalizer that rewrites the above as
It works by determining if a transpose is only shuffling contiguous broadcast dimensions.