-
Notifications
You must be signed in to change notification settings - Fork 527
/
Copy pathmodel_base.py
47 lines (36 loc) · 1.34 KB
/
model_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
import torch
class EagerModelBase(ABC):
    """
    Abstract base class for eager mode models.

    This abstract class defines the interface that eager mode model classes
    should adhere to. Eager mode models inherit from this class to ensure
    consistent behavior and structure.

    All three methods below are declared with ``@abstractmethod``. A concrete
    subclass must define ``__init__``, ``get_eager_model`` and
    ``get_example_inputs``. Otherwise, attempting to instantiate it raises
    ``TypeError``.
    """

    @abstractmethod
    def __init__(self):
        """
        Constructor for EagerModelBase.

        Because this initializer is marked ``@abstractmethod``, every concrete
        subclass must define its own ``__init__`` (even a trivial one).
        The base implementation performs no setup and may be invoked from a
        subclass via ``super().__init__()``.
        """
        pass

    @abstractmethod
    def get_eager_model(self) -> torch.nn.Module:
        """
        Abstract method to return an eager PyTorch model instance.

        Returns:
            torch.nn.Module: An instance of a PyTorch model, suitable for
            eager execution.

        Raises:
            NotImplementedError: If this base implementation is called
                directly (e.g. through ``super().get_eager_model()``).
        """
        raise NotImplementedError("get_eager_model")

    @abstractmethod
    def get_example_inputs(self):
        """
        Abstract method to provide example inputs for the model.

        Returns:
            Any: Example inputs that can be used for testing and tracing.
            Their exact type and structure are defined by the subclass.

        Raises:
            NotImplementedError: If this base implementation is called
                directly (e.g. through ``super().get_example_inputs()``).
        """
        raise NotImplementedError("get_example_inputs")