Coverage for src / chebpy / utilities.py: 98%

190 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 07:22 +0000

1"""Utility functions and classes for the ChebPy package. 

2 

3This module provides various utility functions and classes used throughout the ChebPy 

4package, including interval operations, domain manipulations, and tolerance functions. 

5It defines the core data structures for representing and manipulating intervals and domains. 

6""" 

7 

8import itertools 

9from collections import OrderedDict 

10from collections.abc import Callable, Iterable 

11from typing import Any, Protocol, runtime_checkable 

12 

13import numpy as np 

14 

15from .decorators import cast_other 

16from .exceptions import ( 

17 IntervalGap, 

18 IntervalOverlap, 

19 IntervalValues, 

20 InvalidDomain, 

21 NotSubdomain, 

22 SupportMismatch, 

23) 

24from .settings import _preferences as prefs 

25 

26 

def htol() -> float:
    """Horizontal tolerance used when comparing interval endpoints and breakpoints.

    Returns:
        float: The machine epsilon configured in the package preferences,
        scaled by a fixed factor of five.
    """
    eps = prefs.eps
    return 5 * eps  # type: ignore[return-value]

34 

35 

@runtime_checkable
class IntervalMap(Protocol):
    """Structural protocol for a bijective map between [-1, 1] and [a, b].

    Any object exposing the three mapping methods below can be used as the
    ``_interval`` of a :class:`~chebpy.classicfun.Classicfun`. The canonical
    affine implementation is :class:`Interval`. Non-affine implementations
    (e.g. endpoint-clustering exponential transforms for representing
    functions with branch-type endpoint singularities) can be added later
    without changing this contract.

    Required methods:
        formap(y): Map ``y`` from the reference interval [-1, 1] to ``x`` in [a, b].
        invmap(x): Map ``x`` from [a, b] back to ``y`` in [-1, 1].
        drvmap(y): Derivative dx/dy of ``formap`` evaluated at ``y``.

    Notes:
        - ``formap`` and ``invmap`` must be mutual inverses on their domains.
        - ``drvmap(y)`` must be strictly positive on (-1, 1) so the map is
          orientation-preserving. It may vanish at ``y = ±1`` for non-affine
          maps with endpoint clustering.
        - The protocol is purely structural; no runtime registration is
          required. ``isinstance(obj, IntervalMap)`` works because of
          ``runtime_checkable``, but it only checks that the three method
          names exist. It does not verify signatures, invertibility, or the
          sign of ``drvmap``, so use it only for opt-in duck-type assertions.
    """

    def formap(self, y: float | np.ndarray) -> Any:
        """Map ``y`` from the reference interval [-1, 1] to ``x`` in [a, b]."""
        ...

    def invmap(self, x: float | np.ndarray) -> Any:
        """Map ``x`` from [a, b] back to ``y`` in [-1, 1] (inverse of :meth:`formap`)."""
        ...

    def drvmap(self, y: float | np.ndarray) -> Any:
        """Return the derivative ``dx/dy`` of :meth:`formap` at ``y``."""
        ...

73 

74 

75class Interval(np.ndarray): 

76 """Utility class to implement Interval logic. 

77 

78 The purpose of this class is to both enforce certain properties of domain 

79 components such as having exactly two monotonically increasing elements and 

80 also to implement the functionality of mapping to and from the unit interval. 

81 

82 ``Interval`` is the canonical affine implementer of the 

83 :class:`IntervalMap` protocol: ``formap`` is a linear bijection between 

84 [-1, 1] and [a, b] and ``drvmap`` is the constant ``(b - a) / 2``. 

85 

86 Attributes: 

87 formap: Maps y in [-1,1] to x in [a,b] 

88 invmap: Maps x in [a,b] to y in [-1,1] 

89 drvmap: Derivative mapping from y in [-1,1] to x in [a,b] 

90 

91 Note: 

92 Currently only implemented for finite a and b. 

93 The __call__ method evaluates self.formap since this is the most 

94 frequently used mapping operation. 

95 """ 

96 

97 def __new__(cls, a: float = -1.0, b: float = 1.0) -> "Interval": 

98 """Create a new Interval instance. 

99 

100 Args: 

101 a (float, optional): Left endpoint of the interval. Defaults to -1.0. 

102 b (float, optional): Right endpoint of the interval. Defaults to 1.0. 

103 

104 Raises: 

105 IntervalValues: If a >= b. 

106 

107 Returns: 

108 Interval: A new Interval instance. 

109 

110 Examples: 

111 >>> import numpy as np 

112 >>> interval = Interval(-1, 1) 

113 >>> interval.tolist() 

114 [-1.0, 1.0] 

115 >>> float(interval.formap(0)) 

116 0.0 

117 """ 

118 if a >= b: 

119 raise IntervalValues 

120 return np.asarray((a, b), dtype=float).view(cls) # type: ignore[return-value] 

121 

122 def formap(self, y: float | np.ndarray) -> Any: 

123 """Map from the reference interval [-1,1] to this interval [a,b]. 

124 

125 Args: 

126 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

127 

128 Returns: 

129 float or numpy.ndarray: Corresponding points in the interval [a,b]. 

130 """ 

131 a, b = self 

132 return 0.5 * b * (y + 1.0) + 0.5 * a * (1.0 - y) 

133 

134 def invmap(self, x: float | np.ndarray) -> Any: 

135 """Map from this interval [a,b] to the reference interval [-1,1]. 

136 

137 Args: 

138 x (float or numpy.ndarray): Points in the interval [a,b]. 

139 

140 Returns: 

141 float or numpy.ndarray: Corresponding points in the reference interval [-1,1]. 

142 """ 

143 a, b = self 

144 return (2.0 * x - a - b) / (b - a) 

145 

146 def drvmap(self, y: float | np.ndarray) -> Any: 

147 """Compute the derivative of the forward map. 

148 

149 Args: 

150 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

151 

152 Returns: 

153 float or numpy.ndarray: Derivative values at the corresponding points. 

154 """ 

155 a, b = self # pragma: no cover 

156 return 0.0 * y + 0.5 * (b - a) # pragma: no cover 

157 

158 def __eq__(self, other: object) -> bool: 

159 """Check if two intervals are equal. 

160 

161 Args: 

162 other: Another interval to compare with. 

163 

164 Returns: 

165 bool: True if the intervals have the same endpoints, False otherwise. 

166 """ 

167 if not isinstance(other, Interval): 

168 return NotImplemented 

169 (a, b), (x, y) = self, other 

170 return bool((a == x) & (y == b)) 

171 

172 def __ne__(self, other: object) -> bool: 

173 """Check if two intervals are not equal. 

174 

175 Args: 

176 other: Another interval to compare with. 

177 

178 Returns: 

179 bool: True if the intervals have different endpoints, False otherwise. 

180 """ 

181 if not isinstance(other, Interval): 

182 return NotImplemented 

183 return not self == other 

184 

185 def __call__(self, y: float | np.ndarray) -> Any: 

186 """Map points from [-1,1] to this interval (shorthand for formap). 

187 

188 Args: 

189 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

190 

191 Returns: 

192 float or numpy.ndarray: Corresponding points in the interval [a,b]. 

193 """ 

194 return self.formap(y) 

195 

196 def __contains__(self, other: object) -> bool: 

197 """Check if another interval is contained within this interval. 

198 

199 Args: 

200 other (Interval): Another interval to check. 

201 

202 Returns: 

203 bool: True if other is contained within this interval, False otherwise. 

204 """ 

205 other_interval: Interval = other 

206 (a, b), (x, y) = self, other_interval 

207 return bool((a <= x) & (y <= b)) 

208 

209 def isinterior(self, x: float | np.ndarray) -> Any: 

210 """Check if points are strictly in the interior of the interval. 

211 

212 Args: 

213 x (float or numpy.ndarray): Points to check. 

214 

215 Returns: 

216 bool or numpy.ndarray: Boolean array indicating which points are in the interior. 

217 """ 

218 a, b = self 

219 return np.logical_and(a < x, x < b) 

220 

221 @property 

222 def hscale(self) -> float: 

223 """Calculate the horizontal scale factor of the interval. 

224 

225 Returns: 

226 float: The horizontal scale factor. 

227 """ 

228 a, b = self 

229 h = max(infnorm(self), 1) 

230 h_factor = b - a # if interval == domain: scale hscale back to 1 

231 result = max(h / h_factor, 1) # else: hscale < 1 

232 return float(result) 

233 

234 

235def _merge_duplicates(arr: np.ndarray, tols: np.ndarray) -> np.ndarray: 

236 """Remove duplicate entries from an input array within specified tolerances. 

237 

238 This function works from left to right, keeping the first occurrence of 

239 values that are within tolerance of each other. 

240 

241 Args: 

242 arr (numpy.ndarray): Input array to remove duplicates from. 

243 tols (numpy.ndarray): Array of tolerance values for each pair of adjacent elements. 

244 Should have length one less than arr. 

245 

246 Returns: 

247 numpy.ndarray: Array with duplicates removed. 

248 

249 Note: 

250 Pathological cases may cause issues since this method works by using 

251 consecutive differences. It might be better to take an average (median?), 

252 rather than the left-hand value. 

253 """ 

254 idx = np.append(np.abs(np.diff(arr)) > tols[:-1], True) 

255 return np.asarray(arr[idx]) 

256 

257 

class Domain(np.ndarray):
    """Numpy ndarray with additional Chebfun-specific domain logic.

    A Domain represents a collection of breakpoints that define a piecewise domain.
    It provides methods for manipulating and comparing domains, as well as
    generating intervals between adjacent breakpoints.

    Attributes:
        intervals: Generator yielding Interval objects between adjacent breakpoints.
        support: First and last breakpoints of the domain.
    """

    def __new__(cls, breakpoints: Any) -> "Domain":
        """Create a new Domain instance.

        Args:
            breakpoints (array-like): Collection of strictly increasing breakpoints.
                Must have at least 2 elements (an empty collection is also
                accepted). The outermost breakpoints may be ``-np.inf`` /
                ``+np.inf``; interior breakpoints must be finite.

        Raises:
            InvalidDomain: If breakpoints has exactly 1 element, is not strictly
                increasing (this includes repeated infinities such as
                ``[-inf, -inf]``), contains NaN, or contains non-finite values at
                interior positions.

        Returns:
            Domain: A new Domain instance.
        """
        bpts = np.asarray(breakpoints, dtype=float)
        if bpts.size == 0:
            return bpts.view(cls)  # type: ignore[return-value]
        # ``not all(diff > 0)`` rather than ``any(diff <= 0)``: repeated infinite
        # endpoints give ``inf - inf = nan``, which compares False against
        # everything and would otherwise pass the monotonicity check.
        if bpts.size < 2 or not np.all(np.diff(bpts) > 0):
            raise InvalidDomain
        # Interior breakpoints must be finite; only the outermost two may be infinite.
        if bpts.size > 2 and not np.all(np.isfinite(bpts[1:-1])):
            raise InvalidDomain
        # NaN is never permitted anywhere.
        if np.any(np.isnan(bpts)):
            raise InvalidDomain
        return bpts.view(cls)  # type: ignore[return-value]

    @staticmethod
    def _widen_support(lo: float, hi: float) -> tuple[float, float]:
        """Return ``[lo, hi]`` widened outward by the breakpoint tolerance.

        Each endpoint moves by ``max(htol(), htol() * |endpoint|)``, the same
        rule used by :meth:`union`, :meth:`merge` and :meth:`__eq__`. The
        absolute floor keeps a non-zero slack at endpoints equal to or near
        zero. Infinite endpoints stay infinite.
        """
        tol = htol()
        return lo - max(tol, tol * abs(lo)), hi + max(tol, tol * abs(hi))

    def __contains__(self, other: object) -> bool:
        """Check whether one domain object is a subdomain of another (within tolerance).

        Args:
            other (Domain): Another domain to check.

        Returns:
            bool: True if the support of other lies within the support of this
            domain, widened by the breakpoint tolerance; False otherwise.
        """
        other_domain: Domain = other  # type: ignore[assignment]
        x, y = other_domain.support
        lbnd, rbnd = self._widen_support(*self.support)
        return bool((lbnd <= x) & (y <= rbnd))

    @classmethod
    def from_chebfun(cls, chebfun: Any) -> "Domain":
        """Initialize a Domain object from a Chebfun.

        Args:
            chebfun: A Chebfun object with breakpoints.

        Returns:
            Domain: A new Domain instance with the same breakpoints as the Chebfun.
        """
        return cls(chebfun.breakpoints)

    @property
    def intervals(self) -> Iterable[Interval]:
        """Generate Interval objects between adjacent breakpoints.

        Yields:
            Interval: Interval objects for each pair of adjacent breakpoints.
        """
        for a, b in itertools.pairwise(self):
            yield Interval(a, b)

    @property
    def support(self) -> np.ndarray:
        """Get the first and last breakpoints of the domain.

        Returns:
            numpy.ndarray: Array containing the first and last breakpoints.
        """
        return self[[0, -1]]

    @cast_other
    def union(self, other: "Domain") -> "Domain":
        """Create a union of two domain objects with matching support.

        Args:
            other (Domain): Another domain to union with.

        Raises:
            SupportMismatch: If the supports of the two domains don't match within tolerance.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        dspt = np.abs(self.support - other.support)
        tolerance = np.maximum(htol(), htol() * np.abs(self.support))
        if np.any(dspt > tolerance):
            raise SupportMismatch
        return self.merge(other)

    def merge(self, other: "Domain") -> "Domain":
        """Merge two domain objects without checking if they have the same support.

        Breakpoints closer than ``max(htol(), htol() * |x|)`` are collapsed into one.

        Args:
            other (Domain): Another domain to merge with.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        all_bpts = np.append(self, other)
        new_bpts = np.unique(all_bpts)
        mergetol = np.maximum(htol(), htol() * np.abs(new_bpts))
        mgd_bpts = _merge_duplicates(new_bpts, mergetol)
        return self.__class__(mgd_bpts)

    @cast_other
    def restrict(self, other: "Domain") -> "Domain":
        """Truncate self to the support of other, retaining any interior breakpoints.

        Args:
            other (Domain): Domain to restrict to.

        Raises:
            NotSubdomain: If other is not a subdomain of self.

        Returns:
            Domain: A new Domain with breakpoints from self restricted to other's support.
        """
        if other not in self:
            raise NotSubdomain
        dom = self.merge(other)
        lbnd, rbnd = self._widen_support(*other.support)
        new = dom[(lbnd <= dom) & (dom <= rbnd)]
        return self.__class__(new)

    def breakpoints_in(self, other: "Domain") -> np.ndarray:
        """Check which breakpoints are in another domain within tolerance.

        Args:
            other (Domain): Domain (or array-like of points) to check against.

        Returns:
            numpy.ndarray: Boolean array of size equal to self where True indicates
            that the breakpoint is in other within the specified tolerance.
        """
        tol = htol()
        pts = np.asarray(self, dtype=float).ravel()
        targets = np.atleast_1d(np.asarray(other, dtype=float)).ravel()
        # Relative window bpt * (1 -/+ htol); min/max orders the pair, which
        # matters for negative breakpoints.
        scaled_lo, scaled_hi = pts * (1 - tol), pts * (1 + tol)
        lo = np.minimum(scaled_lo, scaled_hi)
        hi = np.maximum(scaled_lo, scaled_hi)
        # Near zero the relative window collapses, so fall back to an absolute one.
        lo = np.where(np.abs(lo) < tol, -tol, lo)
        hi = np.where(np.abs(hi) < tol, tol, hi)
        # (len(self), len(other)) membership table, reduced over ``other``.
        inside = (lo[:, None] <= targets[None, :]) & (targets[None, :] <= hi[:, None])
        return np.asarray(inside.any(axis=1), dtype=bool)

    def __eq__(self, other: object) -> bool:
        """Test for pointwise equality (within a tolerance) of two Domain objects.

        Args:
            other: Another domain, or an array-like convertible to one.

        Returns:
            bool: True if domains have the same size and all breakpoints match within
            tolerance. NotImplemented if other cannot be converted to a Domain.
        """
        if not isinstance(other, Domain):
            # Try to convert array-like objects to Domain for comparison. Any
            # failure means "not comparable", which Python handles by falling
            # back to the reflected operation.
            try:
                other = Domain(other)
            except Exception:
                return NotImplemented
        if self.size != other.size:
            return False
        dbpt = np.abs(self - other)
        tolerance = np.maximum(htol(), htol() * np.abs(self))
        return bool(np.all(dbpt <= tolerance))  # cast back to bool

    def __ne__(self, other: object) -> bool:
        """Test for inequality of two Domain objects.

        Mirrors :meth:`__eq__` exactly, including its conversion of array-like
        operands, so that ``d != x`` is always ``not (d == x)``.

        Args:
            other: Another domain, or an array-like convertible to one.

        Returns:
            bool: True if domains differ in size or any breakpoints don't match within
            tolerance. NotImplemented if other cannot be converted to a Domain.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

457 

458 

def _sortindex(intervals: Iterable[Interval]) -> np.ndarray:
    """Compute the ordering of a set of intervals and validate that they tile a domain.

    The intervals are ordered by their left endpoints. The sorted sequence must
    then be contiguous: each interval's right endpoint must coincide exactly with
    the next interval's left endpoint.

    Args:
        intervals (iterable): Interval objects (or any ``(a, b)`` pairs).

    Returns:
        numpy.ndarray: Index array that sorts the intervals by left endpoint.

    Raises:
        IntervalOverlap: If any intervals overlap.
        IntervalGap: If there are gaps between intervals.
    """
    pieces = np.array(list(intervals))
    order = np.array([piece[0] for piece in pieces]).argsort()

    # Flattening the sorted pieces gives [a0, b0, a1, b1, ...]. Dropping the
    # outermost two leaves [b0, a1, b1, a2, ...], so each (b_i, a_{i+1}) pair
    # is a junction that must match exactly.
    junctions = pieces[order].flatten()[1:-1]
    mismatch = junctions[1::2] - junctions[::2]
    if np.any(mismatch < 0):
        raise IntervalOverlap
    if np.any(mismatch > 0):
        raise IntervalGap
    return order

491 

492 

def check_funs(funs: Any) -> np.ndarray:
    """Sort a collection of funs by position and validate their tiling.

    Validation (no overlaps, no gaps) is delegated to ``_sortindex``.

    Args:
        funs (array-like): Function objects exposing a ``support`` attribute.

    Returns:
        numpy.ndarray: The funs ordered by the left endpoint of their support,
        or an empty array when no funs are given.

    Raises:
        IntervalOverlap: If any function intervals overlap.
        IntervalGap: If there are gaps between function intervals.
    """
    funs = np.array(funs)
    if not funs.size:
        return np.array([])
    # ``support`` is the logical (user-facing) interval, not the storage
    # ``interval``, so CompactFun pieces are checked against their unbounded
    # endpoints rather than their finite numerical-support storage.
    order = _sortindex(fun.support for fun in funs)
    return funs[order]

520 

521 

522def compute_breakdata(funs: np.ndarray) -> OrderedDict[float, Any]: 

523 """Define function values at breakpoints by averaging left and right limits. 

524 

525 This function computes values at breakpoints by averaging the left and right 

526 limits of adjacent functions. It is typically called after check_funs(), 

527 which ensures that the domain is fully partitioned and non-overlapping. 

528 

529 Args: 

530 funs (numpy.ndarray): Array of function objects with support and endvalues attributes. 

531 

532 Returns: 

533 OrderedDict: Dictionary mapping breakpoints to function values. 

534 """ 

535 if funs.size == 0: 

536 return OrderedDict() 

537 else: 

538 points = np.array([fun.support for fun in funs]) 

539 values = np.array([fun.endvalues for fun in funs]) 

540 points = points.flatten() 

541 values = values.flatten() 

542 xl, xr = points[0], points[-1] 

543 yl, yr = values[0], values[-1] 

544 xx, yy = points[1:-1], values[1:-1] 

545 x = 0.5 * (xx[::2] + xx[1::2]) 

546 y = 0.5 * (yy[::2] + yy[1::2]) 

547 xout = np.append(np.append(xl, x), xr) 

548 yout = np.append(np.append(yl, y), yr) 

549 return OrderedDict(zip(xout, yout, strict=False)) 

550 

551 

def generate_funs(
    domain: Domain | list[float] | None, bndfun_constructor: Callable[..., Any], kwds: dict[str, Any] | None = None
) -> list[Any]:
    """Build one function object per piece of a (possibly unbounded) domain.

    Used by several Chebfun classmethod constructors. Pieces with two finite
    endpoints are built with ``bndfun_constructor``. Pieces touching ``±inf``
    are built with the :class:`CompactFun` classmethod of the same name as
    ``bndfun_constructor``.

    Args:
        domain (array-like or None): Domain breakpoints. If None, uses default domain from preferences.
            The outermost breakpoints may be ``±inf``; interior breakpoints must be finite.
        bndfun_constructor (callable): Constructor function for creating function objects on
            finite intervals (typically a :class:`Bndfun` classmethod).
        kwds (dict, optional): Additional keyword arguments to pass to the constructor. Defaults to {}.

    Returns:
        list: List of function objects covering the domain.

    Raises:
        InvalidDomain: If a piece is unbounded but CompactFun has no classmethod
            matching ``bndfun_constructor``'s name.
    """
    options = {} if kwds is None else kwds
    breakpoints = Domain(prefs.domain if domain is None else domain)
    # Imported here rather than at module level: chebpy.compactfun itself
    # imports Interval / Domain from this module, so a top-level import would
    # be circular.
    from .compactfun import CompactFun

    name = getattr(bndfun_constructor, "__name__", None)
    if name is not None and hasattr(CompactFun, name):
        compact_constructor: Callable[..., Any] | None = getattr(CompactFun, name)
    else:
        compact_constructor = None

    pieces = []
    for left, right in itertools.pairwise(breakpoints):
        lo, hi = float(left), float(right)
        if np.isfinite(lo) and np.isfinite(hi):
            ctor, interval = bndfun_constructor, Interval(lo, hi)
        elif compact_constructor is not None:
            # CompactFun classmethods take a plain (a, b) tuple that may hold ±inf.
            ctor, interval = compact_constructor, (lo, hi)
        else:
            raise InvalidDomain
        pieces.append(ctor(**{**options, "interval": interval}))
    return pieces

599 

600 

def infnorm(vals: np.ndarray) -> float:
    """Calculate the infinity norm (maximum absolute entry) of an array.

    Args:
        vals (array-like): Input values of any shape (scalars included). Complex
            entries contribute their modulus.

    Returns:
        float: ``max(|vals|)`` over all entries, or 0.0 for empty input. NaN
        entries propagate to a NaN result.
    """
    arr = np.asarray(vals)
    if arr.size == 0:
        # ``max`` of an empty array has no identity; the norm of nothing is 0.
        return 0.0
    # Entrywise maximum rather than ``np.linalg.norm(vals, np.inf)``: for 2-D
    # input the latter is the maximum absolute *row sum* (a matrix norm), and
    # for 0-d input it raises.
    return float(np.max(np.abs(arr)))

611 

612 

613def coerce_list(x: object) -> list[Any] | Iterable[Any]: 

614 """Convert a non-iterable object to a list containing that object. 

615 

616 If the input is already an iterable (except strings), it is returned unchanged. 

617 Strings are treated as non-iterables and wrapped in a list. 

618 

619 Args: 

620 x: Input object to coerce to a list if necessary. 

621 

622 Returns: 

623 list or iterable: The input wrapped in a list if it was not an iterable, 

624 or the original input if it was already an iterable (except strings). 

625 """ 

626 if not isinstance(x, Iterable) or isinstance(x, str): # pragma: no cover 

627 x = [x] 

628 return x