Coverage for chebpy/core/utilities.py: 100%
154 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 10:30 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 10:30 +0000
1"""Utility functions and classes for the ChebPy package.
3This module provides various utility functions and classes used throughout the ChebPy
4package, including interval operations, domain manipulations, and tolerance functions.
5It defines the core data structures for representing and manipulating intervals and domains.
6"""
8from collections import OrderedDict
9from collections.abc import Iterable
11import numpy as np
13from .decorators import cast_other
14from .exceptions import (
15 IntervalGap,
16 IntervalOverlap,
17 IntervalValues,
18 InvalidDomain,
19 NotSubdomain,
20 SupportMismatch,
21)
22from .settings import _preferences as prefs
def htol() -> float:
    """Give the horizontal tolerance applied when comparing breakpoints.

    Returns:
        float: The ``eps`` value held in the package preferences, multiplied by 5.
    """
    return prefs.eps * 5
34class Interval(np.ndarray):
35 """Utility class to implement Interval logic.
37 The purpose of this class is to both enforce certain properties of domain
38 components such as having exactly two monotonically increasing elements and
39 also to implement the functionality of mapping to and from the unit interval.
41 Attributes:
42 formap: Maps y in [-1,1] to x in [a,b]
43 invmap: Maps x in [a,b] to y in [-1,1]
44 drvmap: Derivative mapping from y in [-1,1] to x in [a,b]
46 Note:
47 Currently only implemented for finite a and b.
48 The __call__ method evaluates self.formap since this is the most
49 frequently used mapping operation.
50 """
52 def __new__(cls, a: float = -1.0, b: float = 1.0) -> "Interval":
53 """Create a new Interval instance.
55 Args:
56 a (float, optional): Left endpoint of the interval. Defaults to -1.0.
57 b (float, optional): Right endpoint of the interval. Defaults to 1.0.
59 Raises:
60 IntervalValues: If a >= b.
62 Returns:
63 Interval: A new Interval instance.
64 """
65 if a >= b:
66 raise IntervalValues
67 return np.asarray((a, b), dtype=float).view(cls)
69 def formap(self, y: float | np.ndarray) -> float | np.ndarray:
70 """Map from the reference interval [-1,1] to this interval [a,b].
72 Args:
73 y (float or numpy.ndarray): Points in the reference interval [-1,1].
75 Returns:
76 float or numpy.ndarray: Corresponding points in the interval [a,b].
77 """
78 a, b = self
79 return 0.5 * b * (y + 1.0) + 0.5 * a * (1.0 - y)
81 def invmap(self, x: float | np.ndarray) -> float | np.ndarray:
82 """Map from this interval [a,b] to the reference interval [-1,1].
84 Args:
85 x (float or numpy.ndarray): Points in the interval [a,b].
87 Returns:
88 float or numpy.ndarray: Corresponding points in the reference interval [-1,1].
89 """
90 a, b = self
91 return (2.0 * x - a - b) / (b - a)
93 def drvmap(self, y: float | np.ndarray) -> float | np.ndarray:
94 """Compute the derivative of the forward map.
96 Args:
97 y (float or numpy.ndarray): Points in the reference interval [-1,1].
99 Returns:
100 float or numpy.ndarray: Derivative values at the corresponding points.
101 """
102 a, b = self # pragma: no cover
103 return 0.0 * y + 0.5 * (b - a) # pragma: no cover
105 def __eq__(self, other: "Interval") -> bool:
106 """Check if two intervals are equal.
108 Args:
109 other (Interval): Another interval to compare with.
111 Returns:
112 bool: True if the intervals have the same endpoints, False otherwise.
113 """
114 (a, b), (x, y) = self, other
115 return (a == x) & (y == b)
117 def __ne__(self, other: "Interval") -> bool:
118 """Check if two intervals are not equal.
120 Args:
121 other (Interval): Another interval to compare with.
123 Returns:
124 bool: True if the intervals have different endpoints, False otherwise.
125 """
126 return not self == other
128 def __call__(self, y: float | np.ndarray) -> float | np.ndarray:
129 """Map points from [-1,1] to this interval (shorthand for formap).
131 Args:
132 y (float or numpy.ndarray): Points in the reference interval [-1,1].
134 Returns:
135 float or numpy.ndarray: Corresponding points in the interval [a,b].
136 """
137 return self.formap(y)
139 def __contains__(self, other: "Interval") -> bool:
140 """Check if another interval is contained within this interval.
142 Args:
143 other (Interval): Another interval to check.
145 Returns:
146 bool: True if other is contained within this interval, False otherwise.
147 """
148 (a, b), (x, y) = self, other
149 return (a <= x) & (y <= b)
151 def isinterior(self, x: float | np.ndarray) -> bool | np.ndarray:
152 """Check if points are strictly in the interior of the interval.
154 Args:
155 x (float or numpy.ndarray): Points to check.
157 Returns:
158 bool or numpy.ndarray: Boolean array indicating which points are in the interior.
159 """
160 a, b = self
161 return np.logical_and(a < x, x < b)
163 @property
164 def hscale(self) -> float:
165 """Calculate the horizontal scale factor of the interval.
167 Returns:
168 float: The horizontal scale factor.
169 """
170 a, b = self
171 h = max(infnorm(self), 1)
172 h_factor = b - a # if interval == domain: scale hscale back to 1
173 hscale = max(h / h_factor, 1) # else: hscale < 1
174 return hscale
177def _merge_duplicates(arr: np.ndarray, tols: np.ndarray) -> np.ndarray:
178 """Remove duplicate entries from an input array within specified tolerances.
180 This function works from left to right, keeping the first occurrence of
181 values that are within tolerance of each other.
183 Args:
184 arr (numpy.ndarray): Input array to remove duplicates from.
185 tols (numpy.ndarray): Array of tolerance values for each pair of adjacent elements.
186 Should have length one less than arr.
188 Returns:
189 numpy.ndarray: Array with duplicates removed.
191 Note:
192 Pathological cases may cause issues since this method works by using
193 consecutive differences. It might be better to take an average (median?),
194 rather than the left-hand value.
195 """
196 idx = np.append(np.abs(np.diff(arr)) > tols[:-1], True)
197 return arr[idx]
class Domain(np.ndarray):
    """Breakpoint array carrying Chebfun-style domain logic.

    A Domain stores the breakpoints of a piecewise domain and offers
    tolerance-aware comparison, merging and restriction, plus iteration over
    the intervals between neighbouring breakpoints.

    Attributes:
        intervals: Generator of Interval objects between adjacent breakpoints.
        support: First and last breakpoints of the domain.
    """

    def __new__(cls, breakpoints):
        """Build a Domain from a collection of breakpoints.

        Args:
            breakpoints (array-like): Strictly increasing breakpoints. Either
                empty or containing at least 2 elements.

        Raises:
            InvalidDomain: If there is exactly one breakpoint or the
                breakpoints are not strictly increasing.

        Returns:
            Domain: View of the breakpoints as a float array of this class.
        """
        bpts = np.asarray(breakpoints, dtype=float)
        # An empty array is accepted as-is; anything else needs >= 2 increasing points.
        malformed = bpts.size != 0 and (bpts.size < 2 or np.any(np.diff(bpts) <= 0))
        if malformed:
            raise InvalidDomain
        return bpts.view(cls)

    def __contains__(self, other: "Domain") -> bool:
        """Test whether ``other`` is a subdomain of this domain (within tolerance).

        Only the supports are compared; each of this domain's endpoints is
        widened by a relative factor of ``htol()``.

        Args:
            other (Domain): Candidate subdomain.

        Returns:
            bool: True if the support of ``other`` fits inside the widened support of self.
        """
        left, right = self.support
        other_left, other_right = other.support
        window = np.array([1 - htol(), 1 + htol()])
        # min/max choose the outward-widened endpoint regardless of sign.
        lower = np.min(left * window)
        upper = np.max(right * window)
        return (lower <= other_left) & (other_right <= upper)

    @classmethod
    def from_chebfun(cls, chebfun):
        """Build a Domain from the breakpoints of a Chebfun.

        Args:
            chebfun: Object exposing a ``breakpoints`` attribute.

        Returns:
            Domain: New Domain constructed from ``chebfun.breakpoints``.
        """
        bpts = chebfun.breakpoints
        return cls(bpts)

    @property
    def intervals(self) -> Iterable[Interval]:
        """Iterate over the intervals between neighbouring breakpoints.

        Yields:
            Interval: One Interval per pair of adjacent breakpoints, left to right.
        """
        lefts, rights = self[:-1], self[1:]
        for left, right in zip(lefts, rights):
            yield Interval(left, right)

    @property
    def support(self) -> Interval:
        """First and last breakpoints of the domain.

        Returns:
            numpy.ndarray: Two-element array holding the first and last breakpoints.
        """
        # NOTE(review): annotated as Interval, but indexing self presumably
        # yields an array of this class rather than an Interval - confirm.
        ends = [0, -1]
        return self[ends]

    @cast_other
    def union(self, other: "Domain") -> "Domain":
        """Combine two domains whose supports agree within tolerance.

        Args:
            other (Domain): Domain to combine with.

        Raises:
            SupportMismatch: If either endpoint of the supports differs by more than the tolerance.

        Returns:
            Domain: Result of ``self.merge(other)``.
        """
        support_gap = np.abs(self.support - other.support)
        allowed = np.maximum(htol(), htol() * np.abs(self.support))
        if np.any(support_gap > allowed):
            raise SupportMismatch
        return self.merge(other)

    def merge(self, other: "Domain") -> "Domain":
        """Combine the breakpoints of two domains without comparing supports.

        Args:
            other (Domain): Domain to combine with.

        Returns:
            Domain: New Domain built from the pooled breakpoints after
            near-duplicates have been merged.
        """
        pooled = np.unique(np.append(self, other))
        tolerances = np.maximum(htol(), htol() * np.abs(pooled))
        merged = _merge_duplicates(pooled, tolerances)
        return type(self)(merged)

    @cast_other
    def restrict(self, other: "Domain") -> "Domain":
        """Truncate self to the support of other, keeping interior breakpoints.

        Args:
            other (Domain): Domain whose support defines the restriction.

        Raises:
            NotSubdomain: If other is not a subdomain of self.

        Returns:
            Domain: Merged breakpoints of self and other that fall inside the
            (tolerance-widened) support of other.
        """
        if other not in self:
            raise NotSubdomain
        merged = self.merge(other)
        left, right = other.support
        window = np.array([1 - htol(), 1 + htol()])
        lower = np.min(left * window)
        upper = np.max(right * window)
        inside = (lower <= merged) & (merged <= upper)
        return type(self)(merged[inside])

    def breakpoints_in(self, other: "Domain") -> np.ndarray:
        """Flag the breakpoints of self that also occur in other (within tolerance).

        Args:
            other (Domain): Domain to look the breakpoints up in.

        Returns:
            numpy.ndarray: Boolean array of size equal to self where True means
            the breakpoint matches some entry of other within tolerance.
        """
        tol = htol()
        window = np.array([1 - tol, 1 + tol])
        found = np.empty(self.size, dtype=bool)
        # TODO: is there way to vectorise this?
        for position, bpt in enumerate(self):
            lower, upper = np.sort(bpt * window)
            # A relative window collapses near zero, so use an absolute one there.
            if np.abs(lower) < tol:
                lower = -tol
            if np.abs(upper) < tol:
                upper = +tol
            matches = (lower <= other) & (other <= upper)
            found[position] = np.any(matches)
        return found

    def __eq__(self, other: "Domain") -> bool:
        """Compare two domains breakpoint by breakpoint, within tolerance.

        Args:
            other (Domain): Domain to compare against.

        Returns:
            bool: True if the sizes agree and every pair of breakpoints matches within tolerance.
        """
        if self.size != other.size:
            return False
        gaps = np.abs(self - other)
        allowed = np.maximum(htol(), htol() * np.abs(self))
        return bool(np.all(gaps <= allowed))  # cast back to bool

    def __ne__(self, other: "Domain") -> bool:
        """Negation of ``__eq__``.

        Args:
            other (Domain): Domain to compare against.

        Returns:
            bool: True if the sizes differ or any pair of breakpoints is out of tolerance.
        """
        return not (self == other)
384def _sortindex(intervals: list[Interval]) -> np.ndarray:
385 """Return an index determining the ordering of interval objects.
387 This helper function checks that the intervals:
388 1. Do not overlap
389 2. Represent a complete partition of the broader approximation domain
391 Args:
392 intervals (array-like): Array of Interval objects to sort.
394 Returns:
395 numpy.ndarray: Index array for sorting the intervals.
397 Raises:
398 IntervalOverlap: If any intervals overlap.
399 IntervalGap: If there are gaps between intervals.
400 """
401 # sort by the left endpoint Interval values
402 subintervals = np.array([x for x in intervals])
403 leftbreakpts = np.array([s[0] for s in subintervals])
404 idx = leftbreakpts.argsort()
406 # check domain consistency
407 srt = subintervals[idx]
408 x = srt.flatten()[1:-1]
409 d = x[1::2] - x[::2]
410 if (d < 0).any():
411 raise IntervalOverlap
412 if (d > 0).any():
413 raise IntervalGap
415 return idx
def check_funs(funs: list) -> np.ndarray:
    """Validate a collection of funs and return them sorted by interval.

    The intervals of the funs must neither overlap nor leave gaps; that
    validation is delegated to _sortindex.

    Args:
        funs (array-like): Function objects exposing an ``interval`` attribute.

    Returns:
        numpy.ndarray: The funs ordered from left to right, or an empty array
        when no funs are given.

    Raises:
        IntervalOverlap: If any function intervals overlap.
        IntervalGap: If there are gaps between function intervals.
    """
    funs = np.array(funs)
    if funs.size == 0:
        return np.array([])
    order = _sortindex(fun.interval for fun in funs)
    return funs[order]
def compute_breakdata(funs: np.ndarray) -> OrderedDict:
    """Assign a function value to every breakpoint of a piecewise representation.

    At an interior breakpoint the value is the mean of the right end value of
    the fun on its left and the left end value of the fun on its right; the
    two outermost breakpoints take the outermost end values directly. The
    input is expected to be ordered and gap-free, which check_funs() ensures.

    Args:
        funs (numpy.ndarray): Function objects exposing ``support`` and ``endvalues`` attributes.

    Returns:
        OrderedDict: Mapping from breakpoint to function value, left to right.
    """
    if funs.size == 0:
        return OrderedDict()

    points = np.array([fun.support for fun in funs]).flatten()
    values = np.array([fun.endvalues for fun in funs]).flatten()

    # Interior entries come in (right end of fun i, left end of fun i+1) pairs.
    inner_points, inner_values = points[1:-1], values[1:-1]
    mid_points = 0.5 * (inner_points[::2] + inner_points[1::2])
    mid_values = 0.5 * (inner_values[::2] + inner_values[1::2])

    breakpoints = np.concatenate(([points[0]], mid_points, [points[-1]]))
    breakvalues = np.concatenate(([values[0]], mid_values, [values[-1]]))
    return OrderedDict(zip(breakpoints, breakvalues))
def generate_funs(domain: Domain | list | None, bndfun_constructor: callable, kwds: dict | None = None) -> list:
    """Generate a collection of function objects over a domain.

    This method is used by several of the Chebfun classmethod constructors to
    generate a collection of function objects over the specified domain.

    Args:
        domain (array-like or None): Domain breakpoints. If None, uses the default domain from preferences.
        bndfun_constructor (callable): Constructor invoked once per interval, with the
            interval passed as the ``interval`` keyword argument.
        kwds (dict, optional): Additional keyword arguments forwarded to every constructor
            call. Defaults to None, meaning no extra arguments. The dict is never
            modified; an ``"interval"`` entry in it is overridden.

    Returns:
        list: One function object per interval of the domain, ordered left to right.
    """
    # A None sentinel avoids sharing one mutable default dict across calls.
    shared_kwds = {} if kwds is None else kwds
    domain = Domain(domain if domain is not None else prefs.domain)
    return [bndfun_constructor(**{**shared_kwds, "interval": interval}) for interval in domain.intervals]
def infnorm(vals: np.ndarray) -> float:
    """Compute the infinity norm of an array.

    Args:
        vals (array-like): Input array.

    Returns:
        float: For one-dimensional input, the largest absolute value. The
        call is forwarded to ``numpy.linalg.norm`` with ``ord=inf``, which for
        two-dimensional input is the matrix norm (maximum absolute row sum).
    """
    return np.linalg.norm(vals, ord=np.inf)
508def coerce_list(x: object) -> list | Iterable:
509 """Convert a non-iterable object to a list containing that object.
511 If the input is already an iterable (except strings), it is returned unchanged.
512 Strings are treated as non-iterables and wrapped in a list.
514 Args:
515 x: Input object to coerce to a list if necessary.
517 Returns:
518 list or iterable: The input wrapped in a list if it was not an iterable,
519 or the original input if it was already an iterable (except strings).
520 """
521 if not isinstance(x, Iterable) or isinstance(x, str): # pragma: no cover
522 x = [x]
523 return x