-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
/
Copy pathkernels.py
2410 lines (1989 loc) · 83.3 KB
/
kernels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""A set of kernels that can be combined by operators and used in Gaussian processes."""
# Kernels for Gaussian process regression and classification.
#
# The kernels in this module allow kernel-engineering, i.e., they can be
# combined via the "+" and "*" operators or be exponentiated with a scalar
# via "**". These sum and product expressions can also contain scalar values,
# which are automatically converted to a constant kernel.
#
# All kernels allow (analytic) gradient-based hyperparameter optimization.
# The space of hyperparameters can be specified by giving lower and upper
# boundaries for the value of each hyperparameter (the search space is thus
# rectangular). Instead of specifying bounds, hyperparameters can also be
# declared to be "fixed", which causes these hyperparameters to be excluded from
# optimization.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
# Note: this module is strongly inspired by the kernel module of the george
# package.
import math
import warnings
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from inspect import signature
import numpy as np
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.special import gamma, kv
from ..base import clone
from ..exceptions import ConvergenceWarning
from ..metrics.pairwise import pairwise_kernels
from ..utils.validation import _num_samples
def _check_length_scale(X, length_scale):
    """Normalize ``length_scale`` to a float scalar or 1d array and validate it.

    Parameters
    ----------
    X : ndarray
        Data array; only ``X.shape[1]`` is read, and only when the squeezed
        ``length_scale`` is one-dimensional.
    length_scale : scalar or array-like
        Candidate length scale. Singleton axes are squeezed away and the
        result is cast to float.

    Returns
    -------
    length_scale : ndarray
        The squeezed float length scale (0-d or 1-d).

    Raises
    ------
    ValueError
        If the squeezed value has more than one dimension, or if it is 1-d
        and its length differs from ``X.shape[1]``.
    """
    scale = np.squeeze(length_scale).astype(float)
    scale_ndim = np.ndim(scale)
    if scale_ndim > 1:
        raise ValueError("length_scale cannot be of dimension greater than 1")
    if scale_ndim == 1:
        # Anisotropic case: one length scale per column of X is required.
        n_features = X.shape[1]
        n_scales = scale.shape[0]
        if n_features != n_scales:
            raise ValueError(
                "Anisotropic kernel must have the same number of "
                "dimensions as data (%d!=%d)" % (n_scales, n_features)
            )
    return scale
class Hyperparameter(
    namedtuple(
        "Hyperparameter", ("name", "value_type", "bounds", "n_elements", "fixed")
    )
):
    """Namedtuple describing one kernel hyperparameter.

    .. versionadded:: 0.18

    Attributes
    ----------
    name : str
        Name of the hyperparameter. A kernel that uses a hyperparameter
        called "x" must expose the attributes self.x and self.x_bounds.

    value_type : str
        Type of the hyperparameter. Only "numeric" hyperparameters are
        supported at present.

    bounds : pair of floats >= 0 or "fixed"
        Lower and upper bound of the parameter. When n_elements>1, a pair of
        1d arrays with n_elements entries each may be supplied instead. The
        string "fixed" marks a hyperparameter whose value cannot be changed.

    n_elements : int, default=1
        Number of elements making up the hyperparameter value. 1 means a
        scalar hyperparameter; larger values describe a vector-valued one,
        e.g. anisotropic length-scales.

    fixed : bool, default=None
        Whether the value is fixed, i.e. excluded from hyperparameter
        tuning. When None, it is derived from ``bounds``.

    Examples
    --------
    >>> from sklearn.gaussian_process.kernels import ConstantKernel
    >>> from sklearn.datasets import make_friedman2
    >>> from sklearn.gaussian_process import GaussianProcessRegressor
    >>> from sklearn.gaussian_process.kernels import Hyperparameter
    >>> X, y = make_friedman2(n_samples=50, noise=0, random_state=0)
    >>> kernel = ConstantKernel(constant_value=1.0,
    ...    constant_value_bounds=(0.0, 10.0))

    We can access each hyperparameter:

    >>> for hyperparameter in kernel.hyperparameters:
    ...    print(hyperparameter)
    Hyperparameter(name='constant_value', value_type='numeric',
    bounds=array([[ 0., 10.]]), n_elements=1, fixed=False)

    >>> params = kernel.get_params()
    >>> for key in sorted(params): print(f"{key} : {params[key]}")
    constant_value : 1.0
    constant_value_bounds : (0.0, 10.0)
    """

    # Subclassing a namedtuple only to customize construction would normally
    # give every instance a __dict__ again, losing the compact tuple layout.
    # Declaring an empty __slots__ tells the interpreter that the subclass
    # adds no per-instance attributes, so no __dict__ is created.
    __slots__ = ()

    def __new__(cls, name, value_type, bounds, n_elements=1, fixed=None):
        """Build the tuple, normalizing ``bounds`` and deriving ``fixed``.

        Raises
        ------
        ValueError
            If ``n_elements > 1`` and ``bounds`` has a number of rows that
            is neither 1 nor ``n_elements``.
        """
        marked_fixed = isinstance(bounds, str) and bounds == "fixed"
        if not marked_fixed:
            bounds = np.atleast_2d(bounds)
            if n_elements > 1:  # vector-valued parameter
                n_rows = bounds.shape[0]
                if n_rows == 1:
                    # A single (lower, upper) pair applies to every element.
                    bounds = np.repeat(bounds, n_elements, 0)
                elif n_rows != n_elements:
                    raise ValueError(
                        "Bounds on %s should have either 1 or "
                        "%d dimensions. Given are %d" % (name, n_elements, n_rows)
                    )

        if fixed is None:
            fixed = marked_fixed
        return super().__new__(cls, name, value_type, bounds, n_elements, fixed)

    # Field-wise comparison (bounds compared element-wise); mostly useful in
    # tests that need to check two hyperparameter specifications match.
    def __eq__(self, other):
        return (
            self.name == other.name
            and self.value_type == other.value_type
            and np.all(self.bounds == other.bounds)
            and self.n_elements == other.n_elements
            and self.fixed == other.fixed
        )
class Kernel(metaclass=ABCMeta):
    """Abstract base class from which every kernel derives.

    .. versionadded:: 0.18

    Examples
    --------
    >>> from sklearn.gaussian_process.kernels import Kernel, RBF
    >>> import numpy as np
    >>> class CustomKernel(Kernel):
    ...     def __init__(self, length_scale=1.0):
    ...         self.length_scale = length_scale
    ...     def __call__(self, X, Y=None):
    ...         if Y is None:
    ...             Y = X
    ...         return np.inner(X, X if Y is None else Y) ** 2
    ...     def diag(self, X):
    ...         return np.ones(X.shape[0])
    ...     def is_stationary(self):
    ...         return True
    >>> kernel = CustomKernel(length_scale=2.0)
    >>> X = np.array([[1, 2], [3, 4]])
    >>> print(kernel(X))
    [[ 25 121]
     [121 625]]
    """

    def get_params(self, deep=True):
        """Collect the constructor parameters of this kernel.

        Parameters
        ----------
        deep : bool, default=True
            Accepted for API compatibility; this base implementation does
            not read it.

        Returns
        -------
        params : dict
            Maps each ``__init__`` parameter name to the value of the
            attribute with the same name on this instance.

        Raises
        ------
        RuntimeError
            If ``__init__`` declares a ``*args`` parameter.
        """
        cls = self.__class__
        # Honor a wrapped constructor that exposes its original signature.
        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
        init_parameters = list(signature(init).parameters.values())

        if any(p.kind == p.VAR_POSITIONAL for p in init_parameters):
            raise RuntimeError(
                "scikit-learn kernels should always "
                "specify their parameters in the signature"
                " of their __init__ (no varargs)."
                " %s doesn't follow this convention." % (cls,)
            )

        return {
            p.name: getattr(self, p.name)
            for p in init_parameters
            if p.kind != p.VAR_KEYWORD and p.name != "self"
        }

    def set_params(self, **params):
        """Assign parameters on this kernel.

        Works for plain kernels and for nested ones. Parameters of nested
        kernels are addressed as ``<component>__<parameter>`` and forwarded
        to the component's own ``set_params``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a key (or its ``<component>`` prefix) is not among the names
            returned by ``get_params(deep=True)``.
        """
        if not params:
            # Nothing to do; skip the (slow) signature introspection.
            return self

        known = self.get_params(deep=True)
        for key, value in params.items():
            name, delimiter, sub_name = key.partition("__")
            if delimiter:
                # Nested case: delegate to the named sub-kernel.
                if name not in known:
                    raise ValueError(
                        "Invalid parameter %s for kernel %s. "
                        "Check the list of available parameters "
                        "with `kernel.get_params().keys()`." % (name, self)
                    )
                known[name].set_params(**{sub_name: value})
            else:
                # Plain case: the key names an attribute of this kernel.
                if key not in known:
                    raise ValueError(
                        "Invalid parameter %s for kernel %s. "
                        "Check the list of available parameters "
                        "with `kernel.get_params().keys()`."
                        % (key, self.__class__.__name__)
                    )
                setattr(self, key, value)
        return self

    def clone_with_theta(self, theta):
        """Return a clone of this kernel whose ``theta`` is set to ``theta``.

        Parameters
        ----------
        theta : ndarray of shape (n_dims,)
            The hyperparameters
        """
        duplicate = clone(self)
        duplicate.theta = theta
        return duplicate

    @property
    def n_dims(self):
        """Number of non-fixed hyperparameters, i.e. the length of ``theta``."""
        return self.theta.shape[0]

    @property
    def hyperparameters(self):
        """List of all hyperparameter specifications.

        Gathers every attribute whose name starts with ``hyperparameter_``,
        in the order produced by ``dir(self)``.
        """
        return [
            getattr(self, attr_name)
            for attr_name in dir(self)
            if attr_name.startswith("hyperparameter_")
        ]

    @property
    def theta(self):
        """Flattened, log-transformed values of the non-fixed hyperparameters.

        The log-transformed representation is the one exposed to
        hyperparameter search, since quantities such as length-scales
        naturally live on a log-scale.

        Returns
        -------
        theta : ndarray of shape (n_dims,)
            The non-fixed, log-transformed hyperparameters of the kernel
        """
        params = self.get_params()
        free_values = [
            params[hyperparameter.name]
            for hyperparameter in self.hyperparameters
            if not hyperparameter.fixed
        ]
        if not free_values:
            return np.array([])
        return np.log(np.hstack(free_values))

    @theta.setter
    def theta(self, theta):
        """Set the non-fixed hyperparameters from their log-transformed values.

        Parameters
        ----------
        theta : ndarray of shape (n_dims,)
            The non-fixed, log-transformed hyperparameters of the kernel

        Raises
        ------
        ValueError
            If the number of entries consumed differs from ``len(theta)``.
        """
        params = self.get_params()
        cursor = 0
        for hyperparameter in self.hyperparameters:
            if hyperparameter.fixed:
                continue
            width = hyperparameter.n_elements
            if width > 1:
                # Vector-valued parameter: consume a slice of theta.
                params[hyperparameter.name] = np.exp(theta[cursor : cursor + width])
                cursor += width
            else:
                params[hyperparameter.name] = np.exp(theta[cursor])
                cursor += 1

        if cursor != len(theta):
            raise ValueError(
                "theta has not the correct number of entries."
                " Should be %d; given are %d" % (cursor, len(theta))
            )
        self.set_params(**params)

    @property
    def bounds(self):
        """Log-transformed bounds on ``theta``.

        Returns
        -------
        bounds : ndarray of shape (n_dims, 2)
            The log-transformed bounds on the kernel's hyperparameters theta
        """
        free_bounds = [
            hyperparameter.bounds
            for hyperparameter in self.hyperparameters
            if not hyperparameter.fixed
        ]
        if not free_bounds:
            return np.array([])
        return np.log(np.vstack(free_bounds))

    # Arithmetic operators: a non-kernel operand is wrapped in a
    # ConstantKernel before being combined.
    def __add__(self, b):
        right = b if isinstance(b, Kernel) else ConstantKernel(b)
        return Sum(self, right)

    def __radd__(self, b):
        left = b if isinstance(b, Kernel) else ConstantKernel(b)
        return Sum(left, self)

    def __mul__(self, b):
        right = b if isinstance(b, Kernel) else ConstantKernel(b)
        return Product(self, right)

    def __rmul__(self, b):
        left = b if isinstance(b, Kernel) else ConstantKernel(b)
        return Product(left, self)

    def __pow__(self, b):
        return Exponentiation(self, b)

    def __eq__(self, b):
        """Kernels are equal when of the same type with equal parameters."""
        if type(self) != type(b):
            return False
        mine = self.get_params()
        theirs = b.get_params()
        # np.any copes with array-valued parameters compared element-wise.
        return all(
            not np.any(mine.get(key, None) != theirs.get(key, None))
            for key in mine.keys() | theirs.keys()
        )

    def __repr__(self):
        formatted_theta = ", ".join(f"{value:.3g}" for value in self.theta)
        return f"{self.__class__.__name__}({formatted_theta})"

    @abstractmethod
    def __call__(self, X, Y=None, eval_gradient=False):
        """Evaluate the kernel."""

    @abstractmethod
    def diag(self, X):
        """Return the diagonal of the kernel k(X, X).

        Equivalent to np.diag(self(X)), but implementations can compute it
        more cheaply because only the diagonal is needed.

        Parameters
        ----------
        X : array-like of shape (n_samples,)
            Left argument of the returned kernel k(X, Y)

        Returns
        -------
        K_diag : ndarray of shape (n_samples_X,)
            Diagonal of kernel k(X, X)
        """

    @abstractmethod
    def is_stationary(self):
        """Return whether the kernel is stationary."""

    @property
    def requires_vector_input(self):
        """Whether the kernel is defined on fixed-length feature vectors
        rather than generic objects. Defaults to True for backward
        compatibility."""
        return True

    def _check_bounds_params(self):
        """Warn, after fitting, about hyperparameters sitting on a bound.

        Emits a ``ConvergenceWarning`` for every free hyperparameter
        dimension whose ``theta`` entry is close (``np.isclose``) to its
        lower bound, or otherwise to its upper bound.
        """
        near_bound = np.isclose(self.bounds, np.atleast_2d(self.theta).T)
        row = 0  # running row index into near_bound / theta
        for hyp in self.hyperparameters:
            if hyp.fixed:
                continue
            for dim in range(hyp.n_elements):
                if near_bound[row, 0]:
                    warnings.warn(
                        "The optimal value found for dimension %s of parameter "
                        "%s is close to the specified lower bound %s. "
                        "Decreasing the bound and calling fit again may find "
                        "a better value." % (dim, hyp.name, hyp.bounds[dim][0]),
                        ConvergenceWarning,
                    )
                elif near_bound[row, 1]:
                    warnings.warn(
                        "The optimal value found for dimension %s of parameter "
                        "%s is close to the specified upper bound %s. "
                        "Increasing the bound and calling fit again may find "
                        "a better value." % (dim, hyp.name, hyp.bounds[dim][1]),
                        ConvergenceWarning,
                    )
                row += 1
class NormalizedKernelMixin:
    """Mixin for normalized kernels, i.e. kernels with k(X, X)=1.

    .. versionadded:: 0.18
    """

    def diag(self, X):
        """Return the diagonal of the kernel k(X, X), which is all ones.

        Matches np.diag(self(X)) for a normalized kernel without evaluating
        the full kernel matrix.

        Parameters
        ----------
        X : ndarray of shape (n_samples_X, n_features)
            Left argument of the returned kernel k(X, Y)

        Returns
        -------
        K_diag : ndarray of shape (n_samples_X,)
            Diagonal of kernel k(X, X)
        """
        n_samples = X.shape[0]
        return np.ones(n_samples)
class StationaryKernelMixin:
    """Mixin for stationary kernels, i.e. kernels with k(X, Y)= f(X-Y).

    .. versionadded:: 0.18
    """

    def is_stationary(self):
        """Report that the kernel is stationary (always True)."""
        return True
class GenericKernelMixin:
    """Mixin for kernels operating on generic objects, such as
    variable-length sequences, trees, and graphs.

    .. versionadded:: 0.22
    """

    @property
    def requires_vector_input(self):
        """Whether the kernel works only on fixed-length feature vectors.

        Always False for kernels using this mixin.
        """
        return False
class CompoundKernel(Kernel):
    """Kernel which is composed of a set of other kernels.

    .. versionadded:: 0.18

    Parameters
    ----------
    kernels : list of Kernels
        The other kernels

    Examples
    --------
    >>> from sklearn.gaussian_process.kernels import WhiteKernel
    >>> from sklearn.gaussian_process.kernels import RBF
    >>> from sklearn.gaussian_process.kernels import CompoundKernel
    >>> kernel = CompoundKernel(
    ...     [WhiteKernel(noise_level=3.0), RBF(length_scale=2.0)])
    >>> print(kernel.bounds)
    [[-11.51292546  11.51292546]
     [-11.51292546  11.51292546]]
    >>> print(kernel.n_dims)
    2
    >>> print(kernel.theta)
    [1.09861229 0.69314718]
    """

    def __init__(self, kernels):
        self.kernels = kernels

    def get_params(self, deep=True):
        """Get parameters of this kernel.

        Parameters
        ----------
        deep : bool, default=True
            Accepted for API compatibility; not read by this implementation.

        Returns
        -------
        params : dict
            A dict with the single key ``"kernels"`` mapped to the list of
            wrapped kernels.
        """
        return dict(kernels=self.kernels)

    @property
    def theta(self):
        """Returns the (flattened, log-transformed) non-fixed hyperparameters.

        The result is the concatenation of the ``theta`` vectors of the
        wrapped kernels, in order.

        Returns
        -------
        theta : ndarray of shape (n_dims,)
            The non-fixed, log-transformed hyperparameters of the kernel
        """
        return np.hstack([kernel.theta for kernel in self.kernels])

    @theta.setter
    def theta(self, theta):
        """Sets the (flattened, log-transformed) non-fixed hyperparameters.

        ``theta`` is split into consecutive chunks, one per wrapped kernel,
        each chunk being as long as that kernel's own ``n_dims``.

        Parameters
        ----------
        theta : array of shape (n_dims,)
            The non-fixed, log-transformed hyperparameters of the kernel

        Raises
        ------
        ValueError
            If ``len(theta)`` differs from the summed ``n_dims`` of the
            wrapped kernels.
        """
        # Each kernel may own a different number of free hyperparameters, so
        # slice by per-kernel widths rather than assuming a common width.
        widths = [kernel.n_dims for kernel in self.kernels]
        expected = sum(widths)
        if len(theta) != expected:
            # Validate before assigning so no kernel is left half-updated.
            raise ValueError(
                "theta has not the correct number of entries."
                " Should be %d; given are %d" % (expected, len(theta))
            )
        start = 0
        for kernel, width in zip(self.kernels, widths):
            kernel.theta = theta[start : start + width]
            start += width

    @property
    def bounds(self):
        """Returns the log-transformed bounds on the theta.

        Returns
        -------
        bounds : array of shape (n_dims, 2)
            The log-transformed bounds on the kernel's hyperparameters theta
        """
        return np.vstack([kernel.bounds for kernel in self.kernels])

    def __call__(self, X, Y=None, eval_gradient=False):
        """Return the kernel k(X, Y) and optionally its gradient.

        Note that this compound kernel returns the results of all simple kernel
        stacked along an additional axis.

        Parameters
        ----------
        X : array-like of shape (n_samples_X, n_features) or list of object, \
            default=None
            Left argument of the returned kernel k(X, Y)

        Y : array-like of shape (n_samples_Y, n_features) or list of object, \
            default=None
            Right argument of the returned kernel k(X, Y). If None, k(X, X)
            is evaluated instead.

        eval_gradient : bool, default=False
            Determines whether the gradient with respect to the log of the
            kernel hyperparameter is computed.

        Returns
        -------
        K : ndarray of shape (n_samples_X, n_samples_Y, n_kernels)
            Kernel k(X, Y)

        K_gradient : ndarray of shape \
                (n_samples_X, n_samples_X, n_dims, n_kernels), optional
            The gradient of the kernel k(X, X) with respect to the log of the
            hyperparameter of the kernel. Only returned when `eval_gradient`
            is True.
        """
        if eval_gradient:
            K = []
            K_grad = []
            for kernel in self.kernels:
                K_single, K_grad_single = kernel(X, Y, eval_gradient)
                K.append(K_single)
                # Trailing axis indexes the kernel, so gradients can be
                # concatenated along axis 3 below.
                K_grad.append(K_grad_single[..., np.newaxis])
            return np.dstack(K), np.concatenate(K_grad, 3)
        else:
            return np.dstack([kernel(X, Y, eval_gradient) for kernel in self.kernels])

    def __eq__(self, b):
        """Equal when same type, same length and pairwise-equal kernels."""
        if type(self) != type(b) or len(self.kernels) != len(b.kernels):
            return False
        return np.all(
            [self.kernels[i] == b.kernels[i] for i in range(len(self.kernels))]
        )

    def is_stationary(self):
        """Returns whether the kernel is stationary, i.e. whether all
        wrapped kernels are."""
        return np.all([kernel.is_stationary() for kernel in self.kernels])

    @property
    def requires_vector_input(self):
        """Returns whether any of the wrapped kernels requires fixed-length
        feature vectors as input."""
        return np.any([kernel.requires_vector_input for kernel in self.kernels])

    def diag(self, X):
        """Returns the diagonal of the kernel k(X, X).

        The result of this method is identical to `np.diag(self(X))`; however,
        it can be evaluated more efficiently since only the diagonal is
        evaluated.

        Parameters
        ----------
        X : array-like of shape (n_samples_X, n_features) or list of object
            Argument to the kernel.

        Returns
        -------
        K_diag : ndarray of shape (n_samples_X, n_kernels)
            Diagonal of kernel k(X, X)
        """
        return np.vstack([kernel.diag(X) for kernel in self.kernels]).T
class KernelOperator(Kernel):
    """Base class for operators that combine two kernels ``k1`` and ``k2``.

    .. versionadded:: 0.18
    """

    def __init__(self, k1, k2):
        self.k1 = k1
        self.k2 = k2

    def get_params(self, deep=True):
        """Get parameters of this kernel.

        Parameters
        ----------
        deep : bool, default=True
            If True, the parameters of ``k1`` and ``k2`` are included as
            well, under keys prefixed with ``k1__`` and ``k2__``.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        params = dict(k1=self.k1, k2=self.k2)
        if deep:
            for prefix, operand in (("k1__", self.k1), ("k2__", self.k2)):
                params.update(
                    (prefix + key, val) for key, val in operand.get_params().items()
                )
        return params

    @property
    def hyperparameters(self):
        """List of the hyperparameters of both operands.

        Each specification is rebuilt with its name prefixed by ``k1__`` or
        ``k2__``; those of ``k1`` come first.
        """
        return [
            Hyperparameter(
                prefix + hyperparameter.name,
                hyperparameter.value_type,
                hyperparameter.bounds,
                hyperparameter.n_elements,
            )
            for prefix, operand in (("k1__", self.k1), ("k2__", self.k2))
            for hyperparameter in operand.hyperparameters
        ]

    @property
    def theta(self):
        """Flattened, log-transformed non-fixed hyperparameters.

        This is ``k1.theta`` followed by ``k2.theta``.

        Returns
        -------
        theta : ndarray of shape (n_dims,)
            The non-fixed, log-transformed hyperparameters of the kernel
        """
        first, second = self.k1.theta, self.k2.theta
        return np.append(first, second)

    @theta.setter
    def theta(self, theta):
        """Set the non-fixed hyperparameters from log-transformed values.

        The leading ``k1.n_dims`` entries go to ``k1``; the rest to ``k2``.

        Parameters
        ----------
        theta : ndarray of shape (n_dims,)
            The non-fixed, log-transformed hyperparameters of the kernel
        """
        split_at = self.k1.n_dims
        self.k1.theta = theta[:split_at]
        self.k2.theta = theta[split_at:]

    @property
    def bounds(self):
        """Log-transformed bounds on ``theta``.

        Returns
        -------
        bounds : ndarray of shape (n_dims, 2)
            The log-transformed bounds on the kernel's hyperparameters theta
        """
        first = self.k1.bounds
        if first.size == 0:
            # k1 has no free hyperparameters: only k2 contributes.
            return self.k2.bounds
        second = self.k2.bounds
        if second.size == 0:
            return first
        return np.vstack((first, second))

    def __eq__(self, b):
        """Equal when of the same type with equal operands, in either order."""
        if type(self) != type(b):
            return False
        return (self.k1 == b.k1 and self.k2 == b.k2) or (
            self.k1 == b.k2 and self.k2 == b.k1
        )

    def is_stationary(self):
        """Return whether both operands are stationary."""
        return self.k1.is_stationary() and self.k2.is_stationary()

    @property
    def requires_vector_input(self):
        """Whether at least one operand requires vector input."""
        return self.k1.requires_vector_input or self.k2.requires_vector_input
class Sum(KernelOperator):
    """The `Sum` kernel takes two kernels :math:`k_1` and :math:`k_2`
    and adds them:

    .. math::
        k_{sum}(X, Y) = k_1(X, Y) + k_2(X, Y)

    Because the `__add__` magic method is overridden,
    `Sum(RBF(), RBF())` can equivalently be written with the + operator
    as `RBF() + RBF()`.

    Read more in the :ref:`User Guide <gp_kernels>`.

    .. versionadded:: 0.18

    Parameters
    ----------
    k1 : Kernel
        The first base-kernel of the sum-kernel

    k2 : Kernel
        The second base-kernel of the sum-kernel

    Examples
    --------
    >>> from sklearn.datasets import make_friedman2
    >>> from sklearn.gaussian_process import GaussianProcessRegressor
    >>> from sklearn.gaussian_process.kernels import RBF, Sum, ConstantKernel
    >>> X, y = make_friedman2(n_samples=500, noise=0, random_state=0)
    >>> kernel = Sum(ConstantKernel(2), RBF())
    >>> gpr = GaussianProcessRegressor(kernel=kernel,
    ...         random_state=0).fit(X, y)
    >>> gpr.score(X, y)
    1.0
    >>> kernel
    1.41**2 + RBF(length_scale=1)
    """

    def __call__(self, X, Y=None, eval_gradient=False):
        """Evaluate k(X, Y) and, on request, its gradient.

        Parameters
        ----------
        X : array-like of shape (n_samples_X, n_features) or list of object
            Left argument of the returned kernel k(X, Y)

        Y : array-like of shape (n_samples_Y, n_features) or list of object,\
                default=None
            Right argument of the returned kernel k(X, Y). If None, k(X, X)
            is evaluated instead.

        eval_gradient : bool, default=False
            Determines whether the gradient with respect to the log of
            the kernel hyperparameter is computed.

        Returns
        -------
        K : ndarray of shape (n_samples_X, n_samples_Y)
            Kernel k(X, Y)

        K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims),\
                optional
            The gradient of the kernel k(X, X) with respect to the log of the
            hyperparameter of the kernel. Only returned when `eval_gradient`
            is True.
        """
        if not eval_gradient:
            return self.k1(X, Y) + self.k2(X, Y)

        K1, K1_gradient = self.k1(X, Y, eval_gradient=True)
        K2, K2_gradient = self.k2(X, Y, eval_gradient=True)
        # The two operands' gradients are stacked along the last axis.
        stacked_gradient = np.dstack((K1_gradient, K2_gradient))
        return K1 + K2, stacked_gradient

    def diag(self, X):
        """Return the diagonal of the kernel k(X, X).

        Equals `np.diag(self(X))`, computed as the sum of the operands'
        diagonals so the full matrix is never built here.

        Parameters
        ----------
        X : array-like of shape (n_samples_X, n_features) or list of object
            Argument to the kernel.

        Returns
        -------
        K_diag : ndarray of shape (n_samples_X,)
            Diagonal of kernel k(X, X)
        """
        diag_k1 = self.k1.diag(X)
        diag_k2 = self.k2.diag(X)
        return diag_k1 + diag_k2

    def __repr__(self):
        return f"{self.k1} + {self.k2}"
class Product(KernelOperator):
    """The `Product` kernel takes two kernels :math:`k_1` and :math:`k_2`
    and multiplies them:

    .. math::
        k_{prod}(X, Y) = k_1(X, Y) * k_2(X, Y)

    Because the `__mul__` magic method is overridden,
    `Product(RBF(), RBF())` can equivalently be written with the * operator
    as `RBF() * RBF()`.

    Read more in the :ref:`User Guide <gp_kernels>`.

    .. versionadded:: 0.18

    Parameters
    ----------
    k1 : Kernel
        The first base-kernel of the product-kernel

    k2 : Kernel
        The second base-kernel of the product-kernel

    Examples
    --------
    >>> from sklearn.datasets import make_friedman2
    >>> from sklearn.gaussian_process import GaussianProcessRegressor
    >>> from sklearn.gaussian_process.kernels import (RBF, Product,
    ...            ConstantKernel)
    >>> X, y = make_friedman2(n_samples=500, noise=0, random_state=0)
    >>> kernel = Product(ConstantKernel(2), RBF())
    >>> gpr = GaussianProcessRegressor(kernel=kernel,
    ...         random_state=0).fit(X, y)
    >>> gpr.score(X, y)
    1.0
    >>> kernel
    1.41**2 * RBF(length_scale=1)
    """

    def __call__(self, X, Y=None, eval_gradient=False):
        """Evaluate k(X, Y) and, on request, its gradient.

        Parameters
        ----------
        X : array-like of shape (n_samples_X, n_features) or list of object
            Left argument of the returned kernel k(X, Y)

        Y : array-like of shape (n_samples_Y, n_features) or list of object,\
            default=None
            Right argument of the returned kernel k(X, Y). If None, k(X, X)
            is evaluated instead.

        eval_gradient : bool, default=False
            Determines whether the gradient with respect to the log of
            the kernel hyperparameter is computed.

        Returns
        -------
        K : ndarray of shape (n_samples_X, n_samples_Y)
            Kernel k(X, Y)

        K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims), \
                optional
            The gradient of the kernel k(X, X) with respect to the log of the
            hyperparameter of the kernel. Only returned when `eval_gradient`
            is True.
        """
        if not eval_gradient:
            return self.k1(X, Y) * self.k2(X, Y)

        K1, K1_gradient = self.k1(X, Y, eval_gradient=True)
        K2, K2_gradient = self.k2(X, Y, eval_gradient=True)
        # Product rule: each operand's gradient is scaled by the other
        # operand's kernel values, broadcast over the hyperparameter axis.
        k1_part = K1_gradient * K2[:, :, np.newaxis]
        k2_part = K2_gradient * K1[:, :, np.newaxis]
        return K1 * K2, np.dstack((k1_part, k2_part))

    def diag(self, X):
        """Return the diagonal of the kernel k(X, X).

        Equals np.diag(self(X)), computed as the product of the operands'
        diagonals so the full matrix is never built here.

        Parameters
        ----------
        X : array-like of shape (n_samples_X, n_features) or list of object
            Argument to the kernel.

        Returns
        -------
        K_diag : ndarray of shape (n_samples_X,)
            Diagonal of kernel k(X, X)
        """
        diag_k1 = self.k1.diag(X)
        diag_k2 = self.k2.diag(X)
        return diag_k1 * diag_k2

    def __repr__(self):
        return f"{self.k1} * {self.k2}"
class Exponentiation(Kernel):
"""The Exponentiation kernel takes one base kernel and a scalar parameter
:math:`p` and combines them via
.. math::
k_{exp}(X, Y) = k(X, Y) ^p
Note that the `__pow__` magic method is overridden, so