diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bee5c1fd6ed58..74222cb56d412 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -151,13 +151,39 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind, return false; } -AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, - VectorType vectorType) { - int64_t elementVectorRank = 0; +/// Returns the number of dimensions of the `shapedType` that participate in the +/// vector transfer, effectively the rank of the vector dimensions within the +/// `shapedType`. This is calculated by taking the rank of the `vectorType` +/// being transferred and subtracting the rank of the `shapedType`'s element +/// type if it's also a vector. +/// +/// This is used to determine the number of minor dimensions for identity maps +/// in vector transfers. +/// +/// For example, given a transfer operation involving `shapedType` and +/// `vectorType`: +/// +/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32> +/// - shapedType.getElementType() = f32 (rank 0) +/// - vectorType.getRank() = 2 +/// - Result = 2 - 0 = 2 +/// +/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32> +/// - shapedType.getElementType() = vector<20xf32> (rank 1) +/// - vectorType.getRank() = 1 +/// - Result = 1 - 1 = 0 +static unsigned getRealVectorRank(ShapedType shapedType, + VectorType vectorType) { + unsigned elementVectorRank = 0; VectorType elementVectorType = llvm::dyn_cast(shapedType.getElementType()); if (elementVectorType) elementVectorRank += elementVectorType.getRank(); + return vectorType.getRank() - elementVectorRank; +} + +AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, + VectorType vectorType) { // 0-d transfers are to/from tensor/memref and vector<1xt>. // TODO: replace once we have 0-d vectors. if (shapedType.getRank() == 0 && @@ -166,7 +192,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, /*numDims=*/0, /*numSymbols=*/0, getAffineConstantExpr(0, shapedType.getContext())); return AffineMap::getMinorIdentityMap( - shapedType.getRank(), vectorType.getRank() - elementVectorRank, + shapedType.getRank(), getRealVectorRank(shapedType, vectorType), shapedType.getContext()); } @@ -4261,6 +4287,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { Attribute permMapAttr = result.attributes.get(permMapAttrName); AffineMap permMap; if (!permMapAttr) { + if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType)) + return parser.emitError(typesLoc, + "expected a custom permutation_map when " + "rank(source) != rank(destination)"); permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); } else { @@ -4676,6 +4706,10 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser, auto permMapAttr = result.attributes.get(permMapAttrName); AffineMap permMap; if (!permMapAttr) { + if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType)) + return parser.emitError(typesLoc, + "expected a custom permutation_map when " + "rank(source) != rank(destination)"); permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); } else { diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index dbf829e014b8d..63f8667ce6b9e 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -525,6 +525,15 @@ func.func @test_vector.transfer_read(%arg0: memref>) { // ----- +func.func @test_vector.transfer_read(%arg1: memref) -> vector<3x4xindex> { + %c3 = arith.constant 3 : index + // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}} + %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref, vector<3x4xindex> + return %0 : vector<3x4xindex> +} + +// ----- + func.func @test_vector.transfer_write(%arg0: memref) { %c3 = arith.constant 3 : index %cst = arith.constant 3.0 : f32 @@ -646,6 +655,14 @@ func.func @test_vector.transfer_write(%arg0: memref, %arg1: vector<7xf32> // ----- +func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref) { + %c3 = arith.constant 3 : index + // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}} + vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref +} + +// ----- + func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected offsets of same size as destination vector rank}} %1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>