Numerical Testing: Multiple Outputs in MLIR Python Bindings

I have implemented a VarMeanOp in MLIR that computes variance and mean as two outputs. I have successfully lowered this operation to LLVM and am now testing it through the MLIR Python bindings as part of numerical validation.

Problem:

For single-output operations, I can correctly validate results through the Python bindings by creating a memref descriptor and invoking the operation with ExecutionEngine.invoke(), using APIs such as:

get_ranked_memref_descriptor() 
make_nd_memref_descriptor(rank, type)

However, for multi-output operations, I am encountering issues.

What I Have Tried:
I created two separate memref descriptors for the two outputs (variance and mean) and attempted to pass them to engine.invoke() to capture both results, but I am unable to retrieve the correct outputs when invoking the operation.

Single output (works correctly)

res_memref = make_nd_memref_descriptor(rank, element_type)
engine.invoke("single_output_op", input_memref, res_memref)

Multiple outputs (not working)

var_memref = make_nd_memref_descriptor(rank, element_type) 
mean_memref = make_nd_memref_descriptor(rank, element_type) 
engine.invoke("multi_output_op", var_memref, mean_memref, input_memref)

Unable to get the correct results.

How can I correctly handle multiple outputs in ExecutionEngine.invoke() when testing multi-output operations in Python?
Is there a recommended approach to create and pass multiple memrefs for capturing multiple outputs?

Can you provide a minimal test?

def testTOSAVarMean(funcName, shape, dtype):
    with Context():
        command = [
            "/home/mcw/Projects/MXNET-MLIR/build/bin/MxNet-opt",
            "--pass-pipeline=builtin.module(func.func(llvm-request-c-wrappers),mxnet-to-llvm)",
            "../test/MxNetToLLVM/var_mean.mlir",
        ]
        module = subprocess.run(command, text=True, capture_output=True)
        module = Module.parse(module.stdout)
        execution_engine = ExecutionEngine(module)
        input_arg = np.random.randint(1, 100, shape).astype(dtype)

        ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_float * 2)()

        mem_out_var = ctypes.pointer(ctypes.pointer(ref_out))
        mem_out_mean = ctypes.pointer(ctypes.pointer(ref_out))

        mem_input = ctypes.pointer(
            ctypes.pointer(rt.get_ranked_memref_descriptor(input_arg))
        )
        execution_engine.invoke(funcName, mem_out_var, mem_out_mean, mem_input)

        var = ranked_memref_to_numpy(mem_out_var.contents)
        mean = ranked_memref_to_numpy(mem_out_mean.contents)

        c3, c4 = torch.var_mean(torch.from_numpy(input_arg), dim=0, keepdim=True)

        if torch.allclose(torch.from_numpy(var), c3, equal_nan=True) and torch.allclose(
            torch.from_numpy(mean), c4, equal_nan=True
        ):
            print("True")
        else:
            print("False")


testTOSAVarMean("var_mean_f32", (4, 4), np.float32)

I am trying this. What I understand so far is that invoke() accepts multiple inputs but can fetch only a single result: by default it treats the first argument as the slot in which to store the result.

Below is the implementation of ExecutionEngine.invoke()

def invoke(self, name, *ctypes_args):
        """Invoke a function with the list of ctypes arguments.
        All arguments must be pointers.
        Raise a RuntimeError if the function isn't found.
        """
        func = self.lookup(name)
        packed_args = (ctypes.c_void_p * len(ctypes_args))()
        for argNum in range(len(ctypes_args)):
            packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
        func(packed_args)

I meant a test that would run in MLIR upstream (you didn’t provide your IR, and you seem to depend on a custom pass pipeline of yours).

If this is an execution engine issue, we should be able to demonstrate it with a new test in llvm-project/mlir/test/python/execution_engine.py.

// RUN: MxNet-opt --mxnet-to-llvm %s | FileCheck %s

// CHECK-LABEL: func.func @var_mean
func.func @var_mean(%tensorA: tensor<3x3xf32>) -> (tensor<3x1xf32>, tensor<3x1xf32>) {
  %result1, %result2 = "MxNet.var_mean"(%tensorA) {axes = array<i32: 1>, keepdim = 1 : i1} : (tensor<3x3xf32>) -> (tensor<3x1xf32>, tensor<3x1xf32>)
  return %result1, %result2 : tensor<3x1xf32>, tensor<3x1xf32>
}

LLVM IR

module {
  llvm.func @free(!llvm.ptr)
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.mlir.global private constant @__constant_3x1xf32_0(dense<3.000000e+00> : tensor<3x1xf32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<3 x array<1 x f32>>
  llvm.mlir.global private constant @__constant_3x1xf32(dense<2.000000e+00> : tensor<3x1xf32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<3 x array<1 x f32>>
  llvm.func @var_mean(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> !llvm.struct<(struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>)> {
    %0 = llvm.mlir.undef : !llvm.struct<(struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>)>
    %1 = llvm.mlir.constant(64 : index) : i64
    %2 = llvm.mlir.addressof @__constant_3x1xf32_0 : !llvm.ptr
    %3 = llvm.mlir.addressof @__constant_3x1xf32 : !llvm.ptr
    %4 = llvm.mlir.zero : !llvm.ptr
    %5 = llvm.mlir.constant(0.000000e+00 : f32) : f32
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = llvm.mlir.constant(3 : index) : i64
    %8 = llvm.mlir.constant(1 : index) : i64
    %9 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %10 = llvm.getelementptr %3[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<3 x array<1 x f32>>
    %11 = llvm.getelementptr %2[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<3 x array<1 x f32>>
    %12 = llvm.getelementptr %4[3] : (!llvm.ptr) -> !llvm.ptr, f32
    %13 = llvm.ptrtoint %12 : !llvm.ptr to i64
    %14 = llvm.add %13, %1 : i64
    %15 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %16 = llvm.ptrtoint %15 : !llvm.ptr to i64
    %17 = llvm.sub %1, %8 : i64
    %18 = llvm.add %16, %17 : i64
    %19 = llvm.urem %18, %1 : i64
    %20 = llvm.sub %18, %19 : i64
    %21 = llvm.inttoptr %20 : i64 to !llvm.ptr
    llvm.br ^bb1(%6 : i64)
  ^bb1(%22: i64):  // 2 preds: ^bb0, ^bb2
    %23 = llvm.icmp "slt" %22, %7 : i64
    llvm.cond_br %23, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %24 = llvm.getelementptr %21[%22] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %5, %24 : f32, !llvm.ptr
    %25 = llvm.add %22, %8 : i64
    llvm.br ^bb1(%25 : i64)
  ^bb3:  // pred: ^bb1
    llvm.br ^bb4(%6 : i64)
  ^bb4(%26: i64):  // 2 preds: ^bb3, ^bb8
    %27 = llvm.icmp "slt" %26, %7 : i64
    llvm.cond_br %27, ^bb5, ^bb9
  ^bb5:  // pred: ^bb4
    llvm.br ^bb6(%6 : i64)
  ^bb6(%28: i64):  // 2 preds: ^bb5, ^bb7
    %29 = llvm.icmp "slt" %28, %7 : i64
    llvm.cond_br %29, ^bb7, ^bb8
  ^bb7:  // pred: ^bb6
    %30 = llvm.getelementptr %arg1[%arg2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %31 = llvm.mul %26, %arg5 : i64
    %32 = llvm.mul %28, %arg6 : i64
    %33 = llvm.add %31, %32 : i64
    %34 = llvm.getelementptr %30[%33] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %35 = llvm.load %34 : !llvm.ptr -> f32
    %36 = llvm.getelementptr %21[%26] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %37 = llvm.load %36 : !llvm.ptr -> f32
    %38 = llvm.fadd %35, %37 : f32
    llvm.store %38, %36 : f32, !llvm.ptr
    %39 = llvm.add %28, %8 : i64
    llvm.br ^bb6(%39 : i64)
  ^bb8:  // pred: ^bb6
    %40 = llvm.add %26, %8 : i64
    llvm.br ^bb4(%40 : i64)
  ^bb9:  // pred: ^bb4
    %41 = llvm.insertvalue %15, %9[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %42 = llvm.insertvalue %21, %41[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %43 = llvm.insertvalue %6, %42[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %44 = llvm.insertvalue %7, %43[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %45 = llvm.insertvalue %8, %44[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %46 = llvm.insertvalue %8, %45[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %47 = llvm.insertvalue %8, %46[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    llvm.br ^bb10(%6 : i64)
  ^bb10(%48: i64):  // 2 preds: ^bb9, ^bb14
    %49 = llvm.icmp "slt" %48, %7 : i64
    llvm.cond_br %49, ^bb11, ^bb15
  ^bb11:  // pred: ^bb10
    llvm.br ^bb12(%6 : i64)
  ^bb12(%50: i64):  // 2 preds: ^bb11, ^bb13
    %51 = llvm.icmp "slt" %50, %8 : i64
    llvm.cond_br %51, ^bb13, ^bb14
  ^bb13:  // pred: ^bb12
    %52 = llvm.add %48, %50 : i64
    %53 = llvm.getelementptr %21[%52] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %54 = llvm.load %53 : !llvm.ptr -> f32
    %55 = llvm.getelementptr %11[%52] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %56 = llvm.load %55 : !llvm.ptr -> f32
    %57 = llvm.fdiv %54, %56 : f32
    llvm.store %57, %53 : f32, !llvm.ptr
    %58 = llvm.add %50, %8 : i64
    llvm.br ^bb12(%58 : i64)
  ^bb14:  // pred: ^bb12
    %59 = llvm.add %48, %8 : i64
    llvm.br ^bb10(%59 : i64)
  ^bb15:  // pred: ^bb10
    %60 = llvm.getelementptr %4[9] : (!llvm.ptr) -> !llvm.ptr, f32
    %61 = llvm.ptrtoint %60 : !llvm.ptr to i64
    %62 = llvm.add %61, %1 : i64
    %63 = llvm.call @malloc(%62) : (i64) -> !llvm.ptr
    %64 = llvm.ptrtoint %63 : !llvm.ptr to i64
    %65 = llvm.add %64, %17 : i64
    %66 = llvm.urem %65, %1 : i64
    %67 = llvm.sub %65, %66 : i64
    %68 = llvm.inttoptr %67 : i64 to !llvm.ptr
    llvm.br ^bb16(%6 : i64)
  ^bb16(%69: i64):  // 2 preds: ^bb15, ^bb20
    %70 = llvm.icmp "slt" %69, %7 : i64
    llvm.cond_br %70, ^bb17, ^bb21
  ^bb17:  // pred: ^bb16
    llvm.br ^bb18(%6 : i64)
  ^bb18(%71: i64):  // 2 preds: ^bb17, ^bb19
    %72 = llvm.icmp "slt" %71, %7 : i64
    llvm.cond_br %72, ^bb19, ^bb20
  ^bb19:  // pred: ^bb18
    %73 = llvm.getelementptr %arg1[%arg2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %74 = llvm.mul %69, %arg5 : i64
    %75 = llvm.mul %71, %arg6 : i64
    %76 = llvm.add %74, %75 : i64
    %77 = llvm.getelementptr %73[%76] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %78 = llvm.load %77 : !llvm.ptr -> f32
    %79 = llvm.add %69, %6 : i64
    %80 = llvm.getelementptr %21[%79] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %81 = llvm.load %80 : !llvm.ptr -> f32
    %82 = llvm.fsub %78, %81 : f32
    %83 = llvm.mul %69, %7 : i64
    %84 = llvm.add %83, %71 : i64
    %85 = llvm.getelementptr %68[%84] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %82, %85 : f32, !llvm.ptr
    %86 = llvm.add %71, %8 : i64
    llvm.br ^bb18(%86 : i64)
  ^bb20:  // pred: ^bb18
    %87 = llvm.add %69, %8 : i64
    llvm.br ^bb16(%87 : i64)
  ^bb21:  // pred: ^bb16
    %88 = llvm.call @malloc(%62) : (i64) -> !llvm.ptr
    %89 = llvm.ptrtoint %88 : !llvm.ptr to i64
    %90 = llvm.add %89, %17 : i64
    %91 = llvm.urem %90, %1 : i64
    %92 = llvm.sub %90, %91 : i64
    %93 = llvm.inttoptr %92 : i64 to !llvm.ptr
    llvm.br ^bb22(%6 : i64)
  ^bb22(%94: i64):  // 2 preds: ^bb21, ^bb26
    %95 = llvm.icmp "slt" %94, %7 : i64
    llvm.cond_br %95, ^bb23, ^bb27
  ^bb23:  // pred: ^bb22
    llvm.br ^bb24(%6 : i64)
  ^bb24(%96: i64):  // 2 preds: ^bb23, ^bb25
    %97 = llvm.icmp "slt" %96, %7 : i64
    llvm.cond_br %97, ^bb25, ^bb26
  ^bb25:  // pred: ^bb24
    %98 = llvm.mul %94, %7 : i64
    %99 = llvm.add %98, %96 : i64
    %100 = llvm.getelementptr %68[%99] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %101 = llvm.load %100 : !llvm.ptr -> f32
    %102 = llvm.fmul %101, %101 : f32
    %103 = llvm.getelementptr %93[%99] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %102, %103 : f32, !llvm.ptr
    %104 = llvm.add %96, %8 : i64
    llvm.br ^bb24(%104 : i64)
  ^bb26:  // pred: ^bb24
    %105 = llvm.add %94, %8 : i64
    llvm.br ^bb22(%105 : i64)
  ^bb27:  // pred: ^bb22
    %106 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %107 = llvm.ptrtoint %106 : !llvm.ptr to i64
    %108 = llvm.add %107, %17 : i64
    %109 = llvm.urem %108, %1 : i64
    %110 = llvm.sub %108, %109 : i64
    %111 = llvm.inttoptr %110 : i64 to !llvm.ptr
    llvm.br ^bb28(%6 : i64)
  ^bb28(%112: i64):  // 2 preds: ^bb27, ^bb29
    %113 = llvm.icmp "slt" %112, %7 : i64
    llvm.cond_br %113, ^bb29, ^bb30
  ^bb29:  // pred: ^bb28
    %114 = llvm.getelementptr %111[%112] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %5, %114 : f32, !llvm.ptr
    %115 = llvm.add %112, %8 : i64
    llvm.br ^bb28(%115 : i64)
  ^bb30:  // pred: ^bb28
    llvm.br ^bb31(%6 : i64)
  ^bb31(%116: i64):  // 2 preds: ^bb30, ^bb35
    %117 = llvm.icmp "slt" %116, %7 : i64
    llvm.cond_br %117, ^bb32, ^bb36
  ^bb32:  // pred: ^bb31
    llvm.br ^bb33(%6 : i64)
  ^bb33(%118: i64):  // 2 preds: ^bb32, ^bb34
    %119 = llvm.icmp "slt" %118, %7 : i64
    llvm.cond_br %119, ^bb34, ^bb35
  ^bb34:  // pred: ^bb33
    %120 = llvm.mul %116, %7 : i64
    %121 = llvm.add %120, %118 : i64
    %122 = llvm.getelementptr %93[%121] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %123 = llvm.load %122 : !llvm.ptr -> f32
    %124 = llvm.getelementptr %111[%116] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %125 = llvm.load %124 : !llvm.ptr -> f32
    %126 = llvm.fadd %123, %125 : f32
    llvm.store %126, %124 : f32, !llvm.ptr
    %127 = llvm.add %118, %8 : i64
    llvm.br ^bb33(%127 : i64)
  ^bb35:  // pred: ^bb33
    %128 = llvm.add %116, %8 : i64
    llvm.br ^bb31(%128 : i64)
  ^bb36:  // pred: ^bb31
    %129 = llvm.insertvalue %106, %9[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %130 = llvm.insertvalue %111, %129[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %131 = llvm.insertvalue %6, %130[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %132 = llvm.insertvalue %7, %131[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %133 = llvm.insertvalue %8, %132[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %134 = llvm.insertvalue %8, %133[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %135 = llvm.insertvalue %8, %134[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    llvm.br ^bb37(%6 : i64)
  ^bb37(%136: i64):  // 2 preds: ^bb36, ^bb41
    %137 = llvm.icmp "slt" %136, %7 : i64
    llvm.cond_br %137, ^bb38, ^bb42
  ^bb38:  // pred: ^bb37
    llvm.br ^bb39(%6 : i64)
  ^bb39(%138: i64):  // 2 preds: ^bb38, ^bb40
    %139 = llvm.icmp "slt" %138, %8 : i64
    llvm.cond_br %139, ^bb40, ^bb41
  ^bb40:  // pred: ^bb39
    %140 = llvm.add %136, %138 : i64
    %141 = llvm.getelementptr %111[%140] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %142 = llvm.load %141 : !llvm.ptr -> f32
    %143 = llvm.getelementptr %10[%140] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %144 = llvm.load %143 : !llvm.ptr -> f32
    %145 = llvm.fdiv %142, %144 : f32
    llvm.store %145, %141 : f32, !llvm.ptr
    %146 = llvm.add %138, %8 : i64
    llvm.br ^bb39(%146 : i64)
  ^bb41:  // pred: ^bb39
    %147 = llvm.add %136, %8 : i64
    llvm.br ^bb37(%147 : i64)
  ^bb42:  // pred: ^bb37
    llvm.call @free(%63) : (!llvm.ptr) -> ()
    llvm.call @free(%88) : (!llvm.ptr) -> ()
    %148 = llvm.insertvalue %135, %0[0] : !llvm.struct<(struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>)> 
    %149 = llvm.insertvalue %47, %148[1] : !llvm.struct<(struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>)> 
    llvm.return %149 : !llvm.struct<(struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>)>
  }
}

This does not seem like a minimal test to me, or something that I can test immediately in llvm-project/mlir/test/python/execution_engine.py.

But more importantly, the LLVM IR you provided was likely not generated from a pipeline running llvm-request-c-wrappers (unlike your Python snippet).

Yeah, I am keeping the llvm-request-c-wrappers pass only during the numerical test.

I do not completely understand what you are referring to by a minimal test. What I understand is that you want a test that can be run immediately using the execution_engine.py file you linked.

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir import runtime as rt
from mlir.runtime import *

import ctypes
import torch


# pipeline for lowering to LLVM.
def lowerToLLVM(module):
    pm = PassManager.parse(
        "builtin.module(func.func(llvm-request-c-wrappers),func.func(tosa-to-linalg),tosa-to-tensor,tosa-to-arith,convert-elementwise-to-linalg,one-shot-bufferize{bufferize-function-boundaries=true}, func.func(buffer-deallocation),canonicalize,expand-strided-metadata,convert-linalg-to-loops,convert-scf-to-cf,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts,canonicalize)"
    )
    pm.run(module.operation)
    return module


def testTOSAVarMean(funcName, shape, dtype):
    with Context():
        module = Module.parse(
            r"""module {
                  func.func @var_mean_f32(%arg0: tensor<3x3xf32>) -> (tensor<3x1xf32>, tensor<3x1xf32>) {
                    %0 = "tosa.const"() <{value = dense<2.000000e+00> : tensor<3x1xf32>}> : () -> tensor<3x1xf32>
                    %1 = "tosa.const"() <{value = dense<3.000000e+00> : tensor<3x1xf32>}> : () -> tensor<3x1xf32>
                    %2 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<3x3xf32>) -> tensor<3x1xf32>
                    %3 = arith.divf %2, %1 : tensor<3x1xf32>
                    %4 = tosa.sub %arg0, %3 : (tensor<3x3xf32>, tensor<3x1xf32>) -> tensor<3x3xf32>
                    %5 = tosa.mul %4, %4 {shift = 0 : i8} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
                    %6 = tosa.reduce_sum %5 {axis = 1 : i32} : (tensor<3x3xf32>) -> tensor<3x1xf32>
                    %7 = arith.divf %6, %0 : tensor<3x1xf32>
                    return %7, %3 : tensor<3x1xf32>, tensor<3x1xf32>
                  }
                } """
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))

        input_arg = np.random.randint(1, 100, shape).astype(dtype)
        mem_input = ctypes.pointer(
            ctypes.pointer(rt.get_ranked_memref_descriptor(input_arg))
        )

        ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_float * 2)()

        mem_out_var = ctypes.pointer(ctypes.pointer(ref_out))
        mem_out_mean = ctypes.pointer(ctypes.pointer(ref_out))

        execution_engine.invoke(funcName, mem_out_var, mem_out_mean, mem_input)

        var = ranked_memref_to_numpy(mem_out_var.contents)
        mean = ranked_memref_to_numpy(mem_out_mean.contents)

        print(torch.from_numpy(var))
        print(torch.from_numpy(mean))

        c3, c4 = torch.var_mean(torch.from_numpy(input_arg), dim=0, keepdim=True)

        if torch.allclose(torch.from_numpy(var), c3) and torch.allclose(
            torch.from_numpy(mean), c4
        ):
            print("True")
        else:
            print("False")


testTOSAVarMean("var_mean_f32", (3, 3), np.float32)

I don’t remember how this works upstream, but the way I do this is to generate a return_consumer, which is a Python callback that’s compatible with C:

# This doesn't actually generate anything - just forward declares the callback.
def make_return_consumer(kernel_func):
    c_api_compatible_types = [
        T.memref(element_type=t.element_type) if MemRefType.isinstance(t) else t
        for t in kernel_func.function_type.value.results
    ]
    cb = FuncOp(
        f"{kernel_func.name.value}_return_consumer",
        (c_api_compatible_types, []),
        visibility="private",
    )
    cb.attributes["llvm.emit_c_interface"] = UnitAttr.get()
    cb.attributes[refback_cb_attr] = UnitAttr.get()
    return cb

Then the return consumer is registered with the ExecutionEngine like this

RETURN_RESULTS = None
def consume_return_callback(*args):
    global RETURN_RESULTS
    RETURN_RESULTS = convert_returns_from_ctype(args, self.ret_types)

ee.register_runtime(
    return_func_name,
    ctype_wrapper(consume_return_callback),
)

Note, return_func_name == f"{kernel_func.name.value}_return_consumer".

Then the actual func I invoke looks like this:

@FuncOp.from_py_func(*input_types, name=f"{kernel_func.name.value}_capi_wrapper")
def wrapper(*args, **_kwargs):

    # call the original kernel_func
    results = CallOp(kernel_func, list(args)).results

    if return_consumer is not None:
        c_api_compatible_results = []
        for i, a in enumerate(results):

            # cast memrefs to unranked memrefs in order to be compatible with
            # unranked_memref_to_numpy
            if MemRefType.isinstance(a.type):
                a = cast(T.memref(element_type=a.type.element_type), a)
            c_api_compatible_results.append(a)

        # return_consumer == f"{kernel_func.name.value}_return_consumer"
        CallOp(return_consumer, c_api_compatible_results)

Then just invoke as usual and RETURN_RESULTS will have your results. You can find the entire script at mlir/extras/runtime/refbackend.py, and I just double-checked that multiple returns do in fact work.
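For readers unfamiliar with the pattern, here is a stripped-down, pure-ctypes sketch of the consume-return idea. The (f32, f32) signature is just an illustration, and the callback is called directly here to simulate what the JIT'd wrapper would do; with a real engine you would hand it to ee.register_runtime under the _return_consumer name:

```python
import ctypes

RETURN_RESULTS = None

# C-compatible callback that stashes the results into a global, like
# consume_return_callback above. Real code derives the signature from
# the kernel's result types instead of hard-coding two floats.
@ctypes.CFUNCTYPE(None, ctypes.c_float, ctypes.c_float)
def consume_return(var, mean):
    global RETURN_RESULTS
    RETURN_RESULTS = (var, mean)

# Simulate the JIT'd wrapper invoking the registered consumer.
consume_return(2.0, 5.5)
print(RETURN_RESULTS)  # -> (2.0, 5.5)
```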

A minimal test for the execution engine would be something that does not involve TOSA or any complex lowering, does not do any complex math, and just, for example, produces a constant value: a function with a handful of lines and the minimal number of passes to run, with an execution that does not involve PyTorch or anything else. The kind of test that looks like def testMemrefAdd() in the file I sent, for example.

To put it differently: if I were to fix the “bug” you reported for the execution engine, what would be the unit-test for it?

We encountered the same issue: invoke() supports a single return value only.

If you do not want to use a callback, one solution is to follow the LLVM IR Target documentation for MLIR and explicitly un/pack from/to one large return buffer.

f = enginePtr->lookupPacked(std::string("_mlir_ciface_") + fname);
llvm::SmallVector<void *> args;
args.push_back(result_buffer);
for (auto inarg : all_input_args)
  args.push_back(inarg);
f(args.data());
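A rough Python analogue of the same idea (an untested sketch; the 112-byte size is an assumption, corresponding to two rank-2 memref descriptors of 56 bytes each on a 64-bit target):

```python
import ctypes

# One opaque buffer large enough to receive both packed results goes in
# first, followed by pointers to the inputs, mirroring the C++ above.
RESULT_BYTES = 112  # assumption: 2 x rank-2 memref descriptor (56 bytes each)
result_buffer = (ctypes.c_uint8 * RESULT_BYTES)()

packed = [ctypes.cast(ctypes.pointer(result_buffer), ctypes.c_void_p)]
# for inarg in all_input_ptrs:        # inputs appended after the result slot
#     packed.append(ctypes.cast(inarg, ctypes.c_void_p))
argv = (ctypes.c_void_p * len(packed))(*packed)
# f(argv)  # f = the engine's _mlir_ciface_<fname> packed entry point
```

After the call, the two descriptors would have to be unpacked manually from result_buffer at their respective offsets.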

Thank you, @mehdi_amini , for taking the time to respond to me.

This is a minimal test that you can paste into llvm-project/mlir/test/python/execution_engine.py and run to reproduce the issue.

def testInvokeAddExtendedUIOp():
    with Context():
        module = Module.parse(
            r"""
func.func @addUI(%arg0: i64, %arg1: i64) -> (i64, i1) attributes { llvm.emit_c_interface } {
  %sum, %overflow = arith.addui_extended %arg0, %arg1 : i64, i1
  return %sum, %overflow : i64, i1
}
    """
        )
        c_uint64_p = ctypes.c_uint64 * 1  # ctypes storage for an i64 value
        arg0 = c_uint64_p(18446744073709551615)  # max value for uint64
        arg1 = c_uint64_p(12)
        sum = c_uint64_p(0)
        c_bool_p = ctypes.c_bool * 1  # ctypes storage for the i1 result
        overflow = c_bool_p(0)  # to store the overflow flag
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "addUI", sum, overflow, arg0, arg1
        )  # segmentation fault at this line
        print(sum[0], overflow[0])


run(testInvokeAddExtendedUIOp)

To run this script, I am using command:
python3 execution_engine.py

ERROR:
TEST: testInvokeAddExtendedUIOp
Segmentation fault (core dumped)
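Based on the single-return-buffer behavior described earlier in the thread, one thing worth trying against this repro (an untested sketch on my side) is combining the two results into a single ctypes structure and passing one pointer to it as the first argument; AddUIResult is a hypothetical name:

```python
import ctypes

# Plausible shape of the combined result for @addUI: its two results
# (i64, i1) packed into one struct, received through a single pointer
# instead of two separate result arguments.
class AddUIResult(ctypes.Structure):
    _fields_ = [("sum", ctypes.c_uint64), ("overflow", ctypes.c_bool)]

res = AddUIResult()
res_ptr = ctypes.pointer(ctypes.pointer(res))
# execution_engine.invoke("addUI", res_ptr, arg0, arg1)  # sketch, untested
# print(res.sum, res.overflow)
```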