Skip to content

[MLIR][Linalg] pack, unpack to take memref inputs #129036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 40 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
4d523ad
draft
ita9naiwa Feb 27, 2025
4f2dbf4
draft
ita9naiwa Feb 27, 2025
226230c
draft
ita9naiwa Feb 27, 2025
0c184df
init
ita9naiwa Feb 28, 2025
19201c6
lint
ita9naiwa Feb 28, 2025
b99b920
lint
ita9naiwa Feb 28, 2025
be6a119
add
ita9naiwa Feb 28, 2025
eee8805
remove tensor casting
ita9naiwa Mar 1, 2025
c5b3c39
add test
ita9naiwa Mar 1, 2025
714e4c4
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Mar 16, 2025
a5d01df
fix upon review
ita9naiwa Mar 16, 2025
2480616
lint
ita9naiwa Mar 23, 2025
0421e72
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Mar 23, 2025
7b92a4e
format fix
ita9naiwa Mar 24, 2025
6dc08ae
revert changes
ita9naiwa Mar 25, 2025
cf7be57
revert changes
ita9naiwa Mar 25, 2025
4e2f00d
nit
ita9naiwa Mar 25, 2025
ee7a42a
fix upon review: Add getEffects for PackOp and UnPackOp
ita9naiwa Mar 27, 2025
5b95ee8
make clang-format happy
ita9naiwa Mar 27, 2025
8b5ac5a
make clang-format happy
ita9naiwa Mar 27, 2025
c955d21
wrap getEffects function
ita9naiwa Mar 27, 2025
4bedf40
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Mar 27, 2025
276069d
fix upon review
ita9naiwa Mar 29, 2025
790e974
bail out transforms using PackOp, UnPackOp
ita9naiwa Mar 30, 2025
820e40b
fix build error
ita9naiwa Mar 30, 2025
43a64b9
fix build error
ita9naiwa Mar 30, 2025
a3bba60
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Apr 2, 2025
486c62b
add invalid pack/unpack cases
ita9naiwa Apr 2, 2025
ca889b5
fix roundtrip test
ita9naiwa Apr 2, 2025
ce910b9
fix upon review
ita9naiwa Apr 2, 2025
6a501bd
fix upon review
ita9naiwa Apr 2, 2025
535e796
.
ita9naiwa Apr 2, 2025
4cbbb80
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Apr 6, 2025
17ad838
fix upon review
ita9naiwa Apr 13, 2025
2aca3fd
Update mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
ita9naiwa Apr 20, 2025
2541aa2
Revert unnecessary cosmetic changes
ita9naiwa Jul 15, 2025
1488cb9
Address HanHan review feedback: disable canonicalization for memref p…
ita9naiwa Jul 15, 2025
d10cb1d
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Jul 15, 2025
d16448a
Apply clang-format to code changes
ita9naiwa Jul 15, 2025
825f11b
fix upon review
ita9naiwa Aug 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable, NoMemoryEffect,
ConditionallySpeculatable, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
TypesMatchWith<"result type matches type of dest",
OptionalTypesMatchWith<"result type matches type of dest",
"dest", "result",
"$_self">])> {

code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
RankedTensorType getSourceType() {
return ::llvm::cast<RankedTensorType>(getSource().getType()); };
RankedTensorType getDestType() {
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
ShapedType getSourceType() {
return ::llvm::cast<ShapedType>(getSource().getType()); };
ShapedType getDestType() {
return ::llvm::cast<ShapedType>(getDest().getType()); };

MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }

Expand Down Expand Up @@ -168,23 +168,16 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Note: Only tiled dimensions can be padded.
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
let arguments = (ins AnyShaped:$source,
AnyShaped:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $dest attr-dict `:` type($source) `->` type($dest)
}];
let results = (outs Optional<AnyRankedTensor>:$result);

let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
Expand All @@ -206,7 +199,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static RankedTensorType inferPackedType(RankedTensorType sourceType,
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

// Method to get the `MemRefType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static MemRefType inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

// Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic.
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

Expand Down Expand Up @@ -317,21 +322,15 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
// Outer Dims: 9x3x8 Inner Dims: 4x2
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
let arguments = (ins AnyShaped:$source,
AnyShaped:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $dest attr-dict `:` type($source) `->` type($dest)
}];
let results = (outs Optional<AnyRankedTensor>:$result);

let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
omp::FlushOp, omp::MapBoundsOp,
omp::ThreadprivateOp>::value) {
if (isa<MemRefType>(originalOperand.getType())) {
// TODO: Support memref type in variable operands
if (isa<MemRefType>(originalOperand.getType()))
return rewriter.notifyMatchFailure(op, "memref is not supported yet");
}
}
convertedOperands.push_back(convertedOperand);
}
Expand Down
Loading