Coverage for src / chebpy / gpr.py: 100%
178 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 07:22 +0000
"""Gaussian process regression with Chebfun representations.

Implements Gaussian process regression (GPR) following the algorithm described
in Rasmussen & Williams, *Gaussian Processes for Machine Learning*, MIT Press,
2006, and the MATLAB Chebfun ``gpr.m`` by The University of Oxford and The
Chebfun Developers.

The posterior mean, variance, and (optionally) random samples from the
posterior are all returned as Chebfun / Quasimatrix objects so that they can
be manipulated with the full ChebPy toolkit (differentiation, integration,
rootfinding, etc.).

Reference:
    C. E. Rasmussen & C. K. I. Williams, "Gaussian Processes for Machine
    Learning", MIT Press, 2006.
"""
18from __future__ import annotations
20from collections.abc import Callable
21from dataclasses import dataclass, field
23import numpy as np
24from numpy.typing import ArrayLike
26from .algorithms import chebpts2
27from .chebfun import Chebfun
28from .quasimatrix import Quasimatrix
29from .settings import _preferences as prefs
32# ---------------------------------------------------------------------------
33# Options container
34# ---------------------------------------------------------------------------
@dataclass
class _GPROptions:
    """Parsed options for a GPR call.

    Mutable container threaded through the private helpers; ``gpr`` fills in
    ``sigma``, ``domain``, and ``length_scale`` step by step before the
    posterior is computed.
    """

    # Signal standard deviation of the kernel; defaults to max(|y|) when the
    # caller does not supply one (see _parse_inputs).
    sigma: float = 1.0
    # True when the user supplied sigma explicitly — controls whether the
    # length-scale search runs on normalised data (see _infer_length_scale).
    sigma_given: bool = False
    # Kernel length-scale L; 0.0 means "not yet chosen".
    length_scale: float = 0.0
    # Standard deviation of i.i.d. observation noise (0.0 → jitter only).
    noise: float = 0.0
    # Output domain [a, b] for the constructed Chebfuns.
    domain: np.ndarray = field(default_factory=lambda: np.array([-1.0, 1.0]))
    # If True, use the periodic squared-exponential kernel.
    trig: bool = False
    # Number of posterior samples requested by the caller.
    n_samples: int = 0
48# ---------------------------------------------------------------------------
49# Kernel helpers
50# ---------------------------------------------------------------------------
53def _kernel_matrix(
54 x1: np.ndarray,
55 x2: np.ndarray,
56 opts: _GPROptions,
57) -> np.ndarray:
58 """Evaluate the covariance kernel k(x1_i, x2_j) for all pairs."""
59 r = x1[:, None] - x2[None, :]
60 if opts.trig:
61 period = opts.domain[1] - opts.domain[0]
62 return opts.sigma**2 * np.exp(-2.0 / opts.length_scale**2 * np.sin(np.pi / period * r) ** 2)
63 return opts.sigma**2 * np.exp(-0.5 / opts.length_scale**2 * r**2)
66def _log_marginal_likelihood(
67 length_scale: float | np.ndarray,
68 x: np.ndarray,
69 y: np.ndarray,
70 opts: _GPROptions,
71) -> float | np.ndarray:
72 """Negative log marginal likelihood (eq. 2.30 in Rasmussen & Williams).
74 Accepts scalar or array *length_scale* so that it can be wrapped as a
75 Chebfun for optimisation.
76 """
77 scalar_input = np.ndim(length_scale) == 0
78 ls = np.atleast_1d(np.asarray(length_scale, dtype=float))
79 n = len(x)
80 rx = x[:, None] - x[None, :]
81 result = np.empty_like(ls)
83 for idx in np.ndindex(ls.shape):
84 l_val = ls[idx]
85 if opts.trig:
86 period = opts.domain[1] - opts.domain[0]
87 cov_mat = opts.sigma**2 * np.exp(-2.0 / l_val**2 * np.sin(np.pi / period * rx) ** 2)
88 else:
89 cov_mat = opts.sigma**2 * np.exp(-0.5 / l_val**2 * rx**2)
91 if opts.noise != 0:
92 cov_mat += opts.noise**2 * np.eye(n)
93 else:
94 cov_mat += 1e-15 * n * opts.sigma**2 * np.eye(n)
96 chol_l = np.linalg.cholesky(cov_mat)
97 alpha = np.linalg.solve(chol_l.T, np.linalg.solve(chol_l, y))
98 lml = -0.5 * y @ alpha - np.sum(np.log(np.diag(chol_l))) - 0.5 * n * np.log(2 * np.pi)
99 result[idx] = lml
101 return float(result.item()) if scalar_input else result.ravel()
104# ---------------------------------------------------------------------------
105# Length-scale selection via max log marginal likelihood
106# ---------------------------------------------------------------------------
def _select_length_scale(x: np.ndarray, y: np.ndarray, opts: _GPROptions) -> float:
    """Choose the length-scale that maximises the log marginal likelihood."""
    n = len(x)
    width = opts.domain[1] - opts.domain[0]

    # Initial search bracket, depending on the kernel family.
    if opts.trig:
        lo = 1.0 / (2 * n)
        hi = 10.0
    else:
        lo = width / (2 * np.pi * n)
        hi = 10.0 / np.pi * width

    def lml(ls: float) -> float:
        return float(_log_marginal_likelihood(ls, x, y, opts))

    # Heuristic: shrink the right end of the bracket while the lml is
    # monotonically decreasing (mirrors the MATLAB implementation).
    f_lo = lml(lo)
    f_hi = lml(hi)
    while f_lo > f_hi and hi / lo > 1 + 1e-4:
        candidate = lo + (hi - lo) / 10.0
        f_candidate = lml(candidate)
        if f_candidate > f_lo:
            break
        hi = candidate
        f_hi = f_candidate

    # Golden-section search for the argmax on the trimmed bracket.
    return _golden_section_max(lml, lo, hi)
135def _golden_section_max(f: Callable[[float], float], a: float, b: float, tol: float = 1e-6) -> float:
136 """Golden-section search for the scalar argmax of *f* on [a, b]."""
137 gr = (np.sqrt(5.0) + 1.0) / 2.0
138 c = b - (b - a) / gr
139 d = a + (b - a) / gr
140 while abs(b - a) > tol * (abs(a) + abs(b)):
141 if f(c) > f(d):
142 b = d
143 else:
144 a = c
145 c = b - (b - a) / gr
146 d = a + (b - a) / gr
147 return 0.5 * (a + b)
150# ---------------------------------------------------------------------------
151# Public API
152# ---------------------------------------------------------------------------
def _parse_inputs(
    x: ArrayLike,
    y: ArrayLike,
    *,
    sigma: float | None,
    noise: float,
    trig: bool,
    n_samples: int,
) -> tuple[np.ndarray, np.ndarray, _GPROptions, float]:
    """Validate inputs and build the initial options container.

    Returns ``(x_arr, y_arr, opts, scaling_factor)``; ``scaling_factor`` is
    ``max(|y|)`` when *sigma* is auto-selected, else 1.0.
    """
    x_arr = np.ravel(np.asarray(x, dtype=float))
    y_arr = np.ravel(np.asarray(y, dtype=float))
    if x_arr.shape != y_arr.shape:
        msg = "x and y must have the same length."
        raise ValueError(msg)

    opts = _GPROptions(trig=trig, noise=noise, n_samples=n_samples)

    scaling_factor = 1.0
    if sigma is None:
        # Auto-scale: the largest observation magnitude becomes the signal
        # standard deviation.
        if len(y_arr) > 0:
            scaling_factor = float(np.max(np.abs(y_arr)))
        opts.sigma_given = False
        opts.sigma = scaling_factor
    else:
        opts.sigma = sigma
        opts.sigma_given = True

    return x_arr, y_arr, opts, scaling_factor
189def _infer_domain(
190 x_arr: np.ndarray,
191 opts: _GPROptions,
192 domain: tuple[float, float] | list[float] | np.ndarray | None,
193) -> None:
194 """Set ``opts.domain`` from *domain* or from the observation locations."""
195 if domain is not None:
196 opts.domain = np.asarray(domain, dtype=float)
197 elif len(x_arr) == 0:
198 opts.domain = np.array([-1.0, 1.0])
199 elif len(x_arr) == 1:
200 opts.domain = np.array([x_arr[0] - 1, x_arr[0] + 1])
201 elif opts.trig:
202 span = float(np.max(x_arr) - np.min(x_arr))
203 opts.domain = np.array([float(np.min(x_arr)), float(np.max(x_arr)) + 0.1 * span])
204 else:
205 opts.domain = np.array([float(np.min(x_arr)), float(np.max(x_arr))])
def _infer_length_scale(
    x_arr: np.ndarray,
    y_arr: np.ndarray,
    opts: _GPROptions,
    scaling_factor: float,
    length_scale: float | None,
) -> None:
    """Set ``opts.length_scale`` — user-supplied or auto-selected."""
    if length_scale is not None:
        opts.length_scale = length_scale
        return
    if len(x_arr) == 0:
        # No data: any value works; use a neutral default.
        opts.length_scale = 1.0
        return

    if opts.sigma_given:
        # Optimise against the raw data with the user-specified kernel.
        tmp = _GPROptions(
            sigma=opts.sigma,
            sigma_given=True,
            noise=opts.noise,
            domain=opts.domain,
            trig=opts.trig,
        )
        y_opt = y_arr
    else:
        # Optimise on normalised data with unit signal variance so the
        # selected length-scale does not depend on the data's overall scale.
        safe = scaling_factor != 0
        tmp = _GPROptions(
            sigma=1.0,
            sigma_given=True,
            noise=opts.noise / scaling_factor if safe else opts.noise,
            domain=opts.domain,
            trig=opts.trig,
        )
        y_opt = y_arr / scaling_factor if safe else y_arr

    opts.length_scale = _select_length_scale(x_arr, y_opt, tmp)
def _posterior_chebfuns(
    x_arr: np.ndarray,
    y_arr: np.ndarray,
    opts: _GPROptions,
    scaling_factor: float,
    n_samples: int,
) -> tuple[Chebfun, Chebfun] | tuple[Chebfun, Chebfun, Quasimatrix]:
    """Compute posterior mean, variance, and optional samples as Chebfuns.

    Args:
        x_arr: Observation locations (1-D; assumed non-empty by the caller).
        y_arr: Observation values, same length as ``x_arr``.
        opts: Fully populated options (sigma, length_scale, noise, domain, trig).
        scaling_factor: Scale of the data, used to size the diagonal jitter
            when ``opts.noise == 0``.
        n_samples: Number of posterior draws; if ``<= 0`` only the mean and
            variance are returned.

    Returns:
        ``(f_mean, f_var)``, plus a Quasimatrix of posterior samples when
        ``n_samples > 0``.
    """
    n = len(x_arr)
    cov_mat = _kernel_matrix(x_arr, x_arr, opts)
    # Regularise the training covariance so the Cholesky factorisation is
    # stable: observation noise if given, otherwise a tiny data-scaled jitter.
    if opts.noise == 0:
        cov_mat += 1e-15 * scaling_factor**2 * n * np.eye(n)
    else:
        cov_mat += opts.noise**2 * np.eye(n)

    # alpha = K^{-1} y via two triangular solves (R&W algorithm 2.1).
    chol_l = np.linalg.cholesky(cov_mat)
    alpha = np.linalg.solve(chol_l.T, np.linalg.solve(chol_l, y_arr))

    # Sample grid: Chebyshev points for the default tech, equispaced points
    # for the periodic (Trigtech) case. Using the right grid here is critical
    # because the constructed Chebfun pieces are then built via
    # ``Chebfun.initfun_fixedlen`` whose underlying tech evaluates at exactly
    # this grid.
    sample_size = min(20 * n, 2000)
    if opts.trig:
        # n equispaced points on [a, b) — matches Trigtech._trigpts mapped to domain
        x_sample = opts.domain[0] + (opts.domain[1] - opts.domain[0]) * np.arange(sample_size) / sample_size
    else:
        t = chebpts2(sample_size)
        x_sample = 0.5 * (opts.domain[1] - opts.domain[0]) * t + 0.5 * (opts.domain[0] + opts.domain[1])

    # Flags grid points that coincide *bitwise* with an observation; the
    # noise variance is re-added on the diagonal at those points below.
    in_x = np.isin(x_sample, x_arr)

    k_star = _kernel_matrix(x_sample, x_arr, opts)
    if opts.noise:
        # Add noise**2 where a grid point exactly equals an observation.
        k_star += opts.noise**2 * (np.abs(x_sample[:, None] - x_arr[None, :]) == 0)

    # Posterior mean
    mean_vals = k_star @ alpha

    # Posterior variance
    k_ss = _kernel_matrix(x_sample, x_sample, opts)
    if opts.noise:
        k_ss += opts.noise**2 * np.diag(in_x.astype(float))

    v = np.linalg.solve(chol_l, k_star.T)
    # diag of the posterior covariance; clamp tiny negatives from round-off.
    var_diag = np.diag(k_ss) - np.sum(v**2, axis=0)
    var_diag = np.maximum(var_diag, 0.0)

    # Build Chebfuns under the appropriate tech. For ``trig=True`` this
    # produces Trigtech-backed pieces so that downstream calculus is performed
    # in Fourier space.
    tech_name = "Trigtech" if opts.trig else prefs.tech
    with prefs:
        prefs.tech = tech_name
        # The lambdas ignore their argument: initfun_fixedlen evaluates at
        # exactly ``sample_size`` points, whose values are already computed.
        f_mean = Chebfun.initfun_fixedlen(lambda _z: mean_vals, sample_size, opts.domain)
        f_var = Chebfun.initfun_fixedlen(lambda _z: var_diag, sample_size, opts.domain)

    if n_samples <= 0:
        return f_mean, f_var

    # Posterior samples: symmetrise the covariance and add jitter so the
    # Cholesky factorisation succeeds despite round-off.
    cov_post = k_ss - v.T @ v
    cov_post = 0.5 * (cov_post + cov_post.T)
    cov_post += 1e-12 * scaling_factor**2 * n * np.eye(sample_size)
    chol_s = np.linalg.cholesky(cov_post)

    # NOTE: uses the global NumPy RNG (np.random.randn); seed with
    # np.random.seed for reproducible draws.
    draws = mean_vals[:, None] + chol_s @ np.random.randn(sample_size, n_samples)
    cols: list[Chebfun] = []
    for j in range(n_samples):
        cols.append(
            Chebfun.initfun_fixedlen(
                # _j=j binds the column index now, avoiding the late-binding
                # closure pitfall.
                lambda _z, _j=j: draws[:, _j],
                sample_size,
                opts.domain,
            )
        )
    return f_mean, f_var, Quasimatrix(cols)
def gpr(
    x: ArrayLike,
    y: ArrayLike,
    *,
    domain: tuple[float, float] | list[float] | np.ndarray | None = None,
    sigma: float | None = None,
    length_scale: float | None = None,
    noise: float = 0.0,
    trig: bool = False,
    n_samples: int = 0,
) -> tuple[Chebfun, Chebfun] | tuple[Chebfun, Chebfun, Quasimatrix]:
    """Gaussian process regression returning Chebfun objects.

    Given observations ``(x, y)`` of a latent function, compute the posterior
    mean and variance of a Gaussian process with zero prior mean and a squared
    exponential kernel::

        k(x, x') = sigma**2 * exp(-0.5 / L**2 * (x - x')**2)

    When ``trig=True`` a periodic variant is used instead::

        k(x, x') = sigma**2 * exp(-2 / L**2 * sin(pi * (x - x') / P)**2)

    where *P* is the period (length of the domain).

    Args:
        x: Observation locations (1-D array-like).
        y: Observation values (same length as *x*).
        domain: Domain ``[a, b]`` for the output Chebfuns. Defaults to
            ``[min(x), max(x)]`` (or slightly extended for ``trig``).
        sigma: Signal variance of the kernel. Defaults to ``max(|y|)``.
        length_scale: Length-scale *L* of the kernel. If ``None``, it is
            chosen to maximise the log marginal likelihood.
        noise: Standard deviation of i.i.d. Gaussian observation noise.
            The kernel diagonal is augmented by ``noise**2``.
        trig: If ``True``, use a periodic squared-exponential kernel.
        n_samples: Number of independent posterior samples to draw. When
            positive, a :class:`Quasimatrix` with *n_samples* columns is
            returned as the third element of the output tuple.

    Returns:
        ``(f_mean, f_var)`` — posterior mean and variance as Chebfun objects.
        If ``n_samples > 0``, returns ``(f_mean, f_var, samples)`` where
        *samples* is a Quasimatrix whose columns are independent draws from
        the posterior. Empty input is allowed and yields the prior (zero
        mean, constant variance ``sigma**2``).

    Raises:
        ValueError: If *x* and *y* have different lengths.

    Examples:
        >>> import numpy as np
        >>> from chebpy.gpr import gpr
        >>> rng = np.random.default_rng(1)
        >>> x = -2 + 4 * rng.random(10)
        >>> y = np.sin(np.exp(x))
        >>> f_mean, f_var = gpr(x, y, domain=[-2, 2])

    Reference:
        C. E. Rasmussen & C. K. I. Williams, "Gaussian Processes for Machine
        Learning", MIT Press, 2006.
    """
    # Normalise inputs, then fill in domain and length-scale defaults.
    x_arr, y_arr, opts, scaling_factor = _parse_inputs(
        x,
        y,
        sigma=sigma,
        noise=noise,
        trig=trig,
        n_samples=n_samples,
    )
    _infer_domain(x_arr, opts, domain)
    _infer_length_scale(x_arr, y_arr, opts, scaling_factor, length_scale)

    # No data → return prior
    if len(x_arr) == 0:
        f_mean = Chebfun.initconst(0.0, opts.domain)
        f_var = Chebfun.initconst(opts.sigma**2, opts.domain)
        if n_samples > 0:
            return f_mean, f_var, _prior_samples(opts, scaling_factor, n_samples)
        return f_mean, f_var

    return _posterior_chebfuns(x_arr, y_arr, opts, scaling_factor, n_samples)
def _prior_samples(
    opts: _GPROptions,
    scaling_factor: float,
    n_samples: int,
) -> Quasimatrix:
    """Draw samples from the GP prior (no observations)."""
    sample_size = 1000
    a, b = opts.domain[0], opts.domain[1]
    if opts.trig:
        # NOTE(review): this grid includes both endpoints, unlike the
        # half-open [a, b) grid used for the posterior in the trig case —
        # confirm against what initfun_fixedlen expects for the periodic tech.
        x_sample = np.linspace(a, b, sample_size)
    else:
        # Chebyshev points of the second kind mapped onto [a, b].
        x_sample = 0.5 * (b - a) * chebpts2(sample_size) + 0.5 * (a + b)

    # Prior covariance with a tiny jitter so the Cholesky succeeds.
    k_ss = _kernel_matrix(x_sample, x_sample, opts)
    k_ss += 1e-12 * scaling_factor**2 * np.eye(sample_size)
    chol_s = np.linalg.cholesky(k_ss)

    # Zero prior mean: each draw is L @ z with z ~ N(0, I) (global NumPy RNG).
    f_mean_vals = np.zeros(sample_size)
    draws = f_mean_vals[:, None] + chol_s @ np.random.randn(sample_size, n_samples)

    # _j=j binds the column index eagerly (late-binding closure pitfall).
    cols: list[Chebfun] = [
        Chebfun.initfun_fixedlen(lambda _z, _j=j: draws[:, _j], sample_size, opts.domain)
        for j in range(n_samples)
    ]
    return Quasimatrix(cols)