Coverage for src / chebpy / trigtech.py: 93%
324 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 07:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 07:22 +0000
"""Trigonometric (Fourier) technology for periodic function approximation.

This module provides the Trigtech class, which represents smooth periodic functions
on [-1, 1] using truncated Fourier series. It is the trigonometric analogue of
Chebtech and sits in the same class hierarchy:

    Onefun → Smoothfun → Trigtech

Coefficient storage convention (NumPy-native / FFT order)
----------------------------------------------------------
Given n equispaced sample points x_j = -1 + 2j/n (j = 0, …, n-1), the stored
coefficients are

    coeffs[k] = (1/n) * sum_j f(x_j) * exp(-2*pi*i*j*k/n)
              = (numpy.fft.fft(values) / n)[k]

This is exactly the output of ``numpy.fft.fft(values) / n``, i.e. NumPy-native
(FFT) ordering as given by ``numpy.fft.fftfreq(n) * n``:

* DC at index 0;
* then the positive frequencies 1, …, (n-1)//2;
* then the negative frequencies -(n//2), …, -1.

For even n the single Nyquist coefficient is stored at index n//2 and carries
frequency -n/2.

Use ``_coeffs_to_plotorder()`` to obtain the human-readable DC-centred ordering
(equivalent to ``numpy.fft.fftshift``).

Evaluation
----------
Any point x ∈ [-1, 1] is evaluated via the DFT summation formula:

    f(x) = Σ_k coeffs[k] * exp(i*π*ω_k*(x+1))

where ω_k = numpy.fft.fftfreq(n)*n gives the integer frequencies in FFT order.

References
----------
* Trefethen, "Spectral Methods in MATLAB" (SIAM 2000)
* Chebfun @trigtech (github.com/chebfun/chebfun)
"""
38import warnings
39from abc import ABC
40from typing import Any
42import matplotlib.pyplot as plt
43import numpy as np
45from .decorators import self_empty
46from .plotting import plotfun, plotfuncoeffs
47from .settings import _preferences as prefs
48from .smoothfun import Smoothfun
49from .utilities import Interval, coerce_list
def _trig_adaptive(
    cls: Any,
    fun: Any,
    hscale: float = 1,
    maxpow2: int | None = None,
) -> np.ndarray:
    """Adaptively determine the Fourier coefficients needed to represent *fun*.

    Samples *fun* on successively finer equispaced grids of size ``2**k``,
    starting at n = 8, until the high-frequency Fourier modes decay below
    tolerance.

    Convergence is assessed via the one-sided symmetric maximum of the
    DC-centred coefficient magnitudes,
    ``abs_sym[j] = max(|c_j|, |c_{-j}|) / vscale``. The series is considered
    converged once the highest-frequency entry ``abs_sym[-1]`` is at most
    *tol*. On convergence the coefficients are truncated symmetrically to the
    highest significant frequency, so a converged result has odd length.

    Args:
        cls: Trigtech class (provides ``_trigpts`` and ``_vals2coeffs``).
        fun: Callable to approximate. A scalar return value (for example from
            ``lambda x: 1.0``) is broadcast to the sample grid.
        hscale: Horizontal scale for tolerance adjustment.
        maxpow2: Maximum power of 2 to try (defaults to ``prefs.maxpow2``).
            Grids smaller than n = 8 are never used.

    Returns:
        numpy.ndarray: Fourier coefficients in NumPy FFT order. ``[0.0]`` is
        returned when the sampled values (or all coefficients) are below
        tolerance. If the finest grid is reached without convergence, the
        untruncated coefficients from that grid are returned.

    Warns:
        UserWarning: If the finest grid is reached without convergence.
    """
    minpow2 = 3  # start at n = 8
    maxpow2 = maxpow2 if maxpow2 is not None else prefs.maxpow2
    # The loop always runs up to here, even when maxpow2 < minpow2; warning
    # against this value (not maxpow2) guarantees non-convergence is reported.
    lastpow2 = max(minpow2, maxpow2)
    tol = prefs.eps * max(hscale, 1)
    coeffs: np.ndarray = np.array([])
    for k in range(minpow2, lastpow2 + 1):
        n = 2**k
        points = cls._trigpts(n)
        values = np.asarray(fun(points))
        if values.ndim == 0:
            # Constant callables may return a bare scalar; the FFT needs one
            # value per sample point.
            values = np.full(points.shape, values)
        coeffs = cls._vals2coeffs(values)
        vscale = float(np.max(np.abs(values)))
        if vscale <= tol:
            return np.array([0.0])

        # One-sided symmetric maximum:
        #   abs_sym[j] = max(|c_j|, |c_{-j}|) / vscale   for j = 0 … n//2.
        # In the DC-centred array, frequency +j sits at dc_idx + j and -j at
        # dc_idx - j. For even n there is no +n/2 slot, so it counts as 0,
        # while -n/2 sits at centered[0].
        centered = np.fft.fftshift(coeffs)
        dc_idx = n // 2
        mags = np.abs(centered)
        pos = np.zeros(dc_idx + 1)
        pos[: n - dc_idx] = mags[dc_idx:]
        neg = mags[dc_idx::-1]
        abs_sym = np.maximum(pos, neg) / vscale

        # Convergence: the Nyquist/highest-frequency mode is negligible.
        if abs_sym[-1] <= tol:
            above = np.where(abs_sym > tol)[0]
            if len(above) == 0:
                return np.array([0.0])
            max_k = int(above[-1])  # highest significant frequency index
            return np.fft.ifftshift(centered[dc_idx - max_k : dc_idx + max_k + 1])

        if k == lastpow2:
            warnings.warn(
                f"The {cls.__name__} constructor did not converge: using {n} points",
                stacklevel=3,
            )
    return coeffs
class Trigtech(Smoothfun, ABC):
    """Trigonometric (Fourier) function approximation on [-1, 1].

    Represents a smooth periodic function f: [-1, 1] -> R (or C) as a
    truncated Fourier series. Coefficients are stored in NumPy FFT order:
    index k holds the mode with integer frequency ``numpy.fft.fftfreq(n)[k] * n``.

    The class is marked ``ABC`` to mirror Chebtech. It declares no abstract
    methods of its own, and instances are normally created through the
    classmethod constructors (``initfun``, ``initvalues``, ``initconst``, …),
    which themselves call ``cls(coeffs, interval=...)``.
    """

    # ------------------------------------------------------------------
    # alternative constructors
    # ------------------------------------------------------------------

    @classmethod
    def initconst(cls, c: Any = None, *, interval: Any = None) -> "Trigtech":
        """Initialise a Trigtech from a constant *c*.

        Args:
            c: Scalar value. Python ``int`` values are promoted to ``float``.
            interval: Two-element interval. Defaults to ``prefs.domain``.

        Raises:
            ValueError: If *c* is not a scalar.
        """
        if not np.isscalar(c):
            raise ValueError(c)
        if isinstance(c, int):
            c = float(c)
        return cls(np.array([c]), interval=interval)

    @classmethod
    def initempty(cls, *, interval: Any = None) -> "Trigtech":
        """Initialise an empty Trigtech (no coefficients)."""
        return cls(np.array([]), interval=interval)

    @classmethod
    def initidentity(cls, *, interval: Any = None) -> "Trigtech":
        """Trigtech approximation of the identity f(x) = x on [-1, 1].

        Note: f(x) = x is *not* periodic on [-1, 1], so this will not converge
        to machine precision. It is provided for interface compatibility with
        Chebtech; in practice ``Classicfun.initidentity`` is used instead.
        """
        interval = interval if interval is not None else prefs.domain
        return cls.initfun_adaptive(lambda x: x, interval=interval)

    @classmethod
    def initfun(cls, fun: Any = None, n: Any = None, *, interval: Any = None) -> "Trigtech":
        """Construct from *fun*: adaptively if *n* is None, else with *n* points."""
        if n is None:
            return cls.initfun_adaptive(fun, interval=interval)
        return cls.initfun_fixedlen(fun, n, interval=interval)

    @classmethod
    def initfun_fixedlen(cls, fun: Any = None, n: Any = None, *, interval: Any = None) -> "Trigtech":
        """Initialise a Trigtech from callable *fun* using *n* equispaced points.

        Raises:
            ValueError: If *n* is None.
        """
        if n is None:
            raise ValueError("initfun_fixedlen requires the n parameter to be specified")  # noqa: TRY003
        points = cls._trigpts(int(n))
        values = fun(points)
        coeffs = cls._vals2coeffs(values)
        return cls(coeffs, interval=interval)

    @classmethod
    def initfun_adaptive(cls, fun: Any = None, *, interval: Any = None) -> "Trigtech":
        """Initialise a Trigtech from callable *fun* using the adaptive constructor.

        The tolerance passed to the adaptive constructor is scaled by the
        interval's ``hscale``.
        """
        interval = interval if interval is not None else prefs.domain
        interval = Interval(*interval)
        coeffs = _trig_adaptive(cls, fun, hscale=interval.hscale)
        return cls(coeffs, interval=interval)

    @classmethod
    def initvalues(cls, values: Any = None, *, interval: Any = None) -> "Trigtech":
        """Initialise a Trigtech from values at the equispaced points x_j = -1 + 2j/n."""
        return cls(cls._vals2coeffs(np.asarray(values)), interval=interval)

    # ------------------------------------------------------------------
    # core dunder methods
    # ------------------------------------------------------------------

    def __init__(self, coeffs: Any, interval: Any = None) -> None:
        """Initialise a Trigtech with FFT-order *coeffs* on *interval*.

        Coefficients are always stored as complex128. The :attr:`iscomplex`
        property returns True only when the function *values* are complex
        (i.e., the coefficients do **not** satisfy the conjugate-symmetry
        condition C_{n-k} ≈ conj(C_k)).

        Args:
            coeffs: 1-D array of Fourier coefficients in NumPy FFT order.
            interval: Two-element interval [a, b]. Defaults to ``prefs.domain``.
        """
        interval = interval if interval is not None else prefs.domain
        self._coeffs = np.array(coeffs, dtype=complex)
        self._interval = Interval(*interval)

    def __call__(self, x: Any, how: str = "fft") -> Any:
        """Evaluate the Trigtech at points *x* via the DFT summation formula.

            f(x) = Σ_k coeffs[k] * exp(i*π*ω_k*(x+1))

        where ω_k = fftfreq(n)*n gives integer frequencies in FFT order.
        For real-valued functions the imaginary part of the result is discarded.

        Args:
            x: Evaluation points in [-1, 1]. Input is flattened to 1-D.
            how: Ignored; present for interface compatibility with Chebtech.

        Returns:
            A Python scalar (float, or complex for complex-valued functions) when
            *x* is a scalar; otherwise a 1-D array. An empty Trigtech returns an
            empty array.
        """
        if self.isempty:
            return np.array([])
        scalar = np.isscalar(x)
        x = np.atleast_1d(np.asarray(x, dtype=float)).ravel()
        complex_valued = self.iscomplex

        if self.isconst:
            c0 = self._coeffs[0] if complex_valued else self._coeffs[0].real
            out = c0 * np.ones(x.size)
            # .item() keeps a complex constant complex; float() would not.
            return out[0].item() if scalar else out

        n = self.size
        freqs = np.fft.fftfreq(n) * n  # [0, 1, …, (n-1)//2, -(n//2), …, -1]
        # shape: (len(x), n) @ (n,) → (len(x),)
        phases = np.exp(1j * np.pi * np.outer(x + 1.0, freqs))
        result = phases @ self._coeffs
        if not complex_valued:
            result = result.real
        # .item() returns a Python float for real data and a complex for
        # complex data, so scalar evaluation never drops the imaginary part.
        return result[0].item() if scalar else result

    def __repr__(self) -> str:  # pragma: no cover
        """Return a concise string representation."""
        return f"<{self.__class__.__name__}{{{self.size}}}>"

    # ------------------------------------------------------------------
    # properties
    # ------------------------------------------------------------------

    @property
    def coeffs(self) -> np.ndarray:
        """Fourier coefficients in NumPy FFT order (always complex128)."""
        return self._coeffs

    @property
    def interval(self) -> Interval:
        """Interval that the Trigtech is mapped to."""
        return self._interval

    @property
    def size(self) -> int:
        """Number of stored Fourier coefficients."""
        return self._coeffs.size

    @property
    def isempty(self) -> bool:
        """True if the Trigtech has no coefficients."""
        return self.size == 0

    @property
    def iscomplex(self) -> bool:
        """True if the function is complex-valued (values have a non-negligible imaginary part).

        For n > 1 this checks whether the Fourier coefficients violate the
        conjugate-symmetry condition C_{n-k} ≈ conj(C_k), which holds for every
        real-valued periodic function. The comparison uses a relative tolerance
        of 1e-8 times the largest coefficient magnitude. For n <= 1 any nonzero
        imaginary part counts as complex.
        """
        n = self.size
        if n <= 1:
            return bool(np.any(np.abs(np.imag(self._coeffs)) > 0))
        abs_max = float(np.max(np.abs(self._coeffs)))
        if abs_max == 0.0:
            return False
        tol = 1e-8 * abs_max
        # mirror[k-1] = conj(C_{n-k}) for k = 1,...,n-1
        mirror = np.conj(self._coeffs[-1:0:-1])
        return bool(np.any(np.abs(self._coeffs[1:] - mirror) > tol))

    @property
    def isconst(self) -> bool:
        """True if the Trigtech represents a constant (single coefficient)."""
        return self.size == 1

    @property
    def isperiodic(self) -> bool:
        """Always True: Trigtech always represents a periodic function."""
        return True

    @property
    @self_empty(0.0)
    def vscale(self) -> float:
        """Estimate the vertical scale: max |f| over the equispaced sample values."""
        return float(np.abs(np.asarray(coerce_list(self.values()))).max())

    # ------------------------------------------------------------------
    # utilities
    # ------------------------------------------------------------------

    def copy(self) -> "Trigtech":
        """Return a deep copy (coefficients and interval are copied)."""
        return self.__class__(self._coeffs.copy(), interval=self._interval.copy())

    def imag(self) -> "Trigtech":
        """Return the imaginary part of the function as a real-valued Trigtech.

        For a complex function f(x) = g(x) + i·h(x), the Fourier coefficients
        of h(x) are H[k] = (D[k] - conj(D[n-k])) / (2i) for k ≥ 1,
        and H[0] = Im(D[0]). A real-valued function yields the zero constant.
        """
        if not self.iscomplex:
            return self.initconst(0.0, interval=self._interval)
        n = self.size
        c = self._coeffs
        imag_c = np.zeros(n, dtype=complex)
        imag_c[0] = np.imag(c[0])
        if n > 1:
            mirror = np.conj(c[-1:0:-1])  # conj(c[n-1]), ..., conj(c[1])
            imag_c[1:] = (c[1:] - mirror) / (2j)
        return self.__class__(imag_c, self._interval)

    def prolong(self, n: int) -> "Trigtech":
        """Return a Trigtech of length *n* (zero-pad or truncate in frequency space).

        The DC components of the source and target DC-centred representations
        are aligned. The coefficients are then either zero-padded (n > m) or
        sliced (n < m), which handles both even- and odd-length arrays.

        An even-length array stores a single Nyquist coefficient, at frequency
        -m/2 in FFT order. To keep real-valued functions conjugate-symmetric,
        and hence real:

        * when padding from an even length m, that coefficient is split equally
          between frequencies -m/2 and +m/2. Both modes coincide on the
          original m-point grid, so the sampled values are unchanged;
        * when truncating to an even length n, the +n/2 coefficient is folded
          onto the retained -n/2 slot, because the two modes alias on the
          n-point grid.

        Padding and then truncating back to the original length therefore
        round-trips exactly.
        """
        m = self.size
        if n == m:
            return self.copy()

        centered = np.fft.fftshift(self._coeffs)
        dc_src = m // 2
        dc_tgt = n // 2

        if n > m:
            padded = np.zeros(n, dtype=centered.dtype)
            start = dc_tgt - dc_src
            padded[start : start + m] = centered
            if m > 0 and m % 2 == 0:
                nyquist = centered[0]  # frequency -m/2
                padded[start] = 0.5 * nyquist
                # start + m == dc_tgt + m/2, i.e. frequency +m/2 (exists since n > m)
                padded[start + m] = 0.5 * nyquist
            return self.__class__(np.fft.ifftshift(padded), interval=self._interval)

        start = dc_src - dc_tgt
        truncated = centered[start : start + n].copy()
        if n > 0 and n % 2 == 0:
            # truncated[0] is frequency -n/2; add the aliasing +n/2 mode to it.
            truncated[0] += centered[dc_src + n // 2]
        return self.__class__(np.fft.ifftshift(truncated), interval=self._interval)

    def real(self) -> "Trigtech":
        """Return the real part of the function as a real-valued Trigtech.

        For a complex function f(x) = g(x) + i·h(x), the Fourier coefficients
        of g(x) are G[k] = (D[k] + conj(D[n-k])) / 2 for k ≥ 1,
        and G[0] = Re(D[0]). A real-valued function is returned as is (not copied).
        """
        if not self.iscomplex:
            return self
        n = self.size
        c = self._coeffs
        real_c = np.zeros(n, dtype=complex)
        real_c[0] = np.real(c[0])
        if n > 1:
            mirror = np.conj(c[-1:0:-1])  # conj(c[n-1]), ..., conj(c[1])
            real_c[1:] = (c[1:] + mirror) / 2
        return self.__class__(real_c, self._interval)

    @self_empty()
    def simplify(self) -> "Trigtech":
        """Truncate high-frequency Fourier coefficients that are below tolerance.

        Uses the same one-sided symmetric-maximum criterion as the adaptive
        constructor, with magnitudes relative to the largest coefficient. The
        highest-frequency mode retained is the last one where
        ``max(|c_k|, |c_{-k}|) / max|c| > tol``. The retained frequency is also
        capped at ``len(coeffs) // 2``.
        """
        oldlen = len(self._coeffs)
        longself = self.prolong(max(17, oldlen))
        n = longself.size
        tol = prefs.eps * max(self._interval.hscale, 1)

        centered = np.fft.fftshift(longself._coeffs)
        dc_idx = n // 2
        mags = np.abs(centered)
        abs_max = float(np.max(mags))
        if abs_max == 0.0:
            return self.initconst(0.0, interval=self._interval)

        # abs_sym[j] = max(|c_j|, |c_{-j}|) / abs_max for j = 0 … n//2.
        # For even n the +n/2 slot does not exist and counts as 0.
        pos = np.zeros(dc_idx + 1)
        pos[: n - dc_idx] = mags[dc_idx:]
        neg = mags[dc_idx::-1]
        abs_sym = np.maximum(pos, neg) / abs_max

        above = np.where(abs_sym > tol)[0]
        if len(above) == 0:
            return self.initconst(0.0, interval=self._interval)
        max_k = int(above[-1])
        max_k = min(max_k, oldlen // 2)  # don't exceed original size

        start = dc_idx - max_k
        end = dc_idx + max_k + 1
        return self.__class__(np.fft.ifftshift(centered[start:end]), interval=self._interval)

    def values(self) -> np.ndarray:
        """Function values at the n equispaced points x_j = -1 + 2j/n."""
        return self._coeffs2vals(self._coeffs)

    def _coeffs_to_plotorder(self) -> np.ndarray:
        """Return coefficients in DC-centred (human-readable) order.

        Equivalent to ``numpy.fft.fftshift(self.coeffs)``:
        ordering is [c_{-(n//2)}, …, c_{-1}, c_0, c_1, …, c_{(n-1)//2}].
        """
        return np.fft.fftshift(self._coeffs)

    # ------------------------------------------------------------------
    # algebra
    # ------------------------------------------------------------------

    @self_empty()
    def __add__(self, f: Any) -> "Trigtech":
        """Add a scalar or another Trigtech.

        A scalar is added to the DC coefficient. For two Trigtechs the shorter
        one is prolonged to the longer length first. A sum whose coefficients
        are all below ``0.5 * eps * max(vscale)`` collapses to the zero constant.
        """
        cls = self.__class__
        if np.isscalar(f):
            dtype: Any = complex if np.iscomplexobj(f) else self._coeffs.dtype
            cfs = np.array(self._coeffs, dtype=dtype)
            cfs[0] += f  # add to DC component
            return cls(cfs, interval=self._interval)
        if f.isempty:
            return f.copy()
        g = self
        n, m = g.size, f.size
        if n < m:
            g = g.prolong(m)
        elif m < n:
            f = f.prolong(n)
        cfs = f.coeffs + g.coeffs
        eps = prefs.eps
        tol = 0.5 * eps * max(f.vscale, g.vscale)
        if np.all(np.abs(cfs) < tol):
            return cls.initconst(0.0, interval=self._interval)
        return cls(cfs, interval=self._interval)

    @self_empty()
    def __div__(self, f: Any) -> "Trigtech":
        """Divide this Trigtech by a scalar or another Trigtech.

        Scalar division scales the coefficients. Division by a Trigtech
        reconstructs the quotient adaptively from pointwise evaluations.
        """
        cls = self.__class__
        if np.isscalar(f):
            return cls(self._coeffs / np.asarray(f), interval=self._interval)
        if f.isempty:
            return f.copy()
        return cls.initfun_adaptive(lambda x: self(x) / f(x), interval=self._interval)

    __truediv__ = __div__

    @self_empty()
    def __mul__(self, g: Any) -> "Trigtech":
        """Multiply this Trigtech by a scalar or another Trigtech.

        Trig-polynomial multiplication is convolution in frequency space. It is
        implemented by evaluating both factors on a grid of size n1 + n2, which
        is sufficient to avoid aliasing, then multiplying pointwise and taking
        the FFT.
        """
        cls = self.__class__
        if np.isscalar(g):
            return cls(g * self._coeffs, interval=self._interval)
        if g.isempty:
            return g.copy()
        n = self.size + g.size
        f_vals = self.prolong(n).values()
        g_vals = g.prolong(n).values()
        return cls(cls._vals2coeffs(f_vals * g_vals), interval=self._interval)

    def __neg__(self) -> "Trigtech":
        """Return the negation."""
        return self.__class__(-self._coeffs, interval=self._interval)

    def __pos__(self) -> "Trigtech":
        """Return self (unary plus)."""
        return self

    @self_empty()
    def __pow__(self, f: Any) -> "Trigtech":
        """Raise this Trigtech to a power *f* (scalar or Trigtech), adaptively."""

        def powfun(fn: Any, x: Any) -> Any:
            return fn if np.isscalar(fn) else fn(x)

        return self.__class__.initfun_adaptive(
            lambda x: np.power(self(x), powfun(f, x)),
            interval=self._interval,
        )

    def __rdiv__(self, f: Any) -> "Trigtech":
        """Compute f / self where *f* is a scalar."""
        return self.__class__.initfun_adaptive(
            # 0.0 * x + f broadcasts the scalar f to the shape of x
            lambda x: (0.0 * x + f) / self(x),
            interval=self._interval,
        )

    __radd__ = __add__
    __rmul__ = __mul__
    __rtruediv__ = __rdiv__

    def __rsub__(self, f: Any) -> "Trigtech":
        """Compute f - self."""
        return -(self - f)

    @self_empty()
    def __rpow__(self, f: Any) -> "Trigtech":
        """Compute f ** self."""
        return self.__class__.initfun_adaptive(
            lambda x: np.power(f, self(x)),
            interval=self._interval,
        )

    def __sub__(self, f: Any) -> "Trigtech":
        """Subtract *f* (scalar or Trigtech) from this Trigtech."""
        return self + (-f)

    # ------------------------------------------------------------------
    # rootfinding
    # ------------------------------------------------------------------

    def roots(self, sort: bool | None = None) -> np.ndarray:
        """Find the roots of this Trigtech on [-1, 1].

        Converts to a Chebyshev representation via re-sampling on Chebyshev
        points and delegates to the Chebtech colleague-matrix root-finder,
        followed by Newton refinement. Roots are clipped to [-1, 1].

        Args:
            sort: If True, sort the roots in ascending order. Defaults to
                ``prefs.sortroots``.
        """
        from .algorithms import newtonroots, rootsunit
        from .chebtech import Chebtech

        sort = sort if sort is not None else prefs.sortroots

        if self.isempty:
            return np.array([])

        # Sample on a Chebyshev grid at least twice the Fourier length (min 33)
        n = max(2 * self.size + 1, 33)
        cheb_pts = Chebtech._chebpts(n)
        vals = self(cheb_pts)
        ct = Chebtech(Chebtech._vals2coeffs(vals))
        rts = rootsunit(ct.coeffs)
        rts = newtonroots(ct, rts)
        rts = np.clip(rts, -1.0, 1.0)
        return np.sort(rts) if sort else rts

    # ------------------------------------------------------------------
    # calculus
    # ------------------------------------------------------------------

    @self_empty(resultif=0.0)
    def sum(self) -> Any:
        """Definite integral of the Trigtech over [-1, 1].

        Only the DC coefficient contributes:
            ∫_{-1}^{1} exp(i*π*k*(x+1)) dx = 0   for k ≠ 0
            ∫_{-1}^{1} 1 dx                = 2   for k = 0

        Returns:
            A float for real-valued functions, a complex for complex-valued
            ones, and 0.0 for an empty Trigtech.
        """
        c0 = self._coeffs[0]
        if self.iscomplex:
            # Keep the imaginary part: the integral of a complex function is complex.
            return 2.0 * complex(c0)
        return 2.0 * float(np.real(c0))

    @self_empty()
    def cumsum(self) -> "Trigtech":
        """Indefinite integral, zero at x = -1, in Fourier coefficient space.

        For mode k ≠ 0: antiderivative coefficient = c_k / (i*π*ω_k)
        For mode k = 0: set to the constant needed so that F(-1) = 0.

        Note: if the DC component (self.coeffs[0]) is non-zero the true
        antiderivative contains a linear trend and is not periodic. We still
        return a Trigtech representing the *periodic* part, adjusted so that
        the result evaluates to 0 at x = -1.
        """
        n = self.size
        c = self._coeffs.copy()
        freqs = np.fft.fftfreq(n) * n  # FFT-order integer frequencies

        int_c = np.zeros(n, dtype=complex)
        mask = freqs != 0
        int_c[mask] = c[mask] / (1j * np.pi * freqs[mask])

        # Enforce F(-1) = 0.
        #   F(x) = Σ_k int_c[k] * exp(i*π*ω_k*(x+1))
        # At x = -1: exp(i*π*ω_k*0) = 1 for all k, so F(-1) = Σ int_c.
        # Set int_c[0] so that sum(int_c) = 0.
        int_c[0] = -np.sum(int_c[1:])
        return self.__class__(int_c, interval=self._interval)

    @self_empty()
    def diff(self) -> "Trigtech":
        """Derivative via the Fourier multiplier i*π*ω_k.

            d/dx [c_k * exp(i*π*ω_k*(x+1))] = i*π*ω_k * c_k * exp(i*π*ω_k*(x+1))

        NOTE(review): for even n the Nyquist coefficient (frequency -n/2) is
        multiplied by -i*π*n/2, so a real Nyquist coefficient becomes imaginary.
        ``iscomplex`` may then report True for the derivative of a real
        function.
        """
        if self.isconst:
            return self.__class__(np.array([0.0 + 0.0j]), interval=self._interval)
        n = self.size
        freqs = np.fft.fftfreq(n) * n
        d_coeffs = (1j * np.pi * freqs) * self._coeffs
        return self.__class__(d_coeffs, interval=self._interval)

    # ------------------------------------------------------------------
    # static helpers (FFT ↔ values)
    # ------------------------------------------------------------------

    @staticmethod
    def _trigpts(n: int) -> np.ndarray:
        """Return *n* equispaced points x_j = -1 + 2j/n on [-1, 1)."""
        if n == 0:
            return np.array([])
        return -1.0 + 2.0 * np.arange(n) / n

    @staticmethod
    def _vals2coeffs(vals: Any) -> np.ndarray:
        """Convert values at equispaced points to FFT coefficients (divided by n).

        Always returns complex128, even for real-valued inputs, because Fourier
        coefficients for functions such as sin are purely imaginary and would be
        discarded if forced to real.

        Inverse of ``_coeffs2vals``.
        """
        vals = np.asarray(vals)
        n = vals.size
        if n == 0:
            return np.array([], dtype=complex)
        return np.fft.fft(vals) / n

    @staticmethod
    def _coeffs2vals(coeffs: Any) -> np.ndarray:
        """Convert FFT coefficients (divided by n) to values at equispaced points.

        The result is real when its imaginary part is below
        ``1e-10 * max(max|Re|, 1)``, otherwise complex.

        Inverse of ``_vals2coeffs``.
        """
        coeffs = np.asarray(coeffs, dtype=complex)
        n = coeffs.size
        if n == 0:
            return np.array([], dtype=float)
        vals = n * np.fft.ifft(coeffs)
        # Discard negligible imaginary parts for conjugate-symmetric coefficients
        max_real = float(np.max(np.abs(np.real(vals))))
        if float(np.max(np.abs(np.imag(vals)))) < 1e-10 * max(max_real, 1.0):
            return np.real(vals)
        return vals

    # ------------------------------------------------------------------
    # plotting
    # ------------------------------------------------------------------

    def plot(self, ax: Any = None, **kwargs: Any) -> Any:
        """Plot the Trigtech over [-1, 1].

        Args:
            ax: Matplotlib axes. If None, uses the current axes.
            **kwargs: Forwarded to matplotlib.

        Returns:
            The axes on which the plot was drawn.
        """
        return plotfun(self, (-1, 1), ax=ax, **kwargs)

    def plotcoeffs(self, ax: Any = None, **kwargs: Any) -> Any:
        """Plot the absolute Fourier coefficient magnitudes in DC-centred order.

        Uses ``_coeffs_to_plotorder()`` so the horizontal axis runs from
        the most-negative frequency on the left to the most-positive on
        the right, with DC in the centre.

        Args:
            ax: Matplotlib axes. If None, uses the current axes.
            **kwargs: Forwarded to matplotlib.

        Returns:
            The axes on which the plot was drawn.
        """
        ax = ax or plt.gca()
        return plotfuncoeffs(np.abs(self._coeffs_to_plotorder()), ax=ax, **kwargs)