-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrnn_dbscan_big.py
67 lines (50 loc) · 2.12 KB
/
rnn_dbscan_big.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
=======================================================
Demo of RnnDBSCAN clustering algorithm on large dataset
=======================================================
Tests RnnDBSCAN on a large dataset. Requires pandas.
"""
import numpy as np
from joblib import Memory
from sklearn import metrics
from sklearn.datasets import fetch_openml
from sklearn_ann.cluster.rnn_dbscan import simple_rnn_dbscan_pipeline
# #############################################################################
# Load the data: the MNIST dataset from OpenML
def fetch_mnist():
    """Fetch the ``mnist_784`` dataset from OpenML.

    Returns
    -------
    tuple
        ``(features, target)``. ``features`` is the dataset's ``data``
        divided by 255 (presumably mapping 8-bit pixel intensities into
        [0, 1]); ``target`` is the dataset's ``target``, unchanged.
    """
    print("Downloading mnist_784")
    dataset = fetch_openml("mnist_784")
    features = dataset.data / 255
    return features, dataset.target
# Cache fetch_mnist's result on disk under ./mnist (a path relative to the
# current working directory) so later runs can reuse it instead of calling
# fetch_mnist again; the "Downloading" message only appears when it runs.
memory = Memory("./mnist")
# X: scaled features, y: ground-truth labels. run_rnn_dbscan reads both
# as module globals, and they are visible in the interactive session.
X, y = memory.cache(fetch_mnist)()
def run_rnn_dbscan(neighbor_transformer, n_neighbors, **kwargs):
    """Cluster the module-level data with RnnDBSCAN and print quality metrics.

    Builds a pipeline via ``simple_rnn_dbscan_pipeline``, fits it on the
    global ``X`` and compares the predicted labels with the global
    ground-truth ``y``.

    Parameters
    ----------
    neighbor_transformer
        Forwarded to ``simple_rnn_dbscan_pipeline``; e.g. a neighbors
        transformer class such as ``PyNNDescentTransformer``.
    n_neighbors : int
        Forwarded to ``simple_rnn_dbscan_pipeline``.
    **kwargs
        Extra keyword arguments forwarded to ``simple_rnn_dbscan_pipeline``.

    Returns
    -------
    numpy.ndarray
        The predicted label of every sample; ``-1`` marks noise.
    """
    # #############################################################################
    # Compute RnnDBSCAN
    pipeline = simple_rnn_dbscan_pipeline(neighbor_transformer, n_neighbors, **kwargs)
    labels = np.asarray(pipeline.fit_predict(X))

    # Number of clusters in labels, ignoring noise (label -1) if present.
    unique_labels = np.unique(labels)
    n_labels = len(unique_labels)
    n_clusters_ = n_labels - (1 if -1 in unique_labels else 0)
    # Vectorized count; avoids materialising a Python list of every label.
    n_noise_ = int(np.count_nonzero(labels == -1))

    # silhouette_score raises ValueError unless 2 <= n_labels <= n_samples - 1
    # (noise counts as a label here), so guard it instead of crashing after
    # an expensive fit. NOTE: it is evaluated on every sample, which is
    # quadratic in n_samples; sklearn's `sample_size` option can subsample.
    if 2 <= n_labels <= len(labels) - 1:
        silhouette = f"{metrics.silhouette_score(X, labels):0.3f}"
    else:
        silhouette = "n/a (needs between 2 and n_samples - 1 distinct labels)"

    print(f"""\
Estimated number of clusters: {n_clusters_}
Estimated number of noise points: {n_noise_}
Homogeneity: {metrics.homogeneity_score(y, labels):0.3f}
Completeness: {metrics.completeness_score(y, labels):0.3f}
V-measure: {metrics.v_measure_score(y, labels):0.3f}
Adjusted Rand Index: {metrics.adjusted_rand_score(y, labels):0.3f}
Adjusted Mutual Information: {metrics.adjusted_mutual_info_score(y, labels):0.3f}
Silhouette Coefficient: {silhouette}\
""")
    return labels
if __name__ == "__main__":
import code
print("""\
Now you can import your chosen transformer_cls and run:
run_rnn_dbscan(transformer_cls, n_neighbors, **params)
e.g.
from sklearn_ann.kneighbors.pynndescent import PyNNDescentTransformer
run_rnn_dbscan(PyNNDescentTransformer, 10)\
""")
code.interact(local=locals())