Coverage for src / chebpy / trigtech.py: 93%

324 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 07:22 +0000

1"""Trigonometric (Fourier) technology for periodic function approximation. 

2 

3This module provides the Trigtech class, which represents smooth periodic functions 

4on [-1, 1] using truncated Fourier series. It is the trigonometric analogue of 

5Chebtech and sits in the same class hierarchy: 

6 

7 Onefun → Smoothfun → Trigtech 

8 

9Coefficient storage convention (NumPy-native / FFT order) 

10---------------------------------------------------------- 

11Given n equispaced sample points x_j = -1 + 2j/n (j = 0, …, n-1), the stored 

12coefficients are 

13 

14 coeffs[k] = (1/n) * sum_j f(x_j) * exp(-2*pi*i*j*k/n) 

15 = (numpy.fft.fft(values) / n)[k] 

16 

17This is exactly the output of ``numpy.fft.fft(values) / n``, i.e. NumPy-native 

18(FFT) ordering: DC at index 0, positive frequencies 1 … (n-1)//2, then negative 

19frequencies -(n//2) … -1. 

20 

21Use ``_coeffs_to_plotorder()`` to obtain the human-readable DC-centred ordering 

22(equivalent to ``numpy.fft.fftshift``). 

23 

24Evaluation 

25---------- 

26Any point x ∈ [-1, 1] is evaluated via the DFT summation formula: 

27 

28 f(x) = Σ_k coeffs[k] * exp(i*π*ω_k*(x+1)) 

29 

30where ω_k = numpy.fft.fftfreq(n)*n gives the integer frequencies in FFT order. 

31 

32References: 

33---------- 

34* Trefethen, "Spectral Methods in MATLAB" (SIAM 2000) 

35* Chebfun @trigtech (github.com/chebfun/chebfun) 

36""" 

37 

38import warnings 

39from abc import ABC 

40from typing import Any 

41 

42import matplotlib.pyplot as plt 

43import numpy as np 

44 

45from .decorators import self_empty 

46from .plotting import plotfun, plotfuncoeffs 

47from .settings import _preferences as prefs 

48from .smoothfun import Smoothfun 

49from .utilities import Interval, coerce_list 

50 

51 

def _trig_adaptive(
    cls: Any,
    fun: Any,
    hscale: float = 1,
    maxpow2: int | None = None,
) -> np.ndarray:
    """Adaptively determine the Fourier coefficients needed to represent *fun*.

    Samples *fun* on successively finer equispaced grids of size ``2**k``
    (starting at ``n = 8``) until the highest-frequency Fourier mode decays
    below tolerance. Convergence is assessed via the one-sided symmetric
    maximum of the DC-centred coefficient magnitudes,
    ``abs_sym[k] = max(|c_k|, |c_{-k}|) / vscale``; the series is considered
    converged when the Nyquist/highest-frequency entry ``abs_sym[-1]`` is at
    most *tol*. A converged series is truncated to the highest frequency whose
    ``abs_sym`` entry exceeds *tol*, so it always has odd length.

    Args:
        cls: Trigtech class (provides ``_trigpts``, ``_vals2coeffs`` and
            ``__name__``).
        fun: Callable to approximate. A bare scalar return value (e.g. from
            ``lambda x: 2.0``) is broadcast to every sample point.
        hscale: Horizontal scale for tolerance adjustment.
        maxpow2: Maximum power of 2 to try (defaults to ``prefs.maxpow2``).
            Values below 3 still sample on the ``n = 8`` grid.

    Returns:
        numpy.ndarray: Fourier coefficients in NumPy FFT order. ``[0.0]`` is
        returned when *fun* is negligible on the sample grid.

    Warns:
        UserWarning: If no grid converged; the untruncated coefficients from
        the finest grid tried are returned.
    """
    minpow2 = 3  # start at n = 8
    maxpow2 = maxpow2 if maxpow2 is not None else prefs.maxpow2
    lastpow2 = max(minpow2, maxpow2)
    tol = prefs.eps * max(hscale, 1)
    coeffs: np.ndarray = np.array([])
    n = 2**minpow2
    for k in range(minpow2, lastpow2 + 1):
        n = 2**k
        points = cls._trigpts(n)
        values = np.asarray(fun(points))
        if values.ndim == 0:
            # Constant callables may return a bare scalar; sample it everywhere.
            values = np.full(points.shape, values)
        coeffs = cls._vals2coeffs(values)
        vscale = float(np.max(np.abs(values)))
        if vscale <= tol:
            return np.array([0.0])

        # One-sided symmetric maximum for ki = 0, ..., n//2:
        #   abs_sym[ki] = max(|c_{ki}|, |c_{-ki}|) / vscale
        # In DC-centred order c_{ki} sits at dc_idx + ki and c_{-ki} at
        # dc_idx - ki. For even n there is no +n/2 mode, so it counts as 0.
        centered = np.fft.fftshift(coeffs)
        dc_idx = n // 2
        mags = np.abs(centered)
        neg = mags[dc_idx::-1]  # |c_0|, |c_{-1}|, ..., |c_{-dc_idx}|
        pos = np.zeros(dc_idx + 1)
        upper = mags[dc_idx:]  # |c_0|, |c_1|, ... (one short for even n)
        pos[: upper.size] = upper
        abs_sym = np.maximum(pos, neg) / vscale

        # Convergence: the Nyquist/highest-frequency mode is negligible.
        if abs_sym[-1] <= tol:
            above = np.flatnonzero(abs_sym > tol)
            if above.size == 0:
                return np.array([0.0])
            max_k = int(above[-1])  # highest significant frequency index
            return np.fft.ifftshift(centered[dc_idx - max_k : dc_idx + max_k + 1])

    # Every grid up to 2**lastpow2 points failed the convergence test. This is
    # reported even when maxpow2 < minpow2 (only the n = 8 grid was tried).
    warnings.warn(
        f"The {cls.__name__} constructor did not converge: using {n} points",
        stacklevel=3,
    )
    return coeffs

116 

117 

class Trigtech(Smoothfun, ABC):
    """Trigonometric (Fourier) function approximation on [-1, 1].

    Represents a smooth periodic function f: [-1, 1] -> R (or C) as a
    truncated Fourier series. Coefficients are stored in NumPy FFT order:
    for samples at the equispaced points x_j = -1 + 2j/n, ``coeffs`` equals
    ``numpy.fft.fft(values) / n`` (DC term at index 0).

    ``ABC`` is listed among the bases to mirror Chebtech's declaration. It
    does not by itself prevent instantiation: every alternative constructor
    below finishes by calling ``cls(coeffs, interval=...)`` directly.
    """

    # ------------------------------------------------------------------
    # alternative constructors
    # ------------------------------------------------------------------

    @classmethod
    def initconst(cls, c: Any = None, *, interval: Any = None) -> "Trigtech":
        """Initialise a Trigtech from a constant *c*.

        Args:
            c: Scalar value of the constant. Python ints are converted to float.
            interval: Two-element interval; defaults to ``prefs.domain``.

        Raises:
            ValueError: If *c* is not a scalar (including the default ``None``).
        """
        if not np.isscalar(c):
            raise ValueError(c)
        if isinstance(c, int):
            c = float(c)
        return cls(np.array([c]), interval=interval)

    @classmethod
    def initempty(cls, *, interval: Any = None) -> "Trigtech":
        """Initialise an empty Trigtech (zero coefficients)."""
        return cls(np.array([]), interval=interval)

    @classmethod
    def initidentity(cls, *, interval: Any = None) -> "Trigtech":
        """Trigtech approximation of the identity f(x) = x on [-1, 1].

        Note: f(x) = x is *not* periodic on [-1, 1], so this will not converge
        to machine precision. It is provided for interface compatibility with
        Chebtech; in practice ``Classicfun.initidentity`` is used instead.
        """
        interval = interval if interval is not None else prefs.domain
        return cls.initfun_adaptive(lambda x: x, interval=interval)

    @classmethod
    def initfun(cls, fun: Any = None, n: Any = None, *, interval: Any = None) -> "Trigtech":
        """Convenience constructor: adaptive if *n* is None, fixed-length otherwise."""
        if n is None:
            return cls.initfun_adaptive(fun, interval=interval)
        return cls.initfun_fixedlen(fun, n, interval=interval)

    @classmethod
    def initfun_fixedlen(cls, fun: Any = None, n: Any = None, *, interval: Any = None) -> "Trigtech":
        """Initialise a Trigtech from callable *fun* using *n* equispaced points.

        Args:
            fun: Callable sampled at ``_trigpts(n)``. A bare scalar return
                value is broadcast to every sample point.
            n: Number of sample points (and of stored coefficients).
            interval: Two-element interval; defaults to ``prefs.domain``.

        Raises:
            ValueError: If *n* is not given.
        """
        if n is None:
            raise ValueError("initfun_fixedlen requires the n parameter to be specified")  # noqa: TRY003
        points = cls._trigpts(int(n))
        values = np.asarray(fun(points))
        if values.ndim == 0:
            # Constant callables may return a bare scalar; sample it everywhere.
            values = np.full(points.shape, values)
        coeffs = cls._vals2coeffs(values)
        return cls(coeffs, interval=interval)

    @classmethod
    def initfun_adaptive(cls, fun: Any = None, *, interval: Any = None) -> "Trigtech":
        """Initialise a Trigtech from callable *fun* using the adaptive constructor.

        The interval's ``hscale`` loosens the convergence tolerance.
        """
        interval = interval if interval is not None else prefs.domain
        interval = Interval(*interval)
        coeffs = _trig_adaptive(cls, fun, hscale=interval.hscale)
        return cls(coeffs, interval=interval)

    @classmethod
    def initvalues(cls, values: Any = None, *, interval: Any = None) -> "Trigtech":
        """Initialise a Trigtech from function values at the points ``_trigpts(len(values))``."""
        return cls(cls._vals2coeffs(np.asarray(values)), interval=interval)

    # ------------------------------------------------------------------
    # core dunder methods
    # ------------------------------------------------------------------

    def __init__(self, coeffs: Any, interval: Any = None) -> None:
        """Initialise a Trigtech with FFT-order *coeffs* on *interval*.

        Coefficients are always stored as complex128. The :attr:`iscomplex`
        property returns True only when the function *values* are complex
        (i.e., the coefficients do **not** satisfy the conjugate-symmetry
        condition C_{n-k} ≈ conj(C_k)).

        Args:
            coeffs: 1-D array of Fourier coefficients in NumPy FFT order.
            interval: Two-element interval [a, b]. Defaults to ``prefs.domain``.
        """
        interval = interval if interval is not None else prefs.domain
        self._coeffs = np.array(coeffs, dtype=complex)
        self._interval = Interval(*interval)

    def __call__(self, x: Any, how: str = "fft") -> Any:
        """Evaluate the Trigtech at points *x* via the DFT summation formula.

            f(x) = Σ_k coeffs[k] * exp(i*π*ω_k*(x+1))

        where ω_k = fftfreq(n)*n gives integer frequencies in FFT order.
        For real-valued functions (see :attr:`iscomplex`) the imaginary part
        of the result is discarded.

        Args:
            x: Evaluation point(s) in [-1, 1]. Array input is flattened.
            how: Ignored; present for interface compatibility with Chebtech.

        Returns:
            For scalar *x*: a Python float for real-valued functions, or a
            Python complex for complex-valued ones. Otherwise a 1-D array with
            one entry per point. An empty Trigtech returns an empty array.
        """
        if self.isempty:
            return np.array([])
        scalar = np.isscalar(x)
        x = np.atleast_1d(np.asarray(x, dtype=float)).ravel()
        iscomplex = self.iscomplex

        if self.isconst:
            c0 = self._coeffs[0] if iscomplex else self._coeffs[0].real
            out = c0 * np.ones(x.size)
        else:
            n = self.size
            freqs = np.fft.fftfreq(n) * n  # [0, 1, …, (n-1)//2, -(n//2), …, -1]
            # shape: (len(x), n) @ (n,) → (len(x),)
            phases = np.exp(1j * np.pi * np.outer(x + 1.0, freqs))
            out = phases @ self._coeffs
            if not iscomplex:
                out = out.real
        # ``.item()`` gives a Python float for real data and a Python complex
        # for complex data; ``float()`` would silently drop the imaginary part.
        return out[0].item() if scalar else out

    def __repr__(self) -> str:  # pragma: no cover
        """Return a concise string representation."""
        return f"<{self.__class__.__name__}{{{self.size}}}>"

    # ------------------------------------------------------------------
    # properties
    # ------------------------------------------------------------------

    @property
    def coeffs(self) -> np.ndarray:
        """Fourier coefficients in NumPy FFT order (always complex128)."""
        return self._coeffs

    @property
    def interval(self) -> Interval:
        """Interval that the Trigtech is mapped to."""
        return self._interval

    @property
    def size(self) -> int:
        """Number of stored Fourier coefficients."""
        return self._coeffs.size

    @property
    def isempty(self) -> bool:
        """True if the Trigtech has no coefficients."""
        return self.size == 0

    @property
    def iscomplex(self) -> bool:
        """True if the function is complex-valued (values have a non-negligible imaginary part).

        This is determined by checking whether the Fourier coefficients violate
        the conjugate-symmetry condition C_{n-k} ≈ conj(C_k) that holds for
        every real-valued periodic function. The check is relative:
        deviations up to ``1e-8 * max|C_k|`` are ignored. A single coefficient
        counts as complex if it has any non-zero imaginary part.
        """
        n = self.size
        if n <= 1:
            return bool(np.any(np.abs(np.imag(self._coeffs)) > 0))
        abs_max = float(np.max(np.abs(self._coeffs)))
        if abs_max == 0.0:
            return False
        tol = 1e-8 * abs_max
        # mirror[k-1] = conj(C_{n-k}) for k = 1,...,n-1
        mirror = np.conj(self._coeffs[-1:0:-1])
        return bool(np.any(np.abs(self._coeffs[1:] - mirror) > tol))

    @property
    def isconst(self) -> bool:
        """True if the Trigtech represents a constant (single coefficient)."""
        return self.size == 1

    @property
    def isperiodic(self) -> bool:
        """Always True: Trigtech always represents a periodic function."""
        return True

    @property
    @self_empty(0.0)
    def vscale(self) -> float:
        """Estimate the vertical scale as max |f| over the sample values (0.0 if empty)."""
        return float(np.abs(np.asarray(coerce_list(self.values()))).max())

    # ------------------------------------------------------------------
    # utilities
    # ------------------------------------------------------------------

    def copy(self) -> "Trigtech":
        """Return a deep copy (coefficients and interval are both copied)."""
        return self.__class__(self._coeffs.copy(), interval=self._interval.copy())

    def imag(self) -> "Trigtech":
        """Return the imaginary part of the function as a real-valued Trigtech.

        For a complex function f(x) = g(x) + i·h(x), the Fourier coefficients
        of h(x) are H[k] = (D[k] - conj(D[n-k])) / (2i) for k ≥ 1,
        and H[0] = Im(D[0]). A real-valued Trigtech yields the constant 0.
        """
        if not self.iscomplex:
            return self.initconst(0.0, interval=self._interval)
        n = self.size
        c = self._coeffs
        imag_c = np.zeros(n, dtype=complex)
        imag_c[0] = np.imag(c[0])
        if n > 1:
            mirror = np.conj(c[-1:0:-1])  # conj(c[n-1]), ..., conj(c[1])
            imag_c[1:] = (c[1:] - mirror) / (2j)
        return self.__class__(imag_c, self._interval)

    def prolong(self, n: int) -> "Trigtech":
        """Return a Trigtech of length *n* (truncate or zero-pad in frequency space).

        The source and target DC-centred representations are aligned on their
        DC entries, so every retained coefficient keeps its frequency. When
        padding (n > m), zeros are added on both sides. When truncating
        (n < m), frequencies -(n//2) … (n-1)//2 are kept.

        NOTE(review): for an even source length m, the unpaired mode at
        frequency -m/2 is kept at -m/2 when padding rather than being split
        between ±m/2. For real-valued data this can make the padded series
        fail the conjugate-symmetry test used by ``iscomplex``.
        """
        m = self.size
        if n == m:
            return self.copy()

        centered = np.fft.fftshift(self._coeffs)
        dc_src = m // 2
        dc_tgt = n // 2

        if n > m:
            padded = np.zeros(n, dtype=centered.dtype)
            start = dc_tgt - dc_src
            padded[start : start + m] = centered
            return self.__class__(np.fft.ifftshift(padded), interval=self._interval)
        start = dc_src - dc_tgt
        truncated = centered[start : start + n]
        return self.__class__(np.fft.ifftshift(truncated), interval=self._interval)

    def real(self) -> "Trigtech":
        """Return the real part of the function as a real-valued Trigtech.

        For a complex function f(x) = g(x) + i·h(x), the Fourier coefficients
        of g(x) are G[k] = (D[k] + conj(D[n-k])) / 2 for k ≥ 1,
        and G[0] = Re(D[0]). A real-valued Trigtech is returned as-is (not
        copied).
        """
        if not self.iscomplex:
            return self
        n = self.size
        c = self._coeffs
        real_c = np.zeros(n, dtype=complex)
        real_c[0] = np.real(c[0])
        if n > 1:
            mirror = np.conj(c[-1:0:-1])  # conj(c[n-1]), ..., conj(c[1])
            real_c[1:] = (c[1:] + mirror) / 2
        return self.__class__(real_c, self._interval)

    @self_empty()
    def simplify(self) -> "Trigtech":
        """Truncate high-frequency Fourier coefficients that are below tolerance.

        Uses the same one-sided symmetric-maximum criterion as the adaptive
        constructor, but measured relative to the largest coefficient
        magnitude. The highest frequency retained is the largest k with
        ``max(|c_k|, |c_{-k}|) / max_j |c_j| > tol``, capped at ``size // 2``.
        An identically-zero Trigtech becomes the constant 0. An empty
        Trigtech is handled by ``self_empty`` (as in ``diff``/``cumsum``) and
        is no longer turned into the constant 0.
        """
        oldlen = self.size
        # Pad short series so the symmetric scan always sees at least 17 modes.
        longself = self.prolong(max(17, oldlen))
        n = longself.size
        tol = prefs.eps * max(self._interval.hscale, 1)

        centered = np.fft.fftshift(longself._coeffs)
        dc_idx = n // 2
        mags = np.abs(centered)
        abs_max = float(np.max(mags))
        if abs_max == 0.0:
            return self.initconst(0.0, interval=self._interval)

        # abs_sym[k] = max(|c_k|, |c_{-k}|) / abs_max for k = 0, ..., n//2;
        # for even n the +n/2 mode does not exist and counts as 0.
        neg = mags[dc_idx::-1]
        pos = np.zeros(dc_idx + 1)
        upper = mags[dc_idx:]
        pos[: upper.size] = upper
        abs_sym = np.maximum(pos, neg) / abs_max

        above = np.flatnonzero(abs_sym > tol)
        if above.size == 0:
            return self.initconst(0.0, interval=self._interval)
        max_k = min(int(above[-1]), oldlen // 2)  # don't exceed original size

        start = dc_idx - max_k
        end = dc_idx + max_k + 1
        return self.__class__(np.fft.ifftshift(centered[start:end]), interval=self._interval)

    def values(self) -> np.ndarray:
        """Function values at the n equispaced points x_j = -1 + 2j/n."""
        return self._coeffs2vals(self._coeffs)

    def _coeffs_to_plotorder(self) -> np.ndarray:
        """Return coefficients in DC-centred (human-readable) order.

        Equivalent to ``numpy.fft.fftshift(self.coeffs)``: the ordering is
        [c_{-(n//2)}, …, c_{-1}, c_0, c_1, …, c_{(n-1)//2}]. For example,
        n = 4 gives frequencies -2 … 1 and n = 5 gives -2 … 2.
        """
        return np.fft.fftshift(self._coeffs)

    # ------------------------------------------------------------------
    # algebra
    # ------------------------------------------------------------------

    @self_empty()
    def __add__(self, f: Any) -> "Trigtech":
        """Add a scalar or another Trigtech.

        A scalar is added to the DC coefficient. For two Trigtechs, the
        shorter one is prolonged to the longer length first. If every
        resulting coefficient is below ``0.5 * eps * max(vscale)``, the
        constant 0 is returned.
        """
        cls = self.__class__
        if np.isscalar(f):
            dtype: Any = complex if np.iscomplexobj(f) else self._coeffs.dtype
            cfs = np.array(self._coeffs, dtype=dtype)
            cfs[0] += f  # add to DC component
            return cls(cfs, interval=self._interval)
        if f.isempty:
            return f.copy()
        g = self
        n, m = g.size, f.size
        if n < m:
            g = g.prolong(m)
        elif m < n:
            f = f.prolong(n)
        cfs = f.coeffs + g.coeffs
        eps = prefs.eps
        tol = 0.5 * eps * max(f.vscale, g.vscale)
        if np.all(np.abs(cfs) < tol):
            return cls.initconst(0.0, interval=self._interval)
        return cls(cfs, interval=self._interval)

    @self_empty()
    def __div__(self, f: Any) -> "Trigtech":
        """Divide this Trigtech by a scalar or another Trigtech.

        Scalar division scales the coefficients. Division by a Trigtech
        rebuilds the quotient with the adaptive constructor.
        """
        cls = self.__class__
        if np.isscalar(f):
            return cls(self._coeffs / np.asarray(f), interval=self._interval)
        if f.isempty:
            return f.copy()
        return cls.initfun_adaptive(lambda x: self(x) / f(x), interval=self._interval)

    __truediv__ = __div__

    @self_empty()
    def __mul__(self, g: Any) -> "Trigtech":
        """Multiply this Trigtech by a scalar or another Trigtech.

        Multiplying two trig polynomials convolves their coefficient
        sequences. Both factors are zero-padded to n = n1 + n2 coefficients,
        which is large enough that the product of the stored series does not
        alias. They are then evaluated on that grid, multiplied pointwise,
        and transformed back with the FFT.
        """
        cls = self.__class__
        if np.isscalar(g):
            return cls(g * self._coeffs, interval=self._interval)
        if g.isempty:
            return g.copy()
        n = self.size + g.size
        f_vals = self.prolong(n).values()
        g_vals = g.prolong(n).values()
        return cls(cls._vals2coeffs(f_vals * g_vals), interval=self._interval)

    def __neg__(self) -> "Trigtech":
        """Return the negation."""
        return self.__class__(-self._coeffs, interval=self._interval)

    def __pos__(self) -> "Trigtech":
        """Return self (unary plus)."""
        return self

    @self_empty()
    def __pow__(self, f: Any) -> "Trigtech":
        """Raise this Trigtech to a power *f* (scalar or Trigtech), adaptively."""

        def powfun(fn: Any, x: Any) -> Any:
            return fn if np.isscalar(fn) else fn(x)

        return self.__class__.initfun_adaptive(
            lambda x: np.power(self(x), powfun(f, x)),
            interval=self._interval,
        )

    @self_empty()
    def __rdiv__(self, f: Any) -> "Trigtech":
        """Compute f / self where *f* is a scalar.

        The numerator is written as ``0.0 * x + f`` so it has the same shape
        as the sample points. An empty Trigtech is handled by ``self_empty``
        instead of failing inside the adaptive constructor.
        """
        return self.__class__.initfun_adaptive(
            lambda x: (0.0 * x + f) / self(x),
            interval=self._interval,
        )

    __radd__ = __add__
    __rmul__ = __mul__
    __rtruediv__ = __rdiv__

    def __rsub__(self, f: Any) -> "Trigtech":
        """Compute f - self."""
        return -(self - f)

    @self_empty()
    def __rpow__(self, f: Any) -> "Trigtech":
        """Compute f ** self."""
        return self.__class__.initfun_adaptive(
            lambda x: np.power(f, self(x)),
            interval=self._interval,
        )

    def __sub__(self, f: Any) -> "Trigtech":
        """Subtract *f* (scalar or Trigtech) from this Trigtech."""
        return self + (-f)

    # ------------------------------------------------------------------
    # rootfinding
    # ------------------------------------------------------------------

    def roots(self, sort: bool | None = None) -> np.ndarray:
        """Find the roots of this Trigtech on [-1, 1].

        The function is sampled on ``max(2 * size + 1, 33)`` Chebyshev points
        and fitted as a Chebtech. Roots are then found with ``rootsunit``,
        polished with ``newtonroots``, and clipped to [-1, 1].

        Args:
            sort: If True, sort the roots in ascending order. Defaults to
                ``prefs.sortroots``.
        """
        from .algorithms import newtonroots, rootsunit
        from .chebtech import Chebtech

        sort = sort if sort is not None else prefs.sortroots

        if self.isempty:
            return np.array([])

        # Sample on a Chebyshev grid and fit a Chebtech of that resolution.
        n = max(2 * self.size + 1, 33)
        cheb_pts = Chebtech._chebpts(n)
        vals = self(cheb_pts)
        ct = Chebtech(Chebtech._vals2coeffs(vals))
        rts = rootsunit(ct.coeffs)
        rts = newtonroots(ct, rts)
        rts = np.clip(rts, -1.0, 1.0)
        return np.sort(rts) if sort else rts

    # ------------------------------------------------------------------
    # calculus
    # ------------------------------------------------------------------

    @self_empty(resultif=0.0)
    def sum(self) -> Any:
        """Definite integral of the Trigtech over [-1, 1].

        Only the DC coefficient contributes:
            ∫_{-1}^{1} exp(i*π*k*(x+1)) dx = 0   for k ≠ 0
            ∫_{-1}^{1} 1 dx                 = 2   for k = 0

        Returns:
            A float for real-valued functions or a complex for complex-valued
            ones (previously the imaginary part was dropped). An empty
            Trigtech gives 0.0.
        """
        c0 = self._coeffs[0]
        if self.iscomplex:
            return 2.0 * complex(c0)
        return 2.0 * float(np.real(c0))

    @self_empty()
    def cumsum(self) -> "Trigtech":
        """Indefinite integral, zero at x = -1, in Fourier coefficient space.

        For mode k ≠ 0: antiderivative coefficient = c_k / (i*π*ω_k)
        For mode k = 0: set to the constant needed so that F(-1) = 0.

        Note: if the DC component (self.coeffs[0]) is non-zero, the true
        antiderivative contains a linear trend and is not periodic. We still
        return a Trigtech representing the *periodic* part, adjusted so that
        the result evaluates to 0 at x = -1; the linear term is dropped
        silently.
        """
        n = self.size
        c = self._coeffs.copy()
        freqs = np.fft.fftfreq(n) * n  # FFT-order integer frequencies

        int_c = np.zeros(n, dtype=complex)
        mask = freqs != 0
        int_c[mask] = c[mask] / (1j * np.pi * freqs[mask])

        # Enforce F(-1) = 0.
        # F(x) = Σ_k int_c[k] * exp(i*π*ω_k*(x+1))
        # At x = -1: exp(i*π*ω_k*0) = 1 for all k, so F(-1) = Σ int_c.
        # Set int_c[0] so that sum(int_c) = 0.
        int_c[0] = -np.sum(int_c[1:])
        return self.__class__(int_c, interval=self._interval)

    @self_empty()
    def diff(self) -> "Trigtech":
        """Derivative via the Fourier multiplier i*π*ω_k.

        d/dx [c_k * exp(i*π*ω_k*(x+1))] = i*π*ω_k * c_k * exp(i*π*ω_k*(x+1))
        """
        if self.isconst:
            return self.__class__(np.array([0.0 + 0.0j]), interval=self._interval)
        n = self.size
        freqs = np.fft.fftfreq(n) * n
        d_coeffs = (1j * np.pi * freqs) * self._coeffs
        return self.__class__(d_coeffs, interval=self._interval)

    # ------------------------------------------------------------------
    # static helpers (FFT ↔ values)
    # ------------------------------------------------------------------

    @staticmethod
    def _trigpts(n: int) -> np.ndarray:
        """Return *n* equispaced points on [-1, 1): x_j = -1 + 2j/n."""
        if n == 0:
            return np.array([])
        return -1.0 + 2.0 * np.arange(n) / n

    @staticmethod
    def _vals2coeffs(vals: Any) -> np.ndarray:
        """Convert values at equispaced points to FFT coefficients (divided by n).

        The result is complex even for real-valued inputs, because Fourier
        coefficients for functions such as sin are purely imaginary and would
        be discarded if forced to real.

        Inverse of ``_coeffs2vals``.
        """
        vals = np.asarray(vals)
        n = vals.size
        if n == 0:
            return np.array([], dtype=complex)
        return np.fft.fft(vals) / n

    @staticmethod
    def _coeffs2vals(coeffs: Any) -> np.ndarray:
        """Convert FFT coefficients (divided by n) to values at equispaced points.

        The imaginary part is dropped when it is below ``1e-10`` times
        ``max(max|Re(values)|, 1)``; otherwise complex values are returned.

        Inverse of ``_vals2coeffs``.
        """
        coeffs = np.asarray(coeffs, dtype=complex)
        n = coeffs.size
        if n == 0:
            return np.array([], dtype=float)
        vals = n * np.fft.ifft(coeffs)
        # Discard negligible imaginary parts for conjugate-symmetric coefficients.
        max_real = float(np.max(np.abs(np.real(vals))))
        if float(np.max(np.abs(np.imag(vals)))) < 1e-10 * max(max_real, 1.0):
            return np.real(vals)
        return vals

    # ------------------------------------------------------------------
    # plotting
    # ------------------------------------------------------------------

    def plot(self, ax: Any = None, **kwargs: Any) -> Any:
        """Plot the Trigtech over [-1, 1].

        Args:
            ax: Matplotlib axes. If None, uses the current axes.
            **kwargs: Forwarded to matplotlib.

        Returns:
            The axes on which the plot was drawn.
        """
        return plotfun(self, (-1, 1), ax=ax, **kwargs)

    def plotcoeffs(self, ax: Any = None, **kwargs: Any) -> Any:
        """Plot the absolute Fourier coefficient magnitudes in DC-centred order.

        Uses ``_coeffs_to_plotorder()`` so the horizontal axis runs from
        the most-negative frequency on the left to the most-positive on
        the right, with DC in the centre.

        Args:
            ax: Matplotlib axes. If None, uses the current axes.
            **kwargs: Forwarded to matplotlib.

        Returns:
            The axes on which the plot was drawn.
        """
        ax = ax or plt.gca()
        return plotfuncoeffs(np.abs(self._coeffs_to_plotorder()), ax=ax, **kwargs)