I’ve been looking at MLIR operations that have result groups (i.e., they return multiple values, which are given a single name that we can index into).
For example, here `@my_func` returns two values, and the second one is used as an argument to `@double`:
// Example: @my_func returns two i64 values. In @main they are bound to the
// result group %0:2, and the second member of that group (%0#1) is passed
// to @double.
module {
func.func @main() {
// "%0:2" names both call results as one group; "%0#1" below selects the second.
%0:2 = call @my_func() : () -> (i64, i64)
%1 = call @double(%0#1) : (i64) -> (i64)
return
}
// Returns the two constants 42 and 69.
func.func @my_func() -> (i64, i64) {
%0 = arith.constant 42 : i64
%1 = arith.constant 69 : i64
return %0, %1 : i64, i64
}
// Returns %arg0 + %arg0.
func.func @double(%arg0: i64) -> (i64) {
%0 = arith.addi %arg0, %arg0 : i64
return %0 : i64
}
}
However, when analysing the operations with the MLIR Python bindings, I’m finding it tricky to extract the index of the result-group member that is used as an argument to `@double` (i.e., the `#1`).
For my needs, knowing which index is being used is important, as it helps me analyse the data flow in the program.
If I get the operands of the `@double` call with `op.operands`, I get an object of type `OpOperandList`.
If I loop through it, I get `Value` objects; in this example, `Value(%0:2 = func.call @my_func() : () -> (i64, i64))`.
This `Value` object doesn’t seem to encode which member of the result group is being used as the argument: its string form only shows the defining operation.
Standalone example (expects the module above saved as `result_group.mlir`):
#!/usr/bin/env python3
import os
import json
from mlir.ir import Context, Module, Location, InsertionPoint
from mlir.dialects import func
from collections import defaultdict
class MLIRFunctionExtractor:
    """Parse an MLIR file and report the SSA operands of calls made in ``@main``.

    Operands produced by a multi-result op are reported with their
    result-group index, e.g. ``%0#1`` for the second result of ``%0:2 = ...``.
    """

    def __init__(self, input_mlir_file):
        """Remember the path of the MLIR file to analyse."""
        self.input_mlir_file = input_mlir_file

    def extract_functions(self):
        """Parse ``self.input_mlir_file`` and walk every operation in it."""
        with open(self.input_mlir_file, "r", encoding="utf-8") as file:
            mlir_source = file.read()
        # Module.parse uses the Context entered here; the walk stays inside
        # the same `with` block.
        with Context():
            module = Module.parse(mlir_source)
            self.handle_op(module.operation)

    @staticmethod
    def extract_ssa_name(value):
        """Return the SSA name of *value* as it is spelled when used as an operand.

        * A block argument becomes ``%arg<index>``.
        * A result of a single-result op becomes its plain name, e.g. ``%0``.
        * A member of a result group becomes ``<name>#<index>``, e.g. ``%0#1``
          for the second result of ``%0:2 = ...``.

        Raises:
            RuntimeError: if the printed form of *value* is not recognised.
        """
        value_str = str(value)
        if not value_str.startswith("Value("):
            raise RuntimeError(
                f"Value does not start with 'Value(': {value_str}, type: {type(value)}"
            )
        if "<block argument>" in value_str:
            # format: "Value(<block argument> of type 'tensor<1x4x6xf32>' at index: 6)"
            # we can use the index as the SSA name, %arg{idx}
            # NOTE(review): %arg<idx> matches how function entry-block arguments
            # are printed; arguments of other blocks may be named differently.
            return f"%arg{value_str.split(':')[-1].strip().strip(')')}"
        if "=" in value_str:
            # An op result prints as its whole defining op, e.g.
            # "Value(%0:2 = func.call ...)". The text before the first "=" is the
            # op's result list; the printed string never says WHICH result this
            # value is, so the index has to come from the bindings.
            result_names = value_str[len("Value("):].partition("=")[0]
            return MLIRFunctionExtractor._select_result_name(
                result_names, MLIRFunctionExtractor._result_number(value)
            )
        raise RuntimeError(f"Could not extract SSA name from value: {value_str}")

    @staticmethod
    def _result_number(value):
        """Return the index of op result *value* among its defining op's results."""
        # Depending on the bindings version, op.operands may already yield
        # OpResult objects (which expose result_number) or plain Values that
        # need an explicit OpResult(...) cast first.
        number = getattr(value, "result_number", None)
        if number is not None:
            return number
        from mlir.ir import OpResult

        return OpResult(value).result_number

    @staticmethod
    def _select_result_name(result_names, result_number):
        """Map *result_number* onto a printed result list such as ``"%0:2"``.

        *result_names* is a comma-separated list of ``name`` or ``name:count``
        entries (e.g. ``"%0:2"`` or ``"%a, %b:2"``). An entry with a count is a
        result group whose members are addressed as ``name#i``.

        Raises:
            RuntimeError: if *result_number* is past the end of the list.
        """
        first_index = 0
        for entry in result_names.split(","):
            name, _, count = entry.strip().partition(":")
            group_size = int(count) if count else 1
            if result_number < first_index + group_size:
                if group_size == 1:
                    return name
                return f"{name}#{result_number - first_index}"
            first_index += group_size
        raise RuntimeError(
            f"Result #{result_number} not found in result list: {result_names!r}"
        )

    def handle_op(self, operation):
        """Recursively visit *operation*; report the calls made from ``@main``."""
        is_main = isinstance(operation, func.FuncOp) and (
            str(operation.attributes["sym_name"]) == '"main"'
        )
        if is_main:
            self.get_inputs(operation)
        # Recursively walk through all nested operations.
        for region in operation.regions:
            for block in region.blocks:
                for nested_op in block.operations:
                    self.handle_op(nested_op)

    def get_inputs(self, operation):
        """Print and return ``(callee, operand_names)`` for each call in *operation*.

        Only operations directly inside *operation*'s regions are inspected.
        Operand names come from ``extract_ssa_name``, so a member of a result
        group shows up as e.g. ``%0#1``.
        """
        calls = []
        for region in operation.regions:
            for block in region.blocks:
                for op in block.operations:
                    if not isinstance(op, func.CallOp):
                        continue
                    callee = str(op.attributes["callee"])
                    operand_names = [self.extract_ssa_name(arg) for arg in op.operands]
                    print(callee, operand_names)
                    calls.append((callee, operand_names))
        return calls
if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default input file.
    input_path = sys.argv[1] if len(sys.argv) > 1 else "result_group.mlir"
    extractor = MLIRFunctionExtractor(input_path)
    extractor.extract_functions()