@@ -31,23 +31,17 @@ namespace {
 namespace saturated_arith {
 struct Wrapper {
   static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper size(int64_t v) {
     return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
-  int64_t asOffset() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
   int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
   bool operator==(Wrapper other) {
     return (saturated && other.saturated) ||
            (!saturated && !other.saturated && v == other.v);
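For context on the struct being reflowed above: `Wrapper` implements saturating arithmetic over possibly-dynamic offsets, sizes, and strides, where a dynamic input poisons every value derived from it. Below is a minimal standalone sketch of that idea; it uses a stand-in `kDynamic` constant and only the members visible in this hunk, so it is an illustration of the pattern, not the full struct from MemRefOps.cpp.

#include <cstdint>
#include <limits>

// Stand-in for ShapedType::kDynamic (illustrative; the real constant is
// defined by MLIR's ShapedType).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

struct Saturated {
  bool saturated; // true once any input was dynamic
  int64_t v;      // meaningful only when !saturated

  static Saturated make(int64_t v) {
    return v == kDynamic ? Saturated{true, 0} : Saturated{false, v};
  }
  // Any arithmetic involving a saturated operand stays saturated, so
  // dynamic-ness propagates through derived offset/stride computations.
  Saturated operator+(Saturated o) const {
    return (saturated || o.saturated) ? Saturated{true, 0}
                                      : Saturated{false, v + o.v};
  }
  int64_t get() const { return saturated ? kDynamic : v; }
};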
@@ -732,8 +726,7 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (ShapedType::isDynamic(ss) &&
-          !ShapedType::isDynamic(st))
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
         return false;
   }
@@ -766,8 +759,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   // same. They are also compatible if either one is dynamic (see
   // description of MemRefCastOp for details).
   auto checkCompatible = [](int64_t a, int64_t b) {
-    return (ShapedType::isDynamic(a) ||
-            ShapedType::isDynamic(b) || a == b);
+    return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
   };
   if (!checkCompatible(aOffset, bOffset))
     return false;
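The rule being reflowed here is worth spelling out: two offsets (or two strides) are cast-compatible when they are equal or when at least one side is dynamic. The same tolerance for dynamic values recurs in the fold check above and the verifier below. A standalone sketch of the check applied over whole stride lists, again with a stand-in `kDynamic` (an assumption for illustration, not MLIR API):

#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min(); // stand-in

// Mirrors checkCompatible above, element-wise: values match, or at least
// one side is dynamic.
bool castCompatible(const std::vector<int64_t> &a,
                    const std::vector<int64_t> &b) {
  if (a.size() != b.size())
    return false;
  for (std::size_t i = 0; i < a.size(); ++i)
    if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i])
      return false;
  return true;
}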
@@ -1890,8 +1882,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
   if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) &&
-      resultOffset != expectedOffset)
+      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;
@@ -2945,18 +2936,6 @@ static MemRefType getCanonicalSubViewResultType(
                          nonRankReducedType.getMemorySpace());
 }
 
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
-/// to deduce the result type. Additionally, reduce the rank of the inferred
-/// result type if `currentResultType` is lower rank than `sourceType`.
-static MemRefType getCanonicalSubViewResultType(
-    MemRefType currentResultType, MemRefType sourceType,
-    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
-    ArrayRef<OpFoldResult> mixedStrides) {
-  return getCanonicalSubViewResultType(currentResultType, sourceType,
-                                       sourceType, mixedOffsets, mixedSizes,
-                                       mixedStrides);
-}
-
 Value mlir::memref::createCanonicalRankReducingSubViewOp(
     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
   auto memrefType = llvm::cast<MemRefType>(memref.getType());
@@ -3109,9 +3088,32 @@ struct SubViewReturnTypeCanonicalizer {
   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
-    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
-                                         mixedOffsets, mixedSizes,
-                                         mixedStrides);
+    // Infer a memref type without taking into account any rank reductions.
+    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
+
+    // Directly return the non-rank reduced type if there are no dropped dims.
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+    if (droppedDims.empty())
+      return nonReducedType;
+
+    // Take the strides and offset from the non-rank reduced type.
+    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+
+    // Drop dims from shape and strides.
+    SmallVector<int64_t> targetShape;
+    SmallVector<int64_t> targetStrides;
+    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
+      if (droppedDims.test(i))
+        continue;
+      targetStrides.push_back(nonReducedStrides[i]);
+      targetShape.push_back(nonReducedType.getDimSize(i));
+    }
+
+    return MemRefType::get(targetShape, nonReducedType.getElementType(),
+                           StridedLayoutAttr::get(nonReducedType.getContext(),
+                                                  offset, targetStrides),
+                           nonReducedType.getMemorySpace());
   }
 };
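To make the new canonicalization concrete: if the inferred non-rank-reduced type were memref<1x16x4xf32, strided<[64, 4, 1]>> and the subview drops dim 0, the loop above keeps only the surviving entries, yielding memref<16x4xf32, strided<[4, 1]>>. A self-contained sketch of that dim-dropping step on plain vectors follows (hypothetical helper for illustration, not MLIR API):

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the loop above: keep shape/stride entries
// whose index is not marked as dropped.
std::vector<int64_t> dropDims(const std::vector<int64_t> &values,
                              const std::vector<bool> &dropped) {
  std::vector<int64_t> kept;
  for (std::size_t i = 0; i < values.size(); ++i)
    if (!dropped[i])
      kept.push_back(values[i]);
  return kept;
}

// E.g. shape {1, 16, 4} with dim 0 dropped -> {16, 4};
//      strides {64, 4, 1}                  -> {4, 1}.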