Coverage for src / chebpy / classicfun.py: 100%
144 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
1"""Implementation of the Classicfun class for functions on arbitrary intervals.
3This module provides the Classicfun class, which represents functions on arbitrary intervals
4by mapping them to a standard domain [-1, 1] and using a Onefun representation.
5"""
7from abc import ABC
8from typing import Any
10import matplotlib.pyplot as plt
11import numpy as np
13from .chebtech import Chebtech
14from .decorators import self_empty
15from .exceptions import IntervalMismatch, NotSubinterval
16from .fun import Fun
17from .plotting import plotfun
18from .settings import _preferences as prefs
19from .utilities import Interval
# Registry of Onefun implementations, keyed by the name stored in
# ``prefs.tech``; Classicfun constructors look up the class to build with here.
techdict = {"Chebtech": Chebtech}
class Classicfun(Fun, ABC):
    """Abstract base for functions represented on an arbitrary interval [a, b].

    Each instance pairs an ``Interval`` (the domain of the function) with a
    Onefun object (for example a ``Chebtech``) representing the function on the
    reference domain [-1, 1]. Points travel between the two domains through the
    interval's forward map ``interval(y)`` ([-1, 1] -> [a, b]) and its inverse
    ``interval.invmap(x)``; the numerical work itself is delegated to the
    Onefun. Concrete subclasses such as Bndfun build on this class.
    """

    def __init__(self, onefun: Any, interval: Any) -> None:
        """Store the reference-domain representation together with its domain.

        Args:
            onefun: Onefun object representing the function on [-1, 1].
            interval: Interval object giving the domain of the function.
        """
        self.onefun = onefun
        self._interval = interval

    # --------------------------
    # alternative constructors
    # --------------------------
    @classmethod
    def initempty(cls) -> "Classicfun":
        """Build an empty function.

        Emptiness is decided by the Onefun alone, so the domain does not
        matter; the default ``Interval()`` (documented as [-1, 1]) is used.

        Returns:
            Classicfun: A new empty instance.
        """
        domain = Interval()
        tech_cls = techdict[prefs.tech]
        return cls(tech_cls.initempty(interval=domain), domain)

    @classmethod
    def initconst(cls, c: Any, interval: Any) -> "Classicfun":
        """Build the constant function f(x) = c on ``interval``.

        Args:
            c: The constant value.
            interval: Domain of the new function.

        Returns:
            Classicfun: A new constant function.
        """
        tech_cls = techdict[prefs.tech]
        return cls(tech_cls.initconst(c, interval=interval), interval)

    @classmethod
    def initidentity(cls, interval: Any) -> "Classicfun":
        """Build the identity function f(x) = x on ``interval``.

        The Onefun is built from the endpoint values ``np.asarray(interval)``
        (assumes ``initvalues`` treats them as samples at the endpoints of
        [-1, 1], giving the linear interpolant that maps back to x).

        Args:
            interval: Domain of the identity function.

        Returns:
            Classicfun: A new instance representing x on ``interval``.
        """
        samples = np.asarray(interval)
        tech_cls = techdict[prefs.tech]
        return cls(tech_cls.initvalues(samples, interval=interval), interval)

    @classmethod
    def initfun_adaptive(cls, f: Any, interval: Any) -> "Classicfun":
        """Approximate the callable ``f`` on ``interval`` adaptively.

        The Onefun decides how many points it needs; it samples ``f`` through
        the map from [-1, 1] onto ``interval``.

        Args:
            f (callable): Function to approximate.
            interval: Domain of the new function.

        Returns:
            Classicfun: A new instance approximating ``f``.
        """

        def pullback(y):
            # Evaluate f at the point of the interval corresponding to y.
            return f(interval(y))

        tech_cls = techdict[prefs.tech]
        return cls(tech_cls.initfun(pullback, interval=interval), interval)

    @classmethod
    def initfun_fixedlen(cls, f: Any, interval: Any, n: int) -> "Classicfun":
        """Approximate the callable ``f`` on ``interval`` with ``n`` points.

        Args:
            f (callable): Function to approximate.
            interval: Domain of the new function.
            n (int): Number of points passed on to the Onefun constructor.

        Returns:
            Classicfun: A new instance approximating ``f``.
        """

        def pullback(y):
            # Evaluate f at the point of the interval corresponding to y.
            return f(interval(y))

        tech_cls = techdict[prefs.tech]
        return cls(tech_cls.initfun(pullback, n, interval=interval), interval)

    # -------------------
    # 'private' methods
    # -------------------
    def __call__(self, x: Any, how: str = "clenshaw") -> Any:
        """Evaluate the function at ``x``.

        ``x`` is pulled back to [-1, 1] with ``interval.invmap`` and the Onefun
        is evaluated there.

        Args:
            x (float or array-like): Evaluation point(s) in the interval.
            how (str, optional): Evaluation method forwarded to the Onefun.
                Defaults to "clenshaw".

        Returns:
            Whatever the Onefun returns for the mapped points.
        """
        return self.onefun(self.interval.invmap(x), how)

    def __repr__(self) -> str:  # pragma: no cover
        """Return ``ClassName([a, b], size)``."""
        name = self.__class__.__name__
        support = self.support
        return f"{name}([{support[0]}, {support[1]}], {self.size})"

    # ------------
    # properties
    # ------------
    @property
    def coeffs(self) -> Any:
        """Coefficients held by the underlying Onefun."""
        return self.onefun.coeffs

    @property
    def endvalues(self) -> Any:
        """Function values at the two endpoints of ``support``."""
        return self(self.support)

    @property
    def interval(self) -> Any:
        """The Interval object this function is defined on."""
        return self._interval

    @property
    def isconst(self) -> Any:
        """Whether the underlying Onefun reports itself as constant."""
        return self.onefun.isconst

    @property
    def iscomplex(self) -> Any:
        """Whether the underlying Onefun reports complex values."""
        return self.onefun.iscomplex

    @property
    def isempty(self) -> Any:
        """Whether the underlying Onefun is empty."""
        return self.onefun.isempty

    @property
    def size(self) -> Any:
        """Size reported by the underlying Onefun (presumably its length)."""
        return self.onefun.size

    @property
    def support(self) -> Any:
        """Endpoints of the interval as a numpy array, ``np.asarray(interval)``."""
        return np.asarray(self.interval)

    @property
    def vscale(self) -> Any:
        """Vertical scale reported by the underlying Onefun."""
        return self.onefun.vscale

    # -----------
    # utilities
    # -----------

    def imag(self) -> "Classicfun":
        """Return the imaginary part as a new function.

        Returns:
            Classicfun: The Onefun's imaginary part on the same interval, or
            the constant 0 on the same interval when this function is real.
        """
        if not self.iscomplex:
            return self.initconst(0, interval=self.interval)
        return self.__class__(self.onefun.imag(), self.interval)

    def real(self) -> "Classicfun":
        """Return the real part.

        Returns:
            Classicfun: The Onefun's real part on the same interval, or
            ``self`` itself (not a copy) when this function is already real.
        """
        if not self.iscomplex:
            return self
        return self.__class__(self.onefun.real(), self.interval)

    def restrict(self, subinterval: Any) -> "Classicfun":
        """Restrict this function to ``subinterval``.

        The restriction is rebuilt with ``initfun_fixedlen`` using ``self`` as
        the sampled function and the current ``size`` as the point count.

        Args:
            subinterval: Interval-like object; it must support ``in`` against
                ``self.interval``, compare with it via ``==``, and be callable
                as a map from [-1, 1] (it is handed to ``initfun_fixedlen``).

        Returns:
            Classicfun: ``self`` when the intervals are equal, otherwise a
            new function on ``subinterval``.

        Raises:
            NotSubinterval: If ``subinterval`` is not contained in the interval.
        """
        if subinterval not in self.interval:  # pragma: no cover
            raise NotSubinterval(self.interval, subinterval)
        if self.interval == subinterval:
            return self
        npts = self.size
        return self.__class__.initfun_fixedlen(self, subinterval, npts)

    def translate(self, c: float) -> "Classicfun":
        """Return g(x) = f(x - c).

        The Onefun is reused unchanged; only the interval is replaced by
        ``self.interval + c`` (assumes Interval addition shifts both endpoints).

        Args:
            c (float): Horizontal shift.

        Returns:
            Classicfun: The translated function.
        """
        shifted = self.interval + c
        return self.__class__(self.onefun, shifted)

    # -------------
    # rootfinding
    # -------------
    def roots(self) -> Any:
        """Return the zeros of the function in its interval.

        The Onefun's roots on [-1, 1] are pushed forward onto the interval;
        their ordering is whatever the Onefun produces (documented upstream
        as ascending).

        Returns:
            numpy.ndarray: Roots expressed in the coordinates of the interval.
        """
        return self.interval(self.onefun.roots())

    # ----------
    # calculus
    # ----------
    def cumsum(self) -> "Classicfun":
        """Return the indefinite integral on the same interval.

        The Onefun's indefinite integral is scaled by (b - a) / 2, the
        change-of-variables factor between [-1, 1] and [a, b]. The upstream
        documentation states it vanishes at the left endpoint.

        Returns:
            Classicfun: The antiderivative.
        """
        a, b = self.support
        scale = 0.5 * (b - a)
        return self.__class__(scale * self.onefun.cumsum(), self.interval)

    def diff(self) -> "Classicfun":
        """Return the derivative with respect to x on the same interval.

        By the chain rule the Onefun's derivative is multiplied by 2 / (b - a).

        Returns:
            Classicfun: The derivative.
        """
        a, b = self.support
        scale = 2.0 / (b - a)
        return self.__class__(scale * self.onefun.diff(), self.interval)

    def sum(self) -> Any:
        """Return the definite integral over [a, b].

        Returns:
            float or complex: (b - a) / 2 times the Onefun's integral over [-1, 1].
        """
        a, b = self.support
        scale = 0.5 * (b - a)
        return scale * self.onefun.sum()

    # ----------
    # plotting
    # ----------
    def plot(self, ax: Any = None, **kwds: Any) -> Any:
        """Plot the function over its support via ``plotfun``.

        Args:
            ax (matplotlib.axes.Axes, optional): Target axes; ``None`` lets
                ``plotfun`` choose. Defaults to None.
            **kwds: Extra keyword arguments forwarded to ``plotfun``.

        Returns:
            Whatever ``plotfun`` returns (documented as the matplotlib axes).
        """
        return plotfun(self, self.support, ax=ax, **kwds)
# ----------------------------------------------------------------
# methods that execute the corresponding onefun method as is
# ----------------------------------------------------------------

# Onefun methods exposed on Classicfun whose results are returned unchanged
# (they are NOT re-wrapped in a Classicfun).
methods_onefun_other = (
    "values",
    "plotcoeffs",
)
def add_utility(methodname: str) -> None:
    """Attach to Classicfun a method that forwards to ``self.onefun``.

    The generated method calls ``self.onefun.<methodname>(*args, **kwds)`` and
    returns that result unchanged; it is not re-wrapped in a Classicfun.

    Args:
        methodname (str): Name of the onefun method to expose on Classicfun.

    Note:
        The generated function accepts ``*args, **kwds`` and does not copy the
        onefun method's signature. Its ``__name__`` and ``__qualname__`` are
        set so that introspection (``help``, ``repr``, tracebacks) reports
        ``Classicfun.<methodname>`` rather than ``add_utility.<locals>.method``.
    """

    def method(self: Any, *args: Any, **kwds: Any) -> Any:
        """Delegate to the same-named method of the underlying onefun object.

        Args:
            self (Classicfun): The Classicfun object.
            *args: Positional arguments passed through to the onefun method.
            **kwds: Keyword arguments passed through to the onefun method.

        Returns:
            The return value of the corresponding onefun method.
        """
        return getattr(self.onefun, methodname)(*args, **kwds)

    method.__name__ = methodname
    method.__qualname__ = f"{Classicfun.__qualname__}.{methodname}"
    setattr(Classicfun, methodname, method)
# Register the pass-through methods. Plotting helpers are skipped when
# ``plt`` is None; since matplotlib is imported unconditionally at the top of
# this file, that guard presumably never triggers.
for methodname in methods_onefun_other:
    if methodname.startswith("plot") and plt is None:  # pragma: no cover
        continue
    add_utility(methodname)
# -----------------------------------------------------------------------
# unary operators and zero-argument utility methods returning a onefun
# -----------------------------------------------------------------------

# Onefun methods whose result is re-wrapped in a Classicfun on the same
# interval.
methods_onefun_zeroargs = (
    "__pos__",
    "__neg__",
    "copy",
    "simplify",
)
def add_zero_arg_op(methodname: str) -> None:
    """Attach to Classicfun a method that re-wraps a onefun method's result.

    The generated method calls ``self.onefun.<methodname>(*args, **kwds)`` and
    wraps the returned onefun in a new instance of ``self.__class__`` on
    ``self.interval``.

    Args:
        methodname (str): Name of the onefun method to expose on Classicfun.

    Note:
        The generated function accepts ``*args, **kwds`` for forwarding. Its
        ``__name__`` and ``__qualname__`` are set so that introspection reports
        ``Classicfun.<methodname>`` rather than
        ``add_zero_arg_op.<locals>.method``.
    """

    def method(self: Any, *args: Any, **kwds: Any) -> Any:
        """Apply the same-named onefun method and return a new Classicfun.

        Args:
            self (Classicfun): The Classicfun object.
            *args: Positional arguments passed through to the onefun method.
            **kwds: Keyword arguments passed through to the onefun method.

        Returns:
            Classicfun: The onefun result wrapped on the same interval.
        """
        onefun = getattr(self.onefun, methodname)(*args, **kwds)
        return self.__class__(onefun, self.interval)

    method.__name__ = methodname
    method.__qualname__ = f"{Classicfun.__qualname__}.{methodname}"
    setattr(Classicfun, methodname, method)
# Register the unary / zero-argument wrappers on Classicfun.
for methodname in methods_onefun_zeroargs:
    add_zero_arg_op(methodname)
# -----------------------------------------
# binary operators returning a onefun
# -----------------------------------------

# ToDo: change these to operator module methods
# NOTE: "__div__" / "__rdiv__" are Python 2 names; Python 3 never dispatches
# the ``/`` operator to them, so they are only reachable as explicit calls.
methods_onefun_binary = (
    "__add__", "__div__", "__mul__", "__pow__",
    "__radd__", "__rdiv__", "__rmul__", "__rpow__",
    "__rsub__", "__rtruediv__", "__sub__", "__truediv__",
)
def add_binary_op(methodname: str) -> None:
    """Attach to Classicfun a binary operator that forwards to the onefun.

    The generated method combines ``self`` with a second operand ``f``:

    * if ``f`` is an instance of ``self.__class__``, an empty ``f`` is returned
      as ``f.copy()``; otherwise its onefun is used as the other operand, and
      both functions must live on the same interval;
    * any other ``f`` (e.g. a scalar) is passed to the onefun method unchanged,
      leaving the lower-level class to accept or reject it.

    The onefun result is wrapped in a new instance on ``self.interval``. The
    method is decorated with ``self_empty()`` from ``.decorators`` (presumably
    short-circuits when ``self`` is empty — see that decorator).

    Args:
        methodname (str): Name of the binary onefun method to expose.

    Note:
        ``__name__`` and ``__qualname__`` of the generated function are set so
        that introspection reports ``Classicfun.<methodname>`` rather than
        ``add_binary_op.<locals>.method``.
    """

    @self_empty()
    def method(self: Any, f: Any, *args: Any, **kwds: Any) -> Any:
        """Apply a binary operation and return a new Classicfun.

        Args:
            self (Classicfun): The Classicfun object.
            f (Classicfun or scalar): The second operand.
            *args: Extra positional arguments passed to the onefun method.
            **kwds: Extra keyword arguments passed to the onefun method.

        Returns:
            Classicfun: The onefun result wrapped on ``self.interval`` (or a
            copy of ``f`` when ``f`` is an empty Classicfun).

        Raises:
            IntervalMismatch: If ``f`` is a Classicfun on a different interval.
        """
        cls = self.__class__
        if isinstance(f, cls):
            # TODO: as in Chebtech, is a decorator approach better here?
            if f.isempty:
                return f.copy()
            g = f.onefun
            # operands on different intervals cannot be combined pointwise
            if self.interval != f.interval:  # pragma: no cover
                raise IntervalMismatch(self.interval, f.interval)
        else:
            # let the lower level classes raise any other exceptions
            g = f
        onefun = getattr(self.onefun, methodname)(g, *args, **kwds)
        return cls(onefun, self.interval)

    method.__name__ = methodname
    method.__qualname__ = f"{Classicfun.__qualname__}.{methodname}"
    setattr(Classicfun, methodname, method)
# Register the binary operator wrappers on Classicfun.
for methodname in methods_onefun_binary:
    add_binary_op(methodname)
# ---------------------------
# numpy universal functions
# ---------------------------


def add_ufunc(op: Any) -> None:
    """Attach to Classicfun a method applying the NumPy ufunc ``op``.

    The generated method takes only ``self`` and returns
    ``self.__class__.initfun_adaptive(lambda x: op(self(x)), self.interval)``,
    i.e. a new adaptive approximation of ``op(f(x))`` on the same interval.
    It is decorated with ``self_empty()`` from ``.decorators`` (presumably
    short-circuits when ``self`` is empty — see that decorator).

    Args:
        op (callable): The NumPy universal function to apply; its ``__name__``
            becomes the method name.

    Note:
        ``__name__`` and ``__qualname__`` of the generated function are set so
        that introspection reports ``Classicfun.<op name>`` rather than
        ``add_ufunc.<locals>.method``.
    """

    @self_empty()
    def method(self: Any) -> Any:
        """Return a new function approximating ``op(f(x))`` on the same interval.

        Returns:
            Classicfun: The adaptively constructed composition.
        """
        return self.__class__.initfun_adaptive(lambda x: op(self(x)), self.interval)

    name = op.__name__
    method.__name__ = name
    method.__qualname__ = f"{Classicfun.__qualname__}.{name}"
    setattr(Classicfun, name, method)
# NumPy ufuncs exposed as zero-argument Classicfun methods, named after
# each ufunc's ``__name__`` (e.g. ``f.exp()``, ``f.absolute()``).
ufuncs = (
    np.absolute,
    np.arccos, np.arccosh, np.arcsin, np.arcsinh, np.arctan, np.arctanh,
    np.ceil,
    np.cos, np.cosh,
    np.exp, np.exp2, np.expm1,
    np.floor,
    np.log, np.log2, np.log10, np.log1p,
    np.sign,
    np.sinh, np.sin, np.tan, np.tanh,
    np.sqrt,
)
# Register one Classicfun method per NumPy ufunc listed above.
for op in ufuncs:
    add_ufunc(op)