Coverage for src / chebpy / utilities.py: 100%

171 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-22 21:33 +0000

1"""Utility functions and classes for the ChebPy package. 

2 

3This module provides various utility functions and classes used throughout the ChebPy 

4package, including interval operations, domain manipulations, and tolerance functions. 

5It defines the core data structures for representing and manipulating intervals and domains. 

6""" 

7 

8import itertools 

9from collections import OrderedDict 

10from collections.abc import Callable, Iterable 

11from typing import Any 

12 

13import numpy as np 

14 

15from .decorators import cast_other 

16from .exceptions import ( 

17 IntervalGap, 

18 IntervalOverlap, 

19 IntervalValues, 

20 InvalidDomain, 

21 NotSubdomain, 

22 SupportMismatch, 

23) 

24from .settings import _preferences as prefs 

25 

26 

def htol() -> float:
    """Return the horizontal tolerance used when comparing breakpoints.

    The value is recomputed from the package preferences on every call.

    Returns:
        float: Five times ``prefs.eps``.
    """
    tolerance = 5 * prefs.eps
    return tolerance  # type: ignore[return-value]

34 

35 

36class Interval(np.ndarray): 

37 """Utility class to implement Interval logic. 

38 

39 The purpose of this class is to both enforce certain properties of domain 

40 components such as having exactly two monotonically increasing elements and 

41 also to implement the functionality of mapping to and from the unit interval. 

42 

43 Attributes: 

44 formap: Maps y in [-1,1] to x in [a,b] 

45 invmap: Maps x in [a,b] to y in [-1,1] 

46 drvmap: Derivative mapping from y in [-1,1] to x in [a,b] 

47 

48 Note: 

49 Currently only implemented for finite a and b. 

50 The __call__ method evaluates self.formap since this is the most 

51 frequently used mapping operation. 

52 """ 

53 

54 def __new__(cls, a: float = -1.0, b: float = 1.0) -> "Interval": 

55 """Create a new Interval instance. 

56 

57 Args: 

58 a (float, optional): Left endpoint of the interval. Defaults to -1.0. 

59 b (float, optional): Right endpoint of the interval. Defaults to 1.0. 

60 

61 Raises: 

62 IntervalValues: If a >= b. 

63 

64 Returns: 

65 Interval: A new Interval instance. 

66 

67 Examples: 

68 >>> import numpy as np 

69 >>> interval = Interval(-1, 1) 

70 >>> interval.tolist() 

71 [-1.0, 1.0] 

72 >>> float(interval.formap(0)) 

73 0.0 

74 """ 

75 if a >= b: 

76 raise IntervalValues 

77 return np.asarray((a, b), dtype=float).view(cls) # type: ignore[return-value] 

78 

79 def formap(self, y: float | np.ndarray) -> Any: 

80 """Map from the reference interval [-1,1] to this interval [a,b]. 

81 

82 Args: 

83 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

84 

85 Returns: 

86 float or numpy.ndarray: Corresponding points in the interval [a,b]. 

87 """ 

88 a, b = self 

89 return 0.5 * b * (y + 1.0) + 0.5 * a * (1.0 - y) 

90 

91 def invmap(self, x: float | np.ndarray) -> Any: 

92 """Map from this interval [a,b] to the reference interval [-1,1]. 

93 

94 Args: 

95 x (float or numpy.ndarray): Points in the interval [a,b]. 

96 

97 Returns: 

98 float or numpy.ndarray: Corresponding points in the reference interval [-1,1]. 

99 """ 

100 a, b = self 

101 return (2.0 * x - a - b) / (b - a) 

102 

103 def drvmap(self, y: float | np.ndarray) -> Any: 

104 """Compute the derivative of the forward map. 

105 

106 Args: 

107 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

108 

109 Returns: 

110 float or numpy.ndarray: Derivative values at the corresponding points. 

111 """ 

112 a, b = self # pragma: no cover 

113 return 0.0 * y + 0.5 * (b - a) # pragma: no cover 

114 

115 def __eq__(self, other: object) -> bool: 

116 """Check if two intervals are equal. 

117 

118 Args: 

119 other: Another interval to compare with. 

120 

121 Returns: 

122 bool: True if the intervals have the same endpoints, False otherwise. 

123 """ 

124 if not isinstance(other, Interval): 

125 return NotImplemented 

126 (a, b), (x, y) = self, other 

127 return bool((a == x) & (y == b)) 

128 

129 def __ne__(self, other: object) -> bool: 

130 """Check if two intervals are not equal. 

131 

132 Args: 

133 other: Another interval to compare with. 

134 

135 Returns: 

136 bool: True if the intervals have different endpoints, False otherwise. 

137 """ 

138 if not isinstance(other, Interval): 

139 return NotImplemented 

140 return not self == other 

141 

142 def __call__(self, y: float | np.ndarray) -> Any: 

143 """Map points from [-1,1] to this interval (shorthand for formap). 

144 

145 Args: 

146 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

147 

148 Returns: 

149 float or numpy.ndarray: Corresponding points in the interval [a,b]. 

150 """ 

151 return self.formap(y) 

152 

153 def __contains__(self, other: object) -> bool: 

154 """Check if another interval is contained within this interval. 

155 

156 Args: 

157 other (Interval): Another interval to check. 

158 

159 Returns: 

160 bool: True if other is contained within this interval, False otherwise. 

161 """ 

162 other_interval: Interval = other 

163 (a, b), (x, y) = self, other_interval 

164 return bool((a <= x) & (y <= b)) 

165 

166 def isinterior(self, x: float | np.ndarray) -> Any: 

167 """Check if points are strictly in the interior of the interval. 

168 

169 Args: 

170 x (float or numpy.ndarray): Points to check. 

171 

172 Returns: 

173 bool or numpy.ndarray: Boolean array indicating which points are in the interior. 

174 """ 

175 a, b = self 

176 return np.logical_and(a < x, x < b) 

177 

178 @property 

179 def hscale(self) -> float: 

180 """Calculate the horizontal scale factor of the interval. 

181 

182 Returns: 

183 float: The horizontal scale factor. 

184 """ 

185 a, b = self 

186 h = max(infnorm(self), 1) 

187 h_factor = b - a # if interval == domain: scale hscale back to 1 

188 result = max(h / h_factor, 1) # else: hscale < 1 

189 return float(result) 

190 

191 

def _merge_duplicates(arr: np.ndarray, tols: np.ndarray) -> np.ndarray:
    """Collapse runs of near-duplicate neighbours in an array.

    Neighbours ``arr[i]`` and ``arr[i + 1]`` count as duplicates when
    ``abs(arr[i + 1] - arr[i]) <= tols[i]``. For every such pair the left-hand
    entry is dropped, so the right-most member of a run of duplicates is the
    one that survives. The final entry of ``arr`` is always kept.

    Args:
        arr (numpy.ndarray): Values to de-duplicate. Only consecutive
            differences are inspected, so the input is expected to be sorted
            (the caller in this module passes the output of ``np.unique``).
        tols (numpy.ndarray): One tolerance per entry of ``arr`` (same length
            as ``arr``). The last tolerance is ignored because there is no
            neighbour to its right.

    Returns:
        numpy.ndarray: ``arr`` with near-duplicate entries removed. An empty
        input yields an empty array.

    Note:
        Because only consecutive differences are compared, a long chain of
        entries that are pairwise within tolerance collapses to its last
        element even if the ends of the chain are far apart.
    """
    # Without this guard the always-True sentinel below would produce a
    # length-1 mask for a length-0 array and boolean indexing would raise.
    if arr.size == 0:
        return np.asarray(arr)
    keep = np.append(np.abs(np.diff(arr)) > tols[:-1], True)
    return np.asarray(arr[keep])

213 

214 

class Domain(np.ndarray):
    """Numpy ndarray with additional Chebfun-specific domain logic.

    A Domain holds a strictly increasing sequence of breakpoints that define a
    piecewise domain. It offers tolerance-aware comparison, merging and
    restriction, and can generate the Interval between each pair of adjacent
    breakpoints.

    Attributes:
        intervals: Generator yielding an Interval per pair of adjacent breakpoints.
        support: The first and last breakpoints of the domain.
    """

    def __new__(cls, breakpoints: Any) -> "Domain":
        """Create a new Domain instance.

        Args:
            breakpoints (array-like): Strictly increasing breakpoints. Either
                empty (yielding an empty Domain) or at least two values.

        Raises:
            InvalidDomain: If exactly one breakpoint is given or the
                breakpoints are not strictly increasing.

        Returns:
            Domain: View of the breakpoints as a float array.
        """
        bpts = np.asarray(breakpoints, dtype=float)
        # An empty domain is accepted as-is; anything else needs >= 2 points.
        if bpts.size == 0:
            return bpts.view(cls)  # type: ignore[return-value]
        if bpts.size < 2 or np.any(np.diff(bpts) <= 0):
            raise InvalidDomain
        return bpts.view(cls)  # type: ignore[return-value]

    def __contains__(self, other: object) -> bool:
        """Check whether another domain's support lies inside this one's (within tolerance).

        The support of ``self`` is widened by the relative factor ``htol()``
        at each end before comparing. The widening is purely relative, so an
        endpoint equal to zero gets no slack.

        Args:
            other (Domain): Domain whose support is tested.

        Returns:
            bool: True if the support of ``other`` fits inside the widened support of ``self``.
        """
        other_domain: Domain = other
        a, b = self.support
        x, y = other_domain.support
        bounds = np.array([1 - htol(), 1 + htol()])
        # min/max pick the outward-facing product regardless of the sign of a and b.
        lbnd, rbnd = np.min(a * bounds), np.max(b * bounds)
        return bool((lbnd <= x) & (y <= rbnd))

    @classmethod
    def from_chebfun(cls, chebfun: Any) -> "Domain":
        """Initialize a Domain object from a Chebfun.

        Args:
            chebfun: Object exposing a ``breakpoints`` attribute.

        Returns:
            Domain: A new Domain built from ``chebfun.breakpoints``.
        """
        return cls(chebfun.breakpoints)

    @property
    def intervals(self) -> Iterable[Interval]:
        """Generate Interval objects between adjacent breakpoints.

        Yields:
            Interval: One Interval per pair of consecutive breakpoints.
        """
        for a, b in itertools.pairwise(self):
            yield Interval(a, b)

    @property
    def support(self) -> np.ndarray:
        """Get the first and last breakpoints of the domain.

        Returns:
            numpy.ndarray: Two-element array ``[first, last]``.
        """
        return self[[0, -1]]

    @cast_other
    def union(self, other: "Domain") -> "Domain":
        """Create a union of two domain objects with matching support.

        Args:
            other (Domain): Domain to combine with. The ``cast_other``
                decorator presumably converts array-likes to Domain first;
                confirm in ``decorators``.

        Raises:
            SupportMismatch: If the supports differ by more than
                ``max(htol(), htol() * |self.support|)`` at either end.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        dspt = np.abs(self.support - other.support)
        tolerance = np.maximum(htol(), htol() * np.abs(self.support))
        if np.any(dspt > tolerance):
            raise SupportMismatch
        return self.merge(other)

    def merge(self, other: "Domain") -> "Domain":
        """Merge two domain objects without checking if they have the same support.

        Breakpoints from both domains are pooled, sorted and made unique, and
        entries closer together than ``max(htol(), htol() * |x|)`` are then
        collapsed by ``_merge_duplicates``.

        Args:
            other (Domain): Domain to merge with.

        Returns:
            Domain: A new Domain containing the merged breakpoints.
        """
        all_bpts = np.append(self, other)
        new_bpts = np.unique(all_bpts)
        mergetol = np.maximum(htol(), htol() * np.abs(new_bpts))
        mgd_bpts = _merge_duplicates(new_bpts, mergetol)
        return self.__class__(mgd_bpts)

    @cast_other
    def restrict(self, other: "Domain") -> "Domain":
        """Truncate self to the support of other, retaining any interior breakpoints.

        Args:
            other (Domain): Domain to restrict to.

        Raises:
            NotSubdomain: If ``other`` is not contained in ``self``.

        Returns:
            Domain: Merged breakpoints of both domains that fall inside the
            (relatively widened) support of ``other``.
        """
        if other not in self:
            raise NotSubdomain
        dom = self.merge(other)
        a, b = other.support
        bounds = np.array([1 - htol(), 1 + htol()])
        lbnd, rbnd = np.min(a * bounds), np.max(b * bounds)
        new = dom[(lbnd <= dom) & (dom <= rbnd)]
        return self.__class__(new)

    def breakpoints_in(self, other: "Domain") -> np.ndarray:
        """Check which breakpoints are in another domain within tolerance.

        Args:
            other (Domain): Domain to check against.

        Returns:
            numpy.ndarray: Boolean array with one entry per breakpoint of
            ``self``; True where some entry of ``other`` lies within the
            tolerance window around that breakpoint.
        """
        tol = htol()  # loop-invariant, so read it once
        out = np.empty(self.size, dtype=bool)
        window = np.array([1 - tol, 1 + tol])
        # TODO: is there way to vectorise this?
        for idx, bpt in enumerate(self):
            # Sorting orders the relative window correctly for negative breakpoints.
            lbnd, rbnd = np.sort(bpt * window)
            # Near zero the relative window vanishes, so use an absolute one.
            lbnd = -tol if np.abs(lbnd) < tol else lbnd
            rbnd = +tol if np.abs(rbnd) < tol else rbnd
            isin = (lbnd <= other) & (other <= rbnd)
            out[idx] = np.any(isin)
        return out

    def __eq__(self, other: object) -> bool:
        """Test for pointwise equality (within a tolerance) of two Domain objects.

        Args:
            other: Domain, or any object ``Domain(...)`` can convert.

        Returns:
            bool: True if both have the same size and every breakpoint agrees
            to within ``max(htol(), htol() * |self|)``. ``NotImplemented`` is
            returned when ``other`` cannot be converted to a Domain.
        """
        if not isinstance(other, Domain):
            # Best-effort conversion of array-likes; any failure means the
            # operand is simply not comparable to a Domain.
            try:
                other = Domain(other)
            except Exception:
                return NotImplemented
        if self.size != other.size:
            return False
        dbpt = np.abs(self - other)
        tolerance = np.maximum(htol(), htol() * np.abs(self))
        return bool(np.all(dbpt <= tolerance))  # cast back to bool

    def __ne__(self, other: object) -> bool:
        """Test for inequality of two Domain objects.

        This is the exact negation of ``__eq__``, including its conversion of
        array-like operands, so ``a == b`` and ``a != b`` can never both hold.

        Args:
            other: Domain, or any object ``Domain(...)`` can convert.

        Returns:
            bool: True if the sizes differ or any breakpoint differs by more
            than the tolerance. ``NotImplemented`` is returned when ``other``
            cannot be converted to a Domain.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

406 

407 

def _sortindex(intervals: Iterable[Interval]) -> np.ndarray:
    """Compute the permutation that orders intervals by left endpoint.

    The ordered intervals must tile their overall span exactly: each right
    endpoint has to equal the next left endpoint. The comparison is exact,
    with no tolerance.

    Args:
        intervals (iterable): Interval objects (two-element sequences) to order.

    Returns:
        numpy.ndarray: Indices that sort ``intervals`` by left endpoint.

    Raises:
        IntervalOverlap: If a right endpoint exceeds the following left endpoint.
        IntervalGap: If a right endpoint falls short of the following left endpoint.
    """
    stacked = np.array(list(intervals))
    lefts = np.array([row[0] for row in stacked])
    order = np.argsort(lefts)

    # Dropping the two outermost endpoints leaves the sequence
    # right_0, left_1, right_1, left_2, ... so consecutive pairs must match.
    inner = stacked[order].flatten()[1:-1]
    jumps = inner[1::2] - inner[::2]
    if np.any(jumps < 0):
        raise IntervalOverlap
    if np.any(jumps > 0):
        raise IntervalGap

    return order

440 

441 

def check_funs(funs: Any) -> np.ndarray:
    """Validate a collection of funs and return them ordered by interval.

    Ordering and validation are delegated to ``_sortindex``, which is fed the
    ``interval`` attribute of every fun.

    Args:
        funs (array-like): Objects exposing an ``interval`` attribute.

    Returns:
        numpy.ndarray: The funs sorted by left endpoint, or an empty array
        when no funs are given.

    Raises:
        IntervalOverlap: If any function intervals overlap.
        IntervalGap: If there are gaps between function intervals.
    """
    fun_array = np.array(funs)
    if fun_array.size == 0:
        return np.array([])
    order = _sortindex(fun.interval for fun in fun_array)
    return fun_array[order]

466 

467 

468def compute_breakdata(funs: np.ndarray) -> OrderedDict[float, Any]: 

469 """Define function values at breakpoints by averaging left and right limits. 

470 

471 This function computes values at breakpoints by averaging the left and right 

472 limits of adjacent functions. It is typically called after check_funs(), 

473 which ensures that the domain is fully partitioned and non-overlapping. 

474 

475 Args: 

476 funs (numpy.ndarray): Array of function objects with support and endvalues attributes. 

477 

478 Returns: 

479 OrderedDict: Dictionary mapping breakpoints to function values. 

480 """ 

481 if funs.size == 0: 

482 return OrderedDict() 

483 else: 

484 points = np.array([fun.support for fun in funs]) 

485 values = np.array([fun.endvalues for fun in funs]) 

486 points = points.flatten() 

487 values = values.flatten() 

488 xl, xr = points[0], points[-1] 

489 yl, yr = values[0], values[-1] 

490 xx, yy = points[1:-1], values[1:-1] 

491 x = 0.5 * (xx[::2] + xx[1::2]) 

492 y = 0.5 * (yy[::2] + yy[1::2]) 

493 xout = np.append(np.append(xl, x), xr) 

494 yout = np.append(np.append(yl, y), yr) 

495 return OrderedDict(zip(xout, yout, strict=False)) 

496 

497 

def generate_funs(
    domain: Domain | list[float] | None, bndfun_constructor: Callable[..., Any], kwds: dict[str, Any] | None = None
) -> list[Any]:
    """Build one function object per subinterval of a domain.

    For every interval of the domain the constructor is called with the
    supplied keyword arguments plus ``interval=<that interval>``; an
    ``"interval"`` key already present in ``kwds`` is overridden.

    Args:
        domain (array-like or None): Domain breakpoints. When None,
            ``prefs.domain`` is used.
        bndfun_constructor (callable): Called once per interval with keyword
            arguments only.
        kwds (dict, optional): Extra keyword arguments for the constructor.
            None (the default) means no extra arguments.

    Returns:
        list: The constructed objects, in domain order.
    """
    extra = {} if kwds is None else kwds
    breakpoints = prefs.domain if domain is None else domain
    pieces = Domain(breakpoints).intervals
    return [bndfun_constructor(**{**extra, "interval": interval}) for interval in pieces]

522 

523 

def infnorm(vals: np.ndarray) -> float:
    """Return the infinity norm of an array as a Python float.

    The work is done by ``numpy.linalg.norm`` with ``ord=inf``: for 1-D input
    that is the largest absolute value; for 2-D input it is the largest
    absolute row sum.

    Args:
        vals (array-like): Input array.

    Returns:
        float: The infinity norm of ``vals``.
    """
    norm_value = np.linalg.norm(vals, ord=np.inf)
    return float(norm_value)

534 

535 

536def coerce_list(x: object) -> list[Any] | Iterable[Any]: 

537 """Convert a non-iterable object to a list containing that object. 

538 

539 If the input is already an iterable (except strings), it is returned unchanged. 

540 Strings are treated as non-iterables and wrapped in a list. 

541 

542 Args: 

543 x: Input object to coerce to a list if necessary. 

544 

545 Returns: 

546 list or iterable: The input wrapped in a list if it was not an iterable, 

547 or the original input if it was already an iterable (except strings). 

548 """ 

549 if not isinstance(x, Iterable) or isinstance(x, str): # pragma: no cover 

550 x = [x] 

551 return x