train.py
#!/usr/bin/env python3
"""Script to train with flat or hierarchical approaches."""
import argparse
import pickle
import sys
from argparse import Namespace

from joblib import parallel_backend
from omegaconf import DictConfig, OmegaConf

from data import load_dataframe, flatten_labels
from tune import configure_pipeline


def parse_args(args: list) -> Namespace:
    """
    Parse a list of arguments.

    Parameters
    ----------
    args : list
        Arguments to parse.

    Returns
    -------
    _ : Namespace
        Parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Train classifier")
    parser.add_argument(
        "--n-jobs",
        type=int,
        required=True,
        help="Number of jobs to run training in parallel",
    )
    parser.add_argument(
        "--x-train",
        type=str,
        required=True,
        help="Input CSV file with training features",
    )
    parser.add_argument(
        "--y-train",
        type=str,
        required=True,
        help="Input CSV file with training labels",
    )
    parser.add_argument(
        "--trained-model",
        type=str,
        required=True,
        help="Path to store trained model",
    )
    parser.add_argument(
        "--classifier",
        type=str,
        required=True,
        help="Algorithm used for fitting, e.g., logistic_regression or random_forest",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model used for training, e.g., flat, lcpl, lcpn or lcppn",
    )
    parser.add_argument(
        "--best-parameters",
        type=str,
        required=True,
        help="Path to optuna's tuned parameters",
    )
    return parser.parse_args(args)


def load_parameters(yml: str) -> DictConfig:
    """
    Load parameters from a YAML file.

    Parameters
    ----------
    yml : str
        Path to YAML file containing tuned parameters.

    Returns
    -------
    cfg : DictConfig
        Dictionary containing all configuration information.
    """
    cfg = OmegaConf.load(yml)
    return cfg["best_params"]
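
# load_parameters expects the YAML file to carry a top-level ``best_params``
# key (as read above), e.g. as written by the Optuna tuning step. A minimal
# sketch of that layout; the parameter names below are hypothetical
# placeholders, not values from this repository:
#
#   best_params:
#     classifier__C: 1.0
#     classifier__max_iter: 1000
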
def train() -> None:  # pragma: no cover
    """Train with flat or hierarchical approaches."""
    args = parse_args(sys.argv[1:])
    x_train = load_dataframe(args.x_train).squeeze()
    y_train = load_dataframe(args.y_train)
    if args.model == "flat":
        y_train = flatten_labels(y_train)
    best_params = load_parameters(args.best_parameters)
    best_params.model = args.model
    best_params.classifier = args.classifier
    best_params.n_jobs = args.n_jobs
    pipeline = configure_pipeline(best_params)
    with parallel_backend("threading", n_jobs=args.n_jobs):
        pipeline.fit(x_train, y_train)
    with open(args.trained_model, "wb") as model_file:
        pickle.dump(pipeline, model_file)


if __name__ == "__main__":
    train()  # pragma: no cover
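
# Example invocation (a sketch; the file paths are hypothetical placeholders,
# while the flags mirror parse_args above):
#
#   python train.py \
#       --n-jobs 4 \
#       --x-train x_train.csv \
#       --y-train y_train.csv \
#       --trained-model model.pkl \
#       --classifier logistic_regression \
#       --model flat \
#       --best-parameters best_params.yml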