diff --git a/src/tinygp/kernels/quasisep.py b/src/tinygp/kernels/quasisep.py
index af232811..b380f08a 100644
--- a/src/tinygp/kernels/quasisep.py
+++ b/src/tinygp/kernels/quasisep.py
@@ -32,10 +32,11 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+from jax.scipy.linalg import block_diag as jsp_block_diag
 
 from tinygp.helpers import JAXArray
 from tinygp.kernels.base import Kernel
-from tinygp.solvers.quasisep.block import Block
+from tinygp.solvers.quasisep.block import Block, ensure_dense
 from tinygp.solvers.quasisep.core import DiagQSM, StrictLowerTriQSM, SymmQSM
 from tinygp.solvers.quasisep.general import GeneralQSM
 
@@ -220,20 +221,39 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
 
 
 class Sum(Quasisep):
-    """A helper to represent the sum of two quasiseparable kernels"""
+    """A helper to represent the sum of two quasiseparable kernels
+
+    Args:
+        kernel1: The first kernel.
+        kernel2: The second kernel.
+        use_block: If ``True`` (default), use :class:`Block` diagonal matrices
+            for the transition matrices, design matrices, and stationary
+            covariance. If ``False``, use dense ``block_diag`` representations
+            instead, which avoids compatibility issues with some operations
+            (e.g. banded noise, product kernels) at a small performance cost
+            for the state-space matrices.
+    """
 
     kernel1: Quasisep
     kernel2: Quasisep
+    use_block: bool = eqx.field(static=True, default=True)
 
     def coord_to_sortable(self, X: JAXArray) -> JAXArray:
         """We assume that both kernels use the same coordinates"""
         return self.kernel1.coord_to_sortable(X)
 
+    def _block_or_dense(self, m1: JAXArray, m2: JAXArray) -> JAXArray:
+        if self.use_block:
+            return Block(m1, m2)
+        return jsp_block_diag(m1, m2)
+
     def design_matrix(self) -> JAXArray:
-        return Block(self.kernel1.design_matrix(), self.kernel2.design_matrix())
+        return self._block_or_dense(
+            self.kernel1.design_matrix(), self.kernel2.design_matrix()
+        )
 
     def stationary_covariance(self) -> JAXArray:
-        return Block(
+        return self._block_or_dense(
             self.kernel1.stationary_covariance(),
             self.kernel2.stationary_covariance(),
         )
@@ -247,7 +267,7 @@ def observation_model(self, X: JAXArray) -> JAXArray:
         )
 
     def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
-        return Block(
+        return self._block_or_dense(
             self.kernel1.transition_matrix(X1, X2),
             self.kernel2.transition_matrix(X1, X2),
         )
@@ -632,6 +652,8 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
 
 
 def _prod_helper(a1: JAXArray, a2: JAXArray) -> JAXArray:
+    a1 = ensure_dense(a1)
+    a2 = ensure_dense(a2)
     i, j = np.meshgrid(np.arange(a1.shape[0]), np.arange(a2.shape[0]))
     i = i.flatten()
     j = j.flatten()
diff --git a/src/tinygp/solvers/quasisep/block.py b/src/tinygp/solvers/quasisep/block.py
index f5736064..6e102cd8 100644
--- a/src/tinygp/solvers/quasisep/block.py
+++ b/src/tinygp/solvers/quasisep/block.py
@@ -9,6 +9,13 @@
 from tinygp.helpers import JAXArray
 
 
+def ensure_dense(x: Any) -> Any:
+    """Convert a Block to a dense array, passing through non-Block inputs."""
+    if isinstance(x, Block):
+        return x.to_dense()
+    return x
+
+
 class Block(eqx.Module):
     blocks: tuple[Any, ...]
     __array_priority__ = 1999
diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py
index ac1fdf18..412e0e35 100644
--- a/src/tinygp/solvers/quasisep/core.py
+++ b/src/tinygp/solvers/quasisep/core.py
@@ -29,6 +29,7 @@
 from jax.scipy.linalg import block_diag
 
 from tinygp.helpers import JAXArray
+from tinygp.solvers.quasisep.block import ensure_dense
 
 
 def handle_matvec_shapes(
@@ -213,20 +214,24 @@ def impl(
             return StrictLowerTriQSM(
                 p=jnp.concatenate((p1, p2)),
                 q=jnp.concatenate((q1, q2)),
-                a=block_diag(a1, a2),
+                a=block_diag(ensure_dense(a1), ensure_dense(a2)),
             )
 
         return impl(self, other)
 
     def self_mul(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM:
         """The elementwise product of two :class:`StrictLowerTriQSM` matrices"""
+        # vmap is needed because a batched Block has 3D block arrays that
+        # block_diag (used by to_dense) cannot handle without unbatching.
+        self_a = jax.vmap(ensure_dense)(self.a)
+        other_a = jax.vmap(ensure_dense)(other.a)
         i, j = np.meshgrid(np.arange(self.p.shape[1]), np.arange(other.p.shape[1]))
         i = i.flatten()
         j = j.flatten()
         return StrictLowerTriQSM(
             p=self.p[:, i] * other.p[:, j],
             q=self.q[:, i] * other.q[:, j],
-            a=self.a[:, i[:, None], i[None, :]] * other.a[:, j[:, None], j[None, :]],
+            a=self_a[:, i[:, None], i[None, :]] * other_a[:, j[:, None], j[None, :]],
         )
 
     def __neg__(self) -> StrictLowerTriQSM:
diff --git a/src/tinygp/solvers/quasisep/ops.py b/src/tinygp/solvers/quasisep/ops.py
index ff48fb18..4d88d6dd 100644
--- a/src/tinygp/solvers/quasisep/ops.py
+++ b/src/tinygp/solvers/quasisep/ops.py
@@ -8,6 +8,7 @@
 import jax.numpy as jnp
 
 from tinygp.helpers import JAXArray
+from tinygp.solvers.quasisep.block import ensure_dense
 from tinygp.solvers.quasisep.core import (
     QSM,
     DiagQSM,
@@ -145,15 +146,15 @@ def impl(
         u += [upper_b.p] if upper_b is not None else []
 
         if lower_a is not None and lower_b is not None:
+            la_a = ensure_dense(lower_a.a)
+            lb_a = ensure_dense(lower_b.a)
             ell = jnp.concatenate(
                 (
-                    jnp.concatenate(
-                        (lower_a.a, jnp.outer(lower_a.q, lower_b.p)), axis=-1
-                    ),
+                    jnp.concatenate((la_a, jnp.outer(lower_a.q, lower_b.p)), axis=-1),
                     jnp.concatenate(
                         (
-                            jnp.zeros((lower_b.a.shape[0], lower_a.a.shape[0])),
-                            lower_b.a,
+                            jnp.zeros((lb_a.shape[0], la_a.shape[0])),
+                            lb_a,
                         ),
                         axis=-1,
                     ),
@@ -162,33 +163,33 @@ def impl(
             )
         else:
             ell = (
-                lower_a.a
+                ensure_dense(lower_a.a)
                 if lower_a is not None
-                else lower_b.a if lower_b is not None else None
+                else ensure_dense(lower_b.a) if lower_b is not None else None
             )
 
 
         if upper_a is not None and upper_b is not None:
+            ua_a = ensure_dense(upper_a.a)
+            ub_a = ensure_dense(upper_b.a)
             delta = jnp.concatenate(
                 (
                     jnp.concatenate(
                         (
-                            upper_a.a,
-                            jnp.zeros((upper_a.a.shape[0], upper_b.a.shape[0])),
+                            ua_a,
+                            jnp.zeros((ua_a.shape[0], ub_a.shape[0])),
                         ),
                         axis=-1,
                     ),
-                    jnp.concatenate(
-                        (jnp.outer(upper_b.q, upper_a.p), upper_b.a), axis=-1
-                    ),
+                    jnp.concatenate((jnp.outer(upper_b.q, upper_a.p), ub_a), axis=-1),
                 ),
                 axis=0,
             )
         else:
             delta = (
-                upper_a.a
+                ensure_dense(upper_a.a)
                 if upper_a is not None
-                else upper_b.a if upper_b is not None else None
+                else ensure_dense(upper_b.a) if upper_b is not None else None
             )
 
         return (
diff --git a/tests/test_kernels/test_quasisep.py b/tests/test_kernels/test_quasisep.py
index c426224d..c585f989 100644
--- a/tests/test_kernels/test_quasisep.py
+++ b/tests/test_kernels/test_quasisep.py
@@ -6,6 +6,7 @@
 
 from tinygp import GaussianProcess
 from tinygp.kernels import quasisep
+from tinygp.noise import Banded
 from tinygp.test_utils import assert_allclose
 
 
@@ -157,3 +158,41 @@ def test_carma_quads():
     assert_allclose(carma31.arroots, carma31_quads.arroots)
     assert_allclose(carma31.acf, carma31_quads.acf)
     assert_allclose(carma31.obsmodel, carma31_quads.obsmodel)
+
+
+def test_sum_kernel_with_banded_noise(data):
+    x, y, _ = data
+    N = len(x)
+    k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
+    banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1)))
+    gp = GaussianProcess(k, x, noise=banded)
+    assert jnp.isfinite(gp.log_probability(y))
+    lp, cond_gp = gp.condition(y)
+    assert jnp.isfinite(lp)
+
+
+def test_product_of_sum_kernel(data):
+    x, y, _ = data
+    k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0)
+    gp = GaussianProcess(k, x, diag=jnp.ones(len(x)))
+    assert jnp.isfinite(gp.log_probability(y))
+    assert_allclose(k.to_symm_qsm(x).to_dense(), k(x, x))
+
+
+def test_sum_times_sum_kernel(data):
+    x, y, _ = data
+    k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * (
+        quasisep.Exp(0.5) + quasisep.Matern32(1.0)
+    )
+    gp = GaussianProcess(k, x, diag=jnp.ones(len(x)))
+    assert jnp.isfinite(gp.log_probability(y))
+
+
+def test_sum_kernel_use_block_false(data):
+    x, y, _ = data
+    N = len(x)
+    k_block = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
+    k_dense = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False)
+    gp_block = GaussianProcess(k_block, x, diag=0.1 * jnp.ones(N))
+    gp_dense = GaussianProcess(k_dense, x, diag=0.1 * jnp.ones(N))
+    assert_allclose(gp_block.log_probability(y), gp_dense.log_probability(y))