Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit 5efa8b0

Browse files
authored
Refactor tests to use pytest features (#193)
* refactor tests to use pytest features * run CI * hotfix * hotfix * Update test_penalty.py * Update test_sag.py * Update test_sag.py * Update test_sag.py * address review comments
1 parent 3afcb4a commit 5efa8b0

14 files changed

+744
-655
lines changed

lightning/impl/tests/conftest.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
import scipy.sparse as sp
3+
4+
from sklearn.datasets import load_iris
5+
6+
from lightning.impl.datasets.samples_generator import make_classification
7+
8+
9+
@pytest.fixture(scope="module")
10+
def train_data():
11+
iris = load_iris()
12+
return iris.data, iris.target
13+
14+
15+
@pytest.fixture(scope="module")
16+
def bin_train_data(train_data):
17+
X, y = train_data
18+
X_bin = X[y <= 1]
19+
y_bin = y[y <= 1] * 2 - 1
20+
return X_bin, y_bin
21+
22+
23+
@pytest.fixture(scope="module")
24+
def bin_dense_train_data():
25+
bin_dense, bin_target = make_classification(n_samples=200, n_features=100,
26+
n_informative=5,
27+
n_classes=2, random_state=0)
28+
return bin_dense, bin_target
29+
30+
31+
@pytest.fixture(scope="module")
32+
def bin_sparse_train_data(bin_dense_train_data):
33+
bin_dense, bin_target = bin_dense_train_data
34+
bin_csr = sp.csr_matrix(bin_dense)
35+
return bin_csr, bin_target
36+
37+
38+
@pytest.fixture(scope="module")
39+
def mult_dense_train_data():
40+
mult_dense, mult_target = make_classification(n_samples=300, n_features=100,
41+
n_informative=5,
42+
n_classes=3, random_state=0)
43+
return mult_dense, mult_target
44+
45+
46+
@pytest.fixture(scope="module")
47+
def mult_sparse_train_data(mult_dense_train_data):
48+
mult_dense, mult_target = mult_dense_train_data
49+
mult_sparse = sp.csr_matrix(mult_dense)
50+
return mult_sparse, mult_target

lightning/impl/tests/test_adagrad.py

+22-21
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,63 @@
11
import numpy as np
2-
3-
from sklearn.datasets import load_iris
2+
import pytest
43

54
from lightning.classification import AdaGradClassifier
65
from lightning.regression import AdaGradRegressor
76
from lightning.impl.adagrad_fast import _proj_elastic_all
87
from lightning.impl.tests.utils import check_predict_proba
98

10-
iris = load_iris()
11-
X, y = iris.data, iris.target
12-
13-
X_bin = X[y <= 1]
14-
y_bin = y[y <= 1] * 2 - 1
15-
169

17-
def test_adagrad_elastic_hinge():
10+
def test_adagrad_elastic_hinge(bin_train_data):
11+
X_bin, y_bin = bin_train_data
1812
clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, n_iter=10, random_state=0)
1913
clf.fit(X_bin, y_bin)
2014
assert not hasattr(clf, "predict_proba")
2115
assert clf.score(X_bin, y_bin) == 1.0
2216

2317

24-
def test_adagrad_elastic_smooth_hinge():
18+
def test_adagrad_elastic_smooth_hinge(bin_train_data):
19+
X_bin, y_bin = bin_train_data
2520
clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, loss="smooth_hinge",
2621
n_iter=10, random_state=0)
2722
clf.fit(X_bin, y_bin)
2823
assert not hasattr(clf, "predict_proba")
2924
assert clf.score(X_bin, y_bin) == 1.0
3025

3126

32-
def test_adagrad_elastic_log():
27+
def test_adagrad_elastic_log(bin_train_data):
28+
X_bin, y_bin = bin_train_data
3329
clf = AdaGradClassifier(alpha=0.1, l1_ratio=0.85, loss="log", n_iter=10,
3430
random_state=0)
3531
clf.fit(X_bin, y_bin)
3632
assert clf.score(X_bin, y_bin) == 1.0
3733
check_predict_proba(clf, X_bin)
3834

3935

40-
def test_adagrad_hinge_multiclass():
36+
def test_adagrad_hinge_multiclass(train_data):
37+
X, y = train_data
4138
clf = AdaGradClassifier(alpha=1e-2, n_iter=100, loss="hinge", random_state=0)
4239
clf.fit(X, y)
4340
assert not hasattr(clf, "predict_proba")
4441
np.testing.assert_almost_equal(clf.score(X, y), 0.940, 3)
4542

4643

47-
def test_adagrad_classes_binary():
44+
def test_adagrad_classes_binary(bin_train_data):
45+
X_bin, y_bin = bin_train_data
4846
clf = AdaGradClassifier()
4947
assert not hasattr(clf, 'classes_')
5048
clf.fit(X_bin, y_bin)
5149
assert list(clf.classes_) == [-1, 1]
5250

5351

54-
def test_adagrad_classes_multiclass():
52+
def test_adagrad_classes_multiclass(train_data):
53+
X, y = train_data
5554
clf = AdaGradClassifier()
5655
assert not hasattr(clf, 'classes_')
5756
clf.fit(X, y)
5857
assert list(clf.classes_) == [0, 1, 2]
5958

6059

61-
def test_adagrad_callback():
60+
def test_adagrad_callback(bin_train_data):
6261
class Callback(object):
6362

6463
def __init__(self, X, y):
@@ -74,16 +73,18 @@ def __call__(self, clf, t):
7473
score = clf.score(self.X, self.y)
7574
self.acc.append(score)
7675

76+
X_bin, y_bin = bin_train_data
7777
cb = Callback(X_bin, y_bin)
7878
clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, n_iter=10,
7979
callback=cb, random_state=0)
8080
clf.fit(X_bin, y_bin)
8181
assert cb.acc[-1] == 1.0
8282

8383

84-
def test_adagrad_regression():
85-
for loss in ("squared", "absolute"):
86-
reg = AdaGradRegressor(loss=loss)
87-
reg.fit(X_bin, y_bin)
88-
y_pred = np.sign(reg.predict(X_bin))
89-
assert np.mean(y_bin == y_pred) == 1.0
84+
@pytest.mark.parametrize("loss", ["squared", "absolute"])
85+
def test_adagrad_regression(loss, bin_train_data):
86+
X_bin, y_bin = bin_train_data
87+
reg = AdaGradRegressor(loss=loss)
88+
reg.fit(X_bin, y_bin)
89+
y_pred = np.sign(reg.predict(X_bin))
90+
assert np.mean(y_bin == y_pred) == 1.0

lightning/impl/tests/test_dataset.py

+49-22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pickle
2+
23
import numpy as np
4+
import pytest
35
import scipy.sparse as sp
46

57
from sklearn.datasets import make_classification
@@ -10,26 +12,40 @@
1012
from lightning.impl.dataset_fast import CSRDataset
1113
from lightning.impl.dataset_fast import CSCDataset
1214

13-
# Create test datasets.
14-
X, _ = make_classification(n_samples=20, n_features=100,
15-
n_informative=5, n_classes=2, random_state=0)
16-
X2, _ = make_classification(n_samples=10, n_features=100,
17-
n_informative=5, n_classes=2, random_state=0)
1815

19-
# Sparsify datasets.
20-
X[X < 0.3] = 0
16+
@pytest.fixture(scope="module")
17+
def test_data():
18+
X, _ = make_classification(n_samples=20, n_features=100,
19+
n_informative=5, n_classes=2, random_state=0)
20+
X2, _ = make_classification(n_samples=10, n_features=100,
21+
n_informative=5, n_classes=2, random_state=0)
22+
23+
# Sparsify datasets.
24+
X[X < 0.3] = 0
2125

22-
X_csr = sp.csr_matrix(X)
23-
X_csc = sp.csc_matrix(X)
26+
X_csr = sp.csr_matrix(X)
27+
X_csc = sp.csc_matrix(X)
2428

25-
rs = check_random_state(0)
26-
cds = ContiguousDataset(X)
27-
fds = FortranDataset(np.asfortranarray(X))
28-
csr_ds = CSRDataset(X_csr)
29-
csc_ds = CSCDataset(X_csc)
29+
rs = check_random_state(0)
30+
cds = ContiguousDataset(X)
31+
fds = FortranDataset(np.asfortranarray(X))
32+
csr_ds = CSRDataset(X_csr)
33+
csc_ds = CSCDataset(X_csc)
3034

35+
return {
36+
"X": X,
37+
"X_csr": X_csr,
38+
"X_csc": X_csc,
39+
"contiguous_dataset": cds,
40+
"fortran_dataset": fds,
41+
"dataset_csr": csr_ds,
42+
"dataset_csc": csc_ds
43+
}
3144

32-
def test_contiguous_get_row():
45+
46+
def test_contiguous_get_row(test_data):
47+
X = test_data["X"]
48+
cds = test_data["contiguous_dataset"]
3349
ind = np.arange(X.shape[1])
3450
for i in range(X.shape[0]):
3551
indices, data, n_nz = cds.get_row(i)
@@ -38,15 +54,19 @@ def test_contiguous_get_row():
3854
assert n_nz == X.shape[1]
3955

4056

41-
def test_csr_get_row():
57+
def test_csr_get_row(test_data):
58+
X = test_data["X"]
59+
csr_ds = test_data["dataset_csr"]
4260
for i in range(X.shape[0]):
4361
indices, data, n_nz = csr_ds.get_row(i)
4462
for jj in range(n_nz):
4563
j = indices[jj]
4664
assert X[i, j] == data[jj]
4765

4866

49-
def test_fortran_get_column():
67+
def test_fortran_get_column(test_data):
68+
X = test_data["X"]
69+
fds = test_data["fortran_dataset"]
5070
ind = np.arange(X.shape[0])
5171
for j in range(X.shape[1]):
5272
indices, data, n_nz = fds.get_column(j)
@@ -55,18 +75,25 @@ def test_fortran_get_column():
5575
assert n_nz == X.shape[0]
5676

5777

58-
def test_csc_get_column():
78+
def test_csc_get_column(test_data):
79+
X = test_data["X"]
80+
csc_ds = test_data["dataset_csc"]
5981
for j in range(X.shape[1]):
6082
indices, data, n_nz = csc_ds.get_column(j)
6183
for ii in range(n_nz):
6284
i = indices[ii]
6385
assert X[i, j] == data[ii]
6486

6587

66-
def test_picklable_datasets():
67-
"""Test that the datasets are picklable."""
68-
69-
for dataset in [cds, csr_ds, fds, csc_ds]:
88+
def test_picklable_datasets(test_data):
89+
# Test that the datasets are picklable.
90+
X = test_data["X"]
91+
for dataset in [
92+
test_data["contiguous_dataset"],
93+
test_data["dataset_csr"],
94+
test_data["fortran_dataset"],
95+
test_data["dataset_csc"]
96+
]:
7097
pds = pickle.dumps(dataset)
7198
dataset = pickle.loads(pds)
7299
assert dataset.get_n_samples() == X.shape[0]

0 commit comments

Comments
 (0)