Coverage for chebpy/core/chebtech.py: 99%
248 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 10:30 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 10:30 +0000
"""Implementation of Chebyshev polynomial technology for function approximation.

This module provides the Chebtech class, which represents smooth functions on
[-1, 1] by Chebyshev polynomial expansions. Coefficients are stored in the
first-kind basis (T_k), while function values are sampled at Chebyshev points
of the second kind.

The Chebtech class implements core functionality for working with Chebyshev
expansions, including:
- Function evaluation using Clenshaw's algorithm or barycentric interpolation
- Algebraic operations (addition, multiplication, etc.)
- Calculus operations (differentiation, integration, etc.)
- Rootfinding
- Plotting

This class is primarily used internally by higher-level classes like Bndfun
and Chebfun, rather than being used directly by end users.
"""
19from abc import ABC
21import matplotlib.pyplot as plt
22import numpy as np
24from .algorithms import (
25 adaptive,
26 bary,
27 barywts2,
28 chebpts2,
29 clenshaw,
30 coeffmult,
31 coeffs2vals2,
32 newtonroots,
33 rootsunit,
34 standard_chop,
35 vals2coeffs2,
36)
37from .decorators import self_empty
38from .plotting import plotfun, plotfuncoeffs
39from .settings import _preferences as prefs
40from .smoothfun import Smoothfun
41from .utilities import Interval, coerce_list
class Chebtech(Smoothfun, ABC):
    """Chebyshev-series representation of a smooth function on [-1, 1].

    A Chebtech stores first-kind Chebyshev coefficients (the T_k basis) and
    samples function values at Chebyshev points of the second kind. The
    ``interval`` attribute records the interval the Chebtech is mapped to; it
    feeds the hscale used by the adaptive constructor and by ``simplify``,
    while evaluation, calculus and rootfinding below operate on [-1, 1].

    The user will rarely work with this class directly, so several
    assumptions are made regarding input data types (for instance, the
    arithmetic operators treat any non-scalar operand as another Chebtech).
    """

    # --------------
    # constructors
    # --------------
    @classmethod
    def initconst(cls, c, *, interval=None):
        """Initialise a Chebtech from a constant c.

        Args:
            c: Scalar constant. Integer values (Python or NumPy) are converted
                to float so the coefficient array is floating point.
            interval (array-like, optional): Interval the Chebtech is mapped to.

        Raises:
            ValueError: If ``c`` is not a scalar.
        """
        if not np.isscalar(c):
            raise ValueError(c)
        # NumPy integer scalars are not ``int`` instances; convert them too so
        # later element-wise updates (e.g. in __add__) cannot truncate.
        if isinstance(c, (int, np.integer)):
            c = float(c)
        return cls(np.array([c]), interval=interval)

    @classmethod
    def initempty(cls, *, interval=None):
        """Initialise an empty Chebtech."""
        return cls(np.array([]), interval=interval)

    @classmethod
    def initidentity(cls, *, interval=None):
        """Chebtech representation of f(x) = x on [-1,1]."""
        # Float coefficients: integer ones would silently truncate results in
        # operations that write fractional values back into the array.
        return cls(np.array([0.0, 1.0]), interval=interval)

    @classmethod
    def initfun(cls, fun, n=None, *, interval=None):
        """Convenience constructor to automatically select the adaptive or fixedlen constructor.

        Args:
            fun (callable): Vectorised function to sample.
            n (int, optional): Number of degrees of freedom. If None, the
                adaptive constructor is used; otherwise the fixed-length one.
            interval (array-like, optional): Interval the Chebtech is mapped to.
        """
        if n is None:
            return cls.initfun_adaptive(fun, interval=interval)
        return cls.initfun_fixedlen(fun, n, interval=interval)

    @classmethod
    def initfun_fixedlen(cls, fun, n, *, interval=None):
        """Initialise a Chebtech from the callable fun using n degrees of freedom.

        ``fun`` is sampled at n Chebyshev points of the second kind and the
        samples are converted to first-kind Chebyshev coefficients.
        """
        points = cls._chebpts(n)
        values = fun(points)
        coeffs = vals2coeffs2(values)
        return cls(coeffs, interval=interval)

    @classmethod
    def initfun_adaptive(cls, fun, *, interval=None):
        """Initialise a Chebtech from the callable fun utilising the adaptive constructor.

        The ``adaptive`` algorithm chooses the number of degrees of freedom,
        using the interval's hscale. If no interval is given, the default
        domain from preferences is used.
        """
        interval = interval if interval is not None else prefs.domain
        interval = Interval(*interval)
        coeffs = adaptive(cls, fun, hscale=interval.hscale)
        return cls(coeffs, interval=interval)

    @classmethod
    def initvalues(cls, values, *, interval=None):
        """Initialise a Chebtech from an array of values at Chebyshev points."""
        return cls(cls._vals2coeffs(values), interval=interval)

    def __init__(self, coeffs, interval=None):
        """Initialize a Chebtech object.

        Args:
            coeffs (array-like): First-kind Chebyshev series coefficients.
            interval (array-like, optional): The interval on which the function
                is defined. Defaults to None, which uses the default interval
                from preferences.
        """
        interval = interval if interval is not None else prefs.domain
        self._coeffs = np.array(coeffs)
        self._interval = Interval(*interval)

    def __call__(self, x, how="clenshaw"):
        """Evaluate the Chebtech at the given points.

        Args:
            x: Points at which to evaluate the Chebtech.
            how (str, optional): Method to use for evaluation. Either "clenshaw" or "bary".
                Defaults to "clenshaw".

        Returns:
            The values of the Chebtech at the given points.

        Raises:
            ValueError: If the specified method is not supported.
        """
        methods = {
            "clenshaw": self.__call__clenshaw,
            "bary": self.__call__bary,
        }
        # Only the dispatch lookup is guarded, so a KeyError raised during the
        # evaluation itself is not misreported as an unsupported method.
        try:
            method = methods[how]
        except KeyError:
            raise ValueError(how) from None
        return method(x)

    def __call__clenshaw(self, x):
        # Evaluate the coefficient series directly with Clenshaw's algorithm.
        return clenshaw(x, self.coeffs)

    def __call__bary(self, x):
        # Barycentric interpolation through the values at second-kind points.
        fk = self.values()
        xk = self._chebpts(fk.size)
        vk = self._barywts(fk.size)
        return bary(x, fk, xk, vk)

    def __repr__(self):  # pragma: no cover
        """Return a string representation of the Chebtech.

        Returns:
            str: A string representation of the Chebtech.
        """
        out = f"<{self.__class__.__name__}{{{self.size}}}>"
        return out

    # ------------
    # properties
    # ------------
    @property
    def coeffs(self):
        """Chebyshev expansion coefficients in the T_k basis."""
        return self._coeffs

    @property
    def interval(self):
        """Interval that Chebtech is mapped to."""
        return self._interval

    @property
    def size(self):
        """Return the number of coefficients."""
        return self.coeffs.size

    @property
    def isempty(self):
        """Return True if the Chebtech has no coefficients."""
        return self.size == 0

    @property
    def iscomplex(self):
        """Determine whether the underlying onefun is complex or real valued."""
        # iscomplexobj covers every complex dtype (complex64, complex128, ...),
        # not only Python's ``complex``.
        return np.iscomplexobj(self._coeffs)

    @property
    def isconst(self):
        """Return True if the Chebtech represents a constant (one coefficient)."""
        return self.size == 1

    @property
    @self_empty(0.0)
    def vscale(self):
        """Estimate the vertical scale: max absolute value at Chebyshev points."""
        return np.abs(coerce_list(self.values())).max()

    # -----------
    # utilities
    # -----------
    def copy(self):
        """Return a deep copy of the Chebtech."""
        return self.__class__(self.coeffs.copy(), interval=self.interval.copy())

    def imag(self):
        """Return the imaginary part of the Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the imaginary part of this Chebtech.
                If this Chebtech is real-valued, returns a zero Chebtech.
        """
        if self.iscomplex:
            return self.__class__(np.imag(self.coeffs), self.interval)
        return self.initconst(0, interval=self.interval)

    def prolong(self, n):
        """Return a Chebtech of length n.

        Obtained either by truncating if n < self.size or zero-padding if n > self.size.
        In all cases a deep copy is returned.
        """
        m = self.size
        ak = self.coeffs
        cls = self.__class__
        if n < m:
            return cls(ak[:n].copy(), interval=self.interval)
        if n > m:
            return cls(np.append(ak, np.zeros(n - m)), interval=self.interval)
        return self.copy()

    def real(self):
        """Return the real part of the Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the real part of this Chebtech.
                If this Chebtech is already real-valued, returns self.
        """
        if self.iscomplex:
            return self.__class__(np.real(self.coeffs), self.interval)
        return self

    def simplify(self):
        """Call standard_chop on the coefficients of self.

        Returns a Chebtech comprised of a copy of the truncated coefficients.
        The result is never longer than the original.
        """
        oldlen = len(self.coeffs)
        # pad short series to at least 17 coefficients before chopping
        longself = self.prolong(max(17, oldlen))
        cfs = longself.coeffs
        # scale tolerance by hscale (never below the machine tolerance)
        tol = prefs.eps * max(self.interval.hscale, 1)
        npts = standard_chop(cfs, tol=tol)
        npts = min(oldlen, npts)
        return self.__class__(cfs[:npts].copy(), interval=self.interval)

    def values(self):
        """Function values at Chebyshev points of the second kind."""
        return coeffs2vals2(self.coeffs)

    # ---------
    # algebra
    # ---------
    @self_empty()
    def __add__(self, f):
        """Add a scalar or another Chebtech to this Chebtech.

        Args:
            f: A scalar or another Chebtech to add to this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the sum. If the sum of two
                Chebtechs is negligible relative to their vertical scales, a
                zero constant is returned.
        """
        cls = self.__class__
        if np.isscalar(f):
            if np.iscomplexobj(f):
                dtype = complex
            else:
                # Promote rather than reuse self.coeffs.dtype: with integer
                # coefficients a float scalar would otherwise be truncated by
                # the element-wise update below.
                dtype = np.result_type(self.coeffs, f)
            cfs = np.array(self.coeffs, dtype=dtype)
            cfs[0] += f
            return cls(cfs, interval=self.interval)

        # TODO: is a more general decorator approach better here?
        # TODO: for constant Chebtech, convert to constant and call __add__ again
        if f.isempty:
            return f.copy()
        g = self
        n, m = g.size, f.size
        # pad the shorter series so the coefficients align term by term
        if n < m:
            g = g.prolong(m)
        elif m < n:
            f = f.prolong(n)
        cfs = f.coeffs + g.coeffs

        # check for zero output
        tol = 0.5 * prefs.eps * max(f.vscale, g.vscale)
        if np.all(np.abs(cfs) < tol):
            return cls.initconst(0.0, interval=self.interval)
        return cls(cfs, interval=self.interval)

    @self_empty()
    def __div__(self, f):
        """Divide this Chebtech by a scalar or another Chebtech.

        Args:
            f: A scalar or another Chebtech to divide this Chebtech by.

        Returns:
            Chebtech: A new Chebtech representing the quotient. Division by a
                Chebtech is computed by adaptively resampling self(x) / f(x).
        """
        cls = self.__class__
        if np.isscalar(f):
            cfs = 1.0 / f * self.coeffs
            return cls(cfs, interval=self.interval)
        # TODO: review with reference to __add__
        if f.isempty:
            return f.copy()
        return cls.initfun_adaptive(lambda x: self(x) / f(x), interval=self.interval)

    __truediv__ = __div__

    @self_empty()
    def __mul__(self, g):
        """Multiply this Chebtech by a scalar or another Chebtech.

        Args:
            g: A scalar or another Chebtech to multiply this Chebtech by.

        Returns:
            Chebtech: A new Chebtech representing the product.
        """
        cls = self.__class__
        if np.isscalar(g):
            cfs = g * self.coeffs
            return cls(cfs, interval=self.interval)
        # TODO: review with reference to __add__
        if g.isempty:
            return g.copy()
        f = self
        # the product of degree-(a-1) and degree-(b-1) series has a+b-1 terms
        n = f.size + g.size - 1
        f = f.prolong(n)
        g = g.prolong(n)
        cfs = coeffmult(f.coeffs, g.coeffs)
        return cls(cfs, interval=self.interval)

    def __neg__(self):
        """Return the negative of this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the negative of this Chebtech.
        """
        return self.__class__(-self.coeffs, interval=self.interval)

    def __pos__(self):
        """Return this Chebtech (unary positive).

        Returns:
            Chebtech: This Chebtech (self).
        """
        return self

    @self_empty()
    def __pow__(self, f):
        """Raise this Chebtech to a power.

        Args:
            f: The exponent, which can be a scalar or another Chebtech.

        Returns:
            Chebtech: A new Chebtech representing this Chebtech raised to the power f.
        """

        def powfun(fn, x):
            # scalars are used as-is; Chebtech exponents are evaluated at x
            if np.isscalar(fn):
                return fn
            return fn(x)

        return self.__class__.initfun_adaptive(lambda x: np.power(self(x), powfun(f, x)), interval=self.interval)

    @self_empty()
    def __rdiv__(self, f):
        """Divide a scalar by this Chebtech.

        This is called when f / self is executed and f is not a Chebtech.
        Decorated with ``self_empty`` like ``__div__`` and ``__rpow__``, so an
        empty Chebtech is not evaluated.

        Args:
            f: A scalar to be divided by this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f divided by this Chebtech.
        """

        # Executed when __div__(f, self) fails, which is to say whenever f
        # is not a Chebtech. We proceed on the assumption f is a scalar.
        def constfun(x):
            return 0.0 * x + f

        return self.__class__.initfun_adaptive(lambda x: constfun(x) / self(x), interval=self.interval)

    __radd__ = __add__

    def __rsub__(self, f):
        """Subtract this Chebtech from a scalar.

        This is called when f - self is executed and f is not a Chebtech.

        Args:
            f: A scalar from which to subtract this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f minus this Chebtech.
        """
        return -(self - f)

    @self_empty()
    def __rpow__(self, f):
        """Raise a scalar to the power of this Chebtech.

        This is called when f ** self is executed and f is not a Chebtech.

        Args:
            f: A scalar to be raised to the power of this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f raised to the power of this Chebtech.
        """
        return self.__class__.initfun_adaptive(lambda x: np.power(f, self(x)), interval=self.interval)

    __rtruediv__ = __rdiv__
    __rmul__ = __mul__

    def __sub__(self, f):
        """Subtract a scalar or another Chebtech from this Chebtech.

        Args:
            f: A scalar or another Chebtech to subtract from this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the difference.
        """
        return self + (-f)

    # -------
    # roots
    # -------
    def roots(self, sort=None):
        """Compute the roots of the Chebtech on [-1,1].

        Initial roots come from the coefficients of the Chebyshev series and
        are refined with Newton iterations.

        Args:
            sort (bool, optional): Whether to sort the roots. Defaults to the
                ``sortroots`` preference.

        Returns:
            numpy.ndarray: Roots clipped to [-1, 1].
        """
        sort = sort if sort is not None else prefs.sortroots
        rts = rootsunit(self.coeffs)
        rts = newtonroots(self, rts)
        # Newton may push roots that are numerically very close to the
        # endpoints just outside [-1,1]; clip them back in.
        rts = np.clip(rts, -1, 1)
        return np.sort(rts) if sort else rts

    # ----------
    # calculus
    # ----------
    # Note that function returns 0 for an empty Chebtech object; this is
    # consistent with numpy, which returns zero for the sum of an empty array
    @self_empty(resultif=0.0)
    def sum(self):
        """Definite integral of a Chebtech on the interval [-1,1]."""
        if self.isconst:
            return 2.0 * self(0.0)
        ak = self.coeffs.copy()
        # odd T_k integrate to zero over [-1,1]
        ak[1::2] = 0
        # weights: int T_0 = 2, int T_1 = 0, int T_k = 2 / (1 - k^2) for k >= 2
        kk = np.arange(2, ak.size)
        ii = np.append([2, 0], 2 / (1 - kk**2))
        return (ak * ii).sum()

    @self_empty()
    def cumsum(self):
        """Return a Chebtech object representing the indefinite integral.

        Computes the indefinite integral of a Chebtech on the interval [-1,1].
        The constant term is chosen such that F(-1) = 0.
        """
        n = self.size
        # two trailing zeros so ak[k+1] exists for every k used below
        ak = np.append(self.coeffs, [0, 0])
        # The recurrence divides by k, so integer coefficients must be
        # promoted to a floating type or the results would be truncated.
        bk = np.zeros(n + 1, dtype=np.result_type(self.coeffs, 0.5))
        rk = np.arange(2, n + 1)
        bk[2:] = 0.5 * (ak[1:n] - ak[3:]) / rk
        bk[1] = ak[0] - 0.5 * ak[2]
        # T_k(-1) = (-1)^k, so choose b_0 to make F(-1) vanish
        vk = np.ones(n)
        vk[1::2] = -1
        bk[0] = (vk * bk[1:]).sum()
        return self.__class__(bk, interval=self.interval)

    @self_empty()
    def diff(self):
        """Return a Chebtech object representing the derivative.

        Computes the derivative of a Chebtech on the interval [-1,1].
        """
        if self.isconst:
            return self.__class__(np.array([0.0]), interval=self.interval)
        n = self.size
        ak = self.coeffs
        zk = np.zeros(n - 1, dtype=self.coeffs.dtype)
        # derivative coefficients are cumulative sums of 2k*a_k taken
        # separately over indices of each parity, starting from the top
        wk = 2 * np.arange(1, n)
        vk = wk * ak[1:]
        zk[-1::-2] = vk[-1::-2].cumsum()
        zk[-2::-2] = vk[-2::-2].cumsum()
        # the T_0 coefficient carries an extra factor of one half
        zk[0] = 0.5 * zk[0]
        return self.__class__(zk, interval=self.interval)

    @staticmethod
    def _chebpts(n):
        """Return n Chebyshev points of the second-kind."""
        return chebpts2(n)

    @staticmethod
    def _barywts(n):
        """Barycentric weights for Chebyshev points of 2nd kind."""
        return barywts2(n)

    @staticmethod
    def _vals2coeffs(vals):
        """Map function values at Chebyshev points of 2nd kind.

        Converts values at Chebyshev points of 2nd kind to first-kind Chebyshev polynomial coefficients.
        """
        return vals2coeffs2(vals)

    @staticmethod
    def _coeffs2vals(coeffs):
        """Map first-kind Chebyshev polynomial coefficients.

        Converts first-kind Chebyshev polynomial coefficients to function values at Chebyshev points of 2nd kind.
        """
        return coeffs2vals2(coeffs)

    # ----------
    # plotting
    # ----------
    def plot(self, ax=None, **kwargs):
        """Plot the Chebtech on the interval [-1, 1].

        Args:
            ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to None.
            **kwargs: Additional keyword arguments to pass to the plot function.

        Returns:
            matplotlib.lines.Line2D: The line object created by the plot.
        """
        return plotfun(self, (-1, 1), ax=ax, **kwargs)

    def plotcoeffs(self, ax=None, **kwargs):
        """Plot the absolute values of the Chebyshev coefficients.

        Args:
            ax (matplotlib.axes.Axes, optional): The axes on which to plot.
                Defaults to None, which uses the current axes.
            **kwargs: Additional keyword arguments to pass to the plot function.

        Returns:
            matplotlib.lines.Line2D: The line object created by the plot.
        """
        ax = ax or plt.gca()
        return plotfuncoeffs(abs(self.coeffs), ax=ax, **kwargs)