-
Notifications
You must be signed in to change notification settings - Fork 527
/
Copy pathmodel_factory.py
60 lines (50 loc) · 2.11 KB
/
model_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from typing import Any, Dict, Optional, Tuple

import torch
class EagerModelFactory:
    """
    A factory class for dynamically creating instances of classes implementing EagerModelBase.
    """

    @staticmethod
    def create_model(
        module_name, model_class_name, **kwargs
    ) -> Tuple[torch.nn.Module, Tuple[Any, ...], Optional[Dict[str, Any]], Any]:
        """
        Create an instance of a model class that implements EagerModelBase and retrieve related data.

        The model class is looked up as an attribute of the module
        ``examples.models.<module_name>``. That module path is prefixed with
        ``executorch.`` unless the current working directory ends with
        "executorch".

        Args:
            module_name (str): The name of the module containing the model class.
            model_class_name (str): The name of the model class to create an instance of.
            **kwargs: Forwarded unchanged to the model class constructor.

        Returns:
            Tuple[nn.Module, Tuple[Any, ...], Optional[Dict[str, Any]], Any]: A 4-tuple of
                - the eager PyTorch model (``model.get_eager_model()``),
                - the positional example inputs (``model.get_example_inputs()``),
                - the keyword example inputs (``model.get_example_kwarg_inputs()``),
                  or ``None`` if the model does not define that method,
                - the dynamic shape information for the inputs
                  (``model.get_dynamic_shapes()``), or ``None`` if the model does
                  not define that method.

        Raises:
            ValueError: If the provided model class is not found in the module.
            ImportError: Propagated from ``importlib.import_module`` when the
                module itself cannot be imported.
        """
        # Heuristic on the working directory: from a directory whose path ends in
        # "executorch" (presumably the repository root) the examples package is
        # imported directly; otherwise it is reached through the "executorch."
        # package. NOTE(review): relies on the caller's cwd — confirm intended.
        package_prefix = "executorch." if not os.getcwd().endswith("executorch") else ""
        module = importlib.import_module(
            f"{package_prefix}examples.models.{module_name}"
        )

        if not hasattr(module, model_class_name):
            raise ValueError(
                f"Model class '{model_class_name}' not found in module '{module_name}'."
            )

        model_class = getattr(module, model_class_name)
        model = model_class(**kwargs)

        # Optional hooks: a model may not provide keyword example inputs or
        # dynamic shapes, in which case the corresponding slot is None.
        example_kwarg_inputs = (
            model.get_example_kwarg_inputs()
            if hasattr(model, "get_example_kwarg_inputs")
            else None
        )
        dynamic_shapes = (
            model.get_dynamic_shapes() if hasattr(model, "get_dynamic_shapes") else None
        )

        return (
            model.get_eager_model(),
            model.get_example_inputs(),
            example_kwarg_inputs,
            dynamic_shapes,
        )