Get index of result group using Python bindings

I’ve been looking at MLIR operations which have result groups (i.e., operations that return multiple values, which are then given a single name that we can index).

E.g., here @my_func returns two values, and the second one is used as an argument to @double:

module {
  func.func @main() {
    %0:2 = call @my_func() : () -> (i64, i64)
    %1 = call @double(%0#1) : (i64) -> (i64)
    return
  }
  func.func @my_func() -> (i64, i64) {
    %0 = arith.constant 42 : i64
    %1 = arith.constant 69 : i64
    return %0, %1 : i64, i64
  }
  func.func @double(%arg0: i64) -> (i64) {
    %0 = arith.addi %arg0, %arg0 : i64
    return %0 : i64
  }
}

However, when analysing the operations using the Python MLIR bindings, I’m finding it tricky to extract the index of the result group being used as an argument to @double (i.e., #1).

For my needs, knowing which index is being used is important, as it helps me analyse the data flow in the program.

If I get the operands of the @double operation with op.operands, this is of type OpOperandList.
If I loop through it, I get Value objects, in this example Value(%0:2 = func.call @my_func() : () -> (i64, i64)).
This Value object doesn’t seem to encode the index of the result group being used as an argument.

Standalone example

#!/usr/bin/env python3

import os
import json
from mlir.ir import Context, Module, Location, InsertionPoint
from mlir.dialects import func
from collections import defaultdict

class MLIRFunctionExtractor:
    def __init__(self, input_mlir_file):
        self.input_mlir_file = input_mlir_file

    def extract_functions(self):
        with open(self.input_mlir_file, "r") as file:
            mlir_source = file.read()

        with Context() as ctx:
            module = Module.parse(mlir_source)

            # Walk through all operations, starting from the module
            self.handle_op(module.operation)

    @staticmethod
    def extract_ssa_name(value):
        value_str = str(value)
        print("hey value", value_str)
        if not value_str.startswith("Value("):
            raise RuntimeError(
                f"Value does not start with 'Value(': {value_str}, type: {type(value)}"
            )
        if "=" in value_str:
            # Normal SSA value
            ssa_name = value_str.split("=")[0].strip().split("(")[-1].strip()
            return ssa_name
        if "<block argument>" in value_str:
            # format: "Value(<block argument> of type 'tensor<1x4x6xf32>' at index: 6)"
            # we can use the index as the SSA name, %arg{idx}
            ssa_name = f"%arg{value_str.split(':')[-1].strip().strip(')')}"
            return ssa_name
        raise RuntimeError(f"Could not extract SSA name from value: {value_str}")

    def handle_op(self, operation):
        if isinstance(operation, func.FuncOp):
            if str(operation.attributes["sym_name"]) == '"main"':
                self.get_inputs(operation)

        # Recursively walk through all operations
        for region in operation.regions:
            for block in region.blocks:
                for nested_op in block.operations:
                    self.handle_op(nested_op)

    def get_inputs(self, operation):
        for region in operation.regions:
            for block in region.blocks:
                for op in block.operations:
                    if isinstance(op, func.CallOp):
                        op_name = str(op.attributes["callee"])
                        operands = [self.extract_ssa_name(arg) for arg in op.operands]
                        operands = [str(arg) for arg in op.operands]

                        print(operands)
                        if "double" in op_name:
                            print("hey op[0]", op.operands[0], type(op.operands[0]))
                            print(type(op.operands))
                            # print(op.operands.operand_number)
                            for operand in op.operands:
                                print(operand)
                                # print(operand.index)
                                print(type(operand))
                                # print("operand_number", operand.operand_number)
                                # print(type(operand))
                        print()


if __name__ == "__main__":
    extractor = MLIRFunctionExtractor("result_group.mlir")
    extractor.extract_functions()

Looping through the OpOperandList gives you an OpOperand, which has an operand_number property. So it would be op.operands[i].operand_number.

For reference here is the stub file: llvm-project/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi at eb9c49c900f43aa79811f80847c97c6596197430 · llvm/llvm-project · GitHub

The issue is that although type(op.operands) == OpOperandList, if I get items from it, e.g., op.operands[0], those are of type Value, not OpOperand.

This means that we don’t have the operand_number attribute.

I’m not sure if this is expected behaviour, or if there’s a mismatch in the Python API.

Edit

Looking through IRCore.cpp, I can see that yes, OpOperandList does return Values.

However, looking at the definition of Value, I can see it has the uses property, which returns a PyOpOperandIterator. I can iterate over that to get each OpOperand and read its operand_number.

However, this does not give me what I want. In this example, see how I call @sum with %0#1 and %0#0:

module {
  func.func @main() {
    %0:2 = call @my_func() : () -> (i64, i64)
    %1 = call @sum(%0#1, %0#0) : (i64, i64) -> (i64)
    return
  }
  func.func @my_func() -> (i64, i64) {
    %0 = arith.constant 42 : i64
    %1 = arith.constant 69 : i64
    return %0, %1 : i64, i64
  }
  func.func @sum(%arg0: i64, %arg1: i64) -> (i64) {
    %0 = arith.addi %arg0, %arg1 : i64
    return %0 : i64
  }
}

Running

for operand in op.operands:
    for use in operand.uses:
        print(
            "operand:",
            self.extract_ssa_name(operand),
            "operand number:",
            use.operand_number,
        )

I get:

operand: %0:2 operand number: 0
operand: %0:2 operand number: 1

It seems that operand number just tells us the position in the operand list, rather than the index of the result group being used.
In this case, since we are calling (%0#1, %0#0), I would expect to see 1 and 0 as the indexes, since operand 0 is %0#1 and operand 1 is %0#0.

Ah, so you want the result number. Then you can cast the Value to an OpResult. Something like this:

for operand in op.operation.operands:
    print(OpResult(operand).result_number)
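
One caveat with the cast above: as far as I can tell, OpResult(value) raises for values that are not op results (for example block arguments such as %arg0), so a guard is useful when walking arbitrary operands. Below is a minimal sketch, assuming OpResult.isinstance and result_number behave as described in the ir.pyi stub linked earlier. The names result_index, demo and MLIR_SOURCE are hypothetical and only for illustration; the OpResult class is passed in as a parameter so the helper itself does not depend on the bindings at import time.

```python
MLIR_SOURCE = """
module {
  func.func @main() {
    %0:2 = call @my_func() : () -> (i64, i64)
    %1 = call @sum(%0#1, %0#0) : (i64, i64) -> (i64)
    return
  }
  func.func @my_func() -> (i64, i64) {
    %0 = arith.constant 42 : i64
    %1 = arith.constant 69 : i64
    return %0, %1 : i64, i64
  }
  func.func @sum(%arg0: i64, %arg1: i64) -> (i64) {
    %0 = arith.addi %arg0, %arg1 : i64
    return %0 : i64
  }
}
"""


def result_index(value, op_result_cls):
    """Return the result number of `value`, or None if it is not an op result.

    `op_result_cls` is expected to be mlir.ir.OpResult (an assumption of this
    sketch); it is injected so the helper can be exercised without the bindings.
    """
    if not op_result_cls.isinstance(value):
        return None  # e.g. a block argument such as %arg0
    return op_result_cls(value).result_number


def demo():
    # Requires the MLIR Python bindings (the `mlir` package) to be installed.
    from mlir.dialects import func
    from mlir.ir import Context, Module, OpResult

    with Context():
        module = Module.parse(MLIR_SOURCE)
        for fn in module.body.operations:
            for block in fn.regions[0].blocks:
                for op in block.operations:
                    if isinstance(op, func.CallOp):
                        indices = [result_index(v, OpResult) for v in op.operands]
                        print(op.attributes["callee"], indices)
```

With the bindings installed, calling demo() should list the result numbers in operand order for each call, so I would expect 1 followed by 0 for the @sum call. Operands that come from block arguments show up as None instead of raising.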