Coverage for chebpy/core/utilities.py: 100%

154 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-07 10:30 +0000

1"""Utility functions and classes for the ChebPy package. 

2 

3This module provides various utility functions and classes used throughout the ChebPy 

4package, including interval operations, domain manipulations, and tolerance functions. 

5It defines the core data structures for representing and manipulating intervals and domains. 

6""" 

7 

8from collections import OrderedDict 

9from collections.abc import Iterable 

10 

11import numpy as np 

12 

13from .decorators import cast_other 

14from .exceptions import ( 

15 IntervalGap, 

16 IntervalOverlap, 

17 IntervalValues, 

18 InvalidDomain, 

19 NotSubdomain, 

20 SupportMismatch, 

21) 

22from .settings import _preferences as prefs 

23 

24 

def htol() -> float:
    """Horizontal tolerance used when comparing interval and domain endpoints.

    Returns:
        float: The machine epsilon from the package preferences, scaled by five.
    """
    scale = 5
    return scale * prefs.eps

32 

33 

34class Interval(np.ndarray): 

35 """Utility class to implement Interval logic. 

36 

37 The purpose of this class is to both enforce certain properties of domain 

38 components such as having exactly two monotonically increasing elements and 

39 also to implement the functionality of mapping to and from the unit interval. 

40 

41 Attributes: 

42 formap: Maps y in [-1,1] to x in [a,b] 

43 invmap: Maps x in [a,b] to y in [-1,1] 

44 drvmap: Derivative mapping from y in [-1,1] to x in [a,b] 

45 

46 Note: 

47 Currently only implemented for finite a and b. 

48 The __call__ method evaluates self.formap since this is the most 

49 frequently used mapping operation. 

50 """ 

51 

52 def __new__(cls, a: float = -1.0, b: float = 1.0) -> "Interval": 

53 """Create a new Interval instance. 

54 

55 Args: 

56 a (float, optional): Left endpoint of the interval. Defaults to -1.0. 

57 b (float, optional): Right endpoint of the interval. Defaults to 1.0. 

58 

59 Raises: 

60 IntervalValues: If a >= b. 

61 

62 Returns: 

63 Interval: A new Interval instance. 

64 """ 

65 if a >= b: 

66 raise IntervalValues 

67 return np.asarray((a, b), dtype=float).view(cls) 

68 

69 def formap(self, y: float | np.ndarray) -> float | np.ndarray: 

70 """Map from the reference interval [-1,1] to this interval [a,b]. 

71 

72 Args: 

73 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

74 

75 Returns: 

76 float or numpy.ndarray: Corresponding points in the interval [a,b]. 

77 """ 

78 a, b = self 

79 return 0.5 * b * (y + 1.0) + 0.5 * a * (1.0 - y) 

80 

81 def invmap(self, x: float | np.ndarray) -> float | np.ndarray: 

82 """Map from this interval [a,b] to the reference interval [-1,1]. 

83 

84 Args: 

85 x (float or numpy.ndarray): Points in the interval [a,b]. 

86 

87 Returns: 

88 float or numpy.ndarray: Corresponding points in the reference interval [-1,1]. 

89 """ 

90 a, b = self 

91 return (2.0 * x - a - b) / (b - a) 

92 

93 def drvmap(self, y: float | np.ndarray) -> float | np.ndarray: 

94 """Compute the derivative of the forward map. 

95 

96 Args: 

97 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

98 

99 Returns: 

100 float or numpy.ndarray: Derivative values at the corresponding points. 

101 """ 

102 a, b = self # pragma: no cover 

103 return 0.0 * y + 0.5 * (b - a) # pragma: no cover 

104 

105 def __eq__(self, other: "Interval") -> bool: 

106 """Check if two intervals are equal. 

107 

108 Args: 

109 other (Interval): Another interval to compare with. 

110 

111 Returns: 

112 bool: True if the intervals have the same endpoints, False otherwise. 

113 """ 

114 (a, b), (x, y) = self, other 

115 return (a == x) & (y == b) 

116 

117 def __ne__(self, other: "Interval") -> bool: 

118 """Check if two intervals are not equal. 

119 

120 Args: 

121 other (Interval): Another interval to compare with. 

122 

123 Returns: 

124 bool: True if the intervals have different endpoints, False otherwise. 

125 """ 

126 return not self == other 

127 

128 def __call__(self, y: float | np.ndarray) -> float | np.ndarray: 

129 """Map points from [-1,1] to this interval (shorthand for formap). 

130 

131 Args: 

132 y (float or numpy.ndarray): Points in the reference interval [-1,1]. 

133 

134 Returns: 

135 float or numpy.ndarray: Corresponding points in the interval [a,b]. 

136 """ 

137 return self.formap(y) 

138 

139 def __contains__(self, other: "Interval") -> bool: 

140 """Check if another interval is contained within this interval. 

141 

142 Args: 

143 other (Interval): Another interval to check. 

144 

145 Returns: 

146 bool: True if other is contained within this interval, False otherwise. 

147 """ 

148 (a, b), (x, y) = self, other 

149 return (a <= x) & (y <= b) 

150 

151 def isinterior(self, x: float | np.ndarray) -> bool | np.ndarray: 

152 """Check if points are strictly in the interior of the interval. 

153 

154 Args: 

155 x (float or numpy.ndarray): Points to check. 

156 

157 Returns: 

158 bool or numpy.ndarray: Boolean array indicating which points are in the interior. 

159 """ 

160 a, b = self 

161 return np.logical_and(a < x, x < b) 

162 

163 @property 

164 def hscale(self) -> float: 

165 """Calculate the horizontal scale factor of the interval. 

166 

167 Returns: 

168 float: The horizontal scale factor. 

169 """ 

170 a, b = self 

171 h = max(infnorm(self), 1) 

172 h_factor = b - a # if interval == domain: scale hscale back to 1 

173 hscale = max(h / h_factor, 1) # else: hscale < 1 

174 return hscale 

175 

176 

177def _merge_duplicates(arr: np.ndarray, tols: np.ndarray) -> np.ndarray: 

178 """Remove duplicate entries from an input array within specified tolerances. 

179 

180 This function works from left to right, keeping the first occurrence of 

181 values that are within tolerance of each other. 

182 

183 Args: 

184 arr (numpy.ndarray): Input array to remove duplicates from. 

185 tols (numpy.ndarray): Array of tolerance values for each pair of adjacent elements. 

186 Should have length one less than arr. 

187 

188 Returns: 

189 numpy.ndarray: Array with duplicates removed. 

190 

191 Note: 

192 Pathological cases may cause issues since this method works by using 

193 consecutive differences. It might be better to take an average (median?), 

194 rather than the left-hand value. 

195 """ 

196 idx = np.append(np.abs(np.diff(arr)) > tols[:-1], True) 

197 return arr[idx] 

198 

199 

class Domain(np.ndarray):
    """Numpy ndarray with additional Chebfun-specific domain logic.

    A Domain represents a collection of breakpoints that define a piecewise domain.
    It provides methods for manipulating and comparing domains, as well as
    generating intervals between adjacent breakpoints.

    Attributes:
        intervals: Generator yielding Interval objects between adjacent breakpoints.
        support: First and last breakpoints of the domain.
    """

    def __new__(cls, breakpoints):
        """Create a new Domain instance.

        Args:
            breakpoints (array-like): Strictly increasing breakpoints. Must be
                empty or have at least 2 elements.

        Raises:
            InvalidDomain: If breakpoints has exactly one element, is not
                strictly increasing, or produces a NaN difference (a NaN
                breakpoint or a repeated infinity).

        Returns:
            Domain: A new Domain instance (empty if breakpoints is empty).
        """
        bpts = np.asarray(breakpoints, dtype=float)
        if bpts.size == 0:
            return bpts.view(cls)
        # Test `all(diff > 0)` rather than `any(diff <= 0)`: a NaN difference
        # fails every ordered comparison, so the latter accepted it.
        if bpts.size < 2 or not np.all(np.diff(bpts) > 0):
            raise InvalidDomain
        return bpts.view(cls)

    @staticmethod
    def _widened_bounds(a: float, b: float) -> tuple[float, float]:
        """Widen [a, b] outward by the horizontal tolerance htol().

        Each end is moved outward by a relative amount (a factor 1 +/- htol())
        and by at least htol() in absolute terms, so endpoints at or near zero
        still get some slack. This approximately matches the
        ``max(htol, htol*|x|)`` tolerance used by ``union`` and ``__eq__``.
        """
        tol = htol()
        scale = np.array([1 - tol, 1 + tol])
        lbnd = min(np.min(a * scale), a - tol)
        rbnd = max(np.max(b * scale), b + tol)
        return lbnd, rbnd

    def __contains__(self, other: "Domain") -> bool:
        """Check whether one domain object is a subdomain of another (within tolerance).

        Args:
            other (Domain): Another domain to check.

        Returns:
            bool: True if the support of other lies within the support of self,
            widened by htol() (relative, but at least htol() absolute).
        """
        a, b = self.support
        x, y = other.support
        lbnd, rbnd = self._widened_bounds(a, b)
        return bool((lbnd <= x) & (y <= rbnd))

    @classmethod
    def from_chebfun(cls, chebfun):
        """Initialize a Domain object from a Chebfun.

        Args:
            chebfun: A Chebfun object with breakpoints.

        Returns:
            Domain: A new Domain instance with the same breakpoints as the Chebfun.
        """
        return cls(chebfun.breakpoints)

    @property
    def intervals(self) -> Iterable[Interval]:
        """Generate Interval objects between adjacent breakpoints.

        Yields:
            Interval: Interval objects for each pair of adjacent breakpoints.
        """
        for a, b in zip(self[:-1], self[1:]):
            yield Interval(a, b)

    @property
    def support(self) -> "Domain":
        """Get the first and last breakpoints of the domain.

        Returns:
            Domain: Two-element array containing the first and last breakpoints.
        """
        return self[[0, -1]]

    @cast_other
    def union(self, other: "Domain") -> "Domain":
        """Create a union of two domain objects with matching support.

        Args:
            other (Domain): Another domain to union with.

        Raises:
            SupportMismatch: If the supports of the two domains don't match within tolerance.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        dspt = np.abs(self.support - other.support)
        tolerance = np.maximum(htol(), htol() * np.abs(self.support))
        if np.any(dspt > tolerance):
            raise SupportMismatch
        return self.merge(other)

    def merge(self, other: "Domain") -> "Domain":
        """Merge two domain objects without checking if they have the same support.

        Breakpoints closer than ``max(htol(), htol()*|x|)`` are collapsed via
        ``_merge_duplicates``.

        Args:
            other (Domain): Another domain to merge with.

        Returns:
            Domain: A new Domain containing all breakpoints from both domains.
        """
        all_bpts = np.append(self, other)
        new_bpts = np.unique(all_bpts)
        mergetol = np.maximum(htol(), htol() * np.abs(new_bpts))
        mgd_bpts = _merge_duplicates(new_bpts, mergetol)
        return self.__class__(mgd_bpts)

    @cast_other
    def restrict(self, other: "Domain") -> "Domain":
        """Truncate self to the support of other, retaining any interior breakpoints.

        Args:
            other (Domain): Domain to restrict to.

        Raises:
            NotSubdomain: If other is not a subdomain of self.

        Returns:
            Domain: A new Domain with breakpoints from self restricted to other's support.
        """
        if other not in self:
            raise NotSubdomain
        dom = self.merge(other)
        # Use the same widened window as __contains__. A purely relative window
        # has zero width at 0, which could drop a merged endpoint lying just
        # past a support endpoint equal to 0.
        lbnd, rbnd = self._widened_bounds(*other.support)
        new = dom[(lbnd <= dom) & (dom <= rbnd)]
        return self.__class__(new)

    def breakpoints_in(self, other: "Domain") -> np.ndarray:
        """Check which breakpoints of self are in another domain within tolerance.

        Breakpoint ``p`` matches if some entry of ``other`` lies in the window
        ``[min(p*(1-htol()), p*(1+htol())), max(p*(1-htol()), p*(1+htol()))]``.
        A window end whose magnitude is below htol() is replaced by -htol()
        (left end) or +htol() (right end).

        Args:
            other (Domain): Domain (or array-like of breakpoints) to check against.

        Returns:
            numpy.ndarray: Boolean array of size equal to self where True indicates
            that the breakpoint is in other within the specified tolerance.
        """
        tol = htol()
        bpts = np.asarray(self, dtype=float)
        pts = np.asarray(other, dtype=float).ravel()
        lo = np.minimum(bpts * (1 - tol), bpts * (1 + tol))
        hi = np.maximum(bpts * (1 - tol), bpts * (1 + tol))
        # Windows around (near-)zero breakpoints would otherwise have ~zero
        # width, so they get an absolute half-width of htol() instead.
        lo = np.where(np.abs(lo) < tol, -tol, lo)
        hi = np.where(np.abs(hi) < tol, tol, hi)
        # Row i compares breakpoint i's window against every point of other.
        hits = (lo[:, None] <= pts) & (pts <= hi[:, None])
        return hits.any(axis=1)

    def __eq__(self, other: "Domain") -> bool:
        """Test for pointwise equality (within a tolerance) of two Domain objects.

        Args:
            other (Domain): Another domain to compare with.

        Returns:
            bool: True if domains have the same size and all breakpoints match within tolerance.
        """
        if self.size != other.size:
            return False
        else:
            dbpt = np.abs(self - other)
            tolerance = np.maximum(htol(), htol() * np.abs(self))
            return bool(np.all(dbpt <= tolerance))  # cast back to bool

    def __ne__(self, other: "Domain") -> bool:
        """Test for inequality of two Domain objects.

        Args:
            other (Domain): Another domain to compare with.

        Returns:
            bool: True if domains differ in size or any breakpoints don't match within tolerance.
        """
        return not self == other

382 

383 

def _sortindex(intervals: list[Interval]) -> np.ndarray:
    """Return the permutation that orders intervals by their left endpoint.

    The sorted intervals are also validated: each interval's right endpoint
    must coincide exactly with the following interval's left endpoint, so
    that together they partition the overall domain without overlaps or gaps.

    Args:
        intervals (iterable): Interval objects (two-element ``[a, b]`` arrays).
            Any iterable is accepted; it is consumed once.

    Returns:
        numpy.ndarray: Index array that sorts the intervals by left endpoint.

    Raises:
        IntervalOverlap: If any intervals overlap.
        IntervalGap: If there are gaps between intervals.
    """
    endpoints = np.array(list(intervals))
    order = np.array([pair[0] for pair in endpoints]).argsort()

    # After dropping the outermost two endpoints, the flattened sorted pairs
    # read [b0, a1, b1, a2, ...]; compare each a_{i+1} with b_i.
    inner = endpoints[order].ravel()[1:-1]
    mismatch = inner[1::2] - inner[::2]
    if np.any(mismatch < 0):
        raise IntervalOverlap
    if np.any(mismatch > 0):
        raise IntervalGap
    return order

416 

417 

def check_funs(funs: list) -> np.ndarray:
    """Validate a collection of funs and return it sorted by interval.

    The funs' intervals must neither overlap nor leave gaps; the checking
    itself is delegated to ``_sortindex``.

    Args:
        funs (array-like): Function objects exposing an ``interval`` attribute.

    Returns:
        numpy.ndarray: The funs ordered by the left endpoint of their interval
        (an empty array if no funs were given).

    Raises:
        IntervalOverlap: If any function intervals overlap.
        IntervalGap: If there are gaps between function intervals.
    """
    candidates = np.array(funs)
    if candidates.size == 0:
        return np.array([])
    order = _sortindex(f.interval for f in candidates)
    return candidates[order]

442 

443 

def compute_breakdata(funs: np.ndarray) -> OrderedDict:
    """Assign a single location and value to every breakpoint of a piecewise function.

    The outer breakpoints take the first/last support point and end value.
    Each interior breakpoint takes the average of the touching support points
    of the two adjacent funs, and the average of their end values there.
    Typically called after check_funs(), which guarantees the funs are sorted
    and partition the domain without overlaps or gaps.

    Args:
        funs (numpy.ndarray): Function objects with ``support`` and ``endvalues``
            attributes, each expected to hold two entries (left, right).

    Returns:
        OrderedDict: Mapping from breakpoint to function value, in fun order.
    """
    if funs.size == 0:
        return OrderedDict()

    ends = np.array([fun.support for fun in funs]).flatten()
    vals = np.array([fun.endvalues for fun in funs]).flatten()

    # ends = [a0, b0, a1, b1, ...]; pair each right end b_i with the next a_{i+1}.
    mid_x = 0.5 * (ends[1:-1:2] + ends[2:-1:2])
    mid_y = 0.5 * (vals[1:-1:2] + vals[2:-1:2])

    xout = np.concatenate(([ends[0]], mid_x, [ends[-1]]))
    yout = np.concatenate(([vals[0]], mid_y, [vals[-1]]))
    return OrderedDict(zip(xout, yout))

472 

473 

def generate_funs(domain: Domain | list | None, bndfun_constructor: callable, kwds: dict | None = None) -> list:
    """Generate a collection of function objects over a domain.

    This method is used by several of the Chebfun classmethod constructors to
    generate a collection of function objects over the specified domain.

    Args:
        domain (array-like or None): Domain breakpoints. If None, uses default domain from preferences.
        bndfun_constructor (callable): Constructor called once per subinterval
            of the domain, with ``interval=`` set to that subinterval.
        kwds (dict, optional): Additional keyword arguments passed to every
            constructor call. The dict is not modified, and any ``"interval"``
            entry in it is overridden. Defaults to None (no extra arguments).

    Returns:
        list: List of function objects covering the domain, one per interval, in order.

    Raises:
        InvalidDomain: If the breakpoints do not form a valid Domain.
    """
    domain = Domain(domain if domain is not None else prefs.domain)
    # None sentinel instead of a mutable `{}` default shared across calls.
    base_kwds = {} if kwds is None else kwds
    # "interval" comes last so it takes precedence over a same-named entry in kwds.
    return [bndfun_constructor(**{**base_kwds, "interval": interval}) for interval in domain.intervals]

494 

495 

def infnorm(vals: np.ndarray) -> float:
    """Return the infinity norm of ``vals`` as computed by ``numpy.linalg.norm``.

    For 1-D input this is the largest absolute entry. For 2-D input numpy
    instead computes the matrix infinity norm (the largest absolute row sum).

    Args:
        vals (array-like): Input array.

    Returns:
        float: The infinity norm of the input.
    """
    order = np.inf
    return np.linalg.norm(vals, ord=order)

506 

507 

508def coerce_list(x: object) -> list | Iterable: 

509 """Convert a non-iterable object to a list containing that object. 

510 

511 If the input is already an iterable (except strings), it is returned unchanged. 

512 Strings are treated as non-iterables and wrapped in a list. 

513 

514 Args: 

515 x: Input object to coerce to a list if necessary. 

516 

517 Returns: 

518 list or iterable: The input wrapped in a list if it was not an iterable, 

519 or the original input if it was already an iterable (except strings). 

520 """ 

521 if not isinstance(x, Iterable) or isinstance(x, str): # pragma: no cover 

522 x = [x] 

523 return x