Commit ad5ed49

matthias-springer authored and tru committed
[mlir][memref] Fix crash in SubViewReturnTypeCanonicalizer
`SubViewReturnTypeCanonicalizer` is used by `OpWithOffsetSizesAndStridesConstantArgumentFolder`, which folds constant SSA value (dynamic) sizes into static sizes. The previous implementation crashed when a dynamic size was folded into a static `1` dimension, which was then mistaken for a rank reduction.

Differential Revision: https://reviews.llvm.org/D158721
1 parent 08d720d commit ad5ed49
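For context, the shape of IR that triggers the crash is the one exercised by the regression test added below: a `memref.subview` whose middle size is the constant SSA value `%c1` (a sketch extracted from the test, with `%arg0` and `%idx` as function arguments):

    %c1 = arith.constant 1 : index
    %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
        : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>

Folding `%c1` turns the subview sizes into `[1, 1, %idx]`. The new static `1` in position 1 is not a rank reduction (the dimension actually dropped by the rank-reduced result type is dim 0, as the test's CHECK lines confirm), but the old return-type computation misread it as one.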

2 files changed: +49 -32 lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 33 additions & 31 deletions
@@ -31,23 +31,17 @@ namespace {
 namespace saturated_arith {
 struct Wrapper {
   static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper size(int64_t v) {
     return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
-  int64_t asOffset() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
   int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
   bool operator==(Wrapper other) {
     return (saturated && other.saturated) ||
            (!saturated && !other.saturated && v == other.v);
@@ -732,8 +726,7 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (ShapedType::isDynamic(ss) &&
-          !ShapedType::isDynamic(st))
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
         return false;
   }

@@ -766,8 +759,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
     // same. They are also compatible if either one is dynamic (see
     // description of MemRefCastOp for details).
     auto checkCompatible = [](int64_t a, int64_t b) {
-      return (ShapedType::isDynamic(a) ||
-              ShapedType::isDynamic(b) || a == b);
+      return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
     };
     if (!checkCompatible(aOffset, bOffset))
       return false;
@@ -1890,8 +1882,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
   if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) &&
-      resultOffset != expectedOffset)
+      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;

@@ -2945,18 +2936,6 @@ static MemRefType getCanonicalSubViewResultType(
                          nonRankReducedType.getMemorySpace());
 }
 
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
-/// to deduce the result type. Additionally, reduce the rank of the inferred
-/// result type if `currentResultType` is lower rank than `sourceType`.
-static MemRefType getCanonicalSubViewResultType(
-    MemRefType currentResultType, MemRefType sourceType,
-    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
-    ArrayRef<OpFoldResult> mixedStrides) {
-  return getCanonicalSubViewResultType(currentResultType, sourceType,
-                                       sourceType, mixedOffsets, mixedSizes,
-                                       mixedStrides);
-}
-
 Value mlir::memref::createCanonicalRankReducingSubViewOp(
     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
   auto memrefType = llvm::cast<MemRefType>(memref.getType());
@@ -3109,9 +3088,32 @@ struct SubViewReturnTypeCanonicalizer {
   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
-    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
-                                         mixedOffsets, mixedSizes,
-                                         mixedStrides);
+    // Infer a memref type without taking into account any rank reductions.
+    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
+
+    // Directly return the non-rank reduced type if there are no dropped dims.
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+    if (droppedDims.empty())
+      return nonReducedType;
+
+    // Take the strides and offset from the non-rank reduced type.
+    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+
+    // Drop dims from shape and strides.
+    SmallVector<int64_t> targetShape;
+    SmallVector<int64_t> targetStrides;
+    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
+      if (droppedDims.test(i))
+        continue;
+      targetStrides.push_back(nonReducedStrides[i]);
+      targetShape.push_back(nonReducedType.getDimSize(i));
+    }
+
+    return MemRefType::get(targetShape, nonReducedType.getElementType(),
+                           StridedLayoutAttr::get(nonReducedType.getContext(),
+                                                  offset, targetStrides),
+                           nonReducedType.getMemorySpace());
   }
 };
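As a sketch of the type computation on the regression test below (worked by hand, not compiler output): `SubViewOp::inferResultType` without rank reduction yields

    memref<1x1x?xf32, strided<[147456, 384, 1], offset: ?>>

`getDroppedDims()` marks dim 0 as dropped, since stride 147456 does not survive in the op's rank-reduced result layout `strided<[384, 1]>`. The loop therefore keeps dims 1 and 2 and builds

    memref<1x?xf32, strided<[384, 1], offset: ?>>

so the static `1` produced by folding is kept as a real dimension instead of being misread as a second rank reduction.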

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 16 additions & 1 deletion
@@ -931,7 +931,7 @@ func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32
 
 // -----
 
-// CHECK-lABEL: func @ub_negative_alloc_size
+// CHECK-LABEL: func private @ub_negative_alloc_size
 func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
   %idx1 = index.constant 1
   %c-2 = arith.constant -2 : index
@@ -940,3 +940,18 @@ func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
   %alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
   return %alloc : memref<?x?x?xi1>
 }
+
+// -----
+
+// CHECK-LABEL: func @subview_rank_reduction(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
+func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
+    -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
+      : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: return %[[cast]]
+  return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
+}
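Assuming the file's usual RUN line (roughly `mlir-opt -canonicalize --split-input-file %s | FileCheck %s`), the CHECK lines pin down the fixed behavior: the folder now emits a `memref.subview` with static sizes `[1, 1, %idx]` that still drops only dim 0, followed by a `memref.cast` back to the op's original `memref<?x?xf32>` result type, rather than crashing.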
