Coverage for src / chebpy / utilities.py: 100%
171 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
1"""Utility functions and classes for the ChebPy package.
3This module provides various utility functions and classes used throughout the ChebPy
4package, including interval operations, domain manipulations, and tolerance functions.
5It defines the core data structures for representing and manipulating intervals and domains.
6"""
8import itertools
9from collections import OrderedDict
10from collections.abc import Callable, Iterable
11from typing import Any
13import numpy as np
15from .decorators import cast_other
16from .exceptions import (
17 IntervalGap,
18 IntervalOverlap,
19 IntervalValues,
20 InvalidDomain,
21 NotSubdomain,
22 SupportMismatch,
23)
24from .settings import _preferences as prefs
def htol() -> float:
    """Return the horizontal tolerance used when comparing breakpoints.

    Returns:
        float: Five times the ``eps`` value held in the package preferences.
    """
    machine_eps = prefs.eps
    return 5 * machine_eps  # type: ignore[return-value]
36class Interval(np.ndarray):
37 """Utility class to implement Interval logic.
39 The purpose of this class is to both enforce certain properties of domain
40 components such as having exactly two monotonically increasing elements and
41 also to implement the functionality of mapping to and from the unit interval.
43 Attributes:
44 formap: Maps y in [-1,1] to x in [a,b]
45 invmap: Maps x in [a,b] to y in [-1,1]
46 drvmap: Derivative mapping from y in [-1,1] to x in [a,b]
48 Note:
49 Currently only implemented for finite a and b.
50 The __call__ method evaluates self.formap since this is the most
51 frequently used mapping operation.
52 """
54 def __new__(cls, a: float = -1.0, b: float = 1.0) -> "Interval":
55 """Create a new Interval instance.
57 Args:
58 a (float, optional): Left endpoint of the interval. Defaults to -1.0.
59 b (float, optional): Right endpoint of the interval. Defaults to 1.0.
61 Raises:
62 IntervalValues: If a >= b.
64 Returns:
65 Interval: A new Interval instance.
67 Examples:
68 >>> import numpy as np
69 >>> interval = Interval(-1, 1)
70 >>> interval.tolist()
71 [-1.0, 1.0]
72 >>> float(interval.formap(0))
73 0.0
74 """
75 if a >= b:
76 raise IntervalValues
77 return np.asarray((a, b), dtype=float).view(cls) # type: ignore[return-value]
79 def formap(self, y: float | np.ndarray) -> Any:
80 """Map from the reference interval [-1,1] to this interval [a,b].
82 Args:
83 y (float or numpy.ndarray): Points in the reference interval [-1,1].
85 Returns:
86 float or numpy.ndarray: Corresponding points in the interval [a,b].
87 """
88 a, b = self
89 return 0.5 * b * (y + 1.0) + 0.5 * a * (1.0 - y)
91 def invmap(self, x: float | np.ndarray) -> Any:
92 """Map from this interval [a,b] to the reference interval [-1,1].
94 Args:
95 x (float or numpy.ndarray): Points in the interval [a,b].
97 Returns:
98 float or numpy.ndarray: Corresponding points in the reference interval [-1,1].
99 """
100 a, b = self
101 return (2.0 * x - a - b) / (b - a)
103 def drvmap(self, y: float | np.ndarray) -> Any:
104 """Compute the derivative of the forward map.
106 Args:
107 y (float or numpy.ndarray): Points in the reference interval [-1,1].
109 Returns:
110 float or numpy.ndarray: Derivative values at the corresponding points.
111 """
112 a, b = self # pragma: no cover
113 return 0.0 * y + 0.5 * (b - a) # pragma: no cover
115 def __eq__(self, other: object) -> bool:
116 """Check if two intervals are equal.
118 Args:
119 other: Another interval to compare with.
121 Returns:
122 bool: True if the intervals have the same endpoints, False otherwise.
123 """
124 if not isinstance(other, Interval):
125 return NotImplemented
126 (a, b), (x, y) = self, other
127 return bool((a == x) & (y == b))
129 def __ne__(self, other: object) -> bool:
130 """Check if two intervals are not equal.
132 Args:
133 other: Another interval to compare with.
135 Returns:
136 bool: True if the intervals have different endpoints, False otherwise.
137 """
138 if not isinstance(other, Interval):
139 return NotImplemented
140 return not self == other
142 def __call__(self, y: float | np.ndarray) -> Any:
143 """Map points from [-1,1] to this interval (shorthand for formap).
145 Args:
146 y (float or numpy.ndarray): Points in the reference interval [-1,1].
148 Returns:
149 float or numpy.ndarray: Corresponding points in the interval [a,b].
150 """
151 return self.formap(y)
153 def __contains__(self, other: object) -> bool:
154 """Check if another interval is contained within this interval.
156 Args:
157 other (Interval): Another interval to check.
159 Returns:
160 bool: True if other is contained within this interval, False otherwise.
161 """
162 other_interval: Interval = other
163 (a, b), (x, y) = self, other_interval
164 return bool((a <= x) & (y <= b))
166 def isinterior(self, x: float | np.ndarray) -> Any:
167 """Check if points are strictly in the interior of the interval.
169 Args:
170 x (float or numpy.ndarray): Points to check.
172 Returns:
173 bool or numpy.ndarray: Boolean array indicating which points are in the interior.
174 """
175 a, b = self
176 return np.logical_and(a < x, x < b)
178 @property
179 def hscale(self) -> float:
180 """Calculate the horizontal scale factor of the interval.
182 Returns:
183 float: The horizontal scale factor.
184 """
185 a, b = self
186 h = max(infnorm(self), 1)
187 h_factor = b - a # if interval == domain: scale hscale back to 1
188 result = max(h / h_factor, 1) # else: hscale < 1
189 return float(result)
192def _merge_duplicates(arr: np.ndarray, tols: np.ndarray) -> np.ndarray:
193 """Remove duplicate entries from an input array within specified tolerances.
195 This function works from left to right, keeping the first occurrence of
196 values that are within tolerance of each other.
198 Args:
199 arr (numpy.ndarray): Input array to remove duplicates from.
200 tols (numpy.ndarray): Array of tolerance values for each pair of adjacent elements.
201 Should have length one less than arr.
203 Returns:
204 numpy.ndarray: Array with duplicates removed.
206 Note:
207 Pathological cases may cause issues since this method works by using
208 consecutive differences. It might be better to take an average (median?),
209 rather than the left-hand value.
210 """
211 idx = np.append(np.abs(np.diff(arr)) > tols[:-1], True)
212 return np.asarray(arr[idx])
class Domain(np.ndarray):
    """Numpy ndarray with additional Chebfun-specific domain logic.

    A Domain represents a collection of breakpoints that define a piecewise domain.
    It provides methods for manipulating and comparing domains, as well as
    generating intervals between adjacent breakpoints.

    Attributes:
        intervals: Generator yielding Interval objects between adjacent breakpoints.
        support: First and last breakpoints of the domain.
    """

    def __new__(cls, breakpoints: Any) -> "Domain":
        """Create a new Domain instance.

        Args:
            breakpoints (array-like): Collection of strictly increasing breakpoints.
                Either empty, or with at least 2 elements.

        Raises:
            InvalidDomain: If breakpoints has exactly one element or is not
                strictly increasing.

        Returns:
            Domain: A new Domain instance.
        """
        bpts = np.asarray(breakpoints, dtype=float)
        # An empty breakpoint set is accepted as-is; anything else is validated.
        if bpts.size != 0 and (bpts.size < 2 or np.any(np.diff(bpts) <= 0)):
            raise InvalidDomain
        return bpts.view(cls)  # type: ignore[return-value]

    def __contains__(self, other: object) -> bool:
        """Check whether one domain object is a subdomain of another (within tolerance).

        Only the supports (first and last breakpoints) are compared.

        Args:
            other (Domain): Another domain to check.

        Returns:
            bool: True if other's support lies within this domain's support
            (widened by a relative tolerance), False otherwise.
        """
        other_domain: Domain = other
        a, b = self.support
        x, y = other_domain.support
        tol = htol()
        bounds = np.array([1 - tol, 1 + tol])
        # min/max picks the outward-widened bound regardless of endpoint sign.
        lbnd, rbnd = np.min(a * bounds), np.max(b * bounds)
        return bool((lbnd <= x) & (y <= rbnd))

    @classmethod
    def from_chebfun(cls, chebfun: Any) -> "Domain":
        """Initialize a Domain object from a Chebfun.

        Args:
            chebfun: An object exposing a ``breakpoints`` attribute.

        Returns:
            Domain: A new Domain instance built from ``chebfun.breakpoints``.
        """
        return cls(chebfun.breakpoints)

    @property
    def intervals(self) -> Iterable[Interval]:
        """Generate Interval objects between adjacent breakpoints.

        Yields:
            Interval: Interval objects for each pair of adjacent breakpoints.
        """
        for a, b in itertools.pairwise(self):
            yield Interval(a, b)

    @property
    def support(self) -> np.ndarray:
        """Get the first and last breakpoints of the domain.

        Returns:
            numpy.ndarray: Array containing the first and last breakpoints.
        """
        return self[[0, -1]]

    @cast_other
    def union(self, other: "Domain") -> "Domain":
        """Create a union of two domain objects with matching support.

        Args:
            other (Domain): Another domain to union with.

        Raises:
            SupportMismatch: If the supports of the two domains don't match within tolerance.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        tol = htol()
        dspt = np.abs(self.support - other.support)
        # Absolute tolerance near zero, relative tolerance elsewhere.
        tolerance = np.maximum(tol, tol * np.abs(self.support))
        if np.any(dspt > tolerance):
            raise SupportMismatch
        return self.merge(other)

    def merge(self, other: "Domain") -> "Domain":
        """Merge two domain objects without checking if they have the same support.

        Args:
            other (Domain): Another domain to merge with.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains,
            with breakpoints closer than the tolerance collapsed into one.
        """
        tol = htol()
        all_bpts = np.append(self, other)
        new_bpts = np.unique(all_bpts)  # sorted, exact duplicates removed
        mergetol = np.maximum(tol, tol * np.abs(new_bpts))
        mgd_bpts = _merge_duplicates(new_bpts, mergetol)
        return self.__class__(mgd_bpts)

    @cast_other
    def restrict(self, other: "Domain") -> "Domain":
        """Truncate self to the support of other, retaining any interior breakpoints.

        Args:
            other (Domain): Domain to restrict to.

        Raises:
            NotSubdomain: If other is not a subdomain of self.

        Returns:
            Domain: A new Domain with breakpoints from self restricted to other's support.
        """
        if other not in self:
            raise NotSubdomain
        dom = self.merge(other)
        a, b = other.support
        tol = htol()
        bounds = np.array([1 - tol, 1 + tol])
        lbnd, rbnd = np.min(a * bounds), np.max(b * bounds)
        new = dom[(lbnd <= dom) & (dom <= rbnd)]
        return self.__class__(new)

    def breakpoints_in(self, other: "Domain") -> np.ndarray:
        """Check which breakpoints are in another domain within tolerance.

        Args:
            other (Domain): Domain to check against.

        Returns:
            numpy.ndarray: Boolean array of size equal to self where True indicates
            that the breakpoint is in other within the specified tolerance.
        """
        tol = htol()
        window = np.array([1 - tol, 1 + tol])
        pts = np.asarray(self, dtype=float)
        # One (lower, upper) row per breakpoint; sorting handles negative breakpoints.
        edges = np.sort(pts[:, None] * window[None, :], axis=1)
        lbnd, rbnd = edges[:, 0], edges[:, 1]
        # A purely relative window collapses at zero, so fall back to an absolute one.
        lbnd[np.abs(lbnd) < tol] = -tol
        rbnd[np.abs(rbnd) < tol] = +tol
        targets = np.asarray(other, dtype=float).ravel()
        isin = (lbnd[:, None] <= targets) & (targets <= rbnd[:, None])
        return np.any(isin, axis=1)

    def __eq__(self, other: object) -> bool:
        """Test for pointwise equality (within a tolerance) of two Domain objects.

        Array-like operands are converted to a Domain before comparing.

        Args:
            other: Another domain (or array-like of breakpoints) to compare with.

        Returns:
            bool: True if domains have the same size and all breakpoints match
            within tolerance. ``NotImplemented`` if ``other`` cannot be
            converted to a Domain.
        """
        if not isinstance(other, Domain):
            # Try to convert array-like objects to Domain for comparison.
            # Conversion may fail in numpy (non-numeric data) or in Domain's
            # own validation, hence the broad catch.
            try:
                other = Domain(other)
            except Exception:
                return NotImplemented
        if self.size != other.size:
            return False
        tol = htol()
        dbpt = np.abs(self - other)
        tolerance = np.maximum(tol, tol * np.abs(self))
        return bool(np.all(dbpt <= tolerance))  # cast back to bool

    def __ne__(self, other: object) -> bool:
        """Test for inequality of two Domain objects.

        Defined as the negation of ``__eq__`` so that both operators treat
        array-like operands the same way.

        Args:
            other: Another domain (or array-like of breakpoints) to compare with.

        Returns:
            bool: True if domains differ in size or any breakpoints don't match
            within tolerance. ``NotImplemented`` if ``other`` cannot be
            converted to a Domain.
        """
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
408def _sortindex(intervals: Iterable[Interval]) -> np.ndarray:
409 """Return an index determining the ordering of interval objects.
411 This helper function checks that the intervals:
412 1. Do not overlap
413 2. Represent a complete partition of the broader approximation domain
415 Args:
416 intervals (array-like): Array of Interval objects to sort.
418 Returns:
419 numpy.ndarray: Index array for sorting the intervals.
421 Raises:
422 IntervalOverlap: If any intervals overlap.
423 IntervalGap: If there are gaps between intervals.
424 """
425 # sort by the left endpoint Interval values
426 subintervals = np.array(list(intervals))
427 leftbreakpts = np.array([s[0] for s in subintervals])
428 idx = leftbreakpts.argsort()
430 # check domain consistency
431 srt = subintervals[idx]
432 x = srt.flatten()[1:-1]
433 d = x[1::2] - x[::2]
434 if (d < 0).any():
435 raise IntervalOverlap
436 if (d > 0).any():
437 raise IntervalGap
439 return idx
def check_funs(funs: Any) -> np.ndarray:
    """Validate a collection of funs and return it ordered by interval.

    The overlap and gap checks themselves are carried out by ``_sortindex``
    on the ``interval`` attribute of each fun.

    Args:
        funs (array-like): Function objects, each with an ``interval`` attribute.

    Returns:
        numpy.ndarray: The funs sorted by interval; an empty array if no funs
        were given.

    Raises:
        IntervalOverlap: If any function intervals overlap.
        IntervalGap: If there are gaps between function intervals.
    """
    fun_array = np.array(funs)
    if fun_array.size == 0:
        return np.array([])
    order = _sortindex(fun.interval for fun in fun_array)
    return fun_array[order]
468def compute_breakdata(funs: np.ndarray) -> OrderedDict[float, Any]:
469 """Define function values at breakpoints by averaging left and right limits.
471 This function computes values at breakpoints by averaging the left and right
472 limits of adjacent functions. It is typically called after check_funs(),
473 which ensures that the domain is fully partitioned and non-overlapping.
475 Args:
476 funs (numpy.ndarray): Array of function objects with support and endvalues attributes.
478 Returns:
479 OrderedDict: Dictionary mapping breakpoints to function values.
480 """
481 if funs.size == 0:
482 return OrderedDict()
483 else:
484 points = np.array([fun.support for fun in funs])
485 values = np.array([fun.endvalues for fun in funs])
486 points = points.flatten()
487 values = values.flatten()
488 xl, xr = points[0], points[-1]
489 yl, yr = values[0], values[-1]
490 xx, yy = points[1:-1], values[1:-1]
491 x = 0.5 * (xx[::2] + xx[1::2])
492 y = 0.5 * (yy[::2] + yy[1::2])
493 xout = np.append(np.append(xl, x), xr)
494 yout = np.append(np.append(yl, y), yr)
495 return OrderedDict(zip(xout, yout, strict=False))
def generate_funs(
    domain: Domain | list[float] | None, bndfun_constructor: Callable[..., Any], kwds: dict[str, Any] | None = None
) -> list[Any]:
    """Build one function object per subinterval of a domain.

    Used by several of the Chebfun classmethod constructors. The constructor
    is called once per interval with ``kwds`` plus an ``interval`` keyword.

    Args:
        domain (array-like or None): Domain breakpoints. If None, the default
            domain from preferences is used.
        bndfun_constructor (callable): Constructor for the function objects.
        kwds (dict, optional): Extra keyword arguments forwarded to the
            constructor. Defaults to no extra arguments.

    Returns:
        list: Function objects covering the domain, in breakpoint order.
    """
    base_kwds = {} if kwds is None else kwds
    breakpoints = prefs.domain if domain is None else domain
    return [
        bndfun_constructor(**{**base_kwds, "interval": interval})
        for interval in Domain(breakpoints).intervals
    ]
def infnorm(vals: np.ndarray) -> float:
    """Return the infinity norm of ``vals`` as a Python float.

    Delegates to ``numpy.linalg.norm`` with ``ord=inf``; for a 1-D input this
    is the largest absolute value.

    Args:
        vals (array-like): Input array.

    Returns:
        float: The infinity norm of the input.
    """
    largest = np.linalg.norm(vals, ord=np.inf)
    return float(largest)
536def coerce_list(x: object) -> list[Any] | Iterable[Any]:
537 """Convert a non-iterable object to a list containing that object.
539 If the input is already an iterable (except strings), it is returned unchanged.
540 Strings are treated as non-iterables and wrapped in a list.
542 Args:
543 x: Input object to coerce to a list if necessary.
545 Returns:
546 list or iterable: The input wrapped in a list if it was not an iterable,
547 or the original input if it was already an iterable (except strings).
548 """
549 if not isinstance(x, Iterable) or isinstance(x, str): # pragma: no cover
550 x = [x]
551 return x