Skip to content

Commit cc7e24c

Browse files
matthias-springertru
authored andcommitted
[mlir] Fix crash when adding nested dialect extensions
A dialect extension can add additional dialect extensions in its `apply` function. This used to crash when the vector of `extensions` was internally reallocated while it is being iterated over. Differential Revision: https://reviews.llvm.org/D158838
1 parent 94f348b commit cc7e24c

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

mlir/lib/IR/Dialect.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
125125
MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
126126
for (auto *dialect : ctx->getLoadedDialects()) {
127127
#ifndef NDEBUG
128-
dialect->handleUseOfUndefinedPromisedInterface(interfaceKind, interfaceName);
128+
dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
129+
interfaceName);
129130
#endif
130131
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
131132
interfaces.insert(interface);
@@ -243,8 +244,9 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
243244
extension.apply(ctx, requiredDialects);
244245
};
245246

246-
for (const auto &extension : extensions)
247-
applyExtension(*extension);
247+
// Note: Additional extensions may be added while applying an extension.
248+
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
249+
applyExtension(*extensions[i]);
248250
}
249251

250252
void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@@ -264,8 +266,9 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
264266
extension.apply(ctx, requiredDialects);
265267
};
266268

267-
for (const auto &extension : extensions)
268-
applyExtension(*extension);
269+
// Note: Additional extensions may be added while applying an extension.
270+
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
271+
applyExtension(*extensions[i]);
269272
}
270273

271274
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {

mlir/unittests/IR/DialectTest.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,50 @@ TEST(Dialect, RepeatedDelayedRegistration) {
136136
EXPECT_TRUE(testDialectInterface != nullptr);
137137
}
138138

139+
namespace {
140+
/// A dummy extension that increases a counter when being applied and
141+
/// recursively adds additional extensions.
142+
struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
143+
DummyExtension(int *counter, int numRecursive)
144+
: DialectExtension(), counter(counter), numRecursive(numRecursive) {}
145+
146+
void apply(MLIRContext *ctx, TestDialect *dialect) const final {
147+
++(*counter);
148+
DialectRegistry nestedRegistry;
149+
for (int i = 0; i < numRecursive; ++i)
150+
nestedRegistry.addExtension(
151+
std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
152+
// Adding additional extensions may trigger a reallocation of the
153+
// `extensions` vector in the dialect registry.
154+
ctx->appendDialectRegistry(nestedRegistry);
155+
}
156+
157+
private:
158+
int *counter;
159+
int numRecursive;
160+
};
161+
} // namespace
162+
163+
TEST(Dialect, NestedDialectExtension) {
164+
DialectRegistry registry;
165+
registry.insert<TestDialect>();
166+
167+
// Add an extension that adds 100 more extensions.
168+
int counter1 = 0;
169+
registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
170+
// Add one more extension. This should not crash.
171+
int counter2 = 0;
172+
registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));
173+
174+
// Load dialect and apply extensions.
175+
MLIRContext context(registry);
176+
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
177+
ASSERT_TRUE(testDialect != nullptr);
178+
179+
// Extensions may be applied multiple times. Make sure that each expected
180+
// extension was applied at least once.
181+
EXPECT_GE(counter1, 101);
182+
EXPECT_GE(counter2, 1);
183+
}
184+
139185
} // namespace

0 commit comments

Comments
 (0)