Coverage for src / chebpy / utilities.py: 98%
190 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 07:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 07:22 +0000
1"""Utility functions and classes for the ChebPy package.
3This module provides various utility functions and classes used throughout the ChebPy
4package, including interval operations, domain manipulations, and tolerance functions.
5It defines the core data structures for representing and manipulating intervals and domains.
6"""
8import itertools
9from collections import OrderedDict
10from collections.abc import Callable, Iterable
11from typing import Any, Protocol, runtime_checkable
13import numpy as np
15from .decorators import cast_other
16from .exceptions import (
17 IntervalGap,
18 IntervalOverlap,
19 IntervalValues,
20 InvalidDomain,
21 NotSubdomain,
22 SupportMismatch,
23)
24from .settings import _preferences as prefs
def htol() -> float:
    """Return the horizontal tolerance used when comparing breakpoints.

    Returns:
        float: Five times the ``eps`` value held in the package preferences.
    """
    eps = prefs.eps
    return 5 * eps  # type: ignore[return-value]
36@runtime_checkable
37class IntervalMap(Protocol):
38 """Structural protocol for a bijective map between [-1, 1] and [a, b].
40 Any object exposing the three mapping methods below can be used as the
41 ``_interval`` of a :class:`~chebpy.classicfun.Classicfun`. The canonical
42 affine implementation is :class:`Interval`; non-affine implementations
43 (e.g. endpoint-clustering exponential transforms for representing
44 functions with branch-type endpoint singularities) are introduced by
45 follow-up plans without disturbing this contract.
47 Required methods:
48 formap(y): Map ``y`` from the reference interval [-1, 1] to ``x`` in [a, b].
49 invmap(x): Map ``x`` from [a, b] back to ``y`` in [-1, 1].
50 drvmap(y): Derivative dx/dy of ``formap`` evaluated at ``y``.
52 Notes:
53 - ``formap`` and ``invmap`` must be mutual inverses on their domains.
54 - ``drvmap(y)`` must be strictly positive on (-1, 1) so the map is
55 orientation-preserving. It may vanish at ``y = ±1`` for non-affine
56 maps with endpoint clustering.
57 - The protocol is purely structural; no runtime registration is
58 required. ``isinstance(obj, IntervalMap)`` is supported via
59 ``runtime_checkable`` for opt-in duck-type assertions only.
60 """
62 def formap(self, y: float | np.ndarray) -> Any:
63 """Map ``y`` from the reference interval [-1, 1] to ``x`` in [a, b]."""
64 ...
66 def invmap(self, x: float | np.ndarray) -> Any:
67 """Map ``x`` from [a, b] back to ``y`` in [-1, 1]."""
68 ...
70 def drvmap(self, y: float | np.ndarray) -> Any:
71 """Return the derivative ``dx/dy`` of :meth:`formap` at ``y``."""
72 ...
75class Interval(np.ndarray):
76 """Utility class to implement Interval logic.
78 The purpose of this class is to both enforce certain properties of domain
79 components such as having exactly two monotonically increasing elements and
80 also to implement the functionality of mapping to and from the unit interval.
82 ``Interval`` is the canonical affine implementer of the
83 :class:`IntervalMap` protocol: ``formap`` is a linear bijection between
84 [-1, 1] and [a, b] and ``drvmap`` is the constant ``(b - a) / 2``.
86 Attributes:
87 formap: Maps y in [-1,1] to x in [a,b]
88 invmap: Maps x in [a,b] to y in [-1,1]
89 drvmap: Derivative mapping from y in [-1,1] to x in [a,b]
91 Note:
92 Currently only implemented for finite a and b.
93 The __call__ method evaluates self.formap since this is the most
94 frequently used mapping operation.
95 """
97 def __new__(cls, a: float = -1.0, b: float = 1.0) -> "Interval":
98 """Create a new Interval instance.
100 Args:
101 a (float, optional): Left endpoint of the interval. Defaults to -1.0.
102 b (float, optional): Right endpoint of the interval. Defaults to 1.0.
104 Raises:
105 IntervalValues: If a >= b.
107 Returns:
108 Interval: A new Interval instance.
110 Examples:
111 >>> import numpy as np
112 >>> interval = Interval(-1, 1)
113 >>> interval.tolist()
114 [-1.0, 1.0]
115 >>> float(interval.formap(0))
116 0.0
117 """
118 if a >= b:
119 raise IntervalValues
120 return np.asarray((a, b), dtype=float).view(cls) # type: ignore[return-value]
122 def formap(self, y: float | np.ndarray) -> Any:
123 """Map from the reference interval [-1,1] to this interval [a,b].
125 Args:
126 y (float or numpy.ndarray): Points in the reference interval [-1,1].
128 Returns:
129 float or numpy.ndarray: Corresponding points in the interval [a,b].
130 """
131 a, b = self
132 return 0.5 * b * (y + 1.0) + 0.5 * a * (1.0 - y)
134 def invmap(self, x: float | np.ndarray) -> Any:
135 """Map from this interval [a,b] to the reference interval [-1,1].
137 Args:
138 x (float or numpy.ndarray): Points in the interval [a,b].
140 Returns:
141 float or numpy.ndarray: Corresponding points in the reference interval [-1,1].
142 """
143 a, b = self
144 return (2.0 * x - a - b) / (b - a)
146 def drvmap(self, y: float | np.ndarray) -> Any:
147 """Compute the derivative of the forward map.
149 Args:
150 y (float or numpy.ndarray): Points in the reference interval [-1,1].
152 Returns:
153 float or numpy.ndarray: Derivative values at the corresponding points.
154 """
155 a, b = self # pragma: no cover
156 return 0.0 * y + 0.5 * (b - a) # pragma: no cover
158 def __eq__(self, other: object) -> bool:
159 """Check if two intervals are equal.
161 Args:
162 other: Another interval to compare with.
164 Returns:
165 bool: True if the intervals have the same endpoints, False otherwise.
166 """
167 if not isinstance(other, Interval):
168 return NotImplemented
169 (a, b), (x, y) = self, other
170 return bool((a == x) & (y == b))
172 def __ne__(self, other: object) -> bool:
173 """Check if two intervals are not equal.
175 Args:
176 other: Another interval to compare with.
178 Returns:
179 bool: True if the intervals have different endpoints, False otherwise.
180 """
181 if not isinstance(other, Interval):
182 return NotImplemented
183 return not self == other
185 def __call__(self, y: float | np.ndarray) -> Any:
186 """Map points from [-1,1] to this interval (shorthand for formap).
188 Args:
189 y (float or numpy.ndarray): Points in the reference interval [-1,1].
191 Returns:
192 float or numpy.ndarray: Corresponding points in the interval [a,b].
193 """
194 return self.formap(y)
196 def __contains__(self, other: object) -> bool:
197 """Check if another interval is contained within this interval.
199 Args:
200 other (Interval): Another interval to check.
202 Returns:
203 bool: True if other is contained within this interval, False otherwise.
204 """
205 other_interval: Interval = other
206 (a, b), (x, y) = self, other_interval
207 return bool((a <= x) & (y <= b))
209 def isinterior(self, x: float | np.ndarray) -> Any:
210 """Check if points are strictly in the interior of the interval.
212 Args:
213 x (float or numpy.ndarray): Points to check.
215 Returns:
216 bool or numpy.ndarray: Boolean array indicating which points are in the interior.
217 """
218 a, b = self
219 return np.logical_and(a < x, x < b)
221 @property
222 def hscale(self) -> float:
223 """Calculate the horizontal scale factor of the interval.
225 Returns:
226 float: The horizontal scale factor.
227 """
228 a, b = self
229 h = max(infnorm(self), 1)
230 h_factor = b - a # if interval == domain: scale hscale back to 1
231 result = max(h / h_factor, 1) # else: hscale < 1
232 return float(result)
def _merge_duplicates(arr: np.ndarray, tols: np.ndarray) -> np.ndarray:
    """Collapse runs of near-duplicate entries of an array.

    Adjacent entries ``arr[i]`` and ``arr[i + 1]`` are treated as duplicates
    when ``abs(arr[i + 1] - arr[i]) <= tols[i]``. Entry ``i`` is dropped
    whenever it is a duplicate of its right-hand neighbour, so the *last*
    entry of every run of duplicates is the one that survives; the final
    entry of ``arr`` is always kept.

    Args:
        arr (numpy.ndarray): Input array to remove duplicates from.
        tols (numpy.ndarray): Per-entry tolerances with the same length as
            ``arr``. ``tols[i]`` is applied to the gap between ``arr[i]`` and
            ``arr[i + 1]``; the final entry is unused.

    Returns:
        numpy.ndarray: Array with duplicates removed. An empty input yields an
        empty output.

    Note:
        Pathological cases may cause issues since this method works by using
        consecutive differences: a chain of entries, each within tolerance of
        the next, collapses to its last entry even if the ends of the chain
        are far apart. It might be better to take an average (median?) of
        each run rather than a single representative.
    """
    if arr.size == 0:
        # Without this guard the trailing ``True`` appended below would build a
        # length-1 mask for a length-0 array and the indexing would raise.
        return np.asarray(arr)
    keep = np.append(np.abs(np.diff(arr)) > tols[:-1], True)
    return np.asarray(arr[keep])
class Domain(np.ndarray):
    """Numpy ndarray with additional Chebfun-specific domain logic.

    A Domain represents a collection of breakpoints that define a piecewise domain.
    It provides methods for manipulating and comparing domains, as well as
    generating intervals between adjacent breakpoints.

    Attributes:
        intervals: Generator yielding Interval objects between adjacent breakpoints.
        support: First and last breakpoints of the domain.
    """

    def __new__(cls, breakpoints: Any) -> "Domain":
        """Create a new Domain instance.

        Args:
            breakpoints (array-like): Collection of monotonically increasing breakpoints.
                Must have at least 2 elements (an empty collection is also accepted
                and yields an empty Domain). The outermost breakpoints may be
                ``-np.inf`` / ``+np.inf``; interior breakpoints must be finite.

        Raises:
            InvalidDomain: If breakpoints has exactly one element, is not
                strictly increasing, contains NaN, or contains non-finite
                values at interior positions.

        Returns:
            Domain: A new Domain instance.
        """
        bpts = np.asarray(breakpoints, dtype=float)
        if bpts.size == 0:
            return bpts.view(cls)  # type: ignore[return-value]
        # NaN is never permitted anywhere.
        if np.any(np.isnan(bpts)):
            raise InvalidDomain
        # Neighbours are compared directly instead of through ``np.diff``:
        # ``inf - inf`` is NaN and ``NaN <= 0`` is False, so a difference-based
        # test lets ``[inf, inf]`` and ``[-inf, -inf]`` through.
        if bpts.size < 2 or not np.all(bpts[..., 1:] > bpts[..., :-1]):
            raise InvalidDomain
        # Interior breakpoints must be finite; only the outermost two may be infinite.
        if bpts.size > 2 and not np.all(np.isfinite(bpts[1:-1])):
            raise InvalidDomain
        return bpts.view(cls)  # type: ignore[return-value]

    @staticmethod
    def _within_tol(lhs: Any, rhs: Any) -> np.ndarray:
        """Compare two breakpoint arrays elementwise within the horizontal tolerance.

        Two entries match when they are exactly equal (which covers equal
        infinite endpoints) or when both are finite and differ by at most
        ``max(htol(), htol() * abs(lhs))``.

        Args:
            lhs: Reference breakpoints; the relative tolerance is scaled by these.
            rhs: Breakpoints to compare against ``lhs``.

        Returns:
            numpy.ndarray: Boolean array, True where the entries match.
        """
        # Plain ndarrays are required here: ``==`` on Domain instances would
        # dispatch to Domain.__eq__ and return a single bool.
        lhs = np.asarray(lhs, dtype=float)
        rhs = np.asarray(rhs, dtype=float)
        tol = htol()
        # ``inf - inf`` produces NaN (and a RuntimeWarning); those entries are
        # decided by the exact-equality test instead.
        with np.errstate(invalid="ignore"):
            close = np.abs(lhs - rhs) <= np.maximum(tol, tol * np.abs(lhs))
        both_finite = np.isfinite(lhs) & np.isfinite(rhs)
        return (lhs == rhs) | (both_finite & close)

    def __contains__(self, other: object) -> bool:
        """Check whether one domain object is a subdomain of another (within tolerance).

        The tolerance is multiplicative (each endpoint of self is scaled by
        ``1 ± htol()``), so an endpoint equal to zero gets no slack.

        Args:
            other (Domain): Another domain to check.

        Returns:
            bool: True if other is contained within this domain (within tolerance), False otherwise.
        """
        other_domain: Domain = other
        a, b = self.support
        x, y = other_domain.support
        bounds = np.array([1 - htol(), 1 + htol()])
        # min/max pick the outward-widened endpoint regardless of its sign.
        lbnd, rbnd = np.min(a * bounds), np.max(b * bounds)
        return bool((lbnd <= x) & (y <= rbnd))

    @classmethod
    def from_chebfun(cls, chebfun: Any) -> "Domain":
        """Initialize a Domain object from a Chebfun.

        Args:
            chebfun: A Chebfun object with breakpoints.

        Returns:
            Domain: A new Domain instance with the same breakpoints as the Chebfun.
        """
        return cls(chebfun.breakpoints)

    @property
    def intervals(self) -> Iterable[Interval]:
        """Generate Interval objects between adjacent breakpoints.

        Yields:
            Interval: Interval objects for each pair of adjacent breakpoints.
        """
        for a, b in itertools.pairwise(self):
            yield Interval(a, b)

    @property
    def support(self) -> np.ndarray:
        """Get the first and last breakpoints of the domain.

        Returns:
            numpy.ndarray: Array containing the first and last breakpoints.
        """
        return self[[0, -1]]

    @cast_other
    def union(self, other: "Domain") -> "Domain":
        """Create a union of two domain objects with matching support.

        Args:
            other (Domain): Another domain to union with.

        Raises:
            SupportMismatch: If the supports of the two domains don't match within
                tolerance. An infinite endpoint only matches an identical
                infinite endpoint.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        if not np.all(self._within_tol(self.support, other.support)):
            raise SupportMismatch
        return self.merge(other)

    def merge(self, other: "Domain") -> "Domain":
        """Merge two domain objects without checking if they have the same support.

        Args:
            other (Domain): Another domain to merge with.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        all_bpts = np.append(self, other)
        # np.unique sorts and removes exact duplicates; near-duplicates are then
        # collapsed with a per-breakpoint tolerance.
        new_bpts = np.unique(all_bpts)
        mergetol = np.maximum(htol(), htol() * np.abs(new_bpts))
        mgd_bpts = _merge_duplicates(new_bpts, mergetol)
        return self.__class__(mgd_bpts)

    @cast_other
    def restrict(self, other: "Domain") -> "Domain":
        """Truncate self to the support of other, retaining any interior breakpoints.

        Args:
            other (Domain): Domain to restrict to.

        Raises:
            NotSubdomain: If other is not a subdomain of self.

        Returns:
            Domain: A new Domain with breakpoints from self restricted to other's support.
        """
        if other not in self:
            raise NotSubdomain
        dom = self.merge(other)
        a, b = other.support
        bounds = np.array([1 - htol(), 1 + htol()])
        lbnd, rbnd = np.min(a * bounds), np.max(b * bounds)
        new = dom[(lbnd <= dom) & (dom <= rbnd)]
        return self.__class__(new)

    def breakpoints_in(self, other: "Domain") -> np.ndarray:
        """Check which breakpoints are in another domain within tolerance.

        Args:
            other (Domain): Domain to check against.

        Returns:
            numpy.ndarray: Boolean array of size equal to self where True indicates
                that the breakpoint is in other within the specified tolerance.
        """
        tol = htol()
        out = np.empty(self.size, dtype=bool)
        window = np.array([1 - tol, 1 + tol])
        # TODO: is there way to vectorise this?
        for idx, bpt in enumerate(self):
            # Sorting handles negative breakpoints, where scaling by the window
            # reverses the order of the two bounds.
            lbnd, rbnd = np.sort(bpt * window)
            # Near zero the relative window degenerates, so fall back to an
            # absolute window of width htol.
            lbnd = -tol if np.abs(lbnd) < tol else lbnd
            rbnd = +tol if np.abs(rbnd) < tol else rbnd
            isin = (lbnd <= other) & (other <= rbnd)
            out[idx] = np.any(isin)
        return out

    def __eq__(self, other: object) -> bool:
        """Test for pointwise equality (within a tolerance) of two Domain objects.

        Infinite breakpoints compare equal only to an identical infinite
        breakpoint.

        Args:
            other: Another domain to compare with. Array-like objects are
                converted to a Domain first.

        Returns:
            bool: True if domains have the same size and all breakpoints match within tolerance.
            ``NotImplemented`` is returned when ``other`` cannot be converted to a Domain.
        """
        if not isinstance(other, Domain):
            # Try to convert array-like objects to Domain for comparison
            try:
                other = Domain(other)
            except Exception:
                return NotImplemented
        if self.size != other.size:
            return False
        return bool(np.all(self._within_tol(self, other)))  # cast back to bool

    def __ne__(self, other: object) -> bool:
        """Test for inequality of two Domain objects.

        Defined as the negation of ``__eq__`` so that ``==`` and ``!=`` always
        agree, including for array-like operands that ``__eq__`` converts.

        Args:
            other: Another domain to compare with.

        Returns:
            bool: True if domains differ in size or any breakpoints don't match within tolerance.
            ``NotImplemented`` is returned when ``__eq__`` returns it.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result
459def _sortindex(intervals: Iterable[Interval]) -> np.ndarray:
460 """Return an index determining the ordering of interval objects.
462 This helper function checks that the intervals:
463 1. Do not overlap
464 2. Represent a complete partition of the broader approximation domain
466 Args:
467 intervals (array-like): Array of Interval objects to sort.
469 Returns:
470 numpy.ndarray: Index array for sorting the intervals.
472 Raises:
473 IntervalOverlap: If any intervals overlap.
474 IntervalGap: If there are gaps between intervals.
475 """
476 # sort by the left endpoint Interval values
477 subintervals = np.array(list(intervals))
478 leftbreakpts = np.array([s[0] for s in subintervals])
479 idx = leftbreakpts.argsort()
481 # check domain consistency
482 srt = subintervals[idx]
483 x = srt.flatten()[1:-1]
484 d = x[1::2] - x[::2]
485 if (d < 0).any():
486 raise IntervalOverlap
487 if (d > 0).any():
488 raise IntervalGap
490 return idx
def check_funs(funs: Any) -> np.ndarray:
    """Validate a collection of funs and return them sorted by interval.

    The funs must neither overlap nor leave gaps between their intervals; the
    checks themselves are carried out by _sortindex.

    Args:
        funs (array-like): Array of function objects with a ``support`` attribute.

    Returns:
        numpy.ndarray: Sorted array of funs (an empty array for empty input).

    Raises:
        IntervalOverlap: If any function intervals overlap.
        IntervalGap: If there are gaps between function intervals.
    """
    fun_array = np.array(funs)
    if fun_array.size == 0:
        return np.array([])
    # Use ``support`` (logical interval) rather than ``interval`` (storage)
    # so CompactFun pieces are validated against their unbounded user-facing
    # endpoints rather than their finite numerical-support storage.
    order = _sortindex(fun.support for fun in fun_array)
    return fun_array[order]
522def compute_breakdata(funs: np.ndarray) -> OrderedDict[float, Any]:
523 """Define function values at breakpoints by averaging left and right limits.
525 This function computes values at breakpoints by averaging the left and right
526 limits of adjacent functions. It is typically called after check_funs(),
527 which ensures that the domain is fully partitioned and non-overlapping.
529 Args:
530 funs (numpy.ndarray): Array of function objects with support and endvalues attributes.
532 Returns:
533 OrderedDict: Dictionary mapping breakpoints to function values.
534 """
535 if funs.size == 0:
536 return OrderedDict()
537 else:
538 points = np.array([fun.support for fun in funs])
539 values = np.array([fun.endvalues for fun in funs])
540 points = points.flatten()
541 values = values.flatten()
542 xl, xr = points[0], points[-1]
543 yl, yr = values[0], values[-1]
544 xx, yy = points[1:-1], values[1:-1]
545 x = 0.5 * (xx[::2] + xx[1::2])
546 y = 0.5 * (yy[::2] + yy[1::2])
547 xout = np.append(np.append(xl, x), xr)
548 yout = np.append(np.append(yl, y), yr)
549 return OrderedDict(zip(xout, yout, strict=False))
def generate_funs(
    domain: Domain | list[float] | None, bndfun_constructor: Callable[..., Any], kwds: dict[str, Any] | None = None
) -> list[Any]:
    """Build one function object per piece of a domain.

    Several of the Chebfun classmethod constructors use this helper to create
    their pieces. A piece whose endpoints are both finite is built with the
    supplied ``bndfun_constructor``; a piece with one or both endpoints at
    ``±inf`` is built with the :class:`CompactFun` classmethod that has the
    same name as ``bndfun_constructor``.

    Args:
        domain (array-like or None): Domain breakpoints. If None, uses default domain from preferences.
            The outermost breakpoints may be ``±inf``; interior breakpoints must be finite.
        bndfun_constructor (callable): Constructor function for creating function objects on
            finite intervals (typically a :class:`Bndfun` classmethod).
        kwds (dict, optional): Additional keyword arguments to pass to the constructor. Defaults to {}.

    Raises:
        InvalidDomain: If a piece has an infinite endpoint and no CompactFun
            classmethod matching the name of ``bndfun_constructor`` exists.

    Returns:
        list: List of function objects covering the domain.
    """
    extra = {} if kwds is None else kwds
    domain = Domain(prefs.domain if domain is None else domain)
    # Local import avoids a circular dependency with chebpy.compactfun, which
    # imports from utilities for Interval / Domain.
    from .compactfun import CompactFun

    # Find the CompactFun counterpart of the constructor by name, if there is one.
    method_name = getattr(bndfun_constructor, "__name__", None)
    compact_constructor: Callable[..., Any] | None = None
    if method_name is not None and hasattr(CompactFun, method_name):
        compact_constructor = getattr(CompactFun, method_name)

    funs = []
    for left, right in itertools.pairwise(domain):
        lo, hi = float(left), float(right)
        if not (np.isfinite(lo) and np.isfinite(hi)):
            if compact_constructor is None:
                raise InvalidDomain
            # CompactFun classmethods accept (a, b) tuples with ±inf
            funs.append(compact_constructor(**{**extra, "interval": (lo, hi)}))
            continue
        funs.append(bndfun_constructor(**{**extra, "interval": Interval(lo, hi)}))
    return funs
def infnorm(vals: np.ndarray) -> float:
    """Return the infinity norm of an array as a Python float.

    Args:
        vals (array-like): Input array.

    Returns:
        float: ``numpy.linalg.norm(vals, numpy.inf)``; for a one-dimensional
        input this is the maximum absolute value.
    """
    norm = np.linalg.norm(vals, ord=np.inf)
    return float(norm)
613def coerce_list(x: object) -> list[Any] | Iterable[Any]:
614 """Convert a non-iterable object to a list containing that object.
616 If the input is already an iterable (except strings), it is returned unchanged.
617 Strings are treated as non-iterables and wrapped in a list.
619 Args:
620 x: Input object to coerce to a list if necessary.
622 Returns:
623 list or iterable: The input wrapped in a list if it was not an iterable,
624 or the original input if it was already an iterable (except strings).
625 """
626 if not isinstance(x, Iterable) or isinstance(x, str): # pragma: no cover
627 x = [x]
628 return x