Coverage for src / chebpy / chebtech.py: 100%

252 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-22 21:33 +0000

1"""Implementation of Chebyshev polynomial technology for function approximation. 

2 

3This module provides the Chebtech class, an abstract base class for 

4representing functions using Chebyshev polynomial expansions. Chebtech stores 

5first-kind Chebyshev coefficients and samples at Chebyshev points of the second kind. 

6 

7The Chebtech classes implement core functionality for working with Chebyshev 

8expansions, including: 

9- Function evaluation using Clenshaw's algorithm or barycentric interpolation 

10- Algebraic operations (addition, multiplication, etc.) 

11- Calculus operations (differentiation, integration, etc.) 

12- Rootfinding 

13- Plotting 

14 

15These classes are primarily used internally by higher-level classes like Bndfun 

16and Chebfun, rather than being used directly by end users. 

17""" 

18 

19from abc import ABC 

20from collections.abc import Callable 

21from typing import Any 

22 

23import matplotlib.pyplot as plt 

24import numpy as np 

25 

26from .algorithms import ( 

27 adaptive, 

28 bary, 

29 barywts2, 

30 chebpts2, 

31 clenshaw, 

32 coeffmult, 

33 coeffs2vals2, 

34 newtonroots, 

35 rootsunit, 

36 standard_chop, 

37 vals2coeffs2, 

38) 

39from .decorators import self_empty 

40from .plotting import plotfun, plotfuncoeffs 

41from .settings import _preferences as prefs 

42from .smoothfun import Smoothfun 

43from .utilities import Interval, coerce_list 

44 

45 

class Chebtech(Smoothfun, ABC):
    """Abstract base class for Chebyshev-series representations of smooth functions.

    A Chebtech stores the first-kind Chebyshev coefficients of a function on
    [-1, 1]. Whenever values are needed they are taken at Chebyshev points of
    the second kind (see ``_chebpts``, ``_vals2coeffs`` and ``_coeffs2vals``).
    Because every operation works on the coefficient representation, most of
    the core functionality is defined at this level.

    The attached ``interval`` is carried through every operation. Its
    ``hscale`` scales the tolerances used by the adaptive constructor and by
    ``simplify``.

    The user will rarely work with this class directly, so several assumptions
    are made regarding input data types.
    """

    @classmethod
    def initconst(cls, c: Any = None, *, interval: Any = None) -> Any:
        """Initialise a Chebtech from a constant c.

        Args:
            c: A scalar value. Python ints are converted to float so that the
                coefficient array has a floating-point dtype.
            interval: Interval to attach. Defaults to ``prefs.domain``.

        Raises:
            ValueError: If ``c`` is not a scalar.
        """
        if not np.isscalar(c):
            raise ValueError(c)
        if isinstance(c, int):
            c = float(c)
        return cls(np.array([c]), interval=interval)

    @classmethod
    def initempty(cls, *, interval: Any = None) -> "Chebtech":
        """Initialise an empty Chebtech (zero coefficients)."""
        return cls(np.array([]), interval=interval)

    @classmethod
    def initidentity(cls, *, interval: Any = None) -> "Chebtech":
        """Chebtech representation of f(x) = x on [-1,1].

        The coefficients are stored as floats. An integer coefficient array
        would make later in-place updates (e.g. ``cumsum`` or adding a
        fractional scalar) silently truncate.
        """
        # x = 0*T_0 + 1*T_1
        return cls(np.array([0.0, 1.0]), interval=interval)

    @classmethod
    def initfun(cls, fun: Any = None, n: Any = None, *, interval: Any = None) -> Any:
        """Convenience constructor to automatically select the adaptive or fixedlen constructor.

        Dispatches to ``initfun_adaptive`` when ``n`` is None and to
        ``initfun_fixedlen`` otherwise.
        """
        if n is None:
            return cls.initfun_adaptive(fun, interval=interval)
        else:
            return cls.initfun_fixedlen(fun, n, interval=interval)

    @classmethod
    def initfun_fixedlen(cls, fun: Any = None, n: Any = None, *, interval: Any = None) -> Any:
        """Initialise a Chebtech from the callable fun using n degrees of freedom.

        ``fun`` is sampled at ``n`` Chebyshev points of the second kind and the
        samples are converted to coefficients.

        Raises:
            ValueError: If ``n`` is None.
        """
        if n is None:
            raise ValueError("n must be specified for fixed-length initialization")  # noqa: TRY003
        points = cls._chebpts(int(n))
        values = fun(points)
        coeffs = vals2coeffs2(values)
        return cls(coeffs, interval=interval)

    @classmethod
    def initfun_adaptive(cls, fun: Any = None, *, interval: Any = None) -> Any:
        """Initialise a Chebtech from the callable fun utilising the adaptive constructor.

        The number of degrees of freedom is chosen by ``adaptive``, with its
        tolerance scaled by the interval's ``hscale``.
        """
        interval = interval if interval is not None else prefs.domain
        interval = Interval(*interval)
        coeffs = adaptive(cls, fun, hscale=interval.hscale)
        return cls(coeffs, interval=interval)

    @classmethod
    def initvalues(cls, values: Any = None, *, interval: Any = None) -> Any:
        """Initialise a Chebtech from an array of values at Chebyshev points."""
        return cls(cls._vals2coeffs(values), interval=interval)

    def __init__(self, coeffs: Any, interval: Any = None) -> None:
        """Initialize a Chebtech object.

        Args:
            coeffs (array-like): The coefficients of the Chebyshev series.
            interval (array-like, optional): The interval on which the function
                is defined. Defaults to None, which uses the default interval
                from preferences.
        """
        interval = interval if interval is not None else prefs.domain
        self._coeffs = np.array(coeffs)
        self._interval = Interval(*interval)

    def __call__(self, x: Any, how: str = "clenshaw") -> Any:
        """Evaluate the Chebtech at the given points.

        Args:
            x: Points at which to evaluate the Chebtech.
            how (str, optional): Method to use for evaluation. Either "clenshaw" or "bary".
                Defaults to "clenshaw".

        Returns:
            The values of the Chebtech at the given points.

        Raises:
            ValueError: If the specified method is not supported.
        """
        method: dict[str, Callable[[Any], Any]] = {
            "clenshaw": self.__call__clenshaw,
            "bary": self.__call__bary,
        }
        try:
            return method[how](x)
        except KeyError as err:
            raise ValueError(how) from err

    def __call__clenshaw(self, x: Any) -> Any:
        # Evaluate the series directly from its coefficients.
        return clenshaw(x, self.coeffs)

    def __call__bary(self, x: Any) -> Any:
        # Barycentric interpolation through the values at second-kind points.
        fk = self.values()
        xk = self._chebpts(fk.size)
        vk = self._barywts(fk.size)
        return bary(x, fk, xk, vk)

    def __repr__(self) -> str:  # pragma: no cover
        """Return a string representation of the Chebtech.

        Returns:
            str: A string representation of the Chebtech.
        """
        out = f"<{self.__class__.__name__}{{{self.size}}}>"
        return out

    # ------------
    #  properties
    # ------------
    @property
    def coeffs(self) -> np.ndarray:
        """Chebyshev expansion coefficients in the T_k basis."""
        return self._coeffs

    @property
    def interval(self) -> Interval:
        """Interval that Chebtech is mapped to."""
        return self._interval

    @property
    def size(self) -> int:
        """Return the number of coefficients."""
        return self.coeffs.size

    @property
    def isempty(self) -> bool:
        """Return True if the Chebtech has no coefficients."""
        return self.size == 0

    @property
    def iscomplex(self) -> bool:
        """Determine whether the underlying onefun is complex or real valued.

        Any complex dtype counts (complex64 as well as complex128).
        """
        return np.iscomplexobj(self._coeffs)

    @property
    def isconst(self) -> bool:
        """Return True if the Chebtech represents a constant (one coefficient)."""
        return self.size == 1

    @property
    @self_empty(0.0)
    def vscale(self) -> float:
        """Estimate the vertical scale as the max absolute value at Chebyshev points."""
        return float(np.abs(np.asarray(coerce_list(self.values()))).max())

    # -----------
    #  utilities
    # -----------
    def copy(self) -> "Chebtech":
        """Return a deep copy of the Chebtech."""
        return self.__class__(self.coeffs.copy(), interval=self.interval.copy())

    def imag(self) -> Any:
        """Return the imaginary part of the Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the imaginary part of this Chebtech.
                If this Chebtech is real-valued, returns a zero Chebtech.
        """
        if self.iscomplex:
            return self.__class__(np.imag(self.coeffs), self.interval)
        else:
            return self.initconst(0, interval=self.interval)

    def prolong(self, n: int) -> "Chebtech":
        """Return a Chebtech of length n.

        Obtained either by truncating if n < self.size or zero-padding if n > self.size.
        In all cases a deep copy is returned.
        """
        m = self.size
        ak = self.coeffs
        cls = self.__class__
        if n - m < 0:
            out = cls(ak[:n].copy(), interval=self.interval)
        elif n - m > 0:
            out = cls(np.append(ak, np.zeros(n - m)), interval=self.interval)
        else:
            out = self.copy()
        return out

    def real(self) -> "Chebtech":
        """Return the real part of the Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the real part of this Chebtech.
                If this Chebtech is already real-valued, returns self.
        """
        if self.iscomplex:
            return self.__class__(np.real(self.coeffs), self.interval)
        else:
            return self

    def simplify(self) -> "Chebtech":
        """Call standard_chop on the coefficients of self.

        Returns a Chebtech comprised of a copy of the truncated coefficients.
        The result is never longer than the original.
        """
        # coefficients: standard_chop is given at least 17 coefficients
        oldlen = len(self.coeffs)
        longself = self.prolong(max(17, oldlen))
        cfs = longself.coeffs
        # scale (decrease) tolerance by hscale
        tol = prefs.eps * max(self.interval.hscale, 1)
        # chop
        npts = standard_chop(cfs, tol=tol)
        npts = min(oldlen, npts)
        # construct
        return self.__class__(cfs[:npts].copy(), interval=self.interval)

    def values(self) -> np.ndarray:
        """Function values at Chebyshev points of the second kind."""
        return coeffs2vals2(self.coeffs)

    # ---------
    #  algebra
    # ---------
    @self_empty()
    def __add__(self, f: Any) -> Any:
        """Add a scalar or another Chebtech to this Chebtech.

        Args:
            f: A scalar or another Chebtech to add to this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the sum. If the summed
                coefficients are all negligible relative to the operands'
                vertical scales, a zero constant is returned.
        """
        cls = self.__class__
        if np.isscalar(f):
            if np.iscomplexobj(f):
                dtype: Any = complex
            else:
                # Promote (e.g. integer coefficients + float scalar -> float) so
                # the in-place update below cannot silently truncate.
                dtype = np.result_type(self.coeffs, f)
            cfs = np.array(self.coeffs, dtype=dtype)
            cfs[0] += f
            return cls(cfs, interval=self.interval)
        else:
            # TODO: is a more general decorator approach better here?
            # TODO: for constant Chebtech, convert to constant and call __add__ again
            if f.isempty:
                return f.copy()
            g = self
            n, m = g.size, f.size
            if n < m:
                g = g.prolong(m)
            elif m < n:
                f = f.prolong(n)
            cfs = f.coeffs + g.coeffs

            # check for zero output
            eps = prefs.eps
            tol = 0.5 * eps * max([f.vscale, g.vscale])
            if np.all(np.abs(cfs) < tol):
                return cls.initconst(0.0, interval=self.interval)
            else:
                return cls(cfs, interval=self.interval)

    @self_empty()
    def __div__(self, f: Any) -> Any:
        """Divide this Chebtech by a scalar or another Chebtech.

        Division by another Chebtech is done by adaptively resampling the
        pointwise quotient.

        Args:
            f: A scalar or another Chebtech to divide this Chebtech by.

        Returns:
            Chebtech: A new Chebtech representing the quotient.
        """
        cls = self.__class__
        if np.isscalar(f):
            cfs = 1.0 / np.asarray(f) * self.coeffs
            return cls(cfs, interval=self.interval)
        else:
            # TODO: review with reference to __add__
            if f.isempty:
                return f.copy()
            return cls.initfun_adaptive(lambda x: self(x) / f(x), interval=self.interval)

    __truediv__ = __div__

    @self_empty()
    def __mul__(self, g: Any) -> Any:
        """Multiply this Chebtech by a scalar or another Chebtech.

        Args:
            g: A scalar or another Chebtech to multiply this Chebtech by.

        Returns:
            Chebtech: A new Chebtech representing the product.
        """
        cls = self.__class__
        if np.isscalar(g):
            cfs = g * self.coeffs
            return cls(cfs, interval=self.interval)
        else:
            # TODO: review with reference to __add__
            if g.isempty:
                return g.copy()
            # Pad both operands to the length of the product series.
            f = self
            n = f.size + g.size - 1
            f = f.prolong(n)
            g = g.prolong(n)
            cfs = coeffmult(f.coeffs, g.coeffs)
            out = cls(cfs, interval=self.interval)
            return out

    def __neg__(self) -> "Chebtech":
        """Return the negative of this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the negative of this Chebtech.
        """
        coeffs = -self.coeffs
        return self.__class__(coeffs, interval=self.interval)

    def __pos__(self) -> "Chebtech":
        """Return this Chebtech (unary positive).

        Returns:
            Chebtech: This Chebtech (self).
        """
        return self

    @self_empty()
    def __pow__(self, f: Any) -> Any:
        """Raise this Chebtech to a power.

        Args:
            f: The exponent, which can be a scalar or another Chebtech.

        Returns:
            Chebtech: A new Chebtech representing this Chebtech raised to the power f.
        """

        def powfun(fn: Any, x: Any) -> Any:
            if np.isscalar(fn):
                return fn
            else:
                return fn(x)

        return self.__class__.initfun_adaptive(lambda x: np.power(self(x), powfun(f, x)), interval=self.interval)

    @self_empty()
    def __rdiv__(self, f: Any) -> Any:
        """Divide a scalar by this Chebtech.

        This is called when f / self is executed and f is not a Chebtech.
        An empty Chebtech is handled by the ``self_empty`` guard, consistent
        with ``__rpow__``.

        Args:
            f: A scalar to be divided by this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f divided by this Chebtech.
        """

        # Executed when __div__(f, self) fails, which is to say whenever f
        # is not a Chebtech. We proceed on the assumption f is a scalar.
        def constfun(x: Any) -> Any:
            return 0.0 * x + f

        return self.__class__.initfun_adaptive(lambda x: constfun(x) / self(x), interval=self.interval)

    __radd__ = __add__

    def __rsub__(self, f: Any) -> Any:
        """Subtract this Chebtech from a scalar.

        This is called when f - self is executed and f is not a Chebtech.

        Args:
            f: A scalar from which to subtract this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f minus this Chebtech.
        """
        return -(self - f)

    @self_empty()
    def __rpow__(self, f: Any) -> Any:
        """Raise a scalar to the power of this Chebtech.

        This is called when f ** self is executed and f is not a Chebtech.

        Args:
            f: A scalar to be raised to the power of this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing f raised to the power of this Chebtech.
        """
        return self.__class__.initfun_adaptive(lambda x: np.power(f, self(x)), interval=self.interval)

    __rtruediv__ = __rdiv__
    __rmul__ = __mul__

    def __sub__(self, f: Any) -> Any:
        """Subtract a scalar or another Chebtech from this Chebtech.

        Args:
            f: A scalar or another Chebtech to subtract from this Chebtech.

        Returns:
            Chebtech: A new Chebtech representing the difference.
        """
        return self + (-f)

    # -------
    #  roots
    # -------
    def roots(self, sort: bool | None = None) -> np.ndarray:
        """Compute the roots of the Chebtech on [-1,1].

        Uses the coefficients in the associated Chebyshev series approximation,
        refines them with Newton iteration and clips them to [-1, 1]. Sorting
        defaults to ``prefs.sortroots``.
        """
        sort = sort if sort is not None else prefs.sortroots
        rts = rootsunit(self.coeffs)
        rts = newtonroots(self, rts)
        # fix problems with newton for roots that are numerically very close
        rts = np.clip(rts, -1, 1)  # if newton roots are just outside [-1,1]
        rts = rts if not sort else np.sort(rts)
        return rts

    # ----------
    #  calculus
    # ----------
    # Note that function returns 0 for an empty Chebtech object; this is
    # consistent with numpy, which returns zero for the sum of an empty array
    @self_empty(resultif=0.0)
    def sum(self) -> Any:
        """Definite integral of a Chebtech on the interval [-1,1]."""
        if self.isconst:
            out = 2.0 * self(0.0)
        else:
            # int_{-1}^{1} T_k = 2 / (1 - k^2) for even k and 0 for odd k.
            ak = self.coeffs.copy()
            ak[1::2] = 0
            kk = np.arange(2, ak.size)
            ii = np.append([2, 0], 2 / (1 - kk**2))
            out = (ak * ii).sum()
        return out

    @self_empty()
    def cumsum(self) -> "Chebtech":
        """Return a Chebtech object representing the indefinite integral.

        Computes the indefinite integral of a Chebtech on the interval [-1,1].
        The constant term is chosen such that F(-1) = 0.
        """
        n = self.size
        # The recurrence produces fractional values, so integer coefficients
        # are promoted to float rather than being silently truncated.
        coeff_dtype = self.coeffs.dtype
        dtype = coeff_dtype if np.issubdtype(coeff_dtype, np.inexact) else np.float64
        ak = np.append(self.coeffs, [0, 0])
        bk = np.zeros(n + 1, dtype=dtype)
        # b_k = (a_{k-1} - a_{k+1}) / (2k) for k >= 2, and b_1 = a_0 - a_2 / 2
        rk = np.arange(2, n + 1)
        bk[2:] = 0.5 * (ak[1:n] - ak[3:]) / rk
        bk[1] = ak[0] - 0.5 * ak[2]
        # T_k(-1) = (-1)^k, so choose b_0 to make F(-1) = 0
        vk = np.ones(n)
        vk[1::2] = -1
        bk[0] = (vk * bk[1:]).sum()
        out = self.__class__(bk, interval=self.interval)
        return out

    @self_empty()
    def diff(self) -> "Chebtech":
        """Return a Chebtech object representing the derivative.

        Computes the derivative of a Chebtech on the interval [-1,1].
        """
        if self.isconst:
            out = self.__class__(np.array([0.0]), interval=self.interval)
        else:
            n = self.size
            ak = self.coeffs
            zk = np.zeros(n - 1, dtype=self.coeffs.dtype)
            # zk[i] = sum of 2*j*a_j over j > i with j - i odd; computed as
            # reversed cumulative sums over each parity class of vk.
            wk = 2 * np.arange(1, n)
            vk = wk * ak[1:]
            zk[-1::-2] = vk[-1::-2].cumsum()
            zk[-2::-2] = vk[-2::-2].cumsum()
            # the T_0 coefficient carries an extra factor of 1/2
            zk[0] = 0.5 * zk[0]
            out = self.__class__(zk, interval=self.interval)
        return out

    @staticmethod
    def _chebpts(n: int) -> np.ndarray:
        """Return n Chebyshev points of the second-kind."""
        return chebpts2(n)

    @staticmethod
    def _barywts(n: int) -> np.ndarray:
        """Barycentric weights for Chebyshev points of 2nd kind."""
        return barywts2(n)

    @staticmethod
    def _vals2coeffs(vals: Any) -> np.ndarray:
        """Map function values at Chebyshev points of 2nd kind.

        Converts values at Chebyshev points of 2nd kind to first-kind Chebyshev polynomial coefficients.
        """
        return vals2coeffs2(vals)

    @staticmethod
    def _coeffs2vals(coeffs: Any) -> np.ndarray:
        """Map first-kind Chebyshev polynomial coefficients.

        Converts first-kind Chebyshev polynomial coefficients to function values at Chebyshev points of 2nd kind.
        """
        return coeffs2vals2(coeffs)

    # ----------
    #  plotting
    # ----------
    def plot(self, ax: Any = None, **kwargs: Any) -> Any:
        """Plot the Chebtech on the interval [-1, 1].

        Args:
            ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to None.
            **kwargs: Additional keyword arguments to pass to the plot function.

        Returns:
            The result of ``plotfun``.
        """
        return plotfun(self, (-1, 1), ax=ax, **kwargs)

    def plotcoeffs(self, ax: Any = None, **kwargs: Any) -> Any:
        """Plot the absolute values of the Chebyshev coefficients.

        Args:
            ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to
                the current axes (``plt.gca()``).
            **kwargs: Additional keyword arguments to pass to the plot function.

        Returns:
            The result of ``plotfuncoeffs``.
        """
        ax = ax or plt.gca()
        return plotfuncoeffs(abs(self.coeffs), ax=ax, **kwargs)