Coverage for src / chebpy / chebtech.py: 100%
252 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
1"""Implementation of Chebyshev polynomial technology for function approximation.
3This module provides the Chebtech class, which is an abstract base class for
4representing functions using Chebyshev polynomial expansions. Its conversion and
5interpolation helpers are built on Chebyshev points of the second kind.
7The Chebtech classes implement core functionality for working with Chebyshev
8expansions, including:
9- Function evaluation using Clenshaw's algorithm or barycentric interpolation
10- Algebraic operations (addition, multiplication, etc.)
11- Calculus operations (differentiation, integration, etc.)
12- Rootfinding
13- Plotting
15These classes are primarily used internally by higher-level classes like Bndfun
16and Chebfun, rather than being used directly by end users.
17"""
19from abc import ABC
20from collections.abc import Callable
21from typing import Any
23import matplotlib.pyplot as plt
24import numpy as np
26from .algorithms import (
27 adaptive,
28 bary,
29 barywts2,
30 chebpts2,
31 clenshaw,
32 coeffmult,
33 coeffs2vals2,
34 newtonroots,
35 rootsunit,
36 standard_chop,
37 vals2coeffs2,
38)
39from .decorators import self_empty
40from .plotting import plotfun, plotfuncoeffs
41from .settings import _preferences as prefs
42from .smoothfun import Smoothfun
43from .utilities import Interval, coerce_list
class Chebtech(Smoothfun, ABC):
    """Abstract base class serving as the template for Chebyshev-technology subclasses.

    Chebtech objects always work with first-kind coefficients, so much
    of the core operational functionality is defined at this level.
    Conversions between values and coefficients use Chebyshev points of
    the second kind.

    The user will rarely work with these classes directly so we make
    several assumptions regarding input data types.
    """

    @classmethod
    def initconst(cls, c: Any = None, *, interval: Any = None) -> Any:
        """Initialise a Chebtech from a constant c.

        Args:
            c: Scalar value of the constant function.
            interval: Interval forwarded to the constructor.

        Raises:
            ValueError: If c is not a scalar according to ``np.isscalar``.
        """
        if not np.isscalar(c):
            raise ValueError(c)
        if isinstance(c, int):
            c = float(c)
        return cls(np.array([c]), interval=interval)

    @classmethod
    def initempty(cls, *, interval: Any = None) -> "Chebtech":
        """Initialise an empty Chebtech (zero coefficients)."""
        return cls(np.array([]), interval=interval)

    @classmethod
    def initidentity(cls, *, interval: Any = None) -> "Chebtech":
        """Chebtech representation of f(x) = x on [-1,1]."""
        # Float literals: integer coefficients would make later in-place
        # arithmetic on the coefficient array truncate silently.
        return cls(np.array([0.0, 1.0]), interval=interval)

    @classmethod
    def initfun(cls, fun: Any = None, n: Any = None, *, interval: Any = None) -> Any:
        """Convenience constructor to automatically select the adaptive or fixedlen constructor.

        The adaptive constructor is used when ``n`` is None; otherwise the
        fixed-length constructor is called with ``n``.
        """
        if n is None:
            return cls.initfun_adaptive(fun, interval=interval)
        return cls.initfun_fixedlen(fun, n, interval=interval)

    @classmethod
    def initfun_fixedlen(cls, fun: Any = None, n: Any = None, *, interval: Any = None) -> Any:
        """Initialise a Chebtech from the callable fun using n degrees of freedom.

        The callable is sampled at ``cls._chebpts(int(n))`` and the samples
        are converted to coefficients.

        Raises:
            ValueError: If n is None.
        """
        if n is None:
            raise ValueError("n must be specified for fixed-length initialization")  # noqa: TRY003
        points = cls._chebpts(int(n))
        values = fun(points)
        coeffs = vals2coeffs2(values)
        return cls(coeffs, interval=interval)

    @classmethod
    def initfun_adaptive(cls, fun: Any = None, *, interval: Any = None) -> Any:
        """Initialise a Chebtech from the callable fun utilising the adaptive constructor.

        The number of degrees of freedom is chosen by ``adaptive``, which is
        given the horizontal scale of the (default or supplied) interval.
        """
        interval = interval if interval is not None else prefs.domain
        interval = Interval(*interval)
        coeffs = adaptive(cls, fun, hscale=interval.hscale)
        return cls(coeffs, interval=interval)

    @classmethod
    def initvalues(cls, values: Any = None, *, interval: Any = None) -> Any:
        """Initialise a Chebtech from an array of values at Chebyshev points."""
        return cls(cls._vals2coeffs(values), interval=interval)

    def __init__(self, coeffs: Any, interval: Any = None) -> None:
        """Initialize a Chebtech object.

        Args:
            coeffs (array-like): The coefficients of the Chebyshev series.
                Integer and boolean input is promoted to float so that later
                in-place arithmetic on the coefficients cannot truncate.
            interval (array-like, optional): The interval on which the function
                is defined. Defaults to None, which uses the default interval
                from preferences.
        """
        interval = interval if interval is not None else prefs.domain
        coeffs = np.array(coeffs)
        # kind codes: b = bool, i = signed int, u = unsigned int
        if coeffs.dtype.kind in "biu":
            coeffs = coeffs.astype(float)
        self._coeffs = coeffs
        self._interval = Interval(*interval)

    def __call__(self, x: Any, how: str = "clenshaw") -> Any:
        """Evaluate the Chebtech at the given points.

        Args:
            x: Points at which to evaluate the Chebtech.
            how (str, optional): Method to use for evaluation. Either "clenshaw" or "bary".
                Defaults to "clenshaw".

        Returns:
            The values of the Chebtech at the given points.

        Raises:
            ValueError: If the specified method is not supported.
        """
        method: dict[str, Callable[[Any], Any]] = {
            "clenshaw": self.__call__clenshaw,
            "bary": self.__call__bary,
        }
        # Only the lookup is guarded: a KeyError raised while evaluating must
        # not be misreported as an unsupported method.
        try:
            evaluator = method[how]
        except KeyError as err:
            raise ValueError(how) from err
        return evaluator(x)

    def __call__clenshaw(self, x: Any) -> Any:
        # Evaluate directly from the coefficients.
        return clenshaw(x, self.coeffs)

    def __call__bary(self, x: Any) -> Any:
        # Evaluate from values, points and barycentric weights of matching size.
        fk = self.values()
        xk = self._chebpts(fk.size)
        vk = self._barywts(fk.size)
        return bary(x, fk, xk, vk)

    def __repr__(self) -> str:  # pragma: no cover
        """Return a string of the form ``<ClassName{size}>``."""
        out = f"<{self.__class__.__name__}{{{self.size}}}>"
        return out

    # ------------
    #  properties
    # ------------
    @property
    def coeffs(self) -> np.ndarray:
        """Chebyshev expansion coefficients in the T_k basis."""
        return self._coeffs

    @property
    def interval(self) -> Interval:
        """Interval that Chebtech is mapped to."""
        return self._interval

    @property
    def size(self) -> int:
        """Return the number of coefficients."""
        return self.coeffs.size

    @property
    def isempty(self) -> bool:
        """Return True if the Chebtech is empty."""
        return self.size == 0

    @property
    def iscomplex(self) -> bool:
        """Determine whether the underlying onefun is complex or real valued."""
        # iscomplexobj accepts every complex dtype, not only complex128.
        return np.iscomplexobj(self._coeffs)

    @property
    def isconst(self) -> bool:
        """Return True if the Chebtech represents a constant."""
        return self.size == 1

    @property
    @self_empty(0.0)
    def vscale(self) -> float:
        """Estimate the vertical scale as the largest absolute sampled value."""
        return float(np.abs(np.asarray(coerce_list(self.values()))).max())

    # -----------
    #  utilities
    # -----------
    def copy(self) -> "Chebtech":
        """Return a deep copy of the Chebtech."""
        return self.__class__(self.coeffs.copy(), interval=self.interval.copy())

    def imag(self) -> Any:
        """Return the imaginary part of the Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the imaginary part of this Chebtech.
                If this Chebtech is real-valued, returns a zero Chebtech.
        """
        if self.iscomplex:
            return self.__class__(np.imag(self.coeffs), interval=self.interval)
        return self.initconst(0, interval=self.interval)

    def prolong(self, n: int) -> "Chebtech":
        """Return a Chebtech of length n.

        Obtained either by truncating if n < self.size or zero-padding if n > self.size.
        In all cases a deep copy is returned.
        """
        m = self.size
        ak = self.coeffs
        cls = self.__class__
        if n < m:
            return cls(ak[:n].copy(), interval=self.interval)
        if n > m:
            return cls(np.append(ak, np.zeros(n - m)), interval=self.interval)
        return self.copy()

    def real(self) -> "Chebtech":
        """Return the real part of the Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the real part of this Chebtech.
                If this Chebtech is already real-valued, returns self.
        """
        if self.iscomplex:
            return self.__class__(np.real(self.coeffs), interval=self.interval)
        return self

    def simplify(self) -> "Chebtech":
        """Call standard_chop on the coefficients of self.

        Returns a Chebtech comprised of a copy of the truncated coefficients.
        The result is never longer than the original.
        """
        # pad to at least 17 coefficients before handing over to standard_chop
        oldlen = len(self.coeffs)
        longself = self.prolong(max(17, oldlen))
        cfs = longself.coeffs
        # loosen the tolerance by hscale; max(..., 1) never tightens it below prefs.eps
        tol = prefs.eps * max(self.interval.hscale, 1)
        # chop, but never report more points than we started with
        npts = standard_chop(cfs, tol=tol)
        npts = min(oldlen, npts)
        # construct
        return self.__class__(cfs[:npts].copy(), interval=self.interval)

    def values(self) -> np.ndarray:
        """Function values at Chebyshev points."""
        return coeffs2vals2(self.coeffs)

    # ---------
    #  algebra
    # ---------
    @self_empty()
    def __add__(self, f: Any) -> Any:
        """Add a scalar or another Chebtech to this Chebtech.

        Args:
            f: A scalar or another Chebtech to add to this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the sum. If every resulting
                coefficient is below half of ``prefs.eps`` times the larger
                vscale of the operands, a constant-zero Chebtech is returned.
        """
        cls = self.__class__
        if np.isscalar(f):
            # a complex scalar forces complex coefficients; otherwise keep the dtype
            dtype: Any = complex if np.iscomplexobj(f) else self.coeffs.dtype
            cfs = np.array(self.coeffs, dtype=dtype)
            cfs[0] += f
            return cls(cfs, interval=self.interval)

        # TODO: is a more general decorator approach better here?
        # TODO: for constant Chebtech, convert to constant and call __add__ again
        if f.isempty:
            return f.copy()
        g = self
        n, m = g.size, f.size
        # bring both operands to a common length before adding coefficients
        if n < m:
            g = g.prolong(m)
        elif m < n:
            f = f.prolong(n)
        cfs = f.coeffs + g.coeffs

        # check for zero output
        tol = 0.5 * prefs.eps * max(f.vscale, g.vscale)
        if np.all(np.abs(cfs) < tol):
            return cls.initconst(0.0, interval=self.interval)
        return cls(cfs, interval=self.interval)

    @self_empty()
    def __div__(self, f: Any) -> Any:
        """Divide this Chebtech by a scalar or another Chebtech.

        Args:
            f: A scalar or another Chebtech to divide this Chebtech by.

        Returns:
            Chebtech: A new Chebtech representing the quotient. Division by
                another Chebtech is carried out by adaptive construction of
                the pointwise quotient.
        """
        cls = self.__class__
        if np.isscalar(f):
            cfs = 1.0 / np.asarray(f) * self.coeffs
            return cls(cfs, interval=self.interval)
        # TODO: review with reference to __add__
        if f.isempty:
            return f.copy()
        return cls.initfun_adaptive(lambda x: self(x) / f(x), interval=self.interval)

    __truediv__ = __div__

    @self_empty()
    def __mul__(self, g: Any) -> Any:
        """Multiply this Chebtech by a scalar or another Chebtech.

        Args:
            g: A scalar or another Chebtech to multiply this Chebtech by.

        Returns:
            Chebtech: A new Chebtech representing the product.
        """
        cls = self.__class__
        if np.isscalar(g):
            cfs = g * self.coeffs
            return cls(cfs, interval=self.interval)
        # TODO: review with reference to __add__
        if g.isempty:
            return g.copy()
        f = self
        # both operands are prolonged to the length of the product
        n = f.size + g.size - 1
        f = f.prolong(n)
        g = g.prolong(n)
        cfs = coeffmult(f.coeffs, g.coeffs)
        return cls(cfs, interval=self.interval)

    def __neg__(self) -> "Chebtech":
        """Return a new Chebtech with negated coefficients."""
        coeffs = -self.coeffs
        return self.__class__(coeffs, interval=self.interval)

    def __pos__(self) -> "Chebtech":
        """Return this Chebtech itself (unary positive)."""
        return self

    @self_empty()
    def __pow__(self, f: Any) -> Any:
        """Raise this Chebtech to a power.

        Args:
            f: The exponent, which can be a scalar or another Chebtech.

        Returns:
            Chebtech: A new Chebtech representing this Chebtech raised to the power f.
        """

        def powfun(fn: Any, x: Any) -> Any:
            # a scalar exponent is used as-is; anything else is evaluated at x
            if np.isscalar(fn):
                return fn
            return fn(x)

        return self.__class__.initfun_adaptive(lambda x: np.power(self(x), powfun(f, x)), interval=self.interval)

    # NOTE(review): the empty-guard is added for parity with __div__ and
    # __rpow__; it assumes self_empty behaves the same on a reflected operator.
    @self_empty()
    def __rdiv__(self, f: Any) -> Any:
        """Divide a scalar by this Chebtech.

        This is called when f / self is executed and f is not a Chebtech.

        Args:
            f: A scalar to be divided by this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f divided by this Chebtech.
        """

        # Executed when __div__(f, self) fails, which is to say whenever f
        # is not a Chebtech. We proceed on the assumption f is a scalar.
        def constfun(x: Any) -> Any:
            # 0.0 * x gives the constant the same shape as x
            return 0.0 * x + f

        return self.__class__.initfun_adaptive(lambda x: constfun(x) / self(x), interval=self.interval)

    __radd__ = __add__

    def __rsub__(self, f: Any) -> Any:
        """Subtract this Chebtech from a scalar.

        This is called when f - self is executed and f is not a Chebtech.

        Args:
            f: A scalar from which to subtract this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f minus this Chebtech.
        """
        return -(self - f)

    @self_empty()
    def __rpow__(self, f: Any) -> Any:
        """Raise a scalar to the power of this Chebtech.

        This is called when f ** self is executed and f is not a Chebtech.

        Args:
            f: A scalar to be raised to the power of this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f raised to the power of this Chebtech.
        """
        return self.__class__.initfun_adaptive(lambda x: np.power(f, self(x)), interval=self.interval)

    __rtruediv__ = __rdiv__
    __rmul__ = __mul__

    def __sub__(self, f: Any) -> Any:
        """Subtract a scalar or another Chebtech from this Chebtech.

        Args:
            f: A scalar or another Chebtech to subtract from this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the difference.
        """
        return self + (-f)

    # -------
    #  roots
    # -------
    def roots(self, sort: bool | None = None) -> np.ndarray:
        """Compute the roots of the Chebtech on [-1,1].

        Uses the coefficients in the associated Chebyshev series approximation.

        Args:
            sort: Whether to sort the roots. None defers to ``prefs.sortroots``.
        """
        sort = sort if sort is not None else prefs.sortroots
        rts = rootsunit(self.coeffs)
        rts = newtonroots(self, rts)
        # fix problems with newton for roots that are numerically very close
        rts = np.clip(rts, -1, 1)  # if newton roots are just outside [-1,1]
        return np.sort(rts) if sort else rts

    # ----------
    #  calculus
    # ----------
    # Note that function returns 0 for an empty Chebtech object; this is
    # consistent with numpy, which returns zero for the sum of an empty array
    @self_empty(resultif=0.0)
    def sum(self) -> Any:
        """Definite integral of a Chebtech on the interval [-1,1]."""
        if self.isconst:
            return 2.0 * self(0.0)
        ak = self.coeffs.copy()
        # odd-indexed coefficients are dropped from the sum
        ak[1::2] = 0
        kk = np.arange(2, ak.size)
        # per-coefficient weights: 2 for k = 0, 0 for k = 1, 2 / (1 - k**2) for k >= 2
        ii = np.append([2, 0], 2 / (1 - kk**2))
        return (ak * ii).sum()

    @self_empty()
    def cumsum(self) -> "Chebtech":
        """Return a Chebtech object representing the indefinite integral.

        Computes the indefinite integral of a Chebtech on the interval [-1,1].
        The constant term is chosen such that F(-1) = 0.
        """
        n = self.size
        # two trailing zeros let the recurrence below index ak[n] and ak[n + 1]
        ak = np.append(self.coeffs, [0, 0])
        bk = np.zeros(n + 1, dtype=self.coeffs.dtype)
        rk = np.arange(2, n + 1)
        bk[2:] = 0.5 * (ak[1:n] - ak[3:]) / rk
        bk[1] = ak[0] - 0.5 * ak[2]
        # alternating signs +1, -1, ... fix the constant term so that F(-1) = 0
        vk = np.ones(n)
        vk[1::2] = -1
        bk[0] = (vk * bk[1:]).sum()
        return self.__class__(bk, interval=self.interval)

    @self_empty()
    def diff(self) -> "Chebtech":
        """Return a Chebtech object representing the derivative.

        Computes the derivative of a Chebtech on the interval [-1,1].
        A constant Chebtech differentiates to the single coefficient 0.0.
        """
        if self.isconst:
            return self.__class__(np.array([0.0]), interval=self.interval)
        n = self.size
        ak = self.coeffs
        zk = np.zeros(n - 1, dtype=self.coeffs.dtype)
        wk = 2 * np.arange(1, n)
        vk = wk * ak[1:]
        # running sums from the highest index downwards, taken separately
        # over the two interleaved (every-other) index sets
        zk[-1::-2] = vk[-1::-2].cumsum()
        zk[-2::-2] = vk[-2::-2].cumsum()
        zk[0] = 0.5 * zk[0]
        return self.__class__(zk, interval=self.interval)

    @staticmethod
    def _chebpts(n: int) -> np.ndarray:
        """Return n Chebyshev points of the second-kind."""
        return chebpts2(n)

    @staticmethod
    def _barywts(n: int) -> np.ndarray:
        """Barycentric weights for Chebyshev points of 2nd kind."""
        return barywts2(n)

    @staticmethod
    def _vals2coeffs(vals: Any) -> np.ndarray:
        """Map function values at Chebyshev points of 2nd kind.

        Converts values at Chebyshev points of 2nd kind to first-kind Chebyshev polynomial coefficients.
        """
        return vals2coeffs2(vals)

    @staticmethod
    def _coeffs2vals(coeffs: Any) -> np.ndarray:
        """Map first-kind Chebyshev polynomial coefficients.

        Converts first-kind Chebyshev polynomial coefficients to function values at Chebyshev points of 2nd kind.
        """
        return coeffs2vals2(coeffs)

    # ----------
    #  plotting
    # ----------
    def plot(self, ax: Any = None, **kwargs: Any) -> Any:
        """Plot the Chebtech on the interval [-1, 1].

        Args:
            ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to None.
            **kwargs: Additional keyword arguments to pass to the plot function.

        Returns:
            Whatever ``plotfun`` returns for the plotted curve.
        """
        return plotfun(self, (-1, 1), ax=ax, **kwargs)

    def plotcoeffs(self, ax: Any = None, **kwargs: Any) -> Any:
        """Plot the absolute values of the Chebyshev coefficients.

        Args:
            ax (matplotlib.axes.Axes, optional): The axes on which to plot.
                Defaults to None, in which case the current axes are used.
            **kwargs: Additional keyword arguments to pass to the plot function.

        Returns:
            Whatever ``plotfuncoeffs`` returns for the plotted coefficients.
        """
        # explicit None test: do not rely on the truthiness of an axes object
        ax = ax if ax is not None else plt.gca()
        return plotfuncoeffs(abs(self.coeffs), ax=ax, **kwargs)