-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathplot_model_persistence.py
45 lines (35 loc) · 1.14 KB
/
plot_model_persistence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# -*- coding: utf-8 -*-
"""
=====================
Model Persistence
=====================
HiClass is fully compatible with Pickle.
Pickle can be used to easily store machine learning models on disk.
In this example, we demonstrate how to use pickle to store and load trained classifiers.
"""
import pickle
from sklearn.linear_model import LogisticRegression
from hiclass import LocalClassifierPerLevel
# Define data
X_train = [[1, 2], [3, 4], [5, 6], [7, 8]]
X_test = [[7, 8], [5, 6], [3, 4], [1, 2]]
Y_train = [
    ["Animal", "Mammal", "Sheep"],
    ["Animal", "Mammal", "Cow"],
    ["Animal", "Reptile", "Snake"],
    ["Animal", "Reptile", "Lizard"],
]

# Use Logistic Regression classifiers for every level in the hierarchy
lr = LogisticRegression()
classifier = LocalClassifierPerLevel(local_classifier=lr)

# Train local classifier per level
classifier.fit(X_train, Y_train)

# Save the model to disk.
# The ``with`` block guarantees the file is flushed and closed as soon as
# the dump finishes, instead of relying on garbage collection to do it.
filename = "trained_model.sav"
with open(filename, "wb") as model_file:
    pickle.dump(classifier, model_file)

# Some time in the future...
# Load the model from disk.
# Only unpickle files you trust: loading a pickle can execute arbitrary code.
with open(filename, "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Predict
predictions = loaded_model.predict(X_test)
print(predictions)