diff --git a/prob_dim_red/linear_gaussian_ppca.py b/prob_dim_red/linear_gaussian_ppca.py
index b3e69b35453327d903d679add376dd38089022c7..077aa8e952ab8f9fe82e0f99d064cf4bf0f2fe23 100644
--- a/prob_dim_red/linear_gaussian_ppca.py
+++ b/prob_dim_red/linear_gaussian_ppca.py
@@ -26,6 +26,7 @@ from typing import Optional, Generator
 from dataclasses import dataclass
 from functools import cached_property, reduce
 import operator
+from collections import deque
 
 import numpy as np
 import numpy.typing as npt
@@ -292,7 +293,7 @@ class _DataDescFixed:
     @cached_property
     def tr_sample_cov(self):
         """
-        Cached property to compute the trace of the sample covariance.
+        Cached property to compute the trace of the sample covariance: Tr(N⁻¹ Xᵀ X).
         """
         return utils.trace_matmul(self.centered_data.T, self.centered_data) / self.N
@@ -520,13 +521,18 @@ class LinearGaussianPPCAEstimEM(_LinearGaussianPPCAEstim):
         """
         Initialize the EM algorithm.
         """
+        alpha = 0.5
+        scaling_factor_w = (
+            alpha * self.fixed.tr_sample_cov / (self.fixed.X_dim * self.fixed.Z_dim)
+        ) ** 0.5
+
         # pylint:disable=invalid-name
-        W = np.random.normal(size=(self.X_dim, self.Z_dim))
+        W = scaling_factor_w * np.random.normal(size=(self.X_dim, self.Z_dim))
 
         # U, S, _ = np.linalg.svd(W, full_matrices=False)
         # W = U * S[None,:]
 
-        noise_var = 1e-05
+        noise_var = (1 - alpha) * self.fixed.tr_sample_cov / self.fixed.X_dim
         # the convergence is slower if the init value is greater than true noise var.
 
         self.state = _EM_state(W=W, noise_var=noise_var, fixed=self.fixed)
@@ -557,14 +563,16 @@ class LinearGaussianPPCAEstimEM(_LinearGaussianPPCAEstim):
         self.state = _EM_state(W=W_new, noise_var=noise_var_new, fixed=self.fixed)
 
     def _em_states(
-        self, max_iterations, error_tolerance
+        self, max_iterations, error_tolerance, tolerance_window
     ) -> Generator[_EM_state, None, None]:
         """
         early stopping
         """
         self._em_init()
         yield self.state
-        likelihood = self.state.log_likelihood
+        que_crit = deque()
+        crit = self.state.log_likelihood
+        que_crit.append(crit)
         for _ in (
             range(max_iterations) if max_iterations is not None else itertools.count()
         ):
@@ -575,9 +583,12 @@ class LinearGaussianPPCAEstimEM(_LinearGaussianPPCAEstim):
             # self.state.W = U * S[None,:]
 
             yield self.state
-            likelihood, old_likelihood = self.state.log_likelihood, likelihood
-            if (likelihood - old_likelihood) < error_tolerance:
-                break
+            crit = self.state.log_likelihood
+            que_crit.append(crit)
+
+            if len(que_crit) > tolerance_window:
+                if (crit - que_crit.popleft()) / tolerance_window < error_tolerance:
+                    break
 
         U, S, _ = np.linalg.svd(self.state.W, full_matrices=False)
         self.state.W = U * S[None, :]
@@ -592,10 +603,11 @@ class LinearGaussianPPCAEstimEM(_LinearGaussianPPCAEstim):
             log_det_model_var=self.state.log_det_model_var,
         )
 
-    def fit(
+    def fit(  # pylint:disable=too-many-positional-arguments, too-many-arguments
         self,
         max_iterations=None,
         error_tolerance=1e-6,
+        tolerance_window=10,
         trace=False,
         progress=True,
     ) -> Optional[list[_EM_state]]:
@@ -608,6 +620,369 @@ class LinearGaussianPPCAEstimEM(_LinearGaussianPPCAEstim):
             Maximum number of iterations for the EM algorithm.
         error_tolerance : float
             Tolerance for the convergence of the EM algorithm.
+        tolerance_window : int
+            Window size for the moving average of the convergence criterion.
+        trace : bool
+            Whether to store the intermediate states of the EM algorithm.
+        progress : bool
+            Show progress with tqdm (not compatible with trace).
+
+        Returns
+        -------
+        Generator[_EM_state] | LinearGaussianPPCAEstimEM
+            Intermediate states of the EM algorithm if trace is True, otherwise self.
+ """ + states = self._em_states(max_iterations, error_tolerance, tolerance_window) + if trace: + return states + + if progress: + last_comp_lik = 0 + scale = 0 + for state in (progress_bar := tqdm(states, unit_scale=True)): + gap, last_comp_lik, last_lik = ( + state.complete_log_likelihood - last_comp_lik, + state.complete_log_likelihood, + state.log_likelihood, + ) + scale += 0.1 * (np.log10(np.abs(gap)) - scale) + dscale = max(int(-np.floor(scale - 0.5)), 0) + fmt_comp_lik = "comp-log-lik={" + f":.{dscale:d}f" + "}" + fmt_marg_lik = "marg-log-lik={" + f":.{dscale:d}f" + "}" + progress_bar.set_postfix_str( + fmt_comp_lik.format(last_comp_lik) + + " - " + + fmt_marg_lik.format(last_lik), + refresh=False, + ) + else: + for _ in states: + pass + + return self + + +@dataclass +# pylint:disable=invalid-name +class no_cov_EM_state: + """ + Dataclass for storing the state of the EM algorithm. + + Attributes + ---------- + W : npt.NDArray[np.float64] + Factor-loading matrix. + noise_var : np.float64 + Noise variance σ². + fixed : _DataDescFixed + Fixed data description. + """ + + # pylint:disable=invalid-name + W: npt.NDArray[np.float64] + noise_var: np.float64 + fixed: _DataDescFixed + + @cached_property + def noiseless_posterior_precision(self) -> npt.NDArray[np.float64]: + """Noiseless posterior precision matrix: Wáµ€ W""" + return self.W.T @ self.W + + @cached_property + def posterior_precision(self) -> npt.NDArray[np.float64]: + """Posterior precision matrix: M = Wáµ€ W + σ² I""" + return self.noiseless_posterior_precision + self.noise_var * np.eye( + self.W.shape[1] + ) + + @cached_property + def posterior_variance(self) -> npt.NDArray[np.float64]: + """Posterior variance matrix: (Wáµ€ W + σ²I)â»Â¹ = Mâ»Â¹""" + return np.linalg.inv(self.posterior_precision) + + @cached_property + def centered_data_factor_load_mat_prod(self) -> npt.NDArray[np.float64]: + """Matrix product of centered data and factor loading matrix: (Wáµ€ Xáµ€)áµ€""" + return (self.W.T @ self.fixed.centered_data.T).T + + @cached_property + def posterior_mean(self) -> npt.NDArray[np.float64]: + """Posterior: X W Mâ»Â¹""" + return self.centered_data_factor_load_mat_prod @ self.posterior_variance + + @cached_property + def tr_inter_calc_comp_lik(self) -> npt.NDArray[np.float64]: + """Trace of intermediary calculation in complete likelihood: Tr(Wáµ€ Xáµ€ X W Mâ»Â¹)""" + return np.sum( + ( + (self.centered_data_factor_load_mat_prod).T + @ self.centered_data_factor_load_mat_prod + ) + * self.posterior_variance + ) + + @cached_property + def inter_calc_uninv(self) -> npt.NDArray[np.float64]: + """Intermediary calculation not inverted: N σ² Mâ»Â¹ + Mâ»Â¹ Wáµ€ Xáµ€ X W Mâ»Â¹""" + return (self.posterior_mean.T @ self.posterior_mean) + ( + self.noise_var * self.posterior_variance * self.fixed.N + ) + + @cached_property + def tr_inter_calc_uninv(self) -> npt.NDArray[np.float64]: + """Trace of intermediary calculation not inverted: Tr(N σ² Mâ»Â¹ + Mâ»Â¹ Nâ»Â¹ Wáµ€ Xáµ€ X W Mâ»Â¹)""" + return np.trace(self.inter_calc_uninv) + + @cached_property + def tr_inter_calc_uninv_noiseless_posterior_precision_mat_prod( + self, + ) -> npt.NDArray[np.float64]: + """Trace of matrix prod of intermediary calculation + not inverted and noiseless posterior precision: + Tr((N σ² Mâ»Â¹ + Mâ»Â¹ Nâ»Â¹ Wáµ€ Xáµ€ X W Mâ»Â¹) Wáµ€ W) + """ + return utils.trace_matmul( + self.inter_calc_uninv, self.noiseless_posterior_precision, sym=True + ) + + @cached_property + def inter_calc_inv(self) -> npt.NDArray[np.float64]: + """intermediary calculation not inverted: (N 
+        return np.linalg.inv(self.inter_calc_uninv)
+
+    @cached_property
+    def model_var(self) -> npt.NDArray[np.float64]:
+        """Model covariance matrix: W Wᵀ + σ²I = C"""
+        return self.W @ self.W.T + self.noise_var * np.eye(self.fixed.X_dim)
+
+    @cached_property
+    def log_det_model_var(self) -> npt.NDArray[np.float64]:
+        """Log-determinant of the model covariance matrix: log |C|"""
+        return np.linalg.slogdet(self.model_var)[1]
+
+    @cached_property
+    def model_var_inv(self) -> npt.NDArray[np.float64]:
+        """Inverse of the model covariance matrix: C⁻¹ = (W Wᵀ + σ²I)⁻¹"""
+        return np.linalg.inv(self.model_var)
+
+    @cached_property
+    def complete_log_likelihood(self) -> np.float64:
+        """
+        Complete log-likelihood of the model.
+        """
+        # pylint:disable=line-too-long
+        if self.noise_var == 0:
+            return -np.inf
+
+        return (
+            -(
+                self.fixed.N
+                / 2
+                * (self.fixed.X_dim * np.log(2 * np.pi * self.noise_var))
+            )
+            - (self.fixed.N / 2 * (self.fixed.Z_dim * np.log(2 * np.pi)))
+            - self.tr_inter_calc_uninv / 2
+            - (
+                self.tr_inter_calc_uninv_noiseless_posterior_precision_mat_prod
+                / (2 * self.noise_var)
+            )
+            - self.fixed.N * (self.fixed.tr_sample_cov / (2 * self.noise_var))
+            + self.tr_inter_calc_comp_lik / self.noise_var
+        )
+
+    @cached_property
+    def log_likelihood(self) -> np.float64:
+        """
+        Marginal log-likelihood of the model.
+        """
+        if self.noise_var == 0:
+            return -np.inf
+        return (
+            -self.fixed.N
+            / 2
+            * (
+                self.fixed.X_dim * np.log(2 * np.pi)
+                + self.log_det_model_var
+                + 1
+                / self.noise_var
+                * (
+                    self.fixed.tr_sample_cov
+                    - self.tr_inter_calc_comp_lik / self.fixed.N
+                )
+            )
+        )
+
+
+@dataclass(kw_only=True)
+class LinearGaussianPpcaNoCovEMResult(_LinearGaussianPpcaResult):
+    """
+    Dataclass for storing the results of Linear Gaussian PPCA EM estimation.
+
+    Attributes
+    ----------
+    mean : npt.NDArray[np.float64]
+        Mean of the data.
+    noise_var : np.float64
+        Noise variance.
+    complete_log_likelihood : np.float64
+        Complete log-likelihood of the model.
+    log_likelihood : np.float64
+        Marginal log-likelihood of the model.
+    model_var_inv : npt.NDArray[np.float64]
+        Inverse of the model covariance.
+    log_det_model_var : npt.NDArray[np.float64]
+        Log-determinant of the model covariance.
+    """
+
+    # pylint:disable=invalid-name
+    mean: npt.NDArray[np.float64]
+    noise_var: np.float64
+    complete_log_likelihood: np.float64
+    log_likelihood: np.float64
+    model_var_inv: npt.NDArray[np.float64]
+    log_det_model_var: npt.NDArray[np.float64]
+
+
+class LinearGaussianPPCAEstimNoCovEM(_LinearGaussianPPCAEstim):
+    """
+    Class for Linear Gaussian PPCA estimation using the Expectation-Maximization (EM) algorithm, in a variant that never forms the sample covariance matrix.
+
+    Attributes
+    ----------
+    fixed : _DataDescFixed
+        Fixed data descriptors.
+    state : no_cov_EM_state | None
+        State of the EM algorithm.
+    _result : LinearGaussianPpcaNoCovEMResult | None
+        Result of the EM estimation.
+    """
+
+    fixed: _DataDescFixed
+    state: Optional[no_cov_EM_state]
+    _result: Optional[LinearGaussianPpcaNoCovEMResult]
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize the Linear Gaussian PPCA EM estimation object.
+        """
+        super().__init__(*args, **kwargs)
+        self.fixed = _DataDescFixed(
+            N=self.N,
+            X_dim=self.X_dim,
+            Z_dim=self.Z_dim,
+            centered_data=self.centered_data,
+        )
+        self.state = None
+
+    def _em_init(self):
+        """
+        Initialize the EM algorithm.
+ """ + alpha = 0.5 + scaling_factor_w = ( + alpha * self.fixed.tr_sample_cov / (self.fixed.X_dim * self.fixed.Z_dim) + ) ** 0.5 + + # pylint:disable=invalid-name + W = scaling_factor_w * np.random.normal(size=(self.X_dim, self.Z_dim)) + + # U, S, _ = np.linalg.svd(W, full_matrices=False) + # W = U * S[None,:] + + noise_var = (1 - alpha) * self.fixed.tr_sample_cov / self.fixed.X_dim + # the convergence is slower if the init value is greater than true noise var. + + self.state = no_cov_EM_state(W=W, noise_var=noise_var, fixed=self.fixed) + + def _em_step(self): + """ + Perform one step of the EM algorithm. + """ + # pylint:disable=invalid-name + # pylint:disable=line-too-long + W_new = self.fixed.centered_data.T @ ( + self.state.posterior_mean @ self.state.inter_calc_inv + ) + + noise_var_new = ( + 1 + / (self.X_dim * self.fixed.N) + * ( + self.fixed.tr_sample_cov * self.fixed.N + - 2 + * utils.trace_matmul( + self.fixed.centered_data.T @ self.state.posterior_mean, + W_new.T, + sym=False, + ) + + utils.trace_matmul( + self.state.inter_calc_uninv, W_new.T @ W_new, sym=True + ) + ) + ) + + self.state = no_cov_EM_state(W=W_new, noise_var=noise_var_new, fixed=self.fixed) + + def _em_states( + self, max_iterations, error_tolerance, tolerance_window + ) -> Generator[no_cov_EM_state, None, None]: + """ + early stopping + """ + self._em_init() + yield self.state + que_crit = deque() + crit = self.state.log_likelihood + que_crit.append(crit) + for _ in ( + range(max_iterations) if max_iterations is not None else itertools.count() + ): + self._em_step() + + # if count % 120 == 0: + # U, S, _ = np.linalg.svd(self.state.W, full_matrices=False) + # self.state.W = U * S[None,:] + + yield self.state + crit = self.state.log_likelihood + que_crit.append(crit) + + if len(que_crit) > tolerance_window: + if (crit - que_crit.popleft()) / tolerance_window < error_tolerance: + break + + U, S, _ = np.linalg.svd(self.state.W, full_matrices=False) + self.state.W = U * S[None, :] + + self._result = LinearGaussianPpcaNoCovEMResult( + mean=self.mean, + W=self.state.W, + noise_var=self.state.noise_var, + complete_log_likelihood=self.state.complete_log_likelihood, + log_likelihood=self.state.log_likelihood, + model_var_inv=self.state.model_var_inv, + log_det_model_var=self.state.log_det_model_var, + ) + + def fit( # pylint:disable=too-many-positional-arguments, too-many-arguments + self, + max_iterations=None, + error_tolerance=1e-6, + tolerance_window=10, + trace=False, + progress=True, + ) -> Optional[list[no_cov_EM_state]]: + """ + Fit the Linear Gaussian PPCA model using the EM algorithm. + + Parameters + ---------- + max_iterations : int | None + Maximum number of iterations for the EM algorithm. + error_tolerance : float + Tolerance for the convergence of the EM algorithm. + tolerance_window : int + Window size for moving average over criterion. trace : bool Whether to store the intermediate states of the EM algorithm. progress : bool @@ -618,7 +993,7 @@ class LinearGaussianPPCAEstimEM(_LinearGaussianPPCAEstim): Generator[_EM_state] | None intermediate states of the EM algorithm if trace is True, otherwise None. """ - states = self._em_states(max_iterations, error_tolerance) + states = self._em_states(max_iterations, error_tolerance, tolerance_window) if trace: return states