Coverage for src / chebpy / utilities.py: 100%

154 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-20 05:55 +0000

1"""Utility functions and classes for the ChebPy package. 

2 

3This module provides various utility functions and classes used throughout the ChebPy 

4package, including interval operations, domain manipulations, and tolerance functions. 

5It defines the core data structures for representing and manipulating intervals and domains. 

6""" 

7 

8from collections import OrderedDict 

9from collections.abc import Iterable 

10 

11import numpy as np 

12 

13from .decorators import cast_other 

14from .exceptions import ( 

15 IntervalGap, 

16 IntervalOverlap, 

17 IntervalValues, 

18 InvalidDomain, 

19 NotSubdomain, 

20 SupportMismatch, 

21) 

22from .settings import _preferences as prefs 

23 

24 

def htol() -> float:
    """Return the horizontal tolerance used when comparing interval endpoints.

    The tolerance is derived from the package-wide machine epsilon stored in
    the preferences object.

    Returns:
        float: The preference epsilon multiplied by five.
    """
    machine_eps = prefs.eps
    return machine_eps * 5

32 

33 

34class Interval(np.ndarray): 

35 """Utility class to implement Interval logic. 

36 

37 The purpose of this class is to both enforce certain properties of domain 

38 components such as having exactly two monotonically increasing elements and 

39 also to implement the functionality of mapping to and from the unit interval. 

40 

41 Attributes: 

42 formap: Maps y in [-1,1] to x in [a,b] 

43 invmap: Maps x in [a,b] to y in [-1,1] 

44 drvmap: Derivative mapping from y in [-1,1] to x in [a,b] 

45 

46 Note: 

47 Currently only implemented for finite a and b. 

48 The __call__ method evaluates self.formap since this is the most 

49 frequently used mapping operation. 

50 """ 

51 

52 def __new__(cls, a: float = -1.0, b: float = 1.0) -> "Interval": 

53 """Create a new Interval instance. 

54 

55 Args: 

56 a (float, optional): Left endpoint of the interval. Defaults to -1.0. 

57 b (float, optional): Right endpoint of the interval. Defaults to 1.0. 

58 

59 Raises: 

60 IntervalValues: If a >= b. 

61 

62 Returns: 

63 Interval: A new Interval instance. 

64 

65 Examples: 

66 >>> import numpy as np 

67 >>> interval = Interval(-1, 1) 

68 >>> interval.tolist() 

69 [-1.0, 1.0] 

70 >>> float(interval.formap(0)) 

71 0.0 

72 """ 

73 if a >= b: 

74 raise IntervalValues 

75 return np.asarray((a, b), dtype=float).view(cls) 

76 

77 def formap(self, y: float | np.ndarray) -> float | np.ndarray: 

78 """Map from the reference interval [-1,1] to this interval [a,b]. 

79 

80 Args: 

81 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

82 

83 Returns: 

84 float or numpy.ndarray: Corresponding points in the interval [a,b]. 

85 """ 

86 a, b = self 

87 return 0.5 * b * (y + 1.0) + 0.5 * a * (1.0 - y) 

88 

89 def invmap(self, x: float | np.ndarray) -> float | np.ndarray: 

90 """Map from this interval [a,b] to the reference interval [-1,1]. 

91 

92 Args: 

93 x (float or numpy.ndarray): Points in the interval [a,b]. 

94 

95 Returns: 

96 float or numpy.ndarray: Corresponding points in the reference interval [-1,1]. 

97 """ 

98 a, b = self 

99 return (2.0 * x - a - b) / (b - a) 

100 

101 def drvmap(self, y: float | np.ndarray) -> float | np.ndarray: 

102 """Compute the derivative of the forward map. 

103 

104 Args: 

105 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

106 

107 Returns: 

108 float or numpy.ndarray: Derivative values at the corresponding points. 

109 """ 

110 a, b = self # pragma: no cover 

111 return 0.0 * y + 0.5 * (b - a) # pragma: no cover 

112 

113 def __eq__(self, other: "Interval") -> bool: 

114 """Check if two intervals are equal. 

115 

116 Args: 

117 other (Interval): Another interval to compare with. 

118 

119 Returns: 

120 bool: True if the intervals have the same endpoints, False otherwise. 

121 """ 

122 (a, b), (x, y) = self, other 

123 return (a == x) & (y == b) 

124 

125 def __ne__(self, other: "Interval") -> bool: 

126 """Check if two intervals are not equal. 

127 

128 Args: 

129 other (Interval): Another interval to compare with. 

130 

131 Returns: 

132 bool: True if the intervals have different endpoints, False otherwise. 

133 """ 

134 return not self == other 

135 

136 def __call__(self, y: float | np.ndarray) -> float | np.ndarray: 

137 """Map points from [-1,1] to this interval (shorthand for formap). 

138 

139 Args: 

140 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

141 

142 Returns: 

143 float or numpy.ndarray: Corresponding points in the interval [a,b]. 

144 """ 

145 return self.formap(y) 

146 

147 def __contains__(self, other: "Interval") -> bool: 

148 """Check if another interval is contained within this interval. 

149 

150 Args: 

151 other (Interval): Another interval to check. 

152 

153 Returns: 

154 bool: True if other is contained within this interval, False otherwise. 

155 """ 

156 (a, b), (x, y) = self, other 

157 return (a <= x) & (y <= b) 

158 

159 def isinterior(self, x: float | np.ndarray) -> bool | np.ndarray: 

160 """Check if points are strictly in the interior of the interval. 

161 

162 Args: 

163 x (float or numpy.ndarray): Points to check. 

164 

165 Returns: 

166 bool or numpy.ndarray: Boolean array indicating which points are in the interior. 

167 """ 

168 a, b = self 

169 return np.logical_and(a < x, x < b) 

170 

171 @property 

172 def hscale(self) -> float: 

173 """Calculate the horizontal scale factor of the interval. 

174 

175 Returns: 

176 float: The horizontal scale factor. 

177 """ 

178 a, b = self 

179 h = max(infnorm(self), 1) 

180 h_factor = b - a # if interval == domain: scale hscale back to 1 

181 hscale = max(h / h_factor, 1) # else: hscale < 1 

182 return hscale 

183 

184 

185def _merge_duplicates(arr: np.ndarray, tols: np.ndarray) -> np.ndarray: 

186 """Remove duplicate entries from an input array within specified tolerances. 

187 

188 This function works from left to right, keeping the first occurrence of 

189 values that are within tolerance of each other. 

190 

191 Args: 

192 arr (numpy.ndarray): Input array to remove duplicates from. 

193 tols (numpy.ndarray): Array of tolerance values for each pair of adjacent elements. 

194 Should have length one less than arr. 

195 

196 Returns: 

197 numpy.ndarray: Array with duplicates removed. 

198 

199 Note: 

200 Pathological cases may cause issues since this method works by using 

201 consecutive differences. It might be better to take an average (median?), 

202 rather than the left-hand value. 

203 """ 

204 idx = np.append(np.abs(np.diff(arr)) > tols[:-1], True) 

205 return arr[idx] 

206 

207 

class Domain(np.ndarray):
    """Numpy ndarray with additional Chebfun-specific domain logic.

    A Domain represents a collection of breakpoints that define a piecewise domain.
    It provides methods for manipulating and comparing domains, as well as
    generating intervals between adjacent breakpoints.

    Attributes:
        intervals: Generator yielding Interval objects between adjacent breakpoints.
        support: First and last breakpoints of the domain.
    """

    def __new__(cls, breakpoints):
        """Create a new Domain instance.

        Args:
            breakpoints (array-like): Collection of strictly increasing breakpoints.
                Must have at least 2 elements, or be empty.

        Raises:
            InvalidDomain: If breakpoints has exactly 1 element, is not strictly
                increasing, or contains NaN.

        Returns:
            Domain: A new Domain instance (an empty Domain for empty input).
        """
        bpts = np.asarray(breakpoints, dtype=float)
        if bpts.size == 0:
            return bpts.view(cls)
        # ``not all(diff > 0)`` rather than ``any(diff <= 0)``: a NaN
        # difference compares False to everything, so the latter silently
        # accepted NaN breakpoints.
        if bpts.size < 2 or not np.all(np.diff(bpts) > 0):
            raise InvalidDomain
        return bpts.view(cls)

    def __contains__(self, other: "Domain") -> bool:
        """Check whether one domain object is a subdomain of another (within tolerance).

        Args:
            other (Domain): Another domain to check.

        Returns:
            bool: True if other is contained within this domain (within tolerance), False otherwise.
        """
        a, b = self.support
        x, y = other.support
        # NOTE(review): the window is relative (endpoint * (1 +/- htol)), so an
        # endpoint exactly at 0 gets no slack, whereas union() floors its
        # tolerance at htol(). Confirm whether that asymmetry is intended.
        bounds = np.array([1 - htol(), 1 + htol()])
        lbnd, rbnd = np.min(a * bounds), np.max(b * bounds)
        return bool((lbnd <= x) & (y <= rbnd))

    @classmethod
    def from_chebfun(cls, chebfun):
        """Initialize a Domain object from a Chebfun.

        Args:
            chebfun: A Chebfun object with breakpoints.

        Returns:
            Domain: A new Domain instance with the same breakpoints as the Chebfun.
        """
        return cls(chebfun.breakpoints)

    @property
    def intervals(self) -> Iterable[Interval]:
        """Generate Interval objects between adjacent breakpoints.

        Yields:
            Interval: Interval objects for each pair of adjacent breakpoints.
        """
        for a, b in zip(self[:-1], self[1:]):
            yield Interval(a, b)

    @property
    def support(self) -> "Domain":
        """Get the first and last breakpoints of the domain.

        Returns:
            Domain: Two-element array (a Domain view) holding the first and
            last breakpoints.
        """
        return self[[0, -1]]

    @cast_other
    def union(self, other: "Domain") -> "Domain":
        """Create a union of two domain objects with matching support.

        Args:
            other (Domain): Another domain to union with.

        Raises:
            SupportMismatch: If the supports of the two domains don't match within tolerance.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        dspt = np.abs(self.support - other.support)
        # Relative tolerance with an absolute floor of htol() near zero.
        tolerance = np.maximum(htol(), htol() * np.abs(self.support))
        if np.any(dspt > tolerance):
            raise SupportMismatch
        return self.merge(other)

    def merge(self, other: "Domain") -> "Domain":
        """Merge two domain objects without checking if they have the same support.

        Args:
            other (Domain): Another domain to merge with.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains,
            with breakpoints closer than the merge tolerance collapsed.
        """
        all_bpts = np.append(self, other)
        new_bpts = np.unique(all_bpts)
        mergetol = np.maximum(htol(), htol() * np.abs(new_bpts))
        mgd_bpts = _merge_duplicates(new_bpts, mergetol)
        return self.__class__(mgd_bpts)

    @cast_other
    def restrict(self, other: "Domain") -> "Domain":
        """Truncate self to the support of other, retaining any interior breakpoints.

        Args:
            other (Domain): Domain to restrict to.

        Raises:
            NotSubdomain: If other is not a subdomain of self.

        Returns:
            Domain: A new Domain with breakpoints from self restricted to other's support.
        """
        if other not in self:
            raise NotSubdomain
        dom = self.merge(other)
        a, b = other.support
        # Same relative window as __contains__.
        bounds = np.array([1 - htol(), 1 + htol()])
        lbnd, rbnd = np.min(a * bounds), np.max(b * bounds)
        new = dom[(lbnd <= dom) & (dom <= rbnd)]
        return self.__class__(new)

    def breakpoints_in(self, other: "Domain") -> np.ndarray:
        """Check which breakpoints are in another domain within tolerance.

        Each breakpoint ``p`` gets the window ``sorted(p * (1 -/+ htol()))``.
        Where a window bound is smaller than htol() in magnitude (i.e. ``p`` is
        near zero), that bound is widened to ``-htol()`` or ``+htol()``. A
        breakpoint is "in" other if any point of other lies in its window.

        Args:
            other (Domain): Domain (or array-like of points) to check against.

        Returns:
            numpy.ndarray: Boolean array of size equal to self where True indicates
            that the breakpoint is in other within the specified tolerance.
        """
        h = htol()
        window = np.array([1 - h, 1 + h])
        # Row i holds the window around breakpoint i. Sort each row so that
        # negative breakpoints still give (lower, upper) order.
        bounds = np.sort(np.multiply.outer(np.asarray(self).ravel(), window), axis=1)
        lbnd, rbnd = bounds[:, 0], bounds[:, 1]
        # Near zero the relative window collapses, so switch to an absolute one.
        lbnd = np.where(np.abs(lbnd) < h, -h, lbnd)
        rbnd = np.where(np.abs(rbnd) < h, h, rbnd)
        pts = np.asarray(other).ravel()
        # (self.size, other.size) membership table, reduced over other.
        isin = (lbnd[:, None] <= pts[None, :]) & (pts[None, :] <= rbnd[:, None])
        return isin.any(axis=1)

    def __eq__(self, other: "Domain") -> bool:
        """Test for pointwise equality (within a tolerance) of two Domain objects.

        Args:
            other (Domain): Another domain to compare with.

        Returns:
            bool: True if domains have the same size and all breakpoints match within tolerance.
        """
        if self.size != other.size:
            return False
        else:
            dbpt = np.abs(self - other)
            tolerance = np.maximum(htol(), htol() * np.abs(self))
            return bool(np.all(dbpt <= tolerance))  # cast back to bool

    def __ne__(self, other: "Domain") -> bool:
        """Test for inequality of two Domain objects.

        Args:
            other (Domain): Another domain to compare with.

        Returns:
            bool: True if domains differ in size or any breakpoints don't match within tolerance.
        """
        return not self == other

390 

391 

392def _sortindex(intervals: list[Interval]) -> np.ndarray: 

393 """Return an index determining the ordering of interval objects. 

394 

395 This helper function checks that the intervals: 

396 1. Do not overlap 

397 2. Represent a complete partition of the broader approximation domain 

398 

399 Args: 

400 intervals (array-like): Array of Interval objects to sort. 

401 

402 Returns: 

403 numpy.ndarray: Index array for sorting the intervals. 

404 

405 Raises: 

406 IntervalOverlap: If any intervals overlap. 

407 IntervalGap: If there are gaps between intervals. 

408 """ 

409 # sort by the left endpoint Interval values 

410 subintervals = np.array([x for x in intervals]) 

411 leftbreakpts = np.array([s[0] for s in subintervals]) 

412 idx = leftbreakpts.argsort() 

413 

414 # check domain consistency 

415 srt = subintervals[idx] 

416 x = srt.flatten()[1:-1] 

417 d = x[1::2] - x[::2] 

418 if (d < 0).any(): 

419 raise IntervalOverlap 

420 if (d > 0).any(): 

421 raise IntervalGap 

422 

423 return idx 

424 

425 

def check_funs(funs: list) -> np.ndarray:
    """Sort a collection of funs by interval, validating the partition.

    The overlap and gap checks are delegated to ``_sortindex``, which is given
    the ``interval`` attribute of each fun.

    Args:
        funs (array-like): Collection of function objects with an ``interval`` attribute.

    Returns:
        numpy.ndarray: The funs in sorted order, or an empty array if none were given.

    Raises:
        IntervalOverlap: If any function intervals overlap.
        IntervalGap: If there are gaps between function intervals.
    """
    fun_array = np.array(funs)
    if fun_array.size == 0:
        return np.array([])
    order = _sortindex(f.interval for f in fun_array)
    return fun_array[order]

450 

451 

def compute_breakdata(funs: np.ndarray) -> OrderedDict:
    """Assign a value to every breakpoint from the funs' end values.

    The outermost breakpoints take the end value of the first and last fun
    directly. Each interior breakpoint gets the mean of the left fun's right
    end value and the right fun's left end value. The interior x coordinate is
    likewise the mean of the two adjacent support endpoints. This is meant to
    run after check_funs(), which ensures the funs are sorted and partition the
    domain.

    Args:
        funs (numpy.ndarray): Array of function objects with ``support`` and
            ``endvalues`` attributes.

    Returns:
        OrderedDict: Mapping from breakpoint to value, in left-to-right order.
    """
    if funs.size == 0:
        return OrderedDict()

    ends_x = np.array([f.support for f in funs]).flatten()
    ends_y = np.array([f.endvalues for f in funs]).flatten()

    # The interior entries come in (right end of k, left end of k+1) pairs.
    inner_x, inner_y = ends_x[1:-1], ends_y[1:-1]
    joined_x = 0.5 * (inner_x[0::2] + inner_x[1::2])
    joined_y = 0.5 * (inner_y[0::2] + inner_y[1::2])

    xs = np.hstack((ends_x[0], joined_x, ends_x[-1]))
    ys = np.hstack((ends_y[0], joined_y, ends_y[-1]))
    return OrderedDict(zip(xs, ys))

480 

481 

def generate_funs(domain: Domain | list | None, bndfun_constructor: callable, kwds: dict | None = None) -> list:
    """Generate a collection of function objects over a domain.

    This method is used by several of the Chebfun classmethod constructors to
    generate a collection of function objects over the specified domain. One
    object is built per interval of the domain.

    Args:
        domain (array-like or None): Domain breakpoints. If None, uses default domain from preferences.
        bndfun_constructor (callable): Constructor function for creating function objects.
            It is called with ``**kwds`` plus ``interval=<Interval>``.
        kwds (dict, optional): Additional keyword arguments to pass to the constructor.
            Defaults to None (no extra arguments). Any ``"interval"`` entry is
            overridden by the current interval. The dict is not modified.

    Returns:
        list: List of function objects covering the domain, in interval order.
    """
    domain = Domain(domain if domain is not None else prefs.domain)
    # None sentinel instead of a shared mutable ``{}`` default.
    extra = {} if kwds is None else kwds
    # Build a fresh kwargs dict per interval rather than rebinding and
    # accumulating into one across iterations.
    return [bndfun_constructor(**{**extra, "interval": interval}) for interval in domain.intervals]

502 

503 

def infnorm(vals: np.ndarray) -> float:
    """Calculate the infinity norm of an array.

    For scalars and 1-D input this is the maximum absolute value. For 2-D
    input it is the matrix infinity norm (the maximum absolute row sum), as
    computed by ``numpy.linalg.norm`` with ``ord=inf``.

    Args:
        vals (array-like or scalar): Input values.

    Returns:
        float: The infinity norm of the input; 0.0 for empty input.
    """
    arr = np.asarray(vals)
    # numpy.linalg.norm raises on empty input (max of an empty reduction).
    # The infinity norm of an empty vector is conventionally 0.
    if arr.size == 0:
        return 0.0
    # numpy.linalg.norm rejects 0-d input for ord=inf, so treat a scalar as a
    # one-element vector.
    if arr.ndim == 0:
        arr = arr.reshape(1)
    return np.linalg.norm(arr, np.inf)

514 

515 

516def coerce_list(x: object) -> list | Iterable: 

517 """Convert a non-iterable object to a list containing that object. 

518 

519 If the input is already an iterable (except strings), it is returned unchanged. 

520 Strings are treated as non-iterables and wrapped in a list. 

521 

522 Args: 

523 x: Input object to coerce to a list if necessary. 

524 

525 Returns: 

526 list or iterable: The input wrapped in a list if it was not an iterable, 

527 or the original input if it was already an iterable (except strings). 

528 """ 

529 if not isinstance(x, Iterable) or isinstance(x, str): # pragma: no cover 

530 x = [x] 

531 return x