Coverage for src / chebpy / utilities.py: 100%
154 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-20 05:55 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-20 05:55 +0000
1"""Utility functions and classes for the ChebPy package.
3This module provides various utility functions and classes used throughout the ChebPy
4package, including interval operations, domain manipulations, and tolerance functions.
5It defines the core data structures for representing and manipulating intervals and domains.
6"""
8from collections import OrderedDict
9from collections.abc import Iterable
11import numpy as np
13from .decorators import cast_other
14from .exceptions import (
15 IntervalGap,
16 IntervalOverlap,
17 IntervalValues,
18 InvalidDomain,
19 NotSubdomain,
20 SupportMismatch,
21)
22from .settings import _preferences as prefs
def htol() -> float:
    """Horizontal tolerance used when comparing breakpoints and endpoints.

    Returns:
        float: The preference ``eps`` scaled by a factor of five.
    """
    factor = 5
    return factor * prefs.eps
34class Interval(np.ndarray):
35 """Utility class to implement Interval logic.
37 The purpose of this class is to both enforce certain properties of domain
38 components such as having exactly two monotonically increasing elements and
39 also to implement the functionality of mapping to and from the unit interval.
41 Attributes:
42 formap: Maps y in [-1,1] to x in [a,b]
43 invmap: Maps x in [a,b] to y in [-1,1]
44 drvmap: Derivative mapping from y in [-1,1] to x in [a,b]
46 Note:
47 Currently only implemented for finite a and b.
48 The __call__ method evaluates self.formap since this is the most
49 frequently used mapping operation.
50 """
52 def __new__(cls, a: float = -1.0, b: float = 1.0) -> "Interval":
53 """Create a new Interval instance.
55 Args:
56 a (float, optional): Left endpoint of the interval. Defaults to -1.0.
57 b (float, optional): Right endpoint of the interval. Defaults to 1.0.
59 Raises:
60 IntervalValues: If a >= b.
62 Returns:
63 Interval: A new Interval instance.
65 Examples:
66 >>> import numpy as np
67 >>> interval = Interval(-1, 1)
68 >>> interval.tolist()
69 [-1.0, 1.0]
70 >>> float(interval.formap(0))
71 0.0
72 """
73 if a >= b:
74 raise IntervalValues
75 return np.asarray((a, b), dtype=float).view(cls)
77 def formap(self, y: float | np.ndarray) -> float | np.ndarray:
78 """Map from the reference interval [-1,1] to this interval [a,b].
80 Args:
81 y (float or numpy.ndarray): Points in the reference interval [-1,1].
83 Returns:
84 float or numpy.ndarray: Corresponding points in the interval [a,b].
85 """
86 a, b = self
87 return 0.5 * b * (y + 1.0) + 0.5 * a * (1.0 - y)
89 def invmap(self, x: float | np.ndarray) -> float | np.ndarray:
90 """Map from this interval [a,b] to the reference interval [-1,1].
92 Args:
93 x (float or numpy.ndarray): Points in the interval [a,b].
95 Returns:
96 float or numpy.ndarray: Corresponding points in the reference interval [-1,1].
97 """
98 a, b = self
99 return (2.0 * x - a - b) / (b - a)
101 def drvmap(self, y: float | np.ndarray) -> float | np.ndarray:
102 """Compute the derivative of the forward map.
104 Args:
105 y (float or numpy.ndarray): Points in the reference interval [-1,1].
107 Returns:
108 float or numpy.ndarray: Derivative values at the corresponding points.
109 """
110 a, b = self # pragma: no cover
111 return 0.0 * y + 0.5 * (b - a) # pragma: no cover
113 def __eq__(self, other: "Interval") -> bool:
114 """Check if two intervals are equal.
116 Args:
117 other (Interval): Another interval to compare with.
119 Returns:
120 bool: True if the intervals have the same endpoints, False otherwise.
121 """
122 (a, b), (x, y) = self, other
123 return (a == x) & (y == b)
125 def __ne__(self, other: "Interval") -> bool:
126 """Check if two intervals are not equal.
128 Args:
129 other (Interval): Another interval to compare with.
131 Returns:
132 bool: True if the intervals have different endpoints, False otherwise.
133 """
134 return not self == other
136 def __call__(self, y: float | np.ndarray) -> float | np.ndarray:
137 """Map points from [-1,1] to this interval (shorthand for formap).
139 Args:
140 y (float or numpy.ndarray): Points in the reference interval [-1,1].
142 Returns:
143 float or numpy.ndarray: Corresponding points in the interval [a,b].
144 """
145 return self.formap(y)
147 def __contains__(self, other: "Interval") -> bool:
148 """Check if another interval is contained within this interval.
150 Args:
151 other (Interval): Another interval to check.
153 Returns:
154 bool: True if other is contained within this interval, False otherwise.
155 """
156 (a, b), (x, y) = self, other
157 return (a <= x) & (y <= b)
159 def isinterior(self, x: float | np.ndarray) -> bool | np.ndarray:
160 """Check if points are strictly in the interior of the interval.
162 Args:
163 x (float or numpy.ndarray): Points to check.
165 Returns:
166 bool or numpy.ndarray: Boolean array indicating which points are in the interior.
167 """
168 a, b = self
169 return np.logical_and(a < x, x < b)
171 @property
172 def hscale(self) -> float:
173 """Calculate the horizontal scale factor of the interval.
175 Returns:
176 float: The horizontal scale factor.
177 """
178 a, b = self
179 h = max(infnorm(self), 1)
180 h_factor = b - a # if interval == domain: scale hscale back to 1
181 hscale = max(h / h_factor, 1) # else: hscale < 1
182 return hscale
185def _merge_duplicates(arr: np.ndarray, tols: np.ndarray) -> np.ndarray:
186 """Remove duplicate entries from an input array within specified tolerances.
188 This function works from left to right, keeping the first occurrence of
189 values that are within tolerance of each other.
191 Args:
192 arr (numpy.ndarray): Input array to remove duplicates from.
193 tols (numpy.ndarray): Array of tolerance values for each pair of adjacent elements.
194 Should have length one less than arr.
196 Returns:
197 numpy.ndarray: Array with duplicates removed.
199 Note:
200 Pathological cases may cause issues since this method works by using
201 consecutive differences. It might be better to take an average (median?),
202 rather than the left-hand value.
203 """
204 idx = np.append(np.abs(np.diff(arr)) > tols[:-1], True)
205 return arr[idx]
class Domain(np.ndarray):
    """Numpy ndarray with additional Chebfun-specific domain logic.

    A Domain represents a collection of breakpoints that define a piecewise domain.
    It provides methods for manipulating and comparing domains, as well as
    generating intervals between adjacent breakpoints.

    Attributes:
        intervals: Generator yielding Interval objects between adjacent breakpoints.
        support: First and last breakpoints of the domain.
    """

    def __new__(cls, breakpoints):
        """Create a new Domain instance.

        Args:
            breakpoints (array-like): Collection of strictly increasing breakpoints.
                Must be empty or have at least 2 elements.

        Raises:
            InvalidDomain: If breakpoints has exactly 1 element, is not strictly
                increasing, or contains NaN.

        Returns:
            Domain: A new Domain instance (empty if breakpoints is empty).
        """
        bpts = np.asarray(breakpoints, dtype=float)
        if bpts.size == 0:
            return bpts.view(cls)
        # `not all(diff > 0)` also rejects NaN: every comparison with NaN is
        # False, so the previous `any(diff <= 0)` test let NaN through.
        if bpts.size < 2 or not np.all(np.diff(bpts) > 0):
            raise InvalidDomain
        return bpts.view(cls)

    @staticmethod
    def _widen(a, b):
        """Return ``(a, b)`` pushed outward by ``max(htol(), htol() * |endpoint|)``."""
        # Same tolerance form as union/merge/__eq__. The absolute floor keeps a
        # non-zero tolerance at an endpoint of exactly 0, where a purely
        # relative tolerance would vanish.
        return (
            a - max(htol(), htol() * abs(a)),
            b + max(htol(), htol() * abs(b)),
        )

    def __contains__(self, other: "Domain") -> bool:
        """Check whether ``other`` is a subdomain of ``self`` (within tolerance).

        Each end of ``self.support`` is widened outward by
        ``max(htol(), htol() * |endpoint|)``. This is the same tolerance that
        ``union`` uses to compare supports.

        Args:
            other (Domain): Another domain to check.

        Returns:
            bool: True if other's support lies inside the widened support of self.
        """
        lbnd, rbnd = self._widen(*self.support)
        x, y = other.support
        return (lbnd <= x) & (y <= rbnd)

    @classmethod
    def from_chebfun(cls, chebfun):
        """Initialize a Domain object from a Chebfun.

        Args:
            chebfun: A Chebfun object with breakpoints.

        Returns:
            Domain: A new Domain instance with the same breakpoints as the Chebfun.
        """
        return cls(chebfun.breakpoints)

    @property
    def intervals(self) -> Iterable[Interval]:
        """Generate Interval objects between adjacent breakpoints.

        Yields:
            Interval: Interval objects for each pair of adjacent breakpoints.
        """
        for a, b in zip(self[:-1], self[1:]):
            yield Interval(a, b)

    @property
    def support(self) -> "Domain":
        """Get the first and last breakpoints of the domain.

        Returns:
            Domain: A two-element Domain holding the first and last breakpoints.
        """
        return self[[0, -1]]

    @cast_other
    def union(self, other: "Domain") -> "Domain":
        """Create a union of two domain objects with matching support.

        Args:
            other (Domain): Another domain to union with.

        Raises:
            SupportMismatch: If the supports of the two domains don't match within tolerance.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        dspt = np.abs(self.support - other.support)
        tolerance = np.maximum(htol(), htol() * np.abs(self.support))
        if np.any(dspt > tolerance):
            raise SupportMismatch
        return self.merge(other)

    def merge(self, other: "Domain") -> "Domain":
        """Merge two domain objects without checking if they have the same support.

        Breakpoints closer than ``max(htol(), htol() * |x|)`` are collapsed by
        ``_merge_duplicates``.

        Args:
            other (Domain): Another domain to merge with.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        all_bpts = np.append(self, other)
        new_bpts = np.unique(all_bpts)
        mergetol = np.maximum(htol(), htol() * np.abs(new_bpts))
        mgd_bpts = _merge_duplicates(new_bpts, mergetol)
        return self.__class__(mgd_bpts)

    @cast_other
    def restrict(self, other: "Domain") -> "Domain":
        """Truncate self to the support of other, retaining any interior breakpoints.

        Args:
            other (Domain): Domain to restrict to.

        Raises:
            NotSubdomain: If other is not a subdomain of self.

        Returns:
            Domain: A new Domain with breakpoints from self restricted to other's support.
        """
        if other not in self:
            raise NotSubdomain
        dom = self.merge(other)
        # Keep merged breakpoints inside other's support, using the same
        # tolerance-widened window that __contains__ uses.
        lbnd, rbnd = self._widen(*other.support)
        new = dom[(lbnd <= dom) & (dom <= rbnd)]
        return self.__class__(new)

    def breakpoints_in(self, other: "Domain") -> np.ndarray:
        """Check which breakpoints of self are in another domain within tolerance.

        Each breakpoint ``p`` gets the window ``[p*(1-htol), p*(1+htol)]`` (sorted).
        Any window end whose magnitude is below ``htol()`` is replaced by
        ``-htol()`` (left) or ``+htol()`` (right), so breakpoints near zero
        still get an absolute tolerance.

        Args:
            other (Domain): Domain (or array-like of points) to check against.

        Returns:
            numpy.ndarray: Boolean array of size equal to self where True indicates
            that the breakpoint is in other within the specified tolerance.
        """
        tol = htol()
        bpts = np.asarray(self, dtype=float)
        # Shape (len(self), 2): each breakpoint scaled by 1 -/+ tol.
        scaled = bpts[:, None] * np.array([1 - tol, 1 + tol])
        lbnd = scaled.min(axis=1)
        rbnd = scaled.max(axis=1)
        lbnd = np.where(np.abs(lbnd) < tol, -tol, lbnd)
        rbnd = np.where(np.abs(rbnd) < tol, tol, rbnd)
        pts = np.asarray(other, dtype=float).ravel()
        # Broadcast every window against every point of other: (n, 1) vs (1, m).
        isin = (lbnd[:, None] <= pts[None, :]) & (pts[None, :] <= rbnd[:, None])
        return isin.any(axis=1)

    def __eq__(self, other: "Domain") -> bool:
        """Test for pointwise equality (within a tolerance) of two Domain objects.

        The tolerance at each breakpoint is ``max(htol(), htol() * |self[i]|)``.

        Args:
            other (Domain): Another domain to compare with.

        Returns:
            bool: True if domains have the same size and all breakpoints match within tolerance.
        """
        if self.size != other.size:
            return False
        dbpt = np.abs(self - other)
        tolerance = np.maximum(htol(), htol() * np.abs(self))
        return bool(np.all(dbpt <= tolerance))  # cast numpy/Domain result back to bool

    def __ne__(self, other: "Domain") -> bool:
        """Test for inequality of two Domain objects.

        Args:
            other (Domain): Another domain to compare with.

        Returns:
            bool: True if domains differ in size or any breakpoints don't match within tolerance.
        """
        return not self == other
def _sortindex(intervals: Iterable[Interval]) -> np.ndarray:
    """Compute the ordering of intervals by left endpoint and validate the tiling.

    After sorting, consecutive intervals must meet exactly. The right end of
    one must equal the left end of the next, so there are no overlaps and no
    gaps. The comparison is exact (no tolerance).

    Args:
        intervals (iterable): Interval objects (any iterable, including a generator).

    Returns:
        numpy.ndarray: Index array that sorts the intervals by left endpoint.

    Raises:
        IntervalOverlap: If any intervals overlap.
        IntervalGap: If there are gaps between intervals.
    """
    bounds = np.array(list(intervals))
    order = np.array([row[0] for row in bounds]).argsort()

    # Sorted endpoints read a0 b0 a1 b1 ...; trimming the two outer values leaves
    # (b_k, a_{k+1}) pairs whose difference is the spacing between neighbours.
    interior = bounds[order].flatten()[1:-1]
    spacing = interior[1::2] - interior[0::2]
    if np.any(spacing < 0):
        raise IntervalOverlap
    if np.any(spacing > 0):
        raise IntervalGap
    return order
def check_funs(funs: list) -> np.ndarray:
    """Validate a collection of funs and return them ordered by interval.

    The funs' intervals must tile a contiguous region with no overlaps and no
    gaps. Both the validation and the ordering are delegated to ``_sortindex``.

    Args:
        funs (array-like): Function objects, each exposing an ``interval`` attribute.

    Returns:
        numpy.ndarray: The funs sorted by the left endpoint of their interval,
        or an empty float array if ``funs`` is empty.

    Raises:
        IntervalOverlap: If any function intervals overlap.
        IntervalGap: If there are gaps between function intervals.
    """
    fun_array = np.array(funs)
    if not fun_array.size:
        return np.array([])
    order = _sortindex(fun.interval for fun in fun_array)
    return fun_array[order]
def compute_breakdata(funs: np.ndarray) -> OrderedDict:
    """Assign a function value to every breakpoint by averaging one-sided limits.

    The outermost breakpoints take the single adjacent endvalue. Each interior
    breakpoint takes the mean of the right endvalue of the fun on its left and
    the left endvalue of the fun on its right; the breakpoint location itself
    is averaged the same way. Intended to be called after ``check_funs``, so
    that ``funs`` is sorted and tiles the domain.

    Args:
        funs (numpy.ndarray): Function objects with ``support`` and ``endvalues`` attributes.

    Returns:
        OrderedDict: Breakpoints mapped to function values, in left-to-right order.
    """
    if funs.size == 0:
        return OrderedDict()

    xs = np.array([fun.support for fun in funs]).flatten()
    ys = np.array([fun.endvalues for fun in funs]).flatten()

    # Interior entries come in (right end of fun k, left end of fun k+1) pairs.
    x_inner, y_inner = xs[1:-1], ys[1:-1]
    x_mid = 0.5 * (x_inner[::2] + x_inner[1::2])
    y_mid = 0.5 * (y_inner[::2] + y_inner[1::2])

    xout = np.concatenate(([xs[0]], x_mid, [xs[-1]]))
    yout = np.concatenate(([ys[0]], y_mid, [ys[-1]]))
    return OrderedDict(zip(xout, yout))
def generate_funs(
    domain: Domain | list | None,
    bndfun_constructor: callable,
    kwds: dict | None = None,
) -> list:
    """Generate a collection of function objects over a domain.

    This method is used by several of the Chebfun classmethod constructors to
    generate a collection of function objects over the specified domain.

    Args:
        domain (array-like or None): Domain breakpoints. If None, uses the default
            domain from preferences.
        bndfun_constructor (callable): Called once per subinterval with the
            keyword arguments from ``kwds`` plus ``interval=<Interval>``.
        kwds (dict, optional): Extra keyword arguments forwarded to every
            constructor call. ``None`` (the default) means no extra arguments.
            An ``"interval"`` key in ``kwds`` is overridden by each subinterval.
            The mapping itself is never mutated.

    Returns:
        list: One function object per subinterval, in left-to-right order.
    """
    # A None default (instead of a shared `{}` literal) avoids the
    # mutable-default-argument pitfall and lets callers pass kwds=None.
    extra = {} if kwds is None else kwds
    domain = Domain(domain if domain is not None else prefs.domain)
    return [
        bndfun_constructor(**{**extra, "interval": interval})
        for interval in domain.intervals
    ]
def infnorm(vals: np.ndarray) -> float:
    """Return the infinity norm of ``vals`` via ``numpy.linalg.norm``.

    For 1-D input this is the maximum absolute value. Note that for 2-D input
    numpy returns the matrix infinity norm (maximum absolute row sum), not the
    largest absolute entry.

    Args:
        vals (array-like): Input array.

    Returns:
        float: The infinity norm of the input.
    """
    order = np.inf
    return np.linalg.norm(vals, ord=order)
516def coerce_list(x: object) -> list | Iterable:
517 """Convert a non-iterable object to a list containing that object.
519 If the input is already an iterable (except strings), it is returned unchanged.
520 Strings are treated as non-iterables and wrapped in a list.
522 Args:
523 x: Input object to coerce to a list if necessary.
525 Returns:
526 list or iterable: The input wrapped in a list if it was not an iterable,
527 or the original input if it was already an iterable (except strings).
528 """
529 if not isinstance(x, Iterable) or isinstance(x, str): # pragma: no cover
530 x = [x]
531 return x