-
Notifications
You must be signed in to change notification settings - Fork 524
/
Copy pathgen_ops_def.py
96 lines (84 loc) · 3.05 KB
/
gen_ops_def.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Generates a template `functions.yaml` from a model binary. Operators outside
# the `aten::` namespace (custom ops) are ignored.
import argparse
import os
import sys
from typing import Any, List
import torch
import yaml
from executorch.codegen.tools.yaml_util import BlankLineDumper
from executorch.exir._serialize import _deserialize_pte_binary
from executorch.exir.schema import Operator
def get_operators(model_file: str) -> List[Operator]:
    """Load a serialized program and return the operators of its first plan.

    Args:
        model_file: Path to the model binary; it is opened in binary mode and
            its full contents are handed to ``_deserialize_pte_binary``.

    Returns:
        The ``operators`` attribute of ``program.execution_plan[0]``. Only the
        execution plan at index 0 is consulted.
    """
    print("Processing model file: ", model_file)
    with open(model_file, "rb") as model_stream:
        program = _deserialize_pte_binary(model_stream.read())
    print(f"Program loaded from model file: {model_file}")
    first_plan = program.execution_plan[0]
    return first_plan.operators
def dump_yaml(model_file: str, output_file: str) -> None:
    """Write a template yaml describing the ``aten::`` operators of a model.

    Args:
        model_file: Path to the model binary, forwarded to ``get_operators``.
        output_file: Path of the yaml file to write (opened with mode ``"w"``).

    For every operator whose name starts with ``aten::``, the registered
    schemas are looked up and those whose ``overload_name`` equals the
    operator's ``overload`` are kept. Any other operator only produces a
    warning on stdout. Each kept schema is printed and turned into one yaml
    entry with ``func``, ``variants`` and ``dispatch`` keys.
    """
    matched_schemas = []  # type: ignore[var-annotated]
    for operator in get_operators(model_file):
        if not operator.name.startswith("aten::"):
            print(f"Warning: not generating template for {operator.name}")
            continue
        candidates = torch._C._jit_get_schemas_for_operator(operator.name)
        matched_schemas.extend(
            schema
            for schema in candidates
            if schema.overload_name == operator.overload
        )

    entries = []
    for schema in matched_schemas:
        signature = str(schema)
        print(signature)
        # Kernel name is the schema name moved into the torch::executor
        # namespace, suffixed with "_<overload_name>".
        kernel_base = schema.name.replace("aten::", "torch::executor::")
        entry = {
            "func": signature,
            "variants": "function",
            "dispatch": {
                "CPU": f"{kernel_base}_{schema.overload_name}",
            },
        }
        entries.append(entry)

    with open(output_file, "w") as yaml_stream:
        yaml.dump(
            entries,
            yaml_stream,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            width=1000,
        )
def main(args: List[Any]) -> None:
    """Generate a template functions.yaml to be consumed by ExecuTorch codegen.

    Reads the model file, deserializes it and dumps all of its operators into
    a new functions.yaml. The generated file contains placeholder kernels; it
    needs to be updated with proper kernel names.

    Args:
        args: Command-line arguments, without the program name.

    Raises:
        SystemExit: If argument parsing fails, or if ``--model_file_path`` is
            missing or does not refer to an existing file. These failures are
            reported through ``parser.error``, which prints the usage message
            and exits with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Generate operator list from a model file"
    )
    parser.add_argument(
        "--output_path",
        help=("The path to the output yaml file (functions.yaml)"),
        required=True,
    )
    parser.add_argument(
        "--model_file_path",
        help=("Path to an executorch program"),
        required=False,
    )
    options = parser.parse_args(args)

    # Validate explicitly instead of using `assert`: assertions are removed
    # when Python runs with -O, which would let bad input reach open().
    if not options.model_file_path:
        parser.error("Need to provide a model_file_path.")
    if not os.path.isfile(options.model_file_path):
        parser.error("The value for --model_file_path needs to be a valid file.")

    dump_yaml(
        model_file=options.model_file_path,
        # --output_path is required, so this fallback only takes effect when
        # an empty string is passed.
        output_file=options.output_path if options.output_path else "./functions.yaml",
    )
# Script entry point: forward the command-line arguments (minus the program
# name) to main().
if __name__ == "__main__":
    main(args=sys.argv[1:])