Coverage for src / chebpy / quasimatrix.py: 100%

232 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-22 21:33 +0000

1"""Quasimatrix: a matrix with one continuous dimension. 

2 

3A quasimatrix is an inf x n matrix whose columns are Chebfun objects defined on 

4the same domain. This enables continuous analogues of linear algebra operations 

5such as QR factorization, SVD, least-squares, and more. 

6 

7Reference: Trefethen, "Householder triangularization of a quasimatrix," 

8IMA Journal of Numerical Analysis, 30 (2010), 887-897. 

9""" 

10 

11from __future__ import annotations 

12 

13from typing import Any 

14 

15import matplotlib.pyplot as plt 

16import numpy as np 

17from matplotlib.axes import Axes 

18 

19from .chebfun import Chebfun 

20 

21 

22class Quasimatrix: 

23 """An inf x n column quasimatrix whose columns are Chebfun objects. 

24 

25 A quasimatrix generalises the idea of a matrix so that one of its 

26 dimensions is continuous. Here the rows are indexed by points in an 

27 interval and the columns are Chebfun objects. 

28 

29 Attributes: 

30 columns: list of Chebfun objects forming the columns. 

31 """ 

32 

33 # ------------------------------------------------------------------ 

34 # Construction 

35 # ------------------------------------------------------------------ 

36 def __init__(self, columns: list[Any]) -> None: 

37 """Initialise from a list of Chebfun objects, callables, or scalars.""" 

38 cols: list[Chebfun] = [] 

39 for c in columns: 

40 if isinstance(c, Chebfun): 

41 cols.append(c) 

42 elif callable(c): 

43 cols.append(Chebfun.initfun_adaptive(c, cols[0].domain if cols else None)) 

44 else: 

45 # scalar → constant chebfun on the domain of the first column 

46 if cols: 

47 cols.append(Chebfun.initconst(float(c), cols[0].domain)) 

48 else: 

49 from .settings import _preferences as prefs 

50 

51 cols.append(Chebfun.initconst(float(c), prefs.domain)) 

52 if len(cols) > 1: 

53 # Verify all columns share the same support 

54 ref = cols[0].support 

55 for k, col in enumerate(cols[1:], 1): 

56 if col.support != ref: 

57 msg = f"Column {k} support {col.support} does not match column 0 support {ref}" 

58 raise ValueError(msg) 

59 self.columns: list[Chebfun] = cols 

60 

61 # ------------------------------------------------------------------ 

62 # Properties 

63 # ------------------------------------------------------------------ 

64 @property 

65 def shape(self) -> tuple[float, int]: 

66 """Return (∞, n) where n is the number of columns.""" 

67 return (np.inf, len(self.columns)) 

68 

69 @property 

70 def T(self) -> _TransposedQuasimatrix: 

71 """Return the transpose (an n x inf row quasimatrix).""" 

72 return _TransposedQuasimatrix(self) 

73 

74 @property 

75 def domain(self) -> Any: 

76 """Domain of the quasimatrix columns.""" 

77 if not self.columns: 

78 return None 

79 return self.columns[0].domain 

80 

81 @property 

82 def support(self) -> tuple[float, float]: 

83 """Support interval of the quasimatrix.""" 

84 if not self.columns: 

85 return (0.0, 0.0) 

86 return self.columns[0].support 

87 

88 @property 

89 def isempty(self) -> bool: 

90 """Return True if the quasimatrix has no columns.""" 

91 return len(self.columns) == 0 

92 

93 # ------------------------------------------------------------------ 

94 # Indexing A[:, k], A(x, k) 

95 # ------------------------------------------------------------------ 

96 def __getitem__(self, key: Any) -> Any: 

97 """Column indexing: ``A[:, k]`` returns column k as a Chebfun.""" 

98 if isinstance(key, tuple): 

99 row, col = key 

100 if isinstance(col, slice): 

101 return Quasimatrix(self.columns[col]) 

102 # A[x, k] - evaluate column k at point x 

103 if isinstance(row, slice) and row == slice(None): 

104 return self.columns[col] 

105 return self.columns[col](row) 

106 # A[k] - return column k 

107 if isinstance(key, (int, np.integer)): 

108 return self.columns[key] 

109 if isinstance(key, slice): 

110 return Quasimatrix(self.columns[key]) 

111 raise TypeError(key) 

112 

113 def __len__(self) -> int: 

114 """Return the number of columns.""" 

115 return len(self.columns) 

116 

117 def __iter__(self): 

118 """Iterate over the columns.""" 

119 return iter(self.columns) 

120 

121 # ------------------------------------------------------------------ 

122 # Calling A(x) - evaluate all columns at x, return array 

123 # ------------------------------------------------------------------ 

124 def __call__(self, x: Any) -> np.ndarray: 

125 """Evaluate every column at *x* and return the results as an array. 

126 

127 If *x* is a scalar the result has shape ``(n,)``. 

128 If *x* is an array of length *m* the result has shape ``(m, n)``. 

129 """ 

130 vals = [col(x) for col in self.columns] 

131 return np.column_stack(vals) if np.ndim(x) else np.array(vals) 

132 

133 # ------------------------------------------------------------------ 

134 # Arithmetic 

135 # ------------------------------------------------------------------ 

136 def __matmul__(self, other: Any) -> Any: 

137 """Matrix-vector product: ``A @ c`` returns a Chebfun. 

138 

139 *other* must be a 1-D array-like of length n. 

140 """ 

141 c = np.asarray(other, dtype=float) 

142 if c.ndim != 1 or len(c) != len(self.columns): 

143 msg = f"Cannot multiply {self.shape} quasimatrix by vector of length {len(c)}" 

144 raise ValueError(msg) 

145 result = c[0] * self.columns[0] 

146 for coeff, col in zip(c[1:], self.columns[1:], strict=True): 

147 result = result + coeff * col 

148 return result 

149 

150 def __mul__(self, other: Any) -> Quasimatrix: 

151 """Element-wise scalar multiplication.""" 

152 return Quasimatrix([c * other for c in self.columns]) 

153 

154 def __rmul__(self, other: Any) -> Quasimatrix: 

155 """Right scalar multiplication.""" 

156 return self.__mul__(other) 

157 

158 # ------------------------------------------------------------------ 

159 # Integrals and inner products 

160 # ------------------------------------------------------------------ 

161 def sum(self) -> np.ndarray: 

162 """Definite integral of each column (column sums).""" 

163 return np.array([col.sum() for col in self.columns]) 

164 

165 def inner(self, other: Quasimatrix | None = None) -> np.ndarray: 

166 """Gram matrix ``self.T @ other`` (or ``self.T @ self``).""" 

167 other = other if other is not None else self 

168 m = len(self.columns) 

169 n = len(other.columns) 

170 G = np.empty((m, n)) 

171 for i in range(m): 

172 for j in range(n): 

173 G[i, j] = self.columns[i].dot(other.columns[j]) 

174 return G 

175 

176 # ------------------------------------------------------------------ 

177 # QR factorization (modified Gram-Schmidt) 

178 # ------------------------------------------------------------------ 

179 def qr(self) -> tuple[Quasimatrix, np.ndarray]: 

180 """Compute the reduced QR factorization ``A = Q R``. 

181 

182 Uses modified Gram-Schmidt orthogonalisation in function space. 

183 

184 Returns: 

185 Q: Quasimatrix with orthonormal columns. 

186 R: Upper-triangular n x n NumPy array. 

187 """ 

188 n = len(self.columns) 

189 Q = [col.copy() for col in self.columns] 

190 R = np.zeros((n, n)) 

191 for k in range(n): 

192 for j in range(k): 

193 R[j, k] = Q[j].dot(Q[k]) 

194 Q[k] = Q[k] - R[j, k] * Q[j] 

195 R[k, k] = Q[k].norm(2) 

196 if R[k, k] == 0: 

197 msg = "Rank-deficient quasimatrix: QR factorization failed" 

198 raise np.linalg.LinAlgError(msg) 

199 Q[k] = (1.0 / R[k, k]) * Q[k] 

200 return Quasimatrix(Q), R 

201 

202 # ------------------------------------------------------------------ 

203 # SVD 

204 # ------------------------------------------------------------------ 

205 def svd(self) -> tuple[Quasimatrix, np.ndarray, np.ndarray]: 

206 """Compute the reduced SVD ``A = U S V^T``. 

207 

208 Returns: 

209 U: inf x n quasimatrix with orthonormal columns. 

210 S: 1-D array of singular values (length n). 

211 V: n x n orthogonal NumPy matrix. 

212 """ 

213 Q, R = self.qr() 

214 # Economy SVD of the n x n matrix R 

215 U_r, S, Vt = np.linalg.svd(R, full_matrices=False) 

216 # U = Q @ U_r (linear combinations of orthonormal columns) 

217 U_cols = [] 

218 for j in range(U_r.shape[1]): 

219 U_cols.append(Q @ U_r[:, j]) 

220 return Quasimatrix(U_cols), S, Vt.T # V = Vt.T 

221 

222 # ------------------------------------------------------------------ 

223 # Least-squares (backslash) 

224 # ------------------------------------------------------------------ 

225 def solve(self, f: Chebfun) -> np.ndarray: 

226 r"""Least-squares solution ``c`` to ``A c ~ f``. 

227 

228 Equivalent to MATLAB ``A\f``. Computed via QR factorisation. 

229 """ 

230 Q, R = self.qr() 

231 # b = Q' * f (inner products) 

232 b = np.array([col.dot(f) for col in Q.columns]) 

233 # Solve R c = b (back-substitution) 

234 return np.linalg.solve(R, b) 

235 

236 # ------------------------------------------------------------------ 

237 # Norms 

238 # ------------------------------------------------------------------ 

239 def norm(self, p: Any = "fro") -> float: 

240 """Compute the norm of the quasimatrix. 

241 

242 Args: 

243 p: Norm type. 

244 - 2: the 2-norm (largest singular value). 

245 - 1: max column 1-norm. 

246 - np.inf: max row-sum := max_x sum_j |A_j(x)|. 

247 - 'fro': Frobenius norm (default). 

248 """ 

249 if p == 2: 

250 _, S, _ = self.svd() 

251 return float(S[0]) 

252 if p == 1: 

253 return float(max(col.norm(1) for col in self.columns)) 

254 if p == np.inf: 

255 abssum = self.columns[0].absolute() 

256 for col in self.columns[1:]: 

257 abssum = abssum + col.absolute() 

258 return float(abssum.norm(np.inf)) 

259 if p == "fro": 

260 _, S, _ = self.svd() 

261 return float(np.sqrt(np.sum(S**2))) 

262 raise ValueError(f"Unsupported norm type: {p}") # noqa: TRY003 

263 

264 # ------------------------------------------------------------------ 

265 # Condition number 

266 # ------------------------------------------------------------------ 

267 def cond(self) -> float: 

268 """2-norm condition number (ratio of largest to smallest singular value).""" 

269 _, S, _ = self.svd() 

270 return float(S[0] / S[-1]) 

271 

272 # ------------------------------------------------------------------ 

273 # Rank 

274 # ------------------------------------------------------------------ 

275 def rank(self, tol: float | None = None) -> int: 

276 """Numerical rank (number of significant singular values).""" 

277 _, S, _ = self.svd() 

278 if tol is None: 

279 tol = max(self.shape[1], 20) * np.finfo(float).eps * S[0] 

280 return int(np.sum(tol < S)) 

281 

282 # ------------------------------------------------------------------ 

283 # Null space 

284 # ------------------------------------------------------------------ 

285 def null(self, tol: float | None = None) -> np.ndarray: 

286 """Orthonormal basis for the null space of the quasimatrix. 

287 

288 Returns an n x k NumPy array whose columns span ``null(A)``. 

289 """ 

290 _, S, V = self.svd() 

291 if tol is None: 

292 tol = max(self.shape[1], 20) * np.finfo(float).eps * S[0] 

293 mask = tol >= S 

294 return V[:, mask] 

295 

296 # ------------------------------------------------------------------ 

297 # Orth (orthonormal basis for range) 

298 # ------------------------------------------------------------------ 

299 def orth(self, tol: float | None = None) -> Quasimatrix: 

300 """Orthonormal basis for the column space (range) of the quasimatrix.""" 

301 U, S, _ = self.svd() 

302 if tol is None: 

303 tol = max(self.shape[1], 20) * np.finfo(float).eps * S[0] 

304 mask = tol < S 

305 return Quasimatrix([U.columns[j] for j in range(len(S)) if mask[j]]) 

306 

307 # ------------------------------------------------------------------ 

308 # Pseudoinverse 

309 # ------------------------------------------------------------------ 

310 def pinv(self) -> _TransposedQuasimatrix: 

311 """Moore-Penrose pseudoinverse (returned as an n x inf row quasimatrix). 

312 

313 ``pinv(A) @ f`` gives the same result as ``A.solve(f)``. 

314 """ 

315 U, S, V = self.svd() 

316 # pinv(A) = V S^{-1} U^T 

317 # The rows of pinv(A) are: sum_k V[i,k] / S[k] * U_k 

318 n = len(S) 

319 pinv_cols: list[Chebfun] = [] 

320 for i in range(n): 

321 col = (V[i, 0] / S[0]) * U.columns[0] 

322 for k in range(1, n): 

323 col = col + (V[i, k] / S[k]) * U.columns[k] 

324 pinv_cols.append(col) 

325 return _TransposedQuasimatrix(Quasimatrix(pinv_cols)) 

326 

327 # ------------------------------------------------------------------ 

328 # Plotting 

329 # ------------------------------------------------------------------ 

330 def plot(self, ax: Axes | None = None, **kwds: Any) -> Axes: 

331 """Plot all columns on the same axes.""" 

332 ax = ax or plt.gca() 

333 for col in self.columns: 

334 col.plot(ax=ax, **kwds) 

335 return ax 

336 

337 def spy(self, ax: Axes | None = None, **kwds: Any) -> Axes: 

338 """Visualise the shape of the quasimatrix. 

339 

340 Draws a rectangle representing the inf x n structure, with a dot for 

341 each column to indicate nonzero content. 

342 """ 

343 ax = ax or plt.gca() 

344 n = len(self.columns) 

345 # Draw the bounding rectangle 

346 rect = plt.Rectangle((0.5, 0.5), n, 10, fill=False, edgecolor="black", linewidth=1.5) 

347 ax.add_patch(rect) 

348 # A dot for each column 

349 for j in range(n): 

350 ax.plot(j + 1, 5.5, "bs", markersize=8, **kwds) 

351 ax.set_xlim(0, n + 1) 

352 ax.set_ylim(0, 11) 

353 ax.set_aspect("equal") 

354 ax.set_xlabel(f"n = {n}") 

355 ax.set_ylabel("∞") 

356 ax.set_xticks(range(1, n + 1)) 

357 ax.set_yticks([]) 

358 return ax 

359 

360 # ------------------------------------------------------------------ 

361 # Representation 

362 # ------------------------------------------------------------------ 

363 def __repr__(self) -> str: 

364 """Return a string representation.""" 

365 n = len(self.columns) 

366 if n == 0: 

367 return "Quasimatrix(empty)" 

368 sup = self.support 

369 return f"Quasimatrix(inf x {n} on [{sup[0]}, {sup[1]}])" 

370 

371 def __str__(self) -> str: 

372 """Return a string representation.""" 

373 return self.__repr__() 

374 

375 

class _TransposedQuasimatrix:
    """An n x inf row quasimatrix (transpose of a column quasimatrix).

    This is a thin wrapper that enables ``A.T @ f`` and ``A.T @ B``
    with the correct semantics.
    """

    def __init__(self, qm: Quasimatrix) -> None:
        """Wrap a column quasimatrix as its transpose."""
        self._qm = qm

    @property
    def shape(self) -> tuple[int, float]:
        """Return (n, inf)."""
        return (len(self._qm.columns), np.inf)

    @property
    def T(self) -> Quasimatrix:
        """Return the original column quasimatrix."""
        return self._qm

    def __matmul__(self, other: Any) -> Any:
        """Compute inner products: ``A.T @ f`` or ``A.T @ B``.

        Returns:
            For a Quasimatrix operand, the Gram matrix from ``inner``; for a
            Chebfun operand, a 1-D array of ``col.dot(other)`` per column.

        Raises:
            TypeError: if *other* is neither a Quasimatrix nor a Chebfun.
        """
        if isinstance(other, Quasimatrix):
            return self._qm.inner(other)
        if isinstance(other, Chebfun):
            return np.array([col.dot(other) for col in self._qm.columns])
        raise TypeError(f"Cannot multiply _TransposedQuasimatrix by {type(other)}")  # noqa: TRY003

    def spy(self, ax: Axes | None = None, **kwds: Any) -> Axes:
        """Visualise the shape of the transposed quasimatrix.

        Args:
            ax: Axes to draw on; ``plt.gca()`` is used when omitted.
            **kwds: forwarded to ``ax.plot`` for the markers.  A caller-supplied
                ``markersize`` replaces the default of 8.
        """
        if ax is None:
            ax = plt.gca()
        n = len(self._qm.columns)
        rect = plt.Rectangle((0.5, 0.5), 10, n, fill=False, edgecolor="black", linewidth=1.5)
        ax.add_patch(rect)
        # Merge so a user-provided markersize cannot collide with the default.
        marker_kwds = {"markersize": 8, **kwds}
        # One marker per row of the transposed quasimatrix
        for j in range(n):
            ax.plot(5.5, j + 1, "bs", **marker_kwds)
        ax.set_xlim(0, 11)
        ax.set_ylim(0, n + 1)
        ax.set_aspect("equal")
        ax.set_ylabel(f"n = {n}")
        ax.set_xlabel("∞")
        ax.set_yticks(range(1, n + 1))
        ax.set_xticks([])
        return ax

    def __repr__(self) -> str:
        """Return a short description: row count and support interval."""
        n = len(self._qm.columns)
        if n == 0:
            return "_TransposedQuasimatrix(empty)"
        sup = self._qm.support
        return f"_TransposedQuasimatrix({n}x inf on [{sup[0]}, {sup[1]}])"

429 

430 

431# ------------------------------------------------------------------ 

432# Module-level convenience functions 

433# ------------------------------------------------------------------ 

def polyfit(f: Chebfun, n: int) -> Chebfun:
    """Fit a polynomial of degree *n* to the Chebfun *f* by least squares.

    The monomials ``1, x, ..., x**n`` on ``f.domain`` form the columns of a
    Quasimatrix; its least-squares coefficients for *f* are then combined
    back into a single Chebfun, which is returned.
    """
    # Monomial basis, built by repeated multiplication with the identity.
    basis: list[Chebfun] = [Chebfun.initconst(1.0, f.domain)]
    identity = Chebfun.initidentity(f.domain)
    for _ in range(n):
        basis.append(basis[-1] * identity)

    monomials = Quasimatrix(basis)
    coeffs = monomials.solve(f)
    return monomials @ coeffs

448 return A @ c