MLIR Python Bindings: Recreate a Module

It’s great to see how far along the MLIR python bindings are.

As a basic example on constructing new modules, I’m building some Python code which reads in an MLIR file, and tries to recreates it.

For example, with this as input:

func.func @test_addi(%arg0 : i64, %arg1 : i64) -> i64 {
  %0 = arith.constant 42 : i64
  %1 = arith.constant 69 : i64
  %2 = arith.addi %0, %1 : i64
  return %2 : i64
}

I’ve been able to get:

"builtin.module"() ({
  "builtin.module"() ({
    "\22test_addi\22"() ({
      %0 = "arith.constant"() <{value = 42 : i64}> : () -> i64
      %1 = "arith.constant"() <{value = 69 : i64}> : () -> i64
      %2 = "arith.addi"(%0, %1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      "func.return"(%2) : (i64) -> ()
    }) {function_type = () -> i64, sym_name = "test_addi"} : () -> ()
  }) : () -> ()
}) : () -> ()

In addition, if I pass arguments to my FuncOp, e.g.,

func.func @test_addi(%arg0 : i64, %arg1 : i64) -> i64 {
  %0 = arith.addi %arg0, %arg1 : i64
  return %0 : i64
}

Then I need to add handling for the BlockArguments, since I believe that the FuncOp arguments are actually Block arguments, which gives me:

"builtin.module"() ({
  "builtin.module"() ({
    "\22test_addi\22"() ({
    ^bb0(%arg0: i64, %arg1: i64):
      %0 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      "func.return"(%0) : (i64) -> ()
    }) {function_type = (i64, i64) -> i64, sym_name = "test_addi"} : () -> ()
  }) : () -> ()
}) : () -> ()

Note how:

  1. There are two levels of modules, and
  2. Everything is in the generic form

I’ll share my code below, but for now let’s say I’m recreating my ops with:

recreated_op = Operation.create(
                op_name,
                results=results,
                operands=mapped_operands,
                attributes=attributes,
                successors=[],
                regions=len(operation.regions),
                loc=location,
)

I’ve noticed that the docs says that:

Operations can also be constructed using the generic class and based on the canonical string name of the operation using Operation.create
This form is discouraged from use and is intended for generic operation processing.

So in principle I could use the actual constructors (though this would require a lot of if-else statements to determine the correct constructor):

if isinstance(operation, func.FuncOp):
    recreated_op = func.FuncOp(op_name, (mapped_operands, results))
elif isinstance(operation, func.ReturnOp):
    recreated_op = func.ReturnOp(mapped_operands)
...

This gives me a better formatted FuncOp, but everything still appears to be in the generic form:

"builtin.module"() ({
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "\22test_addi\22"}> ({
    ^bb0(%arg0: i64, %arg1: i64):
      %0 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      "func.return"(%0) : (i64) -> ()
    }) : () -> ()
  }) : () -> ()
}) : () -> ()

My full code is below, but is there anything I’m missing here that could avoid the generic form, or more MLIR Python examples out there?

Initial Python code
from mlir.ir import (
    Context,
    Module,
    Operation,
    InsertionPoint,
    Location,
    Block,
    BlockArgument,
    Type,
)
from mlir.dialects import func, arith


class MLIRRecreator:
    def __init__(self, input_mlir_file, output_mlir_file):
        self.input_mlir_file = input_mlir_file
        self.output_mlir_file = output_mlir_file
        self.ssa_map = {}
        self.operand_map = {}

    def recreate_mlir(self):
        with open(self.input_mlir_file, "r") as file:
            mlir_source = file.read()

        with Context() as ctx:
            ctx.allow_unregistered_dialects = True
            original_module = Module.parse(mlir_source)

            with Location.unknown(ctx) as default_loc:
                recreated_module = Module.create()
                with InsertionPoint.at_block_begin(recreated_module.body):
                    self.recreate_operation(
                        original_module.operation, default_location=default_loc
                    )

                print("Recreated module: ", recreated_module)
                with open(self.output_mlir_file, "w") as file:
                    file.write(str(recreated_module))

    @staticmethod
    def extract_ssa_name(value):
        value_str = str(value)
        if value_str.startswith("Value(") and "=" in value_str:
            ssa_name = value_str.split("=")[0].strip().split("(")[-1].strip()
            return ssa_name
        return None

    def recreate_operation(self, operation, default_location=None):
        attributes = {
            named_attr.name: named_attr.attr for named_attr in operation.attributes
        }
        op_name = (
            str(operation.name)
            if not isinstance(operation.name, str)
            else operation.name
        )
        location = getattr(operation, "loc", default_location)

        # Map operands based on previously created SSA values
        mapped_operands = [self.operand_map.get(op) for op in operation.operands]
        results = [res.type for res in operation.results]

        # Create the new operation
        if isinstance(operation, func.FuncOp):
            recreated_op = func.FuncOp(op_name, (mapped_operands, results))
        elif isinstance(operation, func.ReturnOp):
            recreated_op = func.ReturnOp(mapped_operands)
        # elif type(operation) == arith.AddIOp:
        #     # recreated_op = arith.AddIOp(mapped_operands, results[0])
        #     recreated_op = arith.AddIOp(mapped_operands, results)
        else:
            recreated_op = Operation.create(
                op_name,
                results=results,
                operands=mapped_operands,
                attributes=attributes,
                successors=[],
                regions=len(operation.regions),
                loc=location,
            )

        # Update the SSA map with the results from the new operation
        for orig_result, recreated_result in zip(
            operation.results, recreated_op.results
        ):
            original_ssa_name = self.extract_ssa_name(orig_result)
            if original_ssa_name:
                self.ssa_map[original_ssa_name] = recreated_result
                self.operand_map[orig_result] = recreated_result

        # Continue recreating regions and blocks as before
        for orig_region, recreated_region in zip(
            operation.regions, recreated_op.regions
        ):
            self.recreate_region(
                orig_region, recreated_region, default_location=default_location
            )

    def recreate_region(self, original_region, recreated_region, default_location=None):

        for original_block in original_region.blocks:

            # Create a new block at the start of the recreated region
            # and map any block arguments it may need
            recreated_block = Block.create_at_start(recreated_region)

            with InsertionPoint(recreated_block):
                # Map any block arguments that are needed
                for block_arg_orig in original_block.arguments:
                    block_arg_new = recreated_block.add_argument(
                        Type(block_arg_orig.type), default_location
                    )
                    self.operand_map[block_arg_orig] = block_arg_new
                for operation in original_block.operations:
                    self.recreate_operation(
                        operation,
                        default_location=default_location,
                    )


if __name__ == "__main__":
    input_mlir_path = "input.mlir"
    output_mlir_path = "recreated.mlir"
    recreator = MLIRRecreator(input_mlir_path, output_mlir_path)
    recreator.recreate_mlir()
    print(f"MLIR has been recreated at {output_mlir_path}")
    # Load and print the recreated MLIR
    with open(output_mlir_path, "r") as file:
        print(file.read())

Quick answer to your question (I didn’t spot the bug in your code yet): if the module doesn’t verify due to some constraint violation, it prints generic. Call verify() on the module operation to have it run the verifier and throw a real error message. That will probably point at what is going wrong.

(Edit: I see the problem: your function_type does not match your entry block arguments)

Great, thanks! Calling .verify() will really help with my work here.

I was able to get a non-generic output now, by adding a special case to handle FuncOps, but fortunately it seems that’s the only special case I need:

module {
  module {
    func.func @test_addi(%arg0: i64, %arg1: i64) -> i64 {
      %0 = arith.addi %arg0, %arg1 : i64
      return %0 : i64
    }
  }
}
Updated `recreate_operation` function, with `verify()` after every operation creation.
    def recreate_operation(self, operation, default_location=None):
        attributes = {
            named_attr.name: named_attr.attr for named_attr in operation.attributes
        }
        op_name = (
            str(operation.name)
            if not isinstance(operation.name, str)
            else operation.name
        )
        location = getattr(operation, "loc", default_location)

        # Map operands based on previously created SSA values
        mapped_operands = [self.operand_map.get(op) for op in operation.operands]
        results = [res.type for res in operation.results]

         if isinstance(operation, func.FuncOp):
            op_name = op_name.strip('"')
            recreated_op = func.FuncOp(
                name=op_name,
                type=operation.type,
            )
        else:
            recreated_op = Operation.create(
                op_name,
                results=results,
                operands=mapped_operands,
                attributes=attributes,
                successors=[],
                regions=len(operation.regions),
                loc=location,
            )

        # Update the SSA map with the results from the new operation
        for orig_result, recreated_result in zip(
            operation.results, recreated_op.results
        ):
            original_ssa_name = self.extract_ssa_name(orig_result)
            if original_ssa_name:
                self.ssa_map[original_ssa_name] = recreated_result
                self.operand_map[orig_result] = recreated_result

        # Continue recreating regions and blocks as before
        for orig_region, recreated_region in zip(
            operation.regions, recreated_op.regions
        ):
            self.recreate_region(
                orig_region, recreated_region, default_location=default_location
            )

        recreated_op.verify()

I still have this double module, but I can live with that for now.