
Commit 31464fc (1 parent: 6006a40)

Pushing the docs to 1.5/ for branch: 1.5.X, commit 156ef141f3b270edb06c8ae9af37c55253c0aabe

2,734 files changed (+92729, -87709 lines)


1.5/.buildinfo

+1 -1

@@ -1,4 +1,4 @@
 # Sphinx build info version 1
 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: d11dba821463e12eaf4819e95c5fa796
+config: 77a81589eb4374f7aa6937c4aeda73cb
 tags: 645f666f9bcd5a90fca523b33c5a78b7
[25 binary files not shown]

1.5/_downloads/1bcb2039afa126da41f1cea42b4a5866/plot_gpr_prior_posterior.py

+3 -3

@@ -127,8 +127,8 @@ def plot_gpr_samples(gpr_model, n_samples, ax):
 )
 
 # %%
-# Rational Quadradtic kernel
-# ..........................
+# Rational Quadratic kernel
+# .........................
 from sklearn.gaussian_process.kernels import RationalQuadratic
 
 kernel = 1.0 * RationalQuadratic(length_scale=1.0, alpha=0.1, alpha_bounds=(1e-5, 1e15))

@@ -201,7 +201,7 @@ def plot_gpr_samples(gpr_model, n_samples, ax):
 kernel = ConstantKernel(0.1, (0.01, 10.0)) * (
     DotProduct(sigma_0=1.0, sigma_0_bounds=(0.1, 10.0)) ** 2
 )
-gpr = GaussianProcessRegressor(kernel=kernel, random_state=0)
+gpr = GaussianProcessRegressor(kernel=kernel, random_state=0, normalize_y=True)
 
 fig, axs = plt.subplots(nrows=2, sharex=True, sharey=True, figsize=(10, 8))
 
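Note on the second hunk: `normalize_y=True` standardizes the targets (zero mean, unit variance) before fitting and undoes the scaling at prediction time, which keeps the zero-mean GP prior sensible for offset targets. A minimal sketch of the flag in isolation (hypothetical data, not the example's):

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, DotProduct

rng = np.random.RandomState(0)
X_train = np.sort(5 * rng.uniform(size=(10, 1)), axis=0)
y_train = 100.0 + np.sin(X_train).ravel()  # targets far from zero mean

# Same kernel family as the example's dot-product cell.
kernel = ConstantKernel(0.1, (0.01, 10.0)) * DotProduct(sigma_0=1.0) ** 2
# normalize_y=True fits the GP on standardized targets and rescales the
# posterior mean/variance back, so predictions track the offset data.
gpr = GaussianProcessRegressor(kernel=kernel, random_state=0, normalize_y=True)
gpr.fit(X_train, y_train)
print(gpr.predict(X_train[:2]))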
[9 binary files not shown]

1.5/_downloads/21a6ff17ef2837fe1cd49e63223a368d/plot_unveil_tree_structure.py

+25 -10

@@ -68,7 +68,8 @@
 # - ``weighted_n_node_samples[i]``: the weighted number of training samples
 #   reaching node ``i``
 # - ``value[i, j, k]``: the summary of the training samples that reached node i for
-#   output j and class k (for regression tree, class is set to 1).
+#   output j and class k (for regression tree, class is set to 1). See below
+#   for more information about ``value``.
 #
 # Using the arrays, we can traverse the tree structure to compute various
 # properties. Below, we will compute the depth of each node and whether or not

@@ -108,7 +109,7 @@
     if is_leaves[i]:
         print(
             "{space}node={node} is a leaf node with value={value}.".format(
-                space=node_depth[i] * "\t", node=i, value=values[i]
+                space=node_depth[i] * "\t", node=i, value=np.around(values[i], 3)
             )
         )
     else:

@@ -122,24 +123,36 @@
                 feature=feature[i],
                 threshold=threshold[i],
                 right=children_right[i],
-                value=values[i],
+                value=np.around(values[i], 3),
             )
         )
 
 # %%
 # What is the values array used here?
 # -----------------------------------
 # The `tree_.value` array is a 3D array of shape
-# [``n_nodes``, ``n_classes``, ``n_outputs``] which provides the count of samples
-# reaching a node for each class and for each output. Each node has a ``value``
-# array which is the number of weighted samples reaching this
-# node for each output and class.
+# [``n_nodes``, ``n_classes``, ``n_outputs``] which provides the proportion of samples
+# reaching a node for each class and for each output.
+# Each node has a ``value`` array which is the proportion of weighted samples reaching
+# this node for each output and class with respect to the parent node.
+#
+# One could convert this to the absolute weighted number of samples reaching a node,
+# by multiplying this number by `tree_.weighted_n_node_samples[node_idx]` for the
+# given node. Note sample weights are not used in this example, so the weighted
+# number of samples is the number of samples reaching the node because each sample
+# has a weight of 1 by default.
 #
 # For example, in the above tree built on the iris dataset, the root node has
-# ``value = [37, 34, 41]``, indicating there are 37 samples
+# ``value = [0.33, 0.304, 0.366]`` indicating there are 33% of class 0 samples,
+# 30.4% of class 1 samples, and 36.6% of class 2 samples at the root node. One can
+# convert this to the absolute number of samples by multiplying by the number of
+# samples reaching the root node, which is `tree_.weighted_n_node_samples[0]`.
+# Then the root node has ``value = [37, 34, 41]``, indicating there are 37 samples
 # of class 0, 34 samples of class 1, and 41 samples of class 2 at the root node.
+#
 # Traversing the tree, the samples are split and as a result, the ``value`` array
-# reaching each node changes. The left child of the root node has ``value = [37, 0, 0]``
+# reaching each node changes. The left child of the root node has ``value = [1., 0, 0]``
+# (or ``value = [37, 0, 0]`` when converted to the absolute number of samples)
 # because all 37 samples in the left child node are from class 0.
 #
 # Note: In this example, `n_outputs=1`, but the tree classifier can also handle

@@ -148,8 +161,10 @@
 
 ##############################################################################
 # We can compare the above output to the plot of the decision tree.
+# Here, we show the proportions of samples of each class that reach each
+# node corresponding to the actual elements of `tree_.value` array.
 
-tree.plot_tree(clf)
+tree.plot_tree(clf, proportion=True)
 plt.show()
 
 ##############################################################################
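Note: the proportion-to-count conversion described in the rewritten comments can be verified directly; a minimal sketch, assuming the same iris setup as the example (DecisionTreeClassifier(max_leaf_nodes=3, random_state=0) on a random_state=0 train split):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, _, y_train, _ = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0).fit(X_train, y_train)

proportions = clf.tree_.value  # per-node class proportions, as described above
counts = proportions * clf.tree_.weighted_n_node_samples[:, np.newaxis, np.newaxis]
print(np.around(proportions[0], 3))  # root node, e.g. [[0.33  0.304 0.366]]
print(counts[0])                     # root node, e.g. [[37. 34. 41.]]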
[33 binary files not shown]

1.5/_downloads/437df39fcde24ead7b91917f2133a53c/plot_regression.py

+12 -9

@@ -6,43 +6,46 @@
 Demonstrate the resolution of a regression problem
 using a k-Nearest Neighbor and the interpolation of the
 target using both barycenter and constant weights.
-
 """
 
 # Author: Alexandre Gramfort <[email protected]>
 #         Fabian Pedregosa <[email protected]>
 #
 # License: BSD 3 clause (C) INRIA
 
-
 # %%
 # Generate sample data
 # --------------------
+# Here we generate a few data points to use to train the model. We also generate
+# data in the whole range of the training data to visualize how the model would
+# react in that whole region.
 import matplotlib.pyplot as plt
 import numpy as np
 
 from sklearn import neighbors
 
-np.random.seed(0)
-X = np.sort(5 * np.random.rand(40, 1), axis=0)
-T = np.linspace(0, 5, 500)[:, np.newaxis]
-y = np.sin(X).ravel()
+rng = np.random.RandomState(0)
+X_train = np.sort(5 * rng.rand(40, 1), axis=0)
+X_test = np.linspace(0, 5, 500)[:, np.newaxis]
+y = np.sin(X_train).ravel()
 
 # Add noise to targets
 y[::5] += 1 * (0.5 - np.random.rand(8))
 
 # %%
 # Fit regression model
 # --------------------
+# Here we train a model and visualize how `uniform` and `distance`
+# weights in prediction effect predicted values.
 n_neighbors = 5
 
 for i, weights in enumerate(["uniform", "distance"]):
     knn = neighbors.KNeighborsRegressor(n_neighbors, weights=weights)
-    y_ = knn.fit(X, y).predict(T)
+    y_ = knn.fit(X_train, y).predict(X_test)
 
     plt.subplot(2, 1, i + 1)
-    plt.scatter(X, y, color="darkorange", label="data")
-    plt.plot(T, y_, color="navy", label="prediction")
+    plt.scatter(X_train, y, color="darkorange", label="data")
+    plt.plot(X_test, y_, color="navy", label="prediction")
     plt.axis("tight")
     plt.legend()
     plt.title("KNeighborsRegressor (k = %i, weights = '%s')" % (n_neighbors, weights))
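Note: the `uniform` and `distance` options mentioned in the new comments differ only in how neighbor targets are averaged; a minimal sketch for a single query point (hypothetical numbers, not the example's data):

import numpy as np

dist = np.array([0.1, 0.4, 0.5])     # distances to the k nearest neighbors
targets = np.array([1.0, 2.0, 3.0])  # their target values

uniform_pred = targets.mean()                    # 'uniform': plain average
weights = 1.0 / dist                             # 'distance': weight by 1/d
distance_pred = np.sum(weights * targets) / np.sum(weights)
print(uniform_pred, distance_pred)  # 1/d weighting pulls toward the nearest target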
[7 binary files not shown]

1.5/_downloads/4941b506cc56c9cec00d40992e2a4207/plot_permutation_importance_multicollinear.ipynb

+2 -2

@@ -22,7 +22,7 @@
 },
 "outputs": [],
 "source": [
-"from sklearn.inspection import permutation_importance\n\n\ndef plot_permutation_importance(clf, X, y, ax):\n result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2)\n perm_sorted_idx = result.importances_mean.argsort()\n\n ax.boxplot(\n result.importances[perm_sorted_idx].T,\n vert=False,\n labels=X.columns[perm_sorted_idx],\n )\n ax.axvline(x=0, color=\"k\", linestyle=\"--\")\n return ax"
+"import matplotlib\n\nfrom sklearn.inspection import permutation_importance\nfrom sklearn.utils.fixes import parse_version\n\n\ndef plot_permutation_importance(clf, X, y, ax):\n result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2)\n perm_sorted_idx = result.importances_mean.argsort()\n\n # `labels` argument in boxplot is deprecated in matplotlib 3.9 and has been\n # renamed to `tick_labels`. The following code handles this, but as a\n # scikit-learn user you probably can write simpler code by using `labels=...`\n # (matplotlib < 3.9) or `tick_labels=...` (matplotlib >= 3.9).\n tick_labels_parameter_name = (\n \"tick_labels\"\n if parse_version(matplotlib.__version__) >= parse_version(\"3.9\")\n else \"labels\"\n )\n tick_labels_dict = {tick_labels_parameter_name: X.columns[perm_sorted_idx]}\n ax.boxplot(result.importances[perm_sorted_idx].T, vert=False, **tick_labels_dict)\n ax.axvline(x=0, color=\"k\", linestyle=\"--\")\n return ax"
 ]
 },
{

@@ -58,7 +58,7 @@
 },
 "outputs": [],
 "source": [
-"import matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nmdi_importances = pd.Series(clf.feature_importances_, index=X_train.columns)\ntree_importance_sorted_idx = np.argsort(clf.feature_importances_)\ntree_indices = np.arange(0, len(clf.feature_importances_)) + 0.5\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\nmdi_importances.sort_values().plot.barh(ax=ax1)\nax1.set_xlabel(\"Gini importance\")\nplot_permutation_importance(clf, X_train, y_train, ax2)\nax2.set_xlabel(\"Decrease in accuracy score\")\nfig.suptitle(\n \"Impurity-based vs. permutation importances on multicollinear features (train set)\"\n)\n_ = fig.tight_layout()"
+"import matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nmdi_importances = pd.Series(clf.feature_importances_, index=X_train.columns)\ntree_importance_sorted_idx = np.argsort(clf.feature_importances_)\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\nmdi_importances.sort_values().plot.barh(ax=ax1)\nax1.set_xlabel(\"Gini importance\")\nplot_permutation_importance(clf, X_train, y_train, ax2)\nax2.set_xlabel(\"Decrease in accuracy score\")\nfig.suptitle(\n \"Impurity-based vs. permutation importances on multicollinear features (train set)\"\n)\n_ = fig.tight_layout()"
 ]
 },
{
[5 binary files not shown]

1.5/_downloads/4ee88a807e060ca374ab95e0d8d819ed/plot_ica_vs_pca.ipynb

+1 -1

@@ -51,7 +51,7 @@
 },
 "outputs": [],
 "source": [
-"import matplotlib.pyplot as plt\n\n\ndef plot_samples(S, axis_list=None):\n plt.scatter(\n S[:, 0], S[:, 1], s=2, marker=\"o\", zorder=10, color=\"steelblue\", alpha=0.5\n )\n if axis_list is not None:\n for axis, color, label in axis_list:\n axis /= axis.std()\n x_axis, y_axis = axis\n plt.quiver(\n (0, 0),\n (0, 0),\n x_axis,\n y_axis,\n zorder=11,\n width=0.01,\n scale=6,\n color=color,\n label=label,\n )\n\n plt.hlines(0, -3, 3)\n plt.vlines(0, -3, 3)\n plt.xlim(-3, 3)\n plt.ylim(-3, 3)\n plt.xlabel(\"x\")\n plt.ylabel(\"y\")\n\n\nplt.figure()\nplt.subplot(2, 2, 1)\nplot_samples(S / S.std())\nplt.title(\"True Independent Sources\")\n\naxis_list = [(pca.components_.T, \"orange\", \"PCA\"), (ica.mixing_, \"red\", \"ICA\")]\nplt.subplot(2, 2, 2)\nplot_samples(X / np.std(X), axis_list=axis_list)\nlegend = plt.legend(loc=\"lower right\")\nlegend.set_zorder(100)\n\nplt.title(\"Observations\")\n\nplt.subplot(2, 2, 3)\nplot_samples(S_pca_ / np.std(S_pca_, axis=0))\nplt.title(\"PCA recovered signals\")\n\nplt.subplot(2, 2, 4)\nplot_samples(S_ica_ / np.std(S_ica_))\nplt.title(\"ICA recovered signals\")\n\nplt.subplots_adjust(0.09, 0.04, 0.94, 0.94, 0.26, 0.36)\nplt.tight_layout()\nplt.show()"
+"import matplotlib.pyplot as plt\n\n\ndef plot_samples(S, axis_list=None):\n plt.scatter(\n S[:, 0], S[:, 1], s=2, marker=\"o\", zorder=10, color=\"steelblue\", alpha=0.5\n )\n if axis_list is not None:\n for axis, color, label in axis_list:\n x_axis, y_axis = axis / axis.std()\n plt.quiver(\n (0, 0),\n (0, 0),\n x_axis,\n y_axis,\n zorder=11,\n width=0.01,\n scale=6,\n color=color,\n label=label,\n )\n\n plt.hlines(0, -5, 5, color=\"black\", linewidth=0.5)\n plt.vlines(0, -3, 3, color=\"black\", linewidth=0.5)\n plt.xlim(-5, 5)\n plt.ylim(-3, 3)\n plt.gca().set_aspect(\"equal\")\n plt.xlabel(\"x\")\n plt.ylabel(\"y\")\n\n\nplt.figure()\nplt.subplot(2, 2, 1)\nplot_samples(S / S.std())\nplt.title(\"True Independent Sources\")\n\naxis_list = [(pca.components_.T, \"orange\", \"PCA\"), (ica.mixing_, \"red\", \"ICA\")]\nplt.subplot(2, 2, 2)\nplot_samples(X / np.std(X), axis_list=axis_list)\nlegend = plt.legend(loc=\"upper left\")\nlegend.set_zorder(100)\n\nplt.title(\"Observations\")\n\nplt.subplot(2, 2, 3)\nplot_samples(S_pca_ / np.std(S_pca_))\nplt.title(\"PCA recovered signals\")\n\nplt.subplot(2, 2, 4)\nplot_samples(S_ica_ / np.std(S_ica_))\nplt.title(\"ICA recovered signals\")\n\nplt.subplots_adjust(0.09, 0.04, 0.94, 0.94, 0.26, 0.36)\nplt.tight_layout()\nplt.show()"
 ]
 }
],
[2 binary files not shown]

1.5/_downloads/50040ae12dd16e7d2e79135d7793c17e/plot_release_highlights_0_22_0.py

+13 -2

@@ -34,6 +34,7 @@
 # `plot_confusion_matrix`. Read more about this new API in the
 # :ref:`User Guide <visualizations>`.
 
+import matplotlib
 import matplotlib.pyplot as plt
 
 from sklearn.datasets import make_classification

@@ -43,6 +44,7 @@
 from sklearn.metrics import RocCurveDisplay
 from sklearn.model_selection import train_test_split
 from sklearn.svm import SVC
+from sklearn.utils.fixes import parse_version
 
 X, y = make_classification(random_state=0)
 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

@@ -117,9 +119,18 @@
 
 fig, ax = plt.subplots()
 sorted_idx = result.importances_mean.argsort()
-ax.boxplot(
-    result.importances[sorted_idx].T, vert=False, labels=feature_names[sorted_idx]
+
+# `labels` argument in boxplot is deprecated in matplotlib 3.9 and has been
+# renamed to `tick_labels`. The following code handles this, but as a
+# scikit-learn user you probably can write simpler code by using `labels=...`
+# (matplotlib < 3.9) or `tick_labels=...` (matplotlib >= 3.9).
+tick_labels_parameter_name = (
+    "tick_labels"
+    if parse_version(matplotlib.__version__) >= parse_version("3.9")
+    else "labels"
 )
+tick_labels_dict = {tick_labels_parameter_name: feature_names[sorted_idx]}
+ax.boxplot(result.importances[sorted_idx].T, vert=False, **tick_labels_dict)
 ax.set_title("Permutation Importance of each feature")
 ax.set_ylabel("Features")
 fig.tight_layout()
[20 binary files not shown]

1.5/_downloads/69878e8e2864920aa874c5a68cecf1d3/plot_species_distribution_modeling.py

+6 -6

@@ -17,13 +17,13 @@
 
 The two species are:
 
-- `"Bradypus variegatus"
-  <http://www.iucnredlist.org/details/3038/0>`_ ,
-  the Brown-throated Sloth.
+- `Bradypus variegatus
+  <http://www.iucnredlist.org/details/3038/0>`_,
+  the brown-throated sloth.
 
-- `"Microryzomys minutus"
-  <http://www.iucnredlist.org/details/13408/0>`_ ,
-  also known as the Forest Small Rice Rat, a rodent that lives in Peru,
+- `Microryzomys minutus
+  <http://www.iucnredlist.org/details/13408/0>`_,
+  also known as the forest small rice rat, a rodent that lives in Peru,
   Colombia, Ecuador, Peru, and Venezuela.
 
 References
[10 binary files not shown]

1.5/_downloads/756be88c4ccd4c7bba02ab13f0d3258a/plot_permutation_importance_multicollinear.py

+13 -5

@@ -26,18 +26,27 @@
 # ------------------------------------------------------
 #
 # First, we define a function to ease the plotting:
+import matplotlib
+
 from sklearn.inspection import permutation_importance
+from sklearn.utils.fixes import parse_version
 
 
 def plot_permutation_importance(clf, X, y, ax):
     result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2)
     perm_sorted_idx = result.importances_mean.argsort()
 
-    ax.boxplot(
-        result.importances[perm_sorted_idx].T,
-        vert=False,
-        labels=X.columns[perm_sorted_idx],
+    # `labels` argument in boxplot is deprecated in matplotlib 3.9 and has been
+    # renamed to `tick_labels`. The following code handles this, but as a
+    # scikit-learn user you probably can write simpler code by using `labels=...`
+    # (matplotlib < 3.9) or `tick_labels=...` (matplotlib >= 3.9).
+    tick_labels_parameter_name = (
+        "tick_labels"
+        if parse_version(matplotlib.__version__) >= parse_version("3.9")
+        else "labels"
     )
+    tick_labels_dict = {tick_labels_parameter_name: X.columns[perm_sorted_idx]}
+    ax.boxplot(result.importances[perm_sorted_idx].T, vert=False, **tick_labels_dict)
     ax.axvline(x=0, color="k", linestyle="--")
     return ax
 

@@ -66,7 +75,6 @@ def plot_permutation_importance(clf, X, y, ax):
 
 mdi_importances = pd.Series(clf.feature_importances_, index=X_train.columns)
 tree_importance_sorted_idx = np.argsort(clf.feature_importances_)
-tree_indices = np.arange(0, len(clf.feature_importances_)) + 0.5
 
 fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
 mdi_importances.sort_values().plot.barh(ax=ax1)
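Note: a minimal usage sketch for the `plot_permutation_importance` helper defined in the hunk above (toy data and column names; the example itself uses the breast cancer dataset):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.RandomState(42)
X = pd.DataFrame(rng.normal(size=(200, 4)), columns=["f0", "f1", "f2", "f3"])
y = (X["f0"] + 0.1 * rng.normal(size=200) > 0).astype(int)

clf = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y)

fig, ax = plt.subplots()
plot_permutation_importance(clf, X, y, ax)  # helper from the diff above
ax.set_xlabel("Decrease in accuracy score")
plt.show()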

1.5/_downloads/75a08bb798ae7156529a808a0e08e7b4/plot_gpr_prior_posterior.ipynb

+2 -2

@@ -87,7 +87,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Rational Quadradtic kernel\n\n"
+"### Rational Quadratic kernel\n\n"
 ]
 },
{

@@ -156,7 +156,7 @@
 },
 "outputs": [],
 "source": [
-"from sklearn.gaussian_process.kernels import ConstantKernel, DotProduct\n\nkernel = ConstantKernel(0.1, (0.01, 10.0)) * (\n    DotProduct(sigma_0=1.0, sigma_0_bounds=(0.1, 10.0)) ** 2\n)\ngpr = GaussianProcessRegressor(kernel=kernel, random_state=0)\n\nfig, axs = plt.subplots(nrows=2, sharex=True, sharey=True, figsize=(10, 8))\n\n# plot prior\nplot_gpr_samples(gpr, n_samples=n_samples, ax=axs[0])\naxs[0].set_title(\"Samples from prior distribution\")\n\n# plot posterior\ngpr.fit(X_train, y_train)\nplot_gpr_samples(gpr, n_samples=n_samples, ax=axs[1])\naxs[1].scatter(X_train[:, 0], y_train, color=\"red\", zorder=10, label=\"Observations\")\naxs[1].legend(bbox_to_anchor=(1.05, 1.5), loc=\"upper left\")\naxs[1].set_title(\"Samples from posterior distribution\")\n\nfig.suptitle(\"Dot-product kernel\", fontsize=18)\nplt.tight_layout()"
+"from sklearn.gaussian_process.kernels import ConstantKernel, DotProduct\n\nkernel = ConstantKernel(0.1, (0.01, 10.0)) * (\n    DotProduct(sigma_0=1.0, sigma_0_bounds=(0.1, 10.0)) ** 2\n)\ngpr = GaussianProcessRegressor(kernel=kernel, random_state=0, normalize_y=True)\n\nfig, axs = plt.subplots(nrows=2, sharex=True, sharey=True, figsize=(10, 8))\n\n# plot prior\nplot_gpr_samples(gpr, n_samples=n_samples, ax=axs[0])\naxs[0].set_title(\"Samples from prior distribution\")\n\n# plot posterior\ngpr.fit(X_train, y_train)\nplot_gpr_samples(gpr, n_samples=n_samples, ax=axs[1])\naxs[1].scatter(X_train[:, 0], y_train, color=\"red\", zorder=10, label=\"Observations\")\naxs[1].legend(bbox_to_anchor=(1.05, 1.5), loc=\"upper left\")\naxs[1].set_title(\"Samples from posterior distribution\")\n\nfig.suptitle(\"Dot-product kernel\", fontsize=18)\nplt.tight_layout()"
 ]
 },
{
