It’s great to see how far along the MLIR python bindings are.
As a basic example on constructing new modules, I’m building some Python code which reads in an MLIR file, and tries to recreates it.
For example, with this as input:
func.func @test_addi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.constant 42 : i64
%1 = arith.constant 69 : i64
%2 = arith.addi %0, %1 : i64
return %2 : i64
}
I’ve been able to get:
"builtin.module"() ({
"builtin.module"() ({
"\22test_addi\22"() ({
%0 = "arith.constant"() <{value = 42 : i64}> : () -> i64
%1 = "arith.constant"() <{value = 69 : i64}> : () -> i64
%2 = "arith.addi"(%0, %1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
"func.return"(%2) : (i64) -> ()
}) {function_type = () -> i64, sym_name = "test_addi"} : () -> ()
}) : () -> ()
}) : () -> ()
In addition, if I pass arguments to my FuncOp, e.g.,
func.func @test_addi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.addi %arg0, %arg1 : i64
return %0 : i64
}
Then I need to add handling for the BlockArguments, since I believe that the FuncOp arguments are actually Block arguments, which gives me:
"builtin.module"() ({
"builtin.module"() ({
"\22test_addi\22"() ({
^bb0(%arg0: i64, %arg1: i64):
%0 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
"func.return"(%0) : (i64) -> ()
}) {function_type = (i64, i64) -> i64, sym_name = "test_addi"} : () -> ()
}) : () -> ()
}) : () -> ()
Note how:
- There are two levels of modules, and
- Everything is in the generic form
I’ll share my code below, but for now let’s say I’m recreating my ops with:
recreated_op = Operation.create(
op_name,
results=results,
operands=mapped_operands,
attributes=attributes,
successors=[],
regions=len(operation.regions),
loc=location,
)
I’ve noticed that the docs says that:
Operations can also be constructed using the generic class and based on the canonical string name of the operation using Operation.create
This form is discouraged from use and is intended for generic operation processing.
So in principle I could use the actual constructors (though this would require a lot of if-else statements to determine the correct constructor):
if isinstance(operation, func.FuncOp):
recreated_op = func.FuncOp(op_name, (mapped_operands, results))
elif isinstance(operation, func.ReturnOp):
recreated_op = func.ReturnOp(mapped_operands)
...
This gives me a better formatted FuncOp, but everything still appears to be in the generic form:
"builtin.module"() ({
"builtin.module"() ({
"func.func"() <{function_type = () -> (), sym_name = "\22test_addi\22"}> ({
^bb0(%arg0: i64, %arg1: i64):
%0 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
"func.return"(%0) : (i64) -> ()
}) : () -> ()
}) : () -> ()
}) : () -> ()
My full code is below, but is there anything I’m missing here that could avoid the generic form, or more MLIR Python examples out there?
Initial Python code
from mlir.ir import (
Context,
Module,
Operation,
InsertionPoint,
Location,
Block,
BlockArgument,
Type,
)
from mlir.dialects import func, arith
class MLIRRecreator:
def __init__(self, input_mlir_file, output_mlir_file):
self.input_mlir_file = input_mlir_file
self.output_mlir_file = output_mlir_file
self.ssa_map = {}
self.operand_map = {}
def recreate_mlir(self):
with open(self.input_mlir_file, "r") as file:
mlir_source = file.read()
with Context() as ctx:
ctx.allow_unregistered_dialects = True
original_module = Module.parse(mlir_source)
with Location.unknown(ctx) as default_loc:
recreated_module = Module.create()
with InsertionPoint.at_block_begin(recreated_module.body):
self.recreate_operation(
original_module.operation, default_location=default_loc
)
print("Recreated module: ", recreated_module)
with open(self.output_mlir_file, "w") as file:
file.write(str(recreated_module))
@staticmethod
def extract_ssa_name(value):
value_str = str(value)
if value_str.startswith("Value(") and "=" in value_str:
ssa_name = value_str.split("=")[0].strip().split("(")[-1].strip()
return ssa_name
return None
def recreate_operation(self, operation, default_location=None):
attributes = {
named_attr.name: named_attr.attr for named_attr in operation.attributes
}
op_name = (
str(operation.name)
if not isinstance(operation.name, str)
else operation.name
)
location = getattr(operation, "loc", default_location)
# Map operands based on previously created SSA values
mapped_operands = [self.operand_map.get(op) for op in operation.operands]
results = [res.type for res in operation.results]
# Create the new operation
if isinstance(operation, func.FuncOp):
recreated_op = func.FuncOp(op_name, (mapped_operands, results))
elif isinstance(operation, func.ReturnOp):
recreated_op = func.ReturnOp(mapped_operands)
# elif type(operation) == arith.AddIOp:
# # recreated_op = arith.AddIOp(mapped_operands, results[0])
# recreated_op = arith.AddIOp(mapped_operands, results)
else:
recreated_op = Operation.create(
op_name,
results=results,
operands=mapped_operands,
attributes=attributes,
successors=[],
regions=len(operation.regions),
loc=location,
)
# Update the SSA map with the results from the new operation
for orig_result, recreated_result in zip(
operation.results, recreated_op.results
):
original_ssa_name = self.extract_ssa_name(orig_result)
if original_ssa_name:
self.ssa_map[original_ssa_name] = recreated_result
self.operand_map[orig_result] = recreated_result
# Continue recreating regions and blocks as before
for orig_region, recreated_region in zip(
operation.regions, recreated_op.regions
):
self.recreate_region(
orig_region, recreated_region, default_location=default_location
)
def recreate_region(self, original_region, recreated_region, default_location=None):
for original_block in original_region.blocks:
# Create a new block at the start of the recreated region
# and map any block arguments it may need
recreated_block = Block.create_at_start(recreated_region)
with InsertionPoint(recreated_block):
# Map any block arguments that are needed
for block_arg_orig in original_block.arguments:
block_arg_new = recreated_block.add_argument(
Type(block_arg_orig.type), default_location
)
self.operand_map[block_arg_orig] = block_arg_new
for operation in original_block.operations:
self.recreate_operation(
operation,
default_location=default_location,
)
if __name__ == "__main__":
input_mlir_path = "input.mlir"
output_mlir_path = "recreated.mlir"
recreator = MLIRRecreator(input_mlir_path, output_mlir_path)
recreator.recreate_mlir()
print(f"MLIR has been recreated at {output_mlir_path}")
# Load and print the recreated MLIR
with open(output_mlir_path, "r") as file:
print(file.read())