Coverage for chebpy/core/chebtech.py: 99%

248 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-07 10:30 +0000

1"""Implementation of Chebyshev polynomial technology for function approximation. 

2 

3This module provides the Chebtech class, an abstract base class for 

4representing functions using Chebyshev polynomial expansions. Its point-set 

5helpers use Chebyshev points of the second kind. 

6 

7The Chebtech classes implement core functionality for working with Chebyshev 

8expansions, including: 

9- Function evaluation using Clenshaw's algorithm or barycentric interpolation 

10- Algebraic operations (addition, multiplication, etc.) 

11- Calculus operations (differentiation, integration, etc.) 

12- Rootfinding 

13- Plotting 

14 

15These classes are primarily used internally by higher-level classes like Bndfun 

16and Chebfun, rather than being used directly by end users. 

17""" 

18 

19from abc import ABC 

20 

21import matplotlib.pyplot as plt 

22import numpy as np 

23 

24from .algorithms import ( 

25 adaptive, 

26 bary, 

27 barywts2, 

28 chebpts2, 

29 clenshaw, 

30 coeffmult, 

31 coeffs2vals2, 

32 newtonroots, 

33 rootsunit, 

34 standard_chop, 

35 vals2coeffs2, 

36) 

37from .decorators import self_empty 

38from .plotting import plotfun, plotfuncoeffs 

39from .settings import _preferences as prefs 

40from .smoothfun import Smoothfun 

41from .utilities import Interval, coerce_list 

42 

43 

44class Chebtech(Smoothfun, ABC): 

45 """Abstract base class serving as the template for Chebtech1 and Chebtech subclasses. 

46 

47 Chebtech objects always work with first-kind coefficients, so much 

48 of the core operational functionality is defined this level. 

49 

50 The user will rarely work with these classes directly so we make 

51 several assumptions regarding input data types. 

52 """ 

53 

54 @classmethod 

55 def initconst(cls, c, *, interval=None): 

56 """Initialise a Chebtech from a constant c.""" 

57 if not np.isscalar(c): 

58 raise ValueError(c) 

59 if isinstance(c, int): 

60 c = float(c) 

61 return cls(np.array([c]), interval=interval) 

62 

63 @classmethod 

64 def initempty(cls, *, interval=None): 

65 """Initialise an empty Chebtech.""" 

66 return cls(np.array([]), interval=interval) 

67 

68 @classmethod 

69 def initidentity(cls, *, interval=None): 

70 """Chebtech representation of f(x) = x on [-1,1].""" 

71 return cls(np.array([0, 1]), interval=interval) 

72 

73 @classmethod 

74 def initfun(cls, fun, n=None, *, interval=None): 

75 """Convenience constructor to automatically select the adaptive or fixedlen constructor. 

76 

77 This constructor automatically selects between the adaptive or fixed-length 

78 constructor based on the input arguments passed. 

79 """ 

80 if n is None: 

81 return cls.initfun_adaptive(fun, interval=interval) 

82 else: 

83 return cls.initfun_fixedlen(fun, n, interval=interval) 

84 

85 @classmethod 

86 def initfun_fixedlen(cls, fun, n, *, interval=None): 

87 """Initialise a Chebtech from the callable fun using n degrees of freedom. 

88 

89 This constructor creates a Chebtech representation of the function using 

90 a fixed number of degrees of freedom specified by n. 

91 """ 

92 points = cls._chebpts(n) 

93 values = fun(points) 

94 coeffs = vals2coeffs2(values) 

95 return cls(coeffs, interval=interval) 

96 

97 @classmethod 

98 def initfun_adaptive(cls, fun, *, interval=None): 

99 """Initialise a Chebtech from the callable fun utilising the adaptive constructor. 

100 

101 This constructor uses an adaptive algorithm to determine the appropriate 

102 number of degrees of freedom needed to represent the function. 

103 """ 

104 interval = interval if interval is not None else prefs.domain 

105 interval = Interval(*interval) 

106 coeffs = adaptive(cls, fun, hscale=interval.hscale) 

107 return cls(coeffs, interval=interval) 

108 

109 @classmethod 

110 def initvalues(cls, values, *, interval=None): 

111 """Initialise a Chebtech from an array of values at Chebyshev points.""" 

112 return cls(cls._vals2coeffs(values), interval=interval) 

113 

114 def __init__(self, coeffs, interval=None): 

115 """Initialize a Chebtech object. 

116 

117 This method initializes a new Chebtech object with the given coefficients 

118 and interval. If no interval is provided, the default interval from 

119 preferences is used. 

120 

121 Args: 

122 coeffs (array-like): The coefficients of the Chebyshev series. 

123 interval (array-like, optional): The interval on which the function 

124 is defined. Defaults to None, which uses the default interval 

125 from preferences. 

126 """ 

127 interval = interval if interval is not None else prefs.domain 

128 self._coeffs = np.array(coeffs) 

129 self._interval = Interval(*interval) 

130 

131 def __call__(self, x, how="clenshaw"): 

132 """Evaluate the Chebtech at the given points. 

133 

134 Args: 

135 x: Points at which to evaluate the Chebtech. 

136 how (str, optional): Method to use for evaluation. Either "clenshaw" or "bary". 

137 Defaults to "clenshaw". 

138 

139 Returns: 

140 The values of the Chebtech at the given points. 

141 

142 Raises: 

143 ValueError: If the specified method is not supported. 

144 """ 

145 method = { 

146 "clenshaw": self.__call__clenshaw, 

147 "bary": self.__call__bary, 

148 } 

149 try: 

150 return method[how](x) 

151 except KeyError: 

152 raise ValueError(how) 

153 

154 def __call__clenshaw(self, x): 

155 return clenshaw(x, self.coeffs) 

156 

157 def __call__bary(self, x): 

158 fk = self.values() 

159 xk = self._chebpts(fk.size) 

160 vk = self._barywts(fk.size) 

161 return bary(x, fk, xk, vk) 

162 

163 def __repr__(self): # pragma: no cover 

164 """Return a string representation of the Chebtech. 

165 

166 Returns: 

167 str: A string representation of the Chebtech. 

168 """ 

169 out = f"<{self.__class__.__name__}{{{self.size}}}>" 

170 return out 

171 

172 # ------------ 

173 # properties 

174 # ------------ 

175 @property 

176 def coeffs(self): 

177 """Chebyshev expansion coefficients in the T_k basis.""" 

178 return self._coeffs 

179 

180 @property 

181 def interval(self): 

182 """Interval that Chebtech is mapped to.""" 

183 return self._interval 

184 

185 @property 

186 def size(self): 

187 """Return the size of the object.""" 

188 return self.coeffs.size 

189 

190 @property 

191 def isempty(self): 

192 """Return True if the Chebtech is empty.""" 

193 return self.size == 0 

194 

195 @property 

196 def iscomplex(self): 

197 """Determine whether the underlying onefun is complex or real valued.""" 

198 return self._coeffs.dtype == complex 

199 

200 @property 

201 def isconst(self): 

202 """Return True if the Chebtech represents a constant.""" 

203 return self.size == 1 

204 

205 @property 

206 @self_empty(0.0) 

207 def vscale(self): 

208 """Estimate the vertical scale of a Chebtech.""" 

209 return np.abs(coerce_list(self.values())).max() 

210 

211 # ----------- 

212 # utilities 

213 # ----------- 

214 def copy(self): 

215 """Return a deep copy of the Chebtech.""" 

216 return self.__class__(self.coeffs.copy(), interval=self.interval.copy()) 

217 

218 def imag(self): 

219 """Return the imaginary part of the Chebtech. 

220 

221 Returns: 

222 Chebtech: A new Chebtech representing the imaginary part of this Chebtech. 

223 If this Chebtech is real-valued, returns a zero Chebtech. 

224 """ 

225 if self.iscomplex: 

226 return self.__class__(np.imag(self.coeffs), self.interval) 

227 else: 

228 return self.initconst(0, interval=self.interval) 

229 

230 def prolong(self, n): 

231 """Return a Chebtech of length n. 

232 

233 Obtained either by truncating if n < self.size or zero-padding if n > self.size. 

234 In all cases a deep copy is returned. 

235 """ 

236 m = self.size 

237 ak = self.coeffs 

238 cls = self.__class__ 

239 if n - m < 0: 

240 out = cls(ak[:n].copy(), interval=self.interval) 

241 elif n - m > 0: 

242 out = cls(np.append(ak, np.zeros(n - m)), interval=self.interval) 

243 else: 

244 out = self.copy() 

245 return out 

246 

247 def real(self): 

248 """Return the real part of the Chebtech. 

249 

250 Returns: 

251 Chebtech: A new Chebtech representing the real part of this Chebtech. 

252 If this Chebtech is already real-valued, returns self. 

253 """ 

254 if self.iscomplex: 

255 return self.__class__(np.real(self.coeffs), self.interval) 

256 else: 

257 return self 

258 

259 def simplify(self): 

260 """Call standard_chop on the coefficients of self. 

261 

262 Returns a Chebtech comprised of a copy of the truncated coefficients. 

263 """ 

264 # coefficients 

265 oldlen = len(self.coeffs) 

266 longself = self.prolong(max(17, oldlen)) 

267 cfs = longself.coeffs 

268 # scale (decrease) tolerance by hscale 

269 tol = prefs.eps * max(self.interval.hscale, 1) 

270 # chop 

271 npts = standard_chop(cfs, tol=tol) 

272 npts = min(oldlen, npts) 

273 # construct 

274 return self.__class__(cfs[:npts].copy(), interval=self.interval) 

275 

276 def values(self): 

277 """Function values at Chebyshev points.""" 

278 return coeffs2vals2(self.coeffs) 

279 

280 # --------- 

281 # algebra 

282 # --------- 

283 @self_empty() 

284 def __add__(self, f): 

285 """Add a scalar or another Chebtech to this Chebtech. 

286 

287 Args: 

288 f: A scalar or another Chebtech to add to this Chebtech. 

289 

290 Returns: 

291 Chebtech: A new Chebtech representing the sum. 

292 """ 

293 cls = self.__class__ 

294 if np.isscalar(f): 

295 if np.iscomplexobj(f): 

296 dtype = complex 

297 else: 

298 dtype = self.coeffs.dtype 

299 cfs = np.array(self.coeffs, dtype=dtype) 

300 cfs[0] += f 

301 return cls(cfs, interval=self.interval) 

302 else: 

303 # TODO: is a more general decorator approach better here? 

304 # TODO: for constant Chebtech, convert to constant and call __add__ again 

305 if f.isempty: 

306 return f.copy() 

307 g = self 

308 n, m = g.size, f.size 

309 if n < m: 

310 g = g.prolong(m) 

311 elif m < n: 

312 f = f.prolong(n) 

313 cfs = f.coeffs + g.coeffs 

314 

315 # check for zero output 

316 eps = prefs.eps 

317 tol = 0.5 * eps * max([f.vscale, g.vscale]) 

318 if all(abs(cfs) < tol): 

319 return cls.initconst(0.0, interval=self.interval) 

320 else: 

321 return cls(cfs, interval=self.interval) 

322 

323 @self_empty() 

324 def __div__(self, f): 

325 """Divide this Chebtech by a scalar or another Chebtech. 

326 

327 Args: 

328 f: A scalar or another Chebtech to divide this Chebtech by. 

329 

330 Returns: 

331 Chebtech: A new Chebtech representing the quotient. 

332 """ 

333 cls = self.__class__ 

334 if np.isscalar(f): 

335 cfs = 1.0 / f * self.coeffs 

336 return cls(cfs, interval=self.interval) 

337 else: 

338 # TODO: review with reference to __add__ 

339 if f.isempty: 

340 return f.copy() 

341 return cls.initfun_adaptive(lambda x: self(x) / f(x), interval=self.interval) 

342 

343 __truediv__ = __div__ 

344 

345 @self_empty() 

346 def __mul__(self, g): 

347 """Multiply this Chebtech by a scalar or another Chebtech. 

348 

349 Args: 

350 g: A scalar or another Chebtech to multiply this Chebtech by. 

351 

352 Returns: 

353 Chebtech: A new Chebtech representing the product. 

354 """ 

355 cls = self.__class__ 

356 if np.isscalar(g): 

357 cfs = g * self.coeffs 

358 return cls(cfs, interval=self.interval) 

359 else: 

360 # TODO: review with reference to __add__ 

361 if g.isempty: 

362 return g.copy() 

363 f = self 

364 n = f.size + g.size - 1 

365 f = f.prolong(n) 

366 g = g.prolong(n) 

367 cfs = coeffmult(f.coeffs, g.coeffs) 

368 out = cls(cfs, interval=self.interval) 

369 return out 

370 

371 def __neg__(self): 

372 """Return the negative of this Chebtech. 

373 

374 Returns: 

375 Chebtech: A new Chebtech representing the negative of this Chebtech. 

376 """ 

377 coeffs = -self.coeffs 

378 return self.__class__(coeffs, interval=self.interval) 

379 

380 def __pos__(self): 

381 """Return this Chebtech (unary positive). 

382 

383 Returns: 

384 Chebtech: This Chebtech (self). 

385 """ 

386 return self 

387 

388 @self_empty() 

389 def __pow__(self, f): 

390 """Raise this Chebtech to a power. 

391 

392 Args: 

393 f: The exponent, which can be a scalar or another Chebtech. 

394 

395 Returns: 

396 Chebtech: A new Chebtech representing this Chebtech raised to the power f. 

397 """ 

398 

399 def powfun(fn, x): 

400 if np.isscalar(fn): 

401 return fn 

402 else: 

403 return fn(x) 

404 

405 return self.__class__.initfun_adaptive(lambda x: np.power(self(x), powfun(f, x)), interval=self.interval) 

406 

407 def __rdiv__(self, f): 

408 """Divide a scalar by this Chebtech. 

409 

410 This is called when f / self is executed and f is not a Chebtech. 

411 

412 Args: 

413 f: A scalar to be divided by this Chebtech. 

414 

415 Returns: 

416 Chebtech: A new Chebtech representing f divided by this Chebtech. 

417 """ 

418 

419 # Executed when __div__(f, self) fails, which is to say whenever f 

420 # is not a Chebtech. We proceeed on the assumption f is a scalar. 

421 def constfun(x): 

422 return 0.0 * x + f 

423 

424 return self.__class__.initfun_adaptive(lambda x: constfun(x) / self(x), interval=self.interval) 

425 

426 __radd__ = __add__ 

427 

428 def __rsub__(self, f): 

429 """Subtract this Chebtech from a scalar. 

430 

431 This is called when f - self is executed and f is not a Chebtech. 

432 

433 Args: 

434 f: A scalar from which to subtract this Chebtech. 

435 

436 Returns: 

437 Chebtech: A new Chebtech representing f minus this Chebtech. 

438 """ 

439 return -(self - f) 

440 

441 @self_empty() 

442 def __rpow__(self, f): 

443 """Raise a scalar to the power of this Chebtech. 

444 

445 This is called when f ** self is executed and f is not a Chebtech. 

446 

447 Args: 

448 f: A scalar to be raised to the power of this Chebtech. 

449 

450 Returns: 

451 Chebtech: A new Chebtech representing f raised to the power of this Chebtech. 

452 """ 

453 return self.__class__.initfun_adaptive(lambda x: np.power(f, self(x)), interval=self.interval) 

454 

455 __rtruediv__ = __rdiv__ 

456 __rmul__ = __mul__ 

457 

458 def __sub__(self, f): 

459 """Subtract a scalar or another Chebtech from this Chebtech. 

460 

461 Args: 

462 f: A scalar or another Chebtech to subtract from this Chebtech. 

463 

464 Returns: 

465 Chebtech: A new Chebtech representing the difference. 

466 """ 

467 return self + (-f) 

468 

469 # ------- 

470 # roots 

471 # ------- 

472 def roots(self, sort=None): 

473 """Compute the roots of the Chebtech on [-1,1]. 

474 

475 Uses the coefficients in the associated Chebyshev series approximation. 

476 """ 

477 sort = sort if sort is not None else prefs.sortroots 

478 rts = rootsunit(self.coeffs) 

479 rts = newtonroots(self, rts) 

480 # fix problems with newton for roots that are numerically very close 

481 rts = np.clip(rts, -1, 1) # if newton roots are just outside [-1,1] 

482 rts = rts if not sort else np.sort(rts) 

483 return rts 

484 

485 # ---------- 

486 # calculus 

487 # ---------- 

488 # Note that function returns 0 for an empty Chebtech object; this is 

489 # consistent with numpy, which returns zero for the sum of an empty array 

490 @self_empty(resultif=0.0) 

491 def sum(self): 

492 """Definite integral of a Chebtech on the interval [-1,1].""" 

493 if self.isconst: 

494 out = 2.0 * self(0.0) 

495 else: 

496 ak = self.coeffs.copy() 

497 ak[1::2] = 0 

498 kk = np.arange(2, ak.size) 

499 ii = np.append([2, 0], 2 / (1 - kk**2)) 

500 out = (ak * ii).sum() 

501 return out 

502 

503 @self_empty() 

504 def cumsum(self): 

505 """Return a Chebtech object representing the indefinite integral. 

506 

507 Computes the indefinite integral of a Chebtech on the interval [-1,1]. 

508 The constant term is chosen such that F(-1) = 0. 

509 """ 

510 n = self.size 

511 ak = np.append(self.coeffs, [0, 0]) 

512 bk = np.zeros(n + 1, dtype=self.coeffs.dtype) 

513 rk = np.arange(2, n + 1) 

514 bk[2:] = 0.5 * (ak[1:n] - ak[3:]) / rk 

515 bk[1] = ak[0] - 0.5 * ak[2] 

516 vk = np.ones(n) 

517 vk[1::2] = -1 

518 bk[0] = (vk * bk[1:]).sum() 

519 out = self.__class__(bk, interval=self.interval) 

520 return out 

521 

522 @self_empty() 

523 def diff(self): 

524 """Return a Chebtech object representing the derivative. 

525 

526 Computes the derivative of a Chebtech on the interval [-1,1]. 

527 """ 

528 if self.isconst: 

529 out = self.__class__(np.array([0.0]), interval=self.interval) 

530 else: 

531 n = self.size 

532 ak = self.coeffs 

533 zk = np.zeros(n - 1, dtype=self.coeffs.dtype) 

534 wk = 2 * np.arange(1, n) 

535 vk = wk * ak[1:] 

536 zk[-1::-2] = vk[-1::-2].cumsum() 

537 zk[-2::-2] = vk[-2::-2].cumsum() 

538 zk[0] = 0.5 * zk[0] 

539 out = self.__class__(zk, interval=self.interval) 

540 return out 

541 

542 @staticmethod 

543 def _chebpts(n): 

544 """Return n Chebyshev points of the second-kind.""" 

545 return chebpts2(n) 

546 

547 @staticmethod 

548 def _barywts(n): 

549 """Barycentric weights for Chebyshev points of 2nd kind.""" 

550 return barywts2(n) 

551 

552 @staticmethod 

553 def _vals2coeffs(vals): 

554 """Map function values at Chebyshev points of 2nd kind. 

555 

556 Converts values at Chebyshev points of 2nd kind to first-kind Chebyshev polynomial coefficients. 

557 """ 

558 return vals2coeffs2(vals) 

559 

560 @staticmethod 

561 def _coeffs2vals(coeffs): 

562 """Map first-kind Chebyshev polynomial coefficients. 

563 

564 Converts first-kind Chebyshev polynomial coefficients to function values at Chebyshev points of 2nd kind. 

565 """ 

566 return coeffs2vals2(coeffs) 

567 

568 # ---------- 

569 # plotting 

570 # ---------- 

571 def plot(self, ax=None, **kwargs): 

572 """Plot the Chebtech on the interval [-1, 1]. 

573 

574 Args: 

575 ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to None. 

576 **kwargs: Additional keyword arguments to pass to the plot function. 

577 

578 Returns: 

579 matplotlib.lines.Line2D: The line object created by the plot. 

580 """ 

581 return plotfun(self, (-1, 1), ax=ax, **kwargs) 

582 

583 def plotcoeffs(self, ax=None, **kwargs): 

584 """Plot the absolute values of the Chebyshev coefficients. 

585 

586 Args: 

587 ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to None. 

588 **kwargs: Additional keyword arguments to pass to the plot function. 

589 

590 Returns: 

591 matplotlib.lines.Line2D: The line object created by the plot. 

592 """ 

593 ax = ax or plt.gca() 

594 return plotfuncoeffs(abs(self.coeffs), ax=ax, **kwargs)