forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmanifold.py
34 lines (22 loc) · 820 Bytes
/
manifold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from sklearn.manifold import TSNE
from .common import Benchmark, Estimator
from .datasets import _digits_dataset
class TSNEBenchmark(Estimator, Benchmark):
    """
    Benchmarks for t-SNE.

    Parameterized over the ``method`` argument of ``sklearn.manifold.TSNE``,
    comparing the ``"exact"`` and ``"barnes_hut"`` solvers. Both the train
    and the test score reported for a run are the fitted model's
    ``kl_divergence_``.
    """

    # asv parameter grid: a single axis named "method" with two values.
    param_names = ["method"]
    params = (["exact", "barnes_hut"],)

    def setup_cache(self):
        """Run the base-class cache setup for this benchmark class."""
        # NOTE(review): this override only delegates to the parent, but it is
        # presumably intentional -- asv appears to key setup_cache results on
        # the function that defines it, so a per-class override keeps this
        # class's cache separate from other benchmarks. Confirm before
        # removing it.
        super().setup_cache()

    def make_data(self, params):
        """Return the digits dataset sized for the requested t-SNE solver.

        ``params`` must hold exactly one value: the t-SNE ``method``. The
        exact solver gets a 500-sample subset. Every other method passes
        ``n_samples=None`` to ``_digits_dataset``, which presumably means
        "use the default/full size" -- confirm in ``.datasets``.
        """
        [method] = params

        # The exact solver is far costlier than Barnes-Hut (quadratic in the
        # number of samples, per the TSNE docs), so it runs on a small subset.
        if method == "exact":
            n_samples = 500
        else:
            n_samples = None

        return _digits_dataset(n_samples=n_samples)

    def make_estimator(self, params):
        """Build an unfitted ``TSNE`` configured with the requested ``method``.

        ``random_state`` is pinned to 0 so every run starts from the same seed.
        """
        [method] = params
        return TSNE(method=method, random_state=0)

    def make_scorers(self):
        """Install train/test scorers that both report the KL divergence.

        Each scorer accepts two positional arguments and ignores them,
        returning ``self.estimator.kl_divergence_`` instead. This is
        presumably because TSNE offers no ``predict``-style output to score.
        """

        def _new_kl_scorer():
            # The lambda looks up ``self.estimator`` when it is called, not
            # when it is created, so it reports whichever estimator is
            # assigned at scoring time.
            return lambda _, __: self.estimator.kl_divergence_

        # Two independent scorer objects, matching one lambda per attribute.
        self.train_scorer = _new_kl_scorer()
        self.test_scorer = _new_kl_scorer()