Coverage for src / chebpy / quasimatrix.py: 100%
232 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-22 21:33 +0000
1"""Quasimatrix: a matrix with one continuous dimension.
3A quasimatrix is an inf x n matrix whose columns are Chebfun objects defined on
4the same domain. This enables continuous analogues of linear algebra operations
5such as QR factorization, SVD, least-squares, and more.
7Reference: Trefethen, "Householder triangularization of a quasimatrix,"
8IMA Journal of Numerical Analysis, 30 (2010), 887-897.
9"""
11from __future__ import annotations
13from typing import Any
15import matplotlib.pyplot as plt
16import numpy as np
17from matplotlib.axes import Axes
19from .chebfun import Chebfun
class Quasimatrix:
    """An inf x n column quasimatrix whose columns are Chebfun objects.

    A quasimatrix generalises the idea of a matrix so that one of its
    dimensions is continuous. Here the rows are indexed by points in an
    interval and the columns are Chebfun objects.

    Attributes:
        columns: list of Chebfun objects forming the columns.
    """

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def __init__(self, columns: list[Any]) -> None:
        """Initialise from a list of Chebfun objects, callables, or scalars.

        Chebfun entries are stored as they are. Callables are passed to
        ``Chebfun.initfun_adaptive`` together with the domain of the first
        column already collected (or ``None`` if there is none yet). Anything
        else is converted with ``float`` and turned into a constant Chebfun on
        the first column's domain, falling back to the domain stored in the
        package preferences when it is the very first entry.

        Raises:
            ValueError: if any column's support differs from that of column 0.
        """
        cols: list[Chebfun] = []
        for c in columns:
            if isinstance(c, Chebfun):
                cols.append(c)
            elif callable(c):
                cols.append(Chebfun.initfun_adaptive(c, cols[0].domain if cols else None))
            elif cols:
                # scalar -> constant chebfun on the domain of the first column
                cols.append(Chebfun.initconst(float(c), cols[0].domain))
            else:
                # Leading scalar: no column to borrow a domain from yet.
                from .settings import _preferences as prefs

                cols.append(Chebfun.initconst(float(c), prefs.domain))
        if len(cols) > 1:
            # Verify all columns share the same support
            ref = cols[0].support
            for k, col in enumerate(cols[1:], 1):
                if col.support != ref:
                    msg = f"Column {k} support {col.support} does not match column 0 support {ref}"
                    raise ValueError(msg)
        self.columns: list[Chebfun] = cols

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------
    @property
    def shape(self) -> tuple[float, int]:
        """Return (∞, n) where n is the number of columns."""
        return (np.inf, len(self.columns))

    @property
    def T(self) -> _TransposedQuasimatrix:
        """Return the transpose (an n x inf row quasimatrix)."""
        return _TransposedQuasimatrix(self)

    @property
    def domain(self) -> Any:
        """Domain of the first column, or ``None`` when there are no columns."""
        if not self.columns:
            return None
        return self.columns[0].domain

    @property
    def support(self) -> tuple[float, float]:
        """Support of the first column, or ``(0.0, 0.0)`` when there are no columns."""
        if not self.columns:
            return (0.0, 0.0)
        return self.columns[0].support

    @property
    def isempty(self) -> bool:
        """Return True if the quasimatrix has no columns."""
        return len(self.columns) == 0

    # ------------------------------------------------------------------
    # Indexing  A[:, k],  A[x, k]
    # ------------------------------------------------------------------
    def __getitem__(self, key: Any) -> Any:
        """Index by column, optionally evaluating at a row value.

        Supported forms:
            - ``A[k]`` / ``A[:, k]``: column k (a Chebfun).
            - ``A[j:k]`` / ``A[:, j:k]``: a new Quasimatrix of the selected columns.
            - ``A[x, k]``: column k evaluated at ``x``.
            - ``A[x, j:k]``: the selected columns evaluated at ``x``, shaped
              like the result of calling the quasimatrix.

        Raises:
            TypeError: if *key* is none of the supported forms.
        """
        if isinstance(key, tuple):
            row, col = key
            if isinstance(col, slice):
                selected = self.columns[col]
                if isinstance(row, slice):
                    return Quasimatrix(selected)
                # A concrete row value must not be silently dropped.
                return self._evaluate(selected, row)
            if isinstance(row, slice) and row == slice(None):
                return self.columns[col]
            # A[x, k] - evaluate column k at point x
            return self.columns[col](row)
        # A[k] - return column k
        if isinstance(key, (int, np.integer)):
            return self.columns[key]
        if isinstance(key, slice):
            return Quasimatrix(self.columns[key])
        raise TypeError(key)

    def __len__(self) -> int:
        """Return the number of columns."""
        return len(self.columns)

    def __iter__(self):
        """Iterate over the columns."""
        return iter(self.columns)

    # ------------------------------------------------------------------
    # Calling A(x) - evaluate all columns at x, return array
    # ------------------------------------------------------------------
    @staticmethod
    def _evaluate(columns: list[Any], x: Any) -> np.ndarray:
        """Evaluate each of *columns* at *x* and pack the values into an array."""
        vals = [col(x) for col in columns]
        # Array input: one output column per quasimatrix column.
        return np.column_stack(vals) if np.ndim(x) else np.array(vals)

    def __call__(self, x: Any) -> np.ndarray:
        """Evaluate every column at *x* and return the results as an array.

        If *x* is a scalar the result has shape ``(n,)``.
        If *x* is an array of length *m* the result has shape ``(m, n)``.
        """
        return self._evaluate(self.columns, x)

    # ------------------------------------------------------------------
    # Arithmetic
    # ------------------------------------------------------------------
    def __matmul__(self, other: Any) -> Any:
        """Matrix-vector product: ``A @ c`` returns a Chebfun.

        *other* must be a 1-D array-like of length n.

        Raises:
            ValueError: if *other* is not 1-D or its length differs from the
                number of columns.
        """
        c = np.asarray(other, dtype=float)
        if c.ndim != 1:
            # Report the shape: len() is undefined for 0-d arrays.
            msg = f"Cannot multiply {self.shape} quasimatrix by array of shape {c.shape}"
            raise ValueError(msg)
        if len(c) != len(self.columns):
            msg = f"Cannot multiply {self.shape} quasimatrix by vector of length {len(c)}"
            raise ValueError(msg)
        result = c[0] * self.columns[0]
        for coeff, col in zip(c[1:], self.columns[1:], strict=True):
            result = result + coeff * col
        return result

    def __mul__(self, other: Any) -> Quasimatrix:
        """Multiply every column by *other* and return a new Quasimatrix."""
        return Quasimatrix([c * other for c in self.columns])

    def __rmul__(self, other: Any) -> Quasimatrix:
        """Reflected multiplication; delegates to ``__mul__``."""
        return self.__mul__(other)

    # ------------------------------------------------------------------
    # Integrals and inner products
    # ------------------------------------------------------------------
    def sum(self) -> np.ndarray:
        """Definite integral of each column (column sums)."""
        return np.array([col.sum() for col in self.columns])

    def inner(self, other: Quasimatrix | None = None) -> np.ndarray:
        """Gram matrix ``self.T @ other`` (or ``self.T @ self``).

        Entry ``(i, j)`` is ``self.columns[i].dot(other.columns[j])``.
        """
        other = other if other is not None else self
        G = np.empty((len(self.columns), len(other.columns)))
        for i, left in enumerate(self.columns):
            for j, right in enumerate(other.columns):
                G[i, j] = left.dot(right)
        return G

    # ------------------------------------------------------------------
    # QR factorization (modified Gram-Schmidt)
    # ------------------------------------------------------------------
    def qr(self) -> tuple[Quasimatrix, np.ndarray]:
        """Compute the reduced QR factorization ``A = Q R``.

        Uses modified Gram-Schmidt orthogonalisation in function space.

        Returns:
            Q: Quasimatrix with orthonormal columns.
            R: Upper-triangular n x n NumPy array.

        Raises:
            numpy.linalg.LinAlgError: if a column has exactly zero norm after
                the earlier columns have been projected out.
        """
        n = len(self.columns)
        Q = [col.copy() for col in self.columns]
        R = np.zeros((n, n))
        for k in range(n):
            # Project out the already-orthonormalised columns one at a time,
            # always against the *updated* Q[k] (modified Gram-Schmidt).
            for j in range(k):
                R[j, k] = Q[j].dot(Q[k])
                Q[k] = Q[k] - R[j, k] * Q[j]
            R[k, k] = Q[k].norm(2)
            if R[k, k] == 0:
                msg = "Rank-deficient quasimatrix: QR factorization failed"
                raise np.linalg.LinAlgError(msg)
            Q[k] = (1.0 / R[k, k]) * Q[k]
        return Quasimatrix(Q), R

    # ------------------------------------------------------------------
    # SVD
    # ------------------------------------------------------------------
    def svd(self) -> tuple[Quasimatrix, np.ndarray, np.ndarray]:
        """Compute the reduced SVD ``A = U S V^T``.

        Built from ``qr()`` followed by a NumPy SVD of the n x n factor R.

        Returns:
            U: inf x n quasimatrix with orthonormal columns.
            S: 1-D array of singular values (length n).
            V: n x n orthogonal NumPy matrix.
        """
        Q, R = self.qr()
        # Economy SVD of the n x n matrix R
        U_r, S, Vt = np.linalg.svd(R, full_matrices=False)
        # U = Q @ U_r (linear combinations of orthonormal columns)
        U_cols = [Q @ U_r[:, j] for j in range(U_r.shape[1])]
        return Quasimatrix(U_cols), S, Vt.T  # V = Vt.T

    def _default_tol(self, S: np.ndarray) -> float:
        """Default cut-off for deciding which singular values count as zero."""
        return max(self.shape[1], 20) * np.finfo(float).eps * S[0]

    # ------------------------------------------------------------------
    # Least-squares (backslash)
    # ------------------------------------------------------------------
    def solve(self, f: Chebfun) -> np.ndarray:
        r"""Least-squares solution ``c`` to ``A c ~ f``.

        Equivalent to MATLAB ``A\f``. Computed via QR factorisation.
        """
        Q, R = self.qr()
        # b = Q' * f (inner products)
        b = np.array([col.dot(f) for col in Q.columns])
        # Solve R c = b
        return np.linalg.solve(R, b)

    # ------------------------------------------------------------------
    # Norms
    # ------------------------------------------------------------------
    def norm(self, p: Any = "fro") -> float:
        """Compute the norm of the quasimatrix.

        Args:
            p: Norm type.
                - 2: the 2-norm (largest singular value, via ``svd()``).
                - 1: max column 1-norm.
                - np.inf: max row-sum := max_x sum_j |A_j(x)|.
                - 'fro': Frobenius norm (default).

        Raises:
            ValueError: if *p* is not one of the supported values.
        """
        if p == 2:
            _, S, _ = self.svd()
            return float(S[0])
        if p == 1:
            return float(max(col.norm(1) for col in self.columns))
        if p == np.inf:
            abssum = self.columns[0].absolute()
            for col in self.columns[1:]:
                abssum = abssum + col.absolute()
            return float(abssum.norm(np.inf))
        if p == "fro":
            # ||A||_F^2 equals the sum of squared column 2-norms, so no
            # factorisation is needed and rank-deficient input is fine.
            return float(np.sqrt(sum(col.norm(2) ** 2 for col in self.columns)))
        raise ValueError(f"Unsupported norm type: {p}")  # noqa: TRY003

    # ------------------------------------------------------------------
    # Condition number
    # ------------------------------------------------------------------
    def cond(self) -> float:
        """2-norm condition number (ratio of largest to smallest singular value)."""
        _, S, _ = self.svd()
        return float(S[0] / S[-1])

    # ------------------------------------------------------------------
    # Rank
    # ------------------------------------------------------------------
    def rank(self, tol: float | None = None) -> int:
        """Numerical rank: the number of singular values strictly above *tol*.

        When *tol* is ``None`` it defaults to ``max(n, 20) * eps * S[0]``.
        """
        _, S, _ = self.svd()
        if tol is None:
            tol = self._default_tol(S)
        return int(np.sum(tol < S))

    # ------------------------------------------------------------------
    # Null space
    # ------------------------------------------------------------------
    def null(self, tol: float | None = None) -> np.ndarray:
        """Orthonormal basis for the null space of the quasimatrix.

        Returns an n x k NumPy array whose columns span ``null(A)``: the
        columns of V whose singular values are at or below *tol* (default
        ``max(n, 20) * eps * S[0]``).
        """
        _, S, V = self.svd()
        if tol is None:
            tol = self._default_tol(S)
        return V[:, tol >= S]

    # ------------------------------------------------------------------
    # Orth (orthonormal basis for range)
    # ------------------------------------------------------------------
    def orth(self, tol: float | None = None) -> Quasimatrix:
        """Orthonormal basis for the column space (range) of the quasimatrix.

        Keeps the columns of U whose singular values are strictly above *tol*
        (default ``max(n, 20) * eps * S[0]``).
        """
        U, S, _ = self.svd()
        if tol is None:
            tol = self._default_tol(S)
        keep = tol < S
        return Quasimatrix([u for u, flag in zip(U.columns, keep) if flag])

    # ------------------------------------------------------------------
    # Pseudoinverse
    # ------------------------------------------------------------------
    def pinv(self) -> _TransposedQuasimatrix:
        """Moore-Penrose pseudoinverse (returned as an n x inf row quasimatrix).

        ``pinv(A) @ f`` gives the same result as ``A.solve(f)``.
        """
        U, S, V = self.svd()
        # pinv(A) = V S^{-1} U^T
        # Row i of pinv(A) is sum_k V[i, k] / S[k] * U_k, i.e. U @ (V[i] / S).
        pinv_cols: list[Chebfun] = [U @ (V[i, :] / S) for i in range(len(S))]
        return _TransposedQuasimatrix(Quasimatrix(pinv_cols))

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    def plot(self, ax: Axes | None = None, **kwds: Any) -> Axes:
        """Plot all columns on the same axes.

        Args:
            ax: axes to draw on; the current pyplot axes are used when omitted.
            **kwds: forwarded to each column's ``plot`` method.
        """
        if ax is None:
            ax = plt.gca()
        for col in self.columns:
            col.plot(ax=ax, **kwds)
        return ax

    def spy(self, ax: Axes | None = None, **kwds: Any) -> Axes:
        """Visualise the shape of the quasimatrix.

        Draws a rectangle representing the inf x n structure, with a dot for
        each column to indicate nonzero content.

        Args:
            ax: axes to draw on; the current pyplot axes are used when omitted.
            **kwds: forwarded to ``ax.plot`` for the markers; a ``markersize``
                given here replaces the default of 8.
        """
        if ax is None:
            ax = plt.gca()
        n = len(self.columns)
        # Draw the bounding rectangle
        rect = plt.Rectangle((0.5, 0.5), n, 10, fill=False, edgecolor="black", linewidth=1.5)
        ax.add_patch(rect)
        # Merge so a caller-supplied markersize overrides instead of clashing.
        marker_kwds = {"markersize": 8, **kwds}
        # A dot for each column
        for j in range(n):
            ax.plot(j + 1, 5.5, "bs", **marker_kwds)
        ax.set_xlim(0, n + 1)
        ax.set_ylim(0, 11)
        ax.set_aspect("equal")
        ax.set_xlabel(f"n = {n}")
        ax.set_ylabel("∞")
        ax.set_xticks(range(1, n + 1))
        ax.set_yticks([])
        return ax

    # ------------------------------------------------------------------
    # Representation
    # ------------------------------------------------------------------
    def __repr__(self) -> str:
        """Return a short description: column count and support interval."""
        n = len(self.columns)
        if n == 0:
            return "Quasimatrix(empty)"
        sup = self.support
        return f"Quasimatrix(inf x {n} on [{sup[0]}, {sup[1]}])"

    def __str__(self) -> str:
        """Return the same text as ``__repr__``."""
        return self.__repr__()
class _TransposedQuasimatrix:
    """Row-oriented (n x inf) view of a column quasimatrix.

    The wrapper stores the original column quasimatrix and gives ``A.T @ f``
    and ``A.T @ B`` their inner-product meaning.
    """

    def __init__(self, qm: Quasimatrix) -> None:
        """Keep a reference to the column quasimatrix being transposed."""
        self._qm = qm

    @property
    def shape(self) -> tuple[int, float]:
        """Return ``(n, inf)`` with n the number of wrapped columns."""
        count = len(self._qm.columns)
        return (count, np.inf)

    @property
    def T(self) -> Quasimatrix:
        """Return the wrapped column quasimatrix."""
        return self._qm

    def __matmul__(self, other: Any) -> Any:
        """Form inner products with a Quasimatrix or a single Chebfun.

        A Quasimatrix operand yields the Gram matrix from ``inner``; a Chebfun
        operand yields a 1-D array of ``col.dot(other)`` values.

        Raises:
            TypeError: for any other operand type.
        """
        if isinstance(other, Quasimatrix):
            return self._qm.inner(other)
        if not isinstance(other, Chebfun):
            raise TypeError(f"Cannot multiply _TransposedQuasimatrix by {type(other)}")  # noqa: TRY003
        products = [col.dot(other) for col in self._qm.columns]
        return np.array(products)

    def spy(self, ax: Axes | None = None, **kwds: Any) -> Axes:
        """Sketch the n x inf shape: a wide rectangle with one marker per row."""
        if not ax:
            ax = plt.gca()
        count = len(self._qm.columns)
        outline = plt.Rectangle((0.5, 0.5), 10, count, fill=False, edgecolor="black", linewidth=1.5)
        ax.add_patch(outline)
        # One marker per row, centred horizontally in the rectangle.
        for height in range(1, count + 1):
            ax.plot(5.5, height, "bs", markersize=8, **kwds)
        ax.set_xlim(0, 11)
        ax.set_ylim(0, count + 1)
        ax.set_aspect("equal")
        ax.set_ylabel(f"n = {count}")
        ax.set_xlabel("∞")
        ax.set_yticks(range(1, count + 1))
        ax.set_xticks([])
        return ax

    def __repr__(self) -> str:
        """Return a short description: row count and support interval."""
        count = len(self._qm.columns)
        if count == 0:
            return "_TransposedQuasimatrix(empty)"
        lo, hi = self._qm.support[0], self._qm.support[1]
        return f"_TransposedQuasimatrix({count}x inf on [{lo}, {hi}])"
431# ------------------------------------------------------------------
432# Module-level convenience functions
433# ------------------------------------------------------------------
def polyfit(f: Chebfun, n: int) -> Chebfun:
    """Least-squares polynomial fit of degree *n* to a Chebfun *f*.

    Builds the quasimatrix whose columns are the monomials ``1, x, ..., x**n``
    on the domain of *f*, solves the least-squares problem ``A c ~ f`` and
    returns the Chebfun ``A @ c``.
    """
    x = Chebfun.initidentity(f.domain)
    basis: list[Chebfun] = [Chebfun.initconst(1.0, f.domain)]
    # Each new column is the previous power multiplied by x.
    for _ in range(n):
        basis.append(basis[-1] * x)
    A = Quasimatrix(basis)
    coeffs = A.solve(f)
    return A @ coeffs