|
| 1 | +# %% |
| 2 | +# |
| 3 | +# This benchmark compares the speed of PCA solvers on datasets of different |
| 4 | +# sizes in order to determine the best solver to select by default via the |
| 5 | +# "auto" heuristic. |
| 6 | +# |
| 7 | +# Note: we do not control for the accuracy of the solvers: we assume that all |
| 8 | +# solvers yield transformed data with similar explained variance. This |
| 9 | +# assumption is generally true, except for the randomized solver that might |
| 10 | +# require more power iterations. |
| 11 | +# |
| 12 | +# We generate synthetic data with dimensions that are useful to plot: |
| 13 | +# - time vs n_samples for a fixed n_features and, |
| 14 | +# - time vs n_features for a fixed n_samples for a fixed n_features. |
| 15 | +import itertools |
| 16 | +from math import log10 |
| 17 | +from time import perf_counter |
| 18 | + |
| 19 | +import matplotlib.pyplot as plt |
| 20 | +import numpy as np |
| 21 | +import pandas as pd |
| 22 | + |
| 23 | +from sklearn import config_context |
| 24 | +from sklearn.decomposition import PCA |
| 25 | + |
| 26 | +REF_DIMS = [100, 1000, 10_000] |
| 27 | +data_shapes = [] |
| 28 | +for ref_dim in REF_DIMS: |
| 29 | + data_shapes.extend([(ref_dim, 10**i) for i in range(1, 8 - int(log10(ref_dim)))]) |
| 30 | + data_shapes.extend( |
| 31 | + [(ref_dim, 3 * 10**i) for i in range(1, 8 - int(log10(ref_dim)))] |
| 32 | + ) |
| 33 | + data_shapes.extend([(10**i, ref_dim) for i in range(1, 8 - int(log10(ref_dim)))]) |
| 34 | + data_shapes.extend( |
| 35 | + [(3 * 10**i, ref_dim) for i in range(1, 8 - int(log10(ref_dim)))] |
| 36 | + ) |
| 37 | + |
| 38 | +# Remove duplicates: |
| 39 | +data_shapes = sorted(set(data_shapes)) |
| 40 | + |
| 41 | +print("Generating test datasets...") |
| 42 | +rng = np.random.default_rng(0) |
| 43 | +datasets = [rng.normal(size=shape) for shape in data_shapes] |
| 44 | + |
| 45 | + |
| 46 | +# %% |
| 47 | +def measure_one(data, n_components, solver, method_name="fit"): |
| 48 | + print( |
| 49 | + f"Benchmarking {solver=!r}, {n_components=}, {method_name=!r} on data with" |
| 50 | + f" shape {data.shape}" |
| 51 | + ) |
| 52 | + pca = PCA(n_components=n_components, svd_solver=solver, random_state=0) |
| 53 | + timings = [] |
| 54 | + elapsed = 0 |
| 55 | + method = getattr(pca, method_name) |
| 56 | + with config_context(assume_finite=True): |
| 57 | + while elapsed < 0.5: |
| 58 | + tic = perf_counter() |
| 59 | + method(data) |
| 60 | + duration = perf_counter() - tic |
| 61 | + timings.append(duration) |
| 62 | + elapsed += duration |
| 63 | + return np.median(timings) |
| 64 | + |
| 65 | + |
| 66 | +SOLVERS = ["full", "covariance_eigh", "arpack", "randomized", "auto"] |
| 67 | +measurements = [] |
| 68 | +for data, n_components, method_name in itertools.product( |
| 69 | + datasets, [2, 50], ["fit", "fit_transform"] |
| 70 | +): |
| 71 | + if n_components >= min(data.shape): |
| 72 | + continue |
| 73 | + for solver in SOLVERS: |
| 74 | + if solver == "covariance_eigh" and data.shape[1] > 5000: |
| 75 | + # Too much memory and too slow. |
| 76 | + continue |
| 77 | + if solver in ["arpack", "full"] and log10(data.size) > 7: |
| 78 | + # Too slow, in particular for the full solver. |
| 79 | + continue |
| 80 | + time = measure_one(data, n_components, solver, method_name=method_name) |
| 81 | + measurements.append( |
| 82 | + { |
| 83 | + "n_components": n_components, |
| 84 | + "n_samples": data.shape[0], |
| 85 | + "n_features": data.shape[1], |
| 86 | + "time": time, |
| 87 | + "solver": solver, |
| 88 | + "method_name": method_name, |
| 89 | + } |
| 90 | + ) |
| 91 | +measurements = pd.DataFrame(measurements) |
| 92 | +measurements.to_csv("bench_pca_solvers.csv", index=False) |
| 93 | + |
| 94 | +# %% |
| 95 | +all_method_names = measurements["method_name"].unique() |
| 96 | +all_n_components = measurements["n_components"].unique() |
| 97 | + |
| 98 | +for method_name in all_method_names: |
| 99 | + fig, axes = plt.subplots( |
| 100 | + figsize=(16, 16), |
| 101 | + nrows=len(REF_DIMS), |
| 102 | + ncols=len(all_n_components), |
| 103 | + sharey=True, |
| 104 | + constrained_layout=True, |
| 105 | + ) |
| 106 | + fig.suptitle(f"Benchmarks for PCA.{method_name}, varying n_samples", fontsize=16) |
| 107 | + |
| 108 | + for row_idx, ref_dim in enumerate(REF_DIMS): |
| 109 | + for n_components, ax in zip(all_n_components, axes[row_idx]): |
| 110 | + for solver in SOLVERS: |
| 111 | + if solver == "auto": |
| 112 | + style_kwargs = dict(linewidth=2, color="black", style="--") |
| 113 | + else: |
| 114 | + style_kwargs = dict(style="o-") |
| 115 | + ax.set( |
| 116 | + title=f"n_components={n_components}, n_features={ref_dim}", |
| 117 | + ylabel="time (s)", |
| 118 | + ) |
| 119 | + measurements.query( |
| 120 | + "n_components == @n_components and n_features == @ref_dim" |
| 121 | + " and solver == @solver and method_name == @method_name" |
| 122 | + ).plot.line( |
| 123 | + x="n_samples", |
| 124 | + y="time", |
| 125 | + label=solver, |
| 126 | + logx=True, |
| 127 | + logy=True, |
| 128 | + ax=ax, |
| 129 | + **style_kwargs, |
| 130 | + ) |
| 131 | +# %% |
| 132 | +for method_name in all_method_names: |
| 133 | + fig, axes = plt.subplots( |
| 134 | + figsize=(16, 16), |
| 135 | + nrows=len(REF_DIMS), |
| 136 | + ncols=len(all_n_components), |
| 137 | + sharey=True, |
| 138 | + ) |
| 139 | + fig.suptitle(f"Benchmarks for PCA.{method_name}, varying n_features", fontsize=16) |
| 140 | + |
| 141 | + for row_idx, ref_dim in enumerate(REF_DIMS): |
| 142 | + for n_components, ax in zip(all_n_components, axes[row_idx]): |
| 143 | + for solver in SOLVERS: |
| 144 | + if solver == "auto": |
| 145 | + style_kwargs = dict(linewidth=2, color="black", style="--") |
| 146 | + else: |
| 147 | + style_kwargs = dict(style="o-") |
| 148 | + ax.set( |
| 149 | + title=f"n_components={n_components}, n_samples={ref_dim}", |
| 150 | + ylabel="time (s)", |
| 151 | + ) |
| 152 | + measurements.query( |
| 153 | + "n_components == @n_components and n_samples == @ref_dim " |
| 154 | + " and solver == @solver and method_name == @method_name" |
| 155 | + ).plot.line( |
| 156 | + x="n_features", |
| 157 | + y="time", |
| 158 | + label=solver, |
| 159 | + logx=True, |
| 160 | + logy=True, |
| 161 | + ax=ax, |
| 162 | + **style_kwargs, |
| 163 | + ) |
| 164 | + |
| 165 | +# %% |
0 commit comments