diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index d759299cbf762..c0b286494996b 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -101,7 +101,10 @@ Type getType(OpFoldResult ofr);
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
-  ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
+  ArithBuilder(
+      OpBuilder &b, Location loc,
+      arith::IntegerOverflowFlags ovf = arith::IntegerOverflowFlags::none)
+      : b(b), loc(loc), ovf(ovf) {}
 
   Value _and(Value lhs, Value rhs);
   Value add(Value lhs, Value rhs);
@@ -114,6 +117,15 @@ struct ArithBuilder {
 private:
   OpBuilder &b;
   Location loc;
+  arith::IntegerOverflowFlags ovf;
+};
+
+/// ArithBuilder specialized for tensor/memref indexing calculations. These
+/// calculations should generally never overflow in a signed sense and always
+/// use signed integers, so we can set the overflow flags accordingly.
+struct ArithIndexingBuilder : public ArithBuilder {
+  ArithIndexingBuilder(OpBuilder &b, Location loc)
+      : ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}
 };
 
 namespace arith {
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f46aa0428f12f..14cbbac99d9ae 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -458,7 +458,9 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Patterns that remove redundant Vector Ops by re-ordering them with
-    e.g. elementwise Ops:
+    e.g. elementwise Ops.
+
+    Example:
     ```
     %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
     %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -469,8 +471,33 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     %0 = arith.addf %a, %b : vector<4x2xf32>
     %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
     ```
-    At the moment, these patterns are limited to vector.broadcast and
-    vector.transpose.
+    At the moment, these patterns are limited to vector.broadcast,
+    vector.transpose and vector.extract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_mem_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Patterns that fold redundant vector ops adjacent to `vector.load`/
+    `vector.store` into the load/store itself, producing either plain
+    `vector.load`/`vector.store` or `memref.load`/`memref.store` ops.
+    Store sinking is currently limited to 1-element vectors.
+
+    Example:
+    ```
+    %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+    %1 = vector.extract %0[1] : f32 from vector<4xf32>
+    ```
+    Gets converted to:
+    ```
+    %c1 = arith.constant 1 : index
+    %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+    %1 = memref.load %arg0[%0] : memref<?xf32>
+    ```
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ce97847172197..7a079dcc6affc 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -161,6 +161,21 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                    PatternBenefit benefit = 1);
 
+/// Patterns that remove redundant Vector Ops by merging them with load/store
+/// ops.
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 8dde9866b22b3..6b1074e454bd5 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -315,17 +315,17 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
 Value ArithBuilder::add(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::AddFOp>(loc, lhs, rhs);
-  return b.create<arith::AddIOp>(loc, lhs, rhs);
+  return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::sub(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::SubFOp>(loc, lhs, rhs);
-  return b.create<arith::SubIOp>(loc, lhs, rhs);
+  return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::mul(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::MulFOp>(loc, lhs, rhs);
-  return b.create<arith::MulIOp>(loc, lhs, rhs);
+  return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::sgt(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 12dcf768dd928..a888d745be443 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
   vector::populateSinkVectorOpsPatterns(patterns);
 }
 
+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 89839d0440d3c..b94c5fce64f83 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 };
 
-/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast/splat) to broadcast(elementwise).
+///
+/// Example:
 /// ```
 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
 /// This may result in cleaner code when extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
+///
+/// Example:
 /// ```
 /// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
 /// %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1043,6 +1047,150 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Check if the element type is suitable for vector.load/store sinking.
+/// Element type must be index or byte-aligned integer or floating-point type.
+static bool isSupportedMemSinkElementType(Type type) {
+  if (isa<IndexType>(type))
+    return true;
+
+  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
+/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
+///
+/// Example:
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Checking for single use so we won't duplicate load ops.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+
+    // Non-byte-aligned types are tricky and may require special handling,
+    // ignore them for now.
+    if (!isSupportedMemSinkElementType(memType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (extractVecType)
+      finalRank = extractVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilderf(rewriter, loc);
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+      indices[i] = idxBuilderf.add(indices[i], offset);
+    }
+
+    Value base = loadOp.getBase();
+    if (extractVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // We checked for single use so we can safely erase the load op.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite `vector.store(vector.splat) -> vector/memref.store`.
+///
+/// Example:
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreOpFromSplatOrBroadcast final
+    : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::SplatOp, vector::BroadcastOp>(splat))
+      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+
+    // Checking for single use so we can remove splat.
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2109,6 +2257,13 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  // TODO: Consider converting these patterns to canonicalizations.
+  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
+      patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
index ef17b69b2444c..4d04276742164 100644
--- a/mlir/test/Dialect/Vector/vector-sink-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.sink_ops
+      transform.apply_patterns.vector.sink_mem_ops
     } : !transform.any_op
     transform.yield
   }
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 8c8f1797aaab6..900ad99bb4a4c 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -513,3 +513,203 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
   %1 = vector.extract %0[1] : f32 from vector<4xf32>
   return %1 : f32
 }
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_index
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
+func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
+// CHECK: return %[[RES]] : index
+  %0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
+  %1 = vector.extract %0[0] : index from vector<4xindex>
+  return %1 : index
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_non_zero_off
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
+// Subbyte types are tricky, ignore them for now.
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
+// CHECK: return %[[EXT]] : i1
+  %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
+  %1 = vector.extract %0[0] : i1 from vector<8xi1>
+  return %1 : i1
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_load_scalable
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK: return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreOpFromSplatOrBroadcast]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @store_splat
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast_1d_to_2d
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_to_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+  %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+  vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_scalable
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  %0 = vector.splat %arg2 : vector<[1]xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_memref_of_vec
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_memref_of_vec(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_more_than_one_element
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_more_than_one_element(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  %0 = vector.splat %arg2 : vector<4xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_no_single_use
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+// CHECK: return %[[RES]] : vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return %0 : vector<1xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..03f907e46c2c6 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -395,6 +395,7 @@ struct TestVectorSinkPatterns
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateSinkVectorOpsPatterns(patterns);
+    populateSinkVectorMemOpsPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };
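A note for readers trying the patch out: the sketch below is not part of the diff. It shows one way the new entry point could be driven programmatically, mirroring what the `TestVectorSinkPatterns` pass above does; the `runSinkPatterns` wrapper name is made up for illustration, while the populate/apply calls are the ones added or used by the patch itself.

```cpp
// Illustrative sketch only (not part of this patch).
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void runSinkPatterns(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Existing elementwise sinking patterns plus the new load/store sinking ones.
  mlir::vector::populateSinkVectorOpsPatterns(patterns);
  mlir::vector::populateSinkVectorMemOpsPatterns(patterns);
  // Drive to a fixed point; the result only reports convergence, matching how
  // the TestVectorSinkPatterns pass above ignores it.
  (void)mlir::applyPatternsGreedily(root, std::move(patterns));
}
```

The same patterns can also be applied from the transform dialect via `transform.apply_patterns.vector.sink_mem_ops`, as exercised in vector-sink-transform.mlir above.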