Coverage for src / chebpy / chebfun.py: 100%
481 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
1"""Implementation of the Chebfun class for piecewise function approximation.
3This module provides the Chebfun class, which is the main user-facing class in the
4ChebPy package. It represents functions using piecewise polynomial approximations
5on arbitrary intervals, allowing for operations such as integration, differentiation,
6root-finding, and more.
8The Chebfun class is inspired by the MATLAB package of the same name and provides
9similar functionality for working with functions rather than numbers.
10"""
12from __future__ import annotations
14import operator
15from collections.abc import Callable, Iterator
16from typing import Any
18import matplotlib.pyplot as plt
19import numpy as np
20from matplotlib.axes import Axes
22from .algorithms import _conv_legendre, cheb2leg, leg2cheb
23from .bndfun import Bndfun
24from .chebtech import Chebtech
25from .decorators import cache, cast_arg_to_chebfun, float_argument, self_empty
26from .exceptions import BadFunLengthArgument, SupportMismatch
27from .plotting import plotfun
28from .settings import _preferences as prefs
29from .utilities import Domain, Interval, check_funs, compute_breakdata, generate_funs
32class Chebfun:
33 """Main class for representing and manipulating functions in ChebPy.
35 The Chebfun class represents functions using piecewise polynomial approximations
36 on arbitrary intervals. It provides a comprehensive set of operations for working
37 with these function representations, including:
39 - Function evaluation at arbitrary points
40 - Algebraic operations (addition, multiplication, etc.)
41 - Calculus operations (differentiation, integration, etc.)
42 - Rootfinding
43 - Plotting
45 Chebfun objects can be created from callable functions, constant values, or
46 directly from function pieces. The class supports both adaptive and fixed-length
47 approximations, allowing for efficient representation of functions with varying
48 complexity across different intervals.
50 Attributes:
51 funs (numpy.ndarray): Array of function pieces that make up the Chebfun.
52 breakdata (OrderedDict): Mapping of breakpoints to function values.
53 transposed (bool): Flag indicating if the Chebfun is transposed.
54 """
def __init__(self, funs: Any) -> None:
    """Build a Chebfun from a collection of function pieces.

    Args:
        funs (list): Function pieces making up the Chebfun. They are passed
            through ``check_funs`` before being stored.
    """
    checked = check_funs(funs)
    self.funs = checked
    # Breakpoint data is derived from the checked pieces.
    self.breakdata = compute_breakdata(checked)
    # A freshly constructed Chebfun is never flagged as transposed.
    self.transposed = False
@classmethod
def initempty(cls) -> Chebfun:
    """Construct a Chebfun that holds no function pieces.

    Returns:
        Chebfun: The result of calling the class with an empty list.

    Examples:
        >>> f = Chebfun.initempty()
        >>> f.isempty
        True
    """
    no_funs: list[Any] = []
    return cls(no_funs)
@classmethod
def initidentity(cls, domain: Any = None) -> Chebfun:
    """Construct the identity function f(x) = x as a Chebfun.

    Args:
        domain (array-like, optional): Domain for the identity function. It is
            forwarded unchanged (including ``None``) to ``generate_funs``.

    Returns:
        Chebfun: The identity function built from ``Bndfun.initidentity`` pieces.

    Examples:
        >>> import numpy as np
        >>> x = Chebfun.initidentity([-1, 1])
        >>> float(x(0.5))
        0.5
        >>> np.allclose(x([0, 0.5, 1]), [0, 0.5, 1])
        True
    """
    pieces = generate_funs(domain, Bndfun.initidentity)
    return cls(pieces)
@classmethod
def initconst(cls, c: Any, domain: Any = None) -> Chebfun:
    """Construct the constant function f(x) = c as a Chebfun.

    Args:
        c (float or complex): Value of the constant.
        domain (array-like, optional): Domain for the constant function. It is
            forwarded unchanged (including ``None``) to ``generate_funs``.

    Returns:
        Chebfun: The constant function built from ``Bndfun.initconst`` pieces.

    Examples:
        >>> import numpy as np
        >>> f = Chebfun.initconst(3.14, [-1, 1])
        >>> float(f(0))
        3.14
        >>> float(f(0.5))
        3.14
        >>> np.allclose(f([0, 0.5, 1]), [3.14, 3.14, 3.14])
        True
    """
    constructor_kwargs = {"c": c}
    pieces = generate_funs(domain, Bndfun.initconst, constructor_kwargs)
    return cls(pieces)
@classmethod
def initfun_adaptive(cls, f: Callable[..., Any], domain: Any = None) -> Chebfun:
    """Construct a Chebfun from a callable using adaptive construction.

    Each piece is produced by ``Bndfun.initfun_adaptive``, so the length of
    every piece is chosen by that constructor rather than by the caller.

    Args:
        f (callable): Function to approximate.
        domain (array-like, optional): Domain of the approximation. It is
            forwarded unchanged (including ``None``) to ``generate_funs``.

    Returns:
        Chebfun: Approximation of ``f`` on the requested domain.

    Examples:
        >>> import numpy as np
        >>> f = Chebfun.initfun_adaptive(lambda x: np.sin(x), [-np.pi, np.pi])
        >>> bool(abs(f(0)) < 1e-10)
        True
        >>> bool(abs(f(np.pi/2) - 1) < 1e-10)
        True
    """
    constructor_kwargs = {"f": f}
    pieces = generate_funs(domain, Bndfun.initfun_adaptive, constructor_kwargs)
    return cls(pieces)
@classmethod
def initfun_fixedlen(cls, f: Callable[..., Any], n: Any, domain: Any = None) -> Chebfun:
    """Construct a Chebfun from a callable using prescribed piece lengths.

    Args:
        f (callable): Function to approximate.
        n (int or array-like): Number of points. A value with fewer than two
            elements is handed as-is to every piece; a longer array supplies
            one length per interval of the domain.
        domain (array-like, optional): Domain of the approximation. When ``n``
            is an array and ``domain`` is None, ``prefs.domain`` is used.

    Returns:
        Chebfun: Approximation of ``f`` on the requested domain.

    Raises:
        BadFunLengthArgument: If ``n`` has two or more elements and its size
            differs from ``domain.size - 1``.
    """
    lengths = np.array(n)

    # Fewer than two entries: the same ``n`` is reused for every interval.
    if lengths.size < 2:
        shared_kwargs = {"f": f, "n": n}
        return cls(generate_funs(domain, Bndfun.initfun_fixedlen, shared_kwargs))

    # One length per interval: the counts have to line up exactly.
    dom = Domain(prefs.domain if domain is None else domain)
    if lengths.size != dom.size - 1:
        raise BadFunLengthArgument
    pieces = [
        Bndfun.initfun_fixedlen(f, interval, length)
        for interval, length in zip(dom.intervals, lengths, strict=False)
    ]
    return cls(pieces)
@classmethod
def initfun(cls, f: Callable[..., Any], domain: Any = None, n: Any = None) -> Chebfun:
    """Construct a Chebfun from a callable, adaptively or with fixed length.

    Dispatches to ``initfun_fixedlen`` when ``n`` is given and to
    ``initfun_adaptive`` otherwise.

    Args:
        f (callable): Function to approximate.
        domain (array-like, optional): Domain of the approximation; passed on
            to the selected constructor.
        n (int or array-like, optional): Number of points. ``None`` selects
            the adaptive constructor; any other value (including 0) selects
            the fixed-length constructor.

    Returns:
        Chebfun: Approximation of ``f`` on the requested domain.
    """
    if n is not None:
        return cls.initfun_fixedlen(f, n, domain)
    return cls.initfun_adaptive(f, domain)
206 # --------------------
207 # operator overloads
208 # --------------------
def __add__(self, f: Any) -> Any:
    """Return the sum of this Chebfun and ``f``.

    Args:
        f (Chebfun or scalar): Right-hand operand.

    Returns:
        Chebfun: Result of ``_apply_binop`` with ``operator.add``.
    """
    binary_op = operator.add
    return self._apply_binop(f, binary_op)
@self_empty(np.array([]))
@float_argument
def __call__(self, x: Any) -> Any:
    """Evaluate the Chebfun at the points ``x``.

    Interior points are evaluated by the piece that contains them, points
    equal to a breakpoint take the stored breakpoint value, and points lying
    outside the breakpoints are evaluated with the first or last piece.

    Args:
        x (float or array-like): Evaluation points. The body indexes ``x`` as
            an array; presumably ``float_argument`` performs that conversion
            (and any reshaping of the result) — confirm in ``decorators``.

    Returns:
        float or numpy.ndarray: Values of the Chebfun at ``x``.
    """
    # Start from NaN so any point not handled below stays visible.
    result_dtype = float
    if self.iscomplex:
        result_dtype = complex
    values = np.full(x.size, np.nan, dtype=result_dtype)

    # Interior points: delegate to the owning piece.
    for piece in self:
        inside = piece.interval.isinterior(x)
        values[inside] = piece(x[inside])

    # Points sitting exactly on a breakpoint use the stored break data.
    bpts = self.breakpoints
    for bpt in bpts:
        values[x == bpt] = self.breakdata[bpt]

    # Points outside the breakpoints: evaluate the outermost pieces.
    left_of = x < bpts[0]
    right_of = x > bpts[-1]
    values[left_of] = self.funs[0](x[left_of])
    values[right_of] = self.funs[-1](x[right_of])
    return values
def __iter__(self) -> Iterator[Any]:
    """Iterate over the pieces stored in ``self.funs``.

    Returns:
        iterator: Iterator over the constituent funs.
    """
    return iter(self.funs)
def __len__(self) -> int:
    """Return the combined size of all pieces.

    Returns:
        int: Sum of ``size`` over every fun in ``self.funs``.
    """
    total = 0
    for piece in self.funs:
        total += piece.size
    return total
def __eq__(self, other: object) -> bool:
    """Compare two Chebfun objects for approximate equality.

    Two Chebfuns compare equal when both are empty, or when their domains
    compare equal and their values agree within a tolerance at 100 equally
    spaced points across the support.

    Args:
        other (object): Object to compare against.

    Returns:
        bool: True if the objects are considered equal, False otherwise.
    """
    if not isinstance(other, self.__class__):
        return False
    if self.isempty and other.isempty:
        return True
    if self.domain != other.domain:
        return False

    # Sample both functions on a common grid spanning the support.
    lo, hi = self.support[0], self.support[1]
    samples = np.linspace(lo, hi, 100)

    # Tolerance scales with the larger vertical scale of the two operands.
    scale = max(self.vscale, other.vscale)
    tol = 1e2 * scale * prefs.eps
    deviation = np.abs(self(samples) - other(samples))
    return bool(np.all(deviation <= tol))
def __mul__(self, f: Any) -> Any:
    """Return the product of this Chebfun and ``f``.

    Args:
        f (Chebfun or scalar): Right-hand operand.

    Returns:
        Chebfun: Result of ``_apply_binop`` with ``operator.mul``.
    """
    binary_op = operator.mul
    return self._apply_binop(f, binary_op)
def __neg__(self) -> Chebfun:
    """Return the negation of this Chebfun.

    Returns:
        Chebfun: A new Chebfun built from the negated ``funs`` container.
    """
    negated = -self.funs
    return self.__class__(negated)
def __pos__(self) -> Chebfun:
    """Apply unary plus, which leaves the Chebfun unchanged.

    Returns:
        Chebfun: This very object (no copy is made).
    """
    return self
def __abs__(self) -> Chebfun:
    """Return the absolute value of this Chebfun.

    Returns:
        Chebfun: A new Chebfun whose pieces are ``fun.absolute()`` for each
        piece in ``self.funs``.
    """
    return self.__class__([piece.absolute() for piece in self.funs])
def __pow__(self, f: Any) -> Any:
    """Raise this Chebfun to the power ``f``.

    Args:
        f (Chebfun or scalar): Exponent.

    Returns:
        Chebfun: Result of ``_apply_binop`` with ``operator.pow``.
    """
    binary_op = operator.pow
    return self._apply_binop(f, binary_op)
def __rtruediv__(self, c: Any) -> Chebfun:
    """Divide the scalar ``c`` by this Chebfun, i.e. compute ``c / self``.

    Args:
        c (scalar): The numerator.

    Returns:
        Chebfun: A new Chebfun representing ``c / self``, built piece by piece
        with each fun's ``initfun_adaptive`` on that fun's interval.

    Note:
        Python calls this hook when ``c.__truediv__(self)`` is unavailable,
        i.e. when ``c`` is not a Chebfun. ``c`` is therefore treated as a
        scalar.
    """

    def quotient_on(fun: Any) -> Callable[..., Any]:
        # A factory binds ``fun`` per piece instead of capturing the loop variable.
        def quotient(x: Any) -> Any:
            # ``0.0 * x + c`` broadcasts the scalar to the shape of ``x``.
            return (0.0 * x + c) / fun(x)

        return quotient

    pieces = [fun.initfun_adaptive(quotient_on(fun), fun.interval) for fun in self]
    return self.__class__(pieces)
@self_empty("Chebfun<empty>")
def __repr__(self) -> str:
    """Return a multi-line description of the Chebfun.

    The text lists the orientation, number of pieces, support, one row per
    piece (interval, size and endpoint values) and the vertical scale. The
    total length is appended only when there is more than one piece.

    Returns:
        str: The formatted description.
    """
    # NOTE(review): spacing inside these literals controls column alignment —
    # verify the exact padding against the repository copy.
    row_template = "[{:8.2g},{:8.2g}] {:6} {:8.2g} {:8.2g}\n"
    column_titles = " interval length endpoint values\n"

    orientation = "row" if self.transposed else "column"
    npieces = self.funs.size
    suffix = "" if npieces == 1 else "s"

    parts = [
        f"Chebfun {orientation} ({npieces} smooth piece{suffix})\n",
        f"domain: {self.support}\n",
        column_titles,
    ]
    for fun in self:
        ends = fun.support
        left, right = ends
        val_left, val_right = fun(ends)
        parts.append(row_template.format(left, right, fun.size, val_left, val_right))
    parts.append(f"vertical scale = {self.vscale:3.2g}")
    if npieces != 1:
        parts.append(f" total length = {sum([f.size for f in self])}")
    return "".join(parts)
def __rsub__(self, f: Any) -> Any:
    """Compute ``f - self`` for a left operand that is not handled elsewhere.

    Args:
        f (Chebfun or scalar): The object this Chebfun is subtracted from.

    Returns:
        Chebfun: ``f - self``, evaluated as the negation of ``self - f``.
    """
    difference = self - f
    return -difference
@cast_arg_to_chebfun
def __rpow__(self, f: Any) -> Any:
    """Compute ``f ** self`` for a left operand that is not handled elsewhere.

    Args:
        f (Chebfun or scalar): The base. Presumably ``cast_arg_to_chebfun``
            converts it to a Chebfun before this body runs — confirm in
            ``decorators``.

    Returns:
        Chebfun: A new Chebfun representing ``f ** self``.
    """
    return operator.pow(f, self)
def __truediv__(self, f: Any) -> Any:
    """Return the quotient of this Chebfun and ``f``.

    Args:
        f (Chebfun or scalar): The divisor.

    Returns:
        Chebfun: Result of ``_apply_binop`` with ``operator.truediv``.
    """
    binary_op = operator.truediv
    return self._apply_binop(f, binary_op)
# Reflected multiplication reuses the forward implementation.
__rmul__ = __mul__
# __div__ / __rdiv__ are the Python 2 names of the division hooks; Python 3
# only looks up __truediv__ / __rtruediv__.
__div__ = __truediv__
__rdiv__ = __rtruediv__
# Reflected addition reuses the forward implementation.
__radd__ = __add__
def __str__(self) -> str:
    """Return a one-line summary of the Chebfun.

    The summary shows the orientation, the number of pieces, the total size
    of all pieces and either the support or the word ``empty``.

    Returns:
        str: The summary, terminated by a newline.
    """
    orientation = "row" if self.transposed else "col"
    if self.isempty:
        domain_str = "empty"
    else:
        domain_str = f"domain {self.support}"
    npieces = self.funs.size
    total_size = sum([f.size for f in self])
    return f"<Chebfun-{orientation},{npieces},{total_size}, {domain_str}>\n"
def __sub__(self, f: Any) -> Any:
    """Return the difference of this Chebfun and ``f``.

    Args:
        f (Chebfun or scalar): The object to subtract.

    Returns:
        Chebfun: Result of ``_apply_binop`` with ``operator.sub``.
    """
    binary_op = operator.sub
    return self._apply_binop(f, binary_op)
471 # ------------------
472 # internal helpers
473 # ------------------
@self_empty()
def _apply_binop(self, f: Any, op: Callable[..., Any]) -> Any:
    """Apply the binary operator ``op`` to this Chebfun and ``f``.

    Shared funnel for the arithmetic operator overloads. For a Chebfun
    operand both sides are first broken onto the union of their domains,
    ``op`` is applied piece by piece and every resulting piece is simplified.
    For a scalar operand ``op`` is applied directly to each piece and no
    simplification is done.

    Args:
        f (Chebfun or scalar): Second operand.
        op (callable): Binary operation, e.g. ``operator.add``.

    Returns:
        Chebfun: The combined Chebfun. If ``f`` exposes a truthy ``isempty``
        attribute, ``f`` itself is returned.
    """
    if hasattr(f, "isempty") and f.isempty:
        return f

    if np.isscalar(f):
        # Pair every piece with its own copy of the scalar.
        lhs, rhs, needs_simplify = self, f * np.ones(self.funs.size), False
    else:
        # Align both operands on a shared set of breakpoints.
        common = self.domain.union(f.domain)
        lhs, rhs, needs_simplify = self._break(common), f._break(common), True

    results = []
    for left, right in zip(lhs, rhs, strict=False):
        piece = op(left, right)
        results.append(piece.simplify() if needs_simplify else piece)
    return self.__class__(results)
def _break(self, targetdomain: Domain) -> Chebfun:
    """Split this Chebfun's pieces along the intervals of ``targetdomain``.

    Intended as a private helper; the callers in this class obtain
    ``targetdomain`` from ``Domain.union`` or ``Domain.restrict`` first.

    Args:
        targetdomain (Domain): Domain whose intervals define the new pieces.

    Returns:
        Chebfun: A new Chebfun made of the restricted pieces.
    """
    pieces = []
    remaining = iter(targetdomain.intervals)
    current = next(remaining)
    for fun in self:
        # Consume every target interval that lies inside this piece.
        while current in fun.interval:
            pieces.append(fun.restrict(current))
            try:
                current = next(remaining)
            except StopIteration:
                # No target intervals left; stop scanning within this piece.
                break
    return self.__class__(pieces)
538 # ------------
539 # properties
540 # ------------
@property
def breakpoints(self) -> np.ndarray:
    """Return the breakpoints of this Chebfun.

    The breakpoints are the keys of ``self.breakdata``, in that mapping's
    iteration order.

    Returns:
        numpy.ndarray: Array of breakpoints.
    """
    keys = list(self.breakdata)
    return np.array(keys)
@property
@self_empty(Domain([]))
def domain(self) -> Domain:
    """Return the domain of this Chebfun.

    Returns:
        Domain: The object produced by ``Domain.from_chebfun`` for this
        Chebfun.
    """
    return Domain.from_chebfun(self)

@domain.setter
def domain(self, new_domain: Any) -> None:
    """Restrict this Chebfun, in place, to ``new_domain``.

    Assigning to ``domain`` is forwarded to ``restrict_``.

    Args:
        new_domain (array-like): Domain to restrict this Chebfun to.
    """
    self.restrict_(new_domain)
@property
@self_empty(Domain([]))
def support(self) -> Any:
    """Return the support of this Chebfun.

    Returns:
        The ``support`` attribute of ``self.domain``.
    """
    dom = self.domain
    return dom.support
@property
@self_empty(0.0)
def hscale(self) -> float:
    """Return the horizontal scale of this Chebfun.

    Returns:
        float: The largest absolute value found in ``self.support``.
    """
    magnitudes = np.abs(self.support)
    return float(magnitudes.max())
@property
@self_empty(False)
def iscomplex(self) -> bool:
    """Report whether any piece of this Chebfun is complex-valued.

    Returns:
        bool: True as soon as one fun reports ``iscomplex``, else False.
    """
    for fun in self:
        if fun.iscomplex:
            return True
    return False
@property
@self_empty(False)
def isconst(self) -> bool:
    """Report whether this Chebfun represents a constant function.

    Every piece must report ``isconst`` and share the leading coefficient of
    the first piece (compared with exact equality).

    Returns:
        bool: True if all pieces are constant with the same value.

    Note:
        TODO: find an abstract way of referencing funs[0].coeffs[0]
    """
    reference = self.funs[0].coeffs[0]
    for fun in self:
        if not (fun.isconst and fun.coeffs[0] == reference):
            return False
    return True
@property
def isempty(self) -> bool:
    """Report whether this Chebfun has no pieces.

    Returns:
        bool: True when ``self.funs`` has size zero.
    """
    piece_count = self.funs.size
    return piece_count == 0
@property
@self_empty(0.0)
def vscale(self) -> Any:
    """Return the vertical scale of this Chebfun.

    Returns:
        The maximum of the ``vscale`` values of all pieces.
    """
    piece_scales = [fun.vscale for fun in self]
    return np.max(piece_scales)
@property
@self_empty()
def x(self) -> Chebfun:
    """Return the identity function on the support of this Chebfun.

    Returns:
        Chebfun: ``initidentity`` of this class evaluated on ``self.support``.
    """
    klass = self.__class__
    return klass.initidentity(self.support)
658 # -----------
659 # utilities
660 # ----------
def imag(self) -> Chebfun:
    """Return the imaginary part of this Chebfun.

    Returns:
        Chebfun: For a complex Chebfun, a new Chebfun made of ``fun.imag()``
        for every piece; otherwise the constant 0 on ``self.domain``.
    """
    if not self.iscomplex:
        # A real-valued function has an identically zero imaginary part.
        return self.initconst(0, domain=self.domain)
    return self.__class__([fun.imag() for fun in self])
def real(self) -> Chebfun:
    """Return the real part of this Chebfun.

    Returns:
        Chebfun: For a complex Chebfun, a new Chebfun made of ``fun.real()``
        for every piece; otherwise this very object.
    """
    if not self.iscomplex:
        return self
    return self.__class__([fun.real() for fun in self])
def copy(self) -> Chebfun:
    """Return a copy of this Chebfun.

    Returns:
        Chebfun: A new Chebfun built from ``fun.copy()`` of every piece.
    """
    duplicates = [fun.copy() for fun in self]
    return self.__class__(duplicates)
@self_empty()
def _restrict(self, subinterval: Any) -> Chebfun:
    """Restrict this Chebfun to ``subinterval`` without simplifying.

    Internal helper: the current domain is restricted to the requested
    subinterval and the pieces are then broken along the resulting domain.

    Args:
        subinterval (array-like): Subinterval to restrict to.

    Returns:
        Chebfun: The restricted, unsimplified Chebfun.
    """
    target = Domain(subinterval)
    restricted_domain = self.domain.restrict(target)
    return self._break(restricted_domain)
def restrict(self, subinterval: Any) -> Any:
    """Restrict this Chebfun to ``subinterval`` and simplify the result.

    Args:
        subinterval (array-like): Subinterval to restrict to.

    Returns:
        Chebfun: The simplified result of ``_restrict(subinterval)``.
    """
    unsimplified = self._restrict(subinterval)
    return unsimplified.simplify()
@self_empty()
def restrict_(self, subinterval: Any) -> Chebfun:
    """Restrict this Chebfun to ``subinterval`` in place.

    The pieces of the restricted, simplified Chebfun replace ``self.funs``
    and the breakpoint data is recomputed from them.

    Args:
        subinterval (array-like): Subinterval to restrict to.

    Returns:
        Chebfun: This object, after modification.
    """
    # NOTE(review): roots() is wrapped in @cache; confirm that cached results
    # are invalidated when the pieces are replaced here.
    replacement = self._restrict(subinterval).simplify()
    self.funs = replacement.funs
    self.breakdata = compute_breakdata(self.funs)
    return self
@cache
@self_empty(np.array([]))
def roots(self, merge: Any = None) -> np.ndarray:
    """Compute the roots of this Chebfun.

    The roots of every piece are computed and concatenated in piece order.
    With merging enabled, the first root of a piece is dropped when it lies
    within a small tolerance of the last root kept from the previous piece.

    Args:
        merge (bool, optional): Whether to merge roots at breakpoints. If
            None, ``prefs.mergeroots`` is used. Defaults to None.

    Returns:
        numpy.ndarray: Concatenated roots of all pieces.

    Examples:
        >>> import numpy as np
        >>> f = Chebfun.initfun_adaptive(lambda x: x**2 - 1, [-2, 2])
        >>> roots = f.roots()
        >>> len(roots)
        2
        >>> np.allclose(sorted(roots), [-1, 1])
        True
    """
    use_merge = prefs.mergeroots if merge is None else merge
    # Distance below which two roots from adjacent pieces count as one.
    htol = 1e2 * self.hscale * prefs.eps

    collected = []
    previous = np.array([])
    for fun in self:
        current = fun.roots()
        # TODO: there could be multiple roots at breakpoints
        if previous.size > 0 and current.size > 0 and use_merge and abs(previous[-1] - current[0]) <= htol:
            # Same root reported by both neighbours: keep only the earlier one.
            current = current[1:]
        collected.append(current)
        previous = current
    return np.concatenate(collected)
@self_empty()
def simplify(self) -> Chebfun:
    """Return a new Chebfun in which every piece has been simplified."""
    simplified = [fun.simplify() for fun in self]
    return self.__class__(simplified)
def translate(self, c: Any) -> Chebfun:
    """Translate this Chebfun by ``c``, i.e. return f(x - c).

    Args:
        c: The shift, passed to ``translate`` of every piece.

    Returns:
        Chebfun: A new Chebfun made of the translated pieces.
    """
    shifted = [fun.translate(c) for fun in self]
    return self.__class__(shifted)
789 # ----------
790 # calculus
791 # ----------
def cumsum(self) -> Chebfun:
    """Compute the indefinite integral (antiderivative) of the Chebfun.

    Every piece is integrated with its own ``cumsum``. From the second piece
    onwards, the right endpoint value of the previous (already shifted)
    integral is added so that the pieces join continuously.

    Returns:
        Chebfun: A new Chebfun representing the indefinite integral.

    Examples:
        >>> import numpy as np
        >>> f = Chebfun.initconst(1.0, [-1, 1])
        >>> F = f.cumsum()
        >>> bool(abs(F(-1)) < 1e-10)
        True
        >>> bool(abs(F(1) - 2.0) < 1e-10)
        True
    """
    newfuns = []
    prevfun = None
    for fun in self:
        integral = fun.cumsum()
        # Identity check, not truthiness: a piece whose __len__/__bool__ is
        # falsy must still contribute its endpoint value.
        if prevfun is not None:
            # Enforce continuity by adding the value at the right endpoint
            # of the previous integral.
            _, right_value = prevfun.endvalues
            integral = integral + right_value
        newfuns.append(integral)
        prevfun = integral
    return self.__class__(newfuns)
def diff(self, n: int = 1) -> Chebfun:
    """Compute the derivative of the Chebfun.

    The nth derivative is obtained by differentiating every piece ``n``
    times, rebuilding a Chebfun after each pass.

    Args:
        n: Order of differentiation (default: 1). Any non-negative integral
            value is accepted, i.e. anything supporting ``__index__`` such as
            ``int`` or NumPy integer scalars.

    Returns:
        Chebfun: A new Chebfun representing the nth derivative. For
        ``n == 0`` this very object is returned.

    Raises:
        TypeError: If ``n`` is not an integral value.
        ValueError: If ``n`` is negative.

    Examples:
        >>> from chebpy import chebfun
        >>> f = chebfun(lambda x: x**3)
        >>> df1 = f.diff() # first derivative: 3*x**2
        >>> df2 = f.diff(2) # second derivative: 6*x
        >>> df3 = f.diff(3) # third derivative: 6
        >>> bool(abs(df1(0.5) - 0.75) < 1e-10)
        True
        >>> bool(abs(df2(0.5) - 3.0) < 1e-10)
        True
        >>> bool(abs(df3(0.5) - 6.0) < 1e-10)
        True
    """
    # operator.index accepts every integral type but rejects floats/strings.
    try:
        order = operator.index(n)
    except TypeError:
        raise TypeError(n) from None
    if order == 0:
        return self
    if order < 0:
        raise ValueError(n)

    result = self
    for _ in range(order):
        dfuns = np.array([fun.diff() for fun in result])
        result = self.__class__(dfuns)
    return result
def conv(self, g: Chebfun) -> Chebfun:
    """Compute the convolution of this Chebfun with ``g``.

    Computes h(x) = (f ★ g)(x) = ∫ f(t) g(x-t) dt, where domain(f) is
    [a, b] and domain(g) is [c, d]. The result is a piecewise Chebfun on
    [a + c, b + d].

    Both operands may be piecewise. When each consists of a single piece and
    the two pieces have (numerically) equal widths, the fast Hale-Townsend
    Legendre convolution is used; every other case goes through the general
    quadrature-based piecewise routine.

    The algorithm is based on:
    N. Hale and A. Townsend, "An algorithm for the convolution of
    Legendre series", SIAM J. Sci. Comput., 36(3), A1207-A1220, 2014.

    Args:
        g (Chebfun): A Chebfun (single-piece or piecewise).

    Returns:
        Chebfun: A piecewise Chebfun representing (f ★ g); an empty Chebfun
        if either operand is empty.

    Examples:
        >>> import numpy as np
        >>> from chebpy import chebfun
        >>> f = chebfun(lambda x: np.ones_like(x), [-1, 1])
        >>> h = f.conv(f)
        >>> bool(abs(h(0.0) - 2.0) < 1e-10)
        True
        >>> bool(abs(h(-1.0) - 1.0) < 1e-10)
        True
        >>> bool(abs(h(1.0) - 1.0) < 1e-10)
        True
    """
    if self.isempty or g.isempty:
        return self.__class__.initempty()

    both_single = self.funs.size == 1 and g.funs.size == 1
    if both_single:
        left_piece, right_piece = self.funs[0], g.funs[0]
        width_f = float(left_piece.support[1]) - float(left_piece.support[0])
        width_g = float(right_piece.support[1]) - float(right_piece.support[0])
        # Equal widths unlock the fast Legendre-series algorithm.
        if np.isclose(width_f, width_g):
            return self._conv_equal_width_pair(left_piece, right_piece)

    # Everything else: general piecewise convolution.
    return self._conv_piecewise(g)
def _conv_equal_width_pair(self, f_fun: Any, g_fun: Any) -> Chebfun:
    """Convolve two single funs of equal width with the fast algorithm.

    Uses the Hale-Townsend Legendre convolution. The two funs may live on
    different intervals as long as their widths agree.

    Args:
        f_fun: Fun on [a, b].
        g_fun: Fun on [c, d], with d - c equal to b - a.

    Returns:
        Chebfun: Two-piece Chebfun on [a + c, b + d], split at the midpoint.
    """
    a, b = float(f_fun.support[0]), float(f_fun.support[1])
    c, d = float(g_fun.support[0]), float(g_fun.support[1])

    # Half-width, shared by both funs; scales the Legendre-series result.
    half_width = (b - a) / 2.0

    # Chebyshev -> Legendre, convolve, then back to Chebyshev per half.
    gamma_left, gamma_right = _conv_legendre(cheb2leg(f_fun.coeffs), cheb2leg(g_fun.coeffs))
    cheb_left = leg2cheb(half_width * gamma_left)
    cheb_right = leg2cheb(half_width * gamma_right)

    # The result has a single interior breakpoint at the midpoint.
    midpoint = (a + b + c + d) / 2.0
    pieces = [
        Bndfun(Chebtech(cheb_left), Interval(a + c, midpoint)),
        Bndfun(Chebtech(cheb_right), Interval(midpoint, b + d)),
    ]
    return self.__class__(pieces)
def _conv_piecewise(self, g: Chebfun) -> Chebfun:
    """General piecewise convolution via Gauss-Legendre quadrature.

    The breakpoints of the result are the sorted, unique pairwise sums of
    the breakpoints of ``self`` and ``g``. On each output sub-interval a
    Bndfun is constructed adaptively from a quadrature-based evaluator.

    Args:
        g (Chebfun): The second operand of the convolution.

    Returns:
        Chebfun: One adaptively constructed piece per output sub-interval.
    """
    f_breaks = self.breakpoints
    g_breaks = g.breakpoints
    f_a, f_b = float(f_breaks[0]), float(f_breaks[-1])
    g_c, g_d = float(g_breaks[0]), float(g_breaks[-1])

    # Output breakpoints: all pairwise sums, uniquified and coalesced so
    # that sums differing only by rounding collapse into one breakpoint.
    out_breaks = np.unique(np.add.outer(f_breaks, g_breaks).ravel())
    hscl = max(abs(out_breaks[0]), abs(out_breaks[-1]), 1.0)
    tol = 10.0 * np.finfo(float).eps * hscl
    keep = np.concatenate(([True], np.diff(out_breaks) > tol))
    out_breaks = out_breaks[keep]

    # Quadrature order: sufficient for exact integration of the polynomial
    # integrand on each smooth sub-interval (with a floor of 16 nodes).
    max_deg = max(fun.size for fun in self.funs) + max(fun.size for fun in g.funs)
    n_quad = max(int(np.ceil((max_deg + 1) / 2)), 16)
    quad_nodes, quad_weights = np.polynomial.legendre.leggauss(n_quad)

    # Plain-float breakpoints for the inner loop.
    f_bps = [float(bp) for bp in f_breaks]
    g_bps = [float(bp) for bp in g_breaks]

    # A float buffer would drop the imaginary part of complex operands.
    out_dtype = complex if (self.iscomplex or g.iscomplex) else float

    def conv_eval(x: np.ndarray) -> np.ndarray:
        """Evaluate (self ★ g)(x) via Gauss-Legendre quadrature."""
        x = np.atleast_1d(np.asarray(x, dtype=float))
        result = np.zeros(x.shape, dtype=out_dtype)
        for idx, xi in enumerate(x.flat):
            # Overlap of supp(f) with the reflected, shifted supp(g).
            t_lo = max(f_a, xi - g_d)
            t_hi = min(f_b, xi - g_c)
            if t_hi <= t_lo:
                continue
            # Break integration at breakpoints of f and shifted breakpoints
            # of g so the integrand is polynomial on each sub-interval.
            cuts = {t_lo, t_hi}
            cuts.update(bp for bp in f_bps if t_lo < bp < t_hi)
            cuts.update(xi - bp for bp in g_bps if t_lo < xi - bp < t_hi)
            edges = sorted(cuts)

            total = 0.0
            for a_int, b_int in zip(edges, edges[1:]):
                # Map the reference nodes on [-1, 1] to [a_int, b_int].
                hw = (b_int - a_int) / 2.0
                mid = (a_int + b_int) / 2.0
                nodes = hw * quad_nodes + mid
                wts = hw * quad_weights
                total += np.dot(wts, self(nodes) * g(xi - nodes))
            result.flat[idx] = total
        return result

    # Build a Bndfun on each output sub-interval.
    funs_list = [
        Bndfun.initfun_adaptive(conv_eval, Interval(lo, hi))
        for lo, hi in zip(out_breaks[:-1], out_breaks[1:])
    ]
    return self.__class__(funs_list)
def sum(self) -> Any:
    """Integrate the Chebfun over the whole of its domain.

    The definite integral of every piece is obtained from the piece's own
    ``sum()`` and the per-piece results are then added together.

    Returns:
        float or complex: The definite integral of the Chebfun over its domain.

    Examples:
        >>> import numpy as np
        >>> f = Chebfun.initfun_adaptive(lambda x: x**2, [-1, 1])
        >>> bool(abs(f.sum() - 2.0/3.0) < 1e-10)
        True
        >>> g = Chebfun.initconst(1.0, [-1, 1])
        >>> bool(abs(g.sum() - 2.0) < 1e-10)
        True
    """
    piece_integrals = []
    for piece in self:
        piece_integrals.append(piece.sum())
    return np.sum(piece_integrals)
def dot(self, f: Any) -> Any:
    """Return the dot product of this Chebfun with ``f``.

    The two operands are multiplied pointwise (using this object's ``*``
    operator) and the product is integrated with its ``sum()`` method.

    Args:
        f (Chebfun or scalar): The second operand; anything that can be
            multiplied with this Chebfun.

    Returns:
        float or complex: The integral of ``self * f`` over the domain.
    """
    pointwise_product = self * f
    return pointwise_product.sum()
def norm(self, p: Any = 2) -> Any:
    """Compute the Lp norm of the Chebfun over its domain.

    The L2 norm is the default and is computed as ``sqrt(self.dot(self))``.
    For ``p = inf`` the largest absolute value is searched for among:

    - the roots of the derivative (interior extrema),
    - every breakpoint of the Chebfun (not only the two domain endpoints),
      so an extremum sitting at a kink between two pieces is found, and
    - the values of each piece at the edges of its own support, so the
      one-sided limits at a jump discontinuity are taken into account.

    Args:
        p (int or float): The norm type. Supported values are 1, 2, positive
            integers/floats, or np.inf. Defaults to 2 (L2 norm).

    Returns:
        float: The Lp norm of the Chebfun.

    Raises:
        ValueError: If ``p`` is not positive.

    Examples:
        >>> from chebpy import chebfun
        >>> import numpy as np
        >>> f = chebfun(lambda x: x**2, [-1, 1])
        >>> np.allclose(f.norm(), 0.6324555320336759)  # L2 norm
        True
        >>> np.allclose(f.norm(np.inf), 1.0)  # Maximum absolute value
        True
    """
    if p == 2:
        # L2 norm via the dot product of the function with itself
        return np.sqrt(self.dot(self))
    if p == np.inf:
        # Interior extrema can only occur where the derivative vanishes ...
        critical_pts = self.diff().roots()
        # ... or at a breakpoint, where the derivative may jump without
        # ever being zero (e.g. 1 - |x| broken at 0).
        breakpoints = np.asarray(self.breakpoints, dtype=float)
        test_pts = np.concatenate([critical_pts, breakpoints])
        candidates = [np.atleast_1d(np.abs(self(test_pts)))]
        # The value stored at a breakpoint need not equal either one-sided
        # limit, so also sample every piece at the edges of its support.
        for fun in self:
            edges = np.asarray(fun.support, dtype=float)
            candidates.append(np.atleast_1d(np.abs(fun(edges))))
        return np.max(np.concatenate(candidates))
    if p == 1:
        # L1 norm: integral(|f|)
        return self.absolute().sum()
    if p > 0:
        # General Lp norm: (integral(|f|^p))^(1/p)
        integral = (self.absolute() ** p).sum()
        return integral ** (1.0 / p)
    raise ValueError(p)
1106 # ----------
1107 # utilities
1108 # ----------
@self_empty()
def absolute(self) -> Chebfun:
    """Return the absolute value of this Chebfun.

    The roots of the function are merged into its domain, the Chebfun is
    re-broken on that refined domain, and ``absolute()`` is applied to each
    resulting piece (presumably so that no piece changes sign in its interior).

    Returns:
        Chebfun: A new Chebfun built from the per-piece absolute values.
    """
    refined_domain = self.domain.merge(self.roots())
    pieces = []
    for piece in self._break(refined_domain):
        pieces.append(piece.absolute())
    return self.__class__(pieces)


# Alias so the method is also reachable under the shorter name ``abs``.
abs = absolute
@self_empty()
def sign(self) -> Chebfun:
    """Return the sign function of this Chebfun.

    The domain is split at the roots of the function. On every resulting
    sub-interval the function is sampled at the midpoint and replaced by a
    constant piece carrying the sign of that sample.

    The values stored at the breakpoints of the result are set explicitly:
    0 where a breakpoint lies within a small tolerance of a root, and the
    sign of the original function at the breakpoint otherwise.

    Returns:
        Chebfun: A new Chebfun representing sign(f(x)).
    """
    zeros = self.roots()
    pieces = []
    for piece in self._break(self.domain.merge(zeros)):
        left, right = piece.support[0], piece.support[-1]
        midpoint = left + 0.5 * (right - left)
        value = float(np.sign(float(self(midpoint))))
        pieces.append(Bndfun.initconst(value, piece.interval))
    result = self.__class__(pieces)

    # Tolerance used to decide whether a breakpoint coincides with a root
    htol = max(1e2 * self.hscale * prefs.eps, prefs.eps)
    has_zeros = zeros.size > 0
    for bp in result.breakpoints:
        on_root = has_zeros and np.any(np.abs(bp - zeros) <= htol)
        result.breakdata[bp] = 0.0 if on_root else float(np.sign(float(self(bp))))
    return result
@self_empty()
def ceil(self) -> Chebfun:
    """Return the ceiling function of this Chebfun.

    The domain is split at the points returned by ``_integer_crossings``.
    On every resulting sub-interval the function is sampled at the midpoint
    and replaced by a constant piece holding the ceiling of that sample.
    The value stored at each breakpoint of the result is the ceiling of the
    original function evaluated there.

    Returns:
        Chebfun: A new Chebfun representing ceil(f(x)).
    """
    split_domain = self.domain.merge(self._integer_crossings())
    pieces = []
    for piece in self._break(split_domain):
        left, right = piece.support[0], piece.support[-1]
        midpoint = left + 0.5 * (right - left)
        level = float(np.ceil(float(self(midpoint))))
        pieces.append(Bndfun.initconst(level, piece.interval))
    stepped = self.__class__(pieces)
    for bp in stepped.breakpoints:
        stepped.breakdata[bp] = float(np.ceil(float(self(bp))))
    return stepped
@self_empty()
def floor(self) -> Chebfun:
    """Return the floor function of this Chebfun.

    The domain is split at the points returned by ``_integer_crossings``.
    On every resulting sub-interval the function is sampled at the midpoint
    and replaced by a constant piece holding the floor of that sample.
    The value stored at each breakpoint of the result is the floor of the
    original function evaluated there.

    Returns:
        Chebfun: A new Chebfun representing floor(f(x)).
    """
    split_domain = self.domain.merge(self._integer_crossings())
    pieces = []
    for piece in self._break(split_domain):
        left, right = piece.support[0], piece.support[-1]
        midpoint = left + 0.5 * (right - left)
        level = float(np.floor(float(self(midpoint))))
        pieces.append(Bndfun.initconst(level, piece.interval))
    stepped = self.__class__(pieces)
    for bp in stepped.breakpoints:
        stepped.breakdata[bp] = float(np.floor(float(self(bp))))
    return stepped
def _integer_crossings(self) -> np.ndarray:
    """Locate the points where this Chebfun takes an integer value.

    The range of candidate integers is bracketed by the floor of the
    smallest and the ceiling of the largest entry of the pieces'
    ``values()``. For every integer ``n`` in that range the roots of
    ``self - n`` are collected.

    Returns:
        numpy.ndarray: The collected roots, in order of increasing ``n``.
    """
    samples = np.concatenate([piece.values() for piece in self])
    lowest = int(np.floor(np.min(samples)))
    highest = int(np.ceil(np.max(samples)))
    found = []
    for level in range(lowest, highest + 1):
        found += (self - level).roots().tolist()
    return np.array(found)
@self_empty()
@cast_arg_to_chebfun
def maximum(self, other: Any) -> Any:
    """Return the pointwise maximum of this Chebfun and ``other``.

    The work is delegated to ``_maximum_minimum`` with ``operator.ge`` as
    the comparator.
    """
    comparator = operator.ge
    return self._maximum_minimum(other, comparator)
@self_empty()
@cast_arg_to_chebfun
def minimum(self, other: Any) -> Any:
    """Return the pointwise minimum of this Chebfun and ``other``.

    The work is delegated to ``_maximum_minimum`` with ``operator.lt`` as
    the comparator.
    """
    comparator = operator.lt
    return self._maximum_minimum(other, comparator)
def _maximum_minimum(self, other: Chebfun, comparator: Callable[..., bool]) -> Any:
    """Compute the pointwise maximum or minimum of two Chebfuns.

    This is the shared implementation behind ``maximum()`` and ``minimum()``.
    The common domain is split at the roots of ``self - other``. On every
    resulting sub-interval the two functions are compared at the midpoint
    and the winning one is restricted to that sub-interval.

    Deciding each sub-interval on its own (instead of alternating between
    the two functions from one root to the next) keeps the result correct
    when the functions are equal at the left end of the domain or merely
    touch without crossing.

    Args:
        other (Chebfun): Another Chebfun to compare with this one.
        comparator (callable): Called as ``comparator(other_value, self_value)``;
            a true result selects ``other`` on that sub-interval. For
            maximum this is operator.ge (>=), and for minimum this is
            operator.lt (<).

    Returns:
        Chebfun: A new Chebfun representing the pointwise maximum or minimum.
            It is empty if either operand is empty or the supports do not
            overlap.
    """
    # Handle empty Chebfuns
    if self.isempty or other.isempty:
        return self.__class__.initempty()

    try:
        # Works directly when both supports match
        newdom = self.domain.union(other.domain)
    except SupportMismatch:
        # Otherwise fall back to the intersection of the two supports
        a_min, a_max = self.support
        b_min, b_max = other.support
        c_min = max(a_min, b_min)
        c_max = min(a_max, b_max)

        # No overlap at all: nothing to compare
        if c_min >= c_max:
            return self.__class__.initempty()

        # Restrict both operands to the overlap and start over
        self_restricted = self.restrict([c_min, c_max])
        other_restricted = other.restrict([c_min, c_max])
        return self_restricted._maximum_minimum(other_restricted, comparator)

    roots = (self - other).roots()
    newdom = newdom.merge(roots)
    switch = newdom.support.merge(roots)

    # Handle the case where switch is empty
    if switch.size == 0:  # pragma: no cover
        return self.__class__.initempty()

    # Between two consecutive switch points the difference has no root, so
    # one sample per sub-interval is enough to decide which function wins.
    edges = np.asarray(switch, dtype=float)
    midpoints = 0.5 * (edges[:-1] + edges[1:])

    funs = []
    for interval, mid in zip(switch.intervals, midpoints, strict=False):
        use_other = comparator(other(mid), self(mid))
        subdom = newdom.restrict(interval)
        subfun = other.restrict(subdom) if use_other else self.restrict(subdom)
        funs.extend(subfun.funs)
    return self.__class__(funs)
1288 # ----------
1289 # plotting
1290 # ----------
def plot(self, ax: Axes | None = None, **kwds: Any) -> Any:
    """Plot the Chebfun over its support.

    The drawing itself is delegated to ``plotfun``, which receives this
    Chebfun, its support and the given axes. According to the original
    documentation, complex-valued Chebfuns are drawn as real part against
    imaginary part; that behaviour lives in ``plotfun``.

    Args:
        ax (matplotlib.axes.Axes, optional): The axes on which to plot.
            Defaults to None, in which case the choice of axes is left to
            ``plotfun``.
        **kwds: Additional keyword arguments forwarded to ``plotfun``.

    Returns:
        Whatever ``plotfun`` returns (documented as the matplotlib axes on
        which the plot was created).
    """
    extent = self.support
    return plotfun(self, extent, ax=ax, **kwds)
def plotcoeffs(self, ax: Axes | None = None, **kwds: Any) -> Axes:
    """Plot the coefficients of every piece of the Chebfun on shared axes.

    Each piece draws itself through its own ``plotcoeffs`` method (documented
    as a semilogy plot of the absolute coefficient values), which is useful
    for visualising the decay of the coefficients.

    Args:
        ax (matplotlib.axes.Axes, optional): The axes on which to plot. If
            not given, the current matplotlib axes are used.
        **kwds: Additional keyword arguments forwarded to each piece's
            ``plotcoeffs``.

    Returns:
        matplotlib.axes.Axes: The axes on which the plot was created.
    """
    if not ax:
        ax = plt.gca()
    for piece in self:
        piece.plotcoeffs(ax=ax, **kwds)
    return ax
1328# ---------
1329# ufuncs
1330# ---------
def add_ufunc(op: Callable[..., Any]) -> None:
    """Add a NumPy universal function method to the Chebfun class.

    A method is created that applies ``op`` to each piece of a Chebfun and
    builds a new Chebfun from the results. It is attached to ``Chebfun``
    under the name of the NumPy function, with a docstring and qualified
    name that identify the wrapped ufunc.

    Args:
        op (callable): The NumPy universal function to apply.

    Note:
        The created method has the same name as the NumPy function and
        takes no arguments other than self.
    """

    @self_empty()
    def method(self: Chebfun) -> Chebfun:
        return self.__class__([op(fun) for fun in self])

    name = op.__name__  # type: ignore[attr-defined]
    # Make the generated method introspect like a hand-written one, so that
    # help(), tracebacks and documentation tools show which ufunc it wraps.
    method.__name__ = name
    method.__qualname__ = f"{Chebfun.__qualname__}.{name}"
    method.__doc__ = (
        f"Apply ``numpy.{name}`` to this Chebfun.\n"
        "\n"
        f"The ufunc is applied to each piece of this Chebfun and a new\n"
        f"Chebfun is built from the results.\n"
        "\n"
        "Returns:\n"
        f"    Chebfun: A new Chebfun representing {name}(f(x)).\n"
    )
    setattr(Chebfun, name, method)
# NumPy ufuncs exposed as Chebfun methods of the same name.
ufuncs = (
    # inverse trigonometric / inverse hyperbolic
    np.arccos,
    np.arccosh,
    np.arcsin,
    np.arcsinh,
    np.arctan,
    np.arctanh,
    # cosine family
    np.cos,
    np.cosh,
    # exponentials
    np.exp,
    np.exp2,
    np.expm1,
    # logarithms
    np.log,
    np.log2,
    np.log10,
    np.log1p,
    # remaining trigonometric / hyperbolic
    np.sinh,
    np.sin,
    np.tan,
    np.tanh,
    # roots
    np.sqrt,
)

# Register one Chebfun method per ufunc listed above.
for op in ufuncs:
    add_ufunc(op)