Implements regime (ii) of Hoek and Elliott (2024). Runs the textbook expectation-maximisation (EM) algorithm for Gaussian mixtures on the supplied samples, with diagonal ridge regularisation for numerical stability, optional multi-start, and a check that the log-likelihood increases monotonically.
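As a rough illustration of the algorithm described above, the following is a minimal base-R sketch of EM for a Gaussian mixture with a diagonal ridge and a relative log-likelihood stopping rule. It is not the package's implementation: the helper names log_dmvnorm() and em_sketch(), the use of stats::kmeans() centres with a pooled covariance as the starting point, and the exact form of the stopping rule are all assumptions made for the example.

```r
# Hypothetical helpers for illustration only; not the package internals.

# Log-density of a multivariate normal at each row of x, via a Cholesky factor.
log_dmvnorm <- function(x, mu, sigma) {
  R <- chol(sigma)                                  # sigma = t(R) %*% R
  z <- backsolve(R, t(x) - mu, transpose = TRUE)    # solves t(R) z = x - mu
  -0.5 * colSums(z^2) - sum(log(diag(R))) - 0.5 * ncol(x) * log(2 * pi)
}

em_sketch <- function(x, N = 2L, max_iter = 100L, tol = 1e-6, ridge_eps = 1e-6) {
  n <- nrow(x)
  p <- ncol(x)
  # Assumed initialisation: k-means centres, pooled covariance, equal weights.
  mu <- stats::kmeans(x, centers = N)$centers
  sigma <- replicate(N, stats::cov(x) + ridge_eps * diag(p), simplify = FALSE)
  w <- rep(1 / N, N)
  loglik <- numeric(0)
  for (iter in seq_len(max_iter)) {
    # E-step: log(w_k) + log N(x_i | mu_k, sigma_k), then a row-wise log-sum-exp.
    lp <- sapply(seq_len(N), function(k) {
      log(w[k]) + log_dmvnorm(x, mu[k, ], sigma[[k]])
    })
    m <- apply(lp, 1, max)
    ll_i <- m + log(rowSums(exp(lp - m)))
    resp <- exp(lp - ll_i)                          # n by N responsibilities
    loglik <- c(loglik, sum(ll_i))
    # Stop once the relative change in log-likelihood is at most tol.
    L <- length(loglik)
    if (L > 1 && abs(loglik[L] - loglik[L - 1]) <= tol * abs(loglik[L - 1])) break
    # M-step: weighted updates, with the ridge added to each covariance diagonal.
    nk <- colSums(resp)
    w <- nk / n
    for (k in seq_len(N)) {
      mu[k, ] <- colSums(resp[, k] * x) / nk[k]
      xc <- sweep(x, 2, mu[k, ])
      sigma[[k]] <- crossprod(xc * sqrt(resp[, k])) / nk[k] + ridge_eps * diag(p)
    }
  }
  list(weights = w, mu = mu, sigma = sigma, loglik = loglik)
}
```

The returned loglik trace can be inspected with diff() to confirm that it does not decrease between iterations, which is the kind of monotonicity check the description refers to.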

Usage

fit_em_samples(
  target,
  N = 2L,
  init = NULL,
  max_iter = 100L,
  tol = 1e-06,
  ridge_eps = 1e-06,
  n_starts = 5L,
  anneal = FALSE,
  temp_schedule = NULL,
  seed = NULL,
  canonicalise = TRUE
)

Arguments

target

A gmm_target carrying an n by p samples matrix.

N

Number of mixture components.

init

A gmm initialisation, or NULL to use init_kmeans().

max_iter

Maximum number of EM iterations.

tol

Convergence tolerance on the relative change in log-likelihood between iterations.

ridge_eps

Ridge term added to the diagonal of each component covariance at every M-step.

n_starts

Number of multi-start initialisations (used only when init is NULL). The fit with the highest final log-likelihood is returned.

anneal

Logical. If TRUE, a deterministic-annealing warm-start (see gmm_anneal_path()) replaces the multi-start: the components are annealed from a high temperature down to one, and the resulting parameters seed a single final (cold) EM polish. This reduces cold EM's sensitivity to local optima, at a computational cost that grows with the length of the schedule. Defaults to FALSE (cold best-of-n_starts).

temp_schedule

Optional numeric vector of descending temperatures for the annealing warm-start. NULL (the default) uses a geometric schedule from 10 down to 1 in covariance-whitened units. Ignored when anneal = FALSE.

seed

Optional integer seed for the annealing perturbations (the warm-start is deterministic given a seed). Ignored when anneal = FALSE.

canonicalise

Logical. If TRUE (the default), the fitted mixture is post-processed by gmm_canonicalise() so that components are sorted by descending weight and (as a tiebreaker) by descending ||mu||.
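The behaviour of n_starts, temp_schedule and canonicalise described above can be sketched with three small base-R helpers. These are hypothetical stand-ins, not the package's code: best_of_starts(), geometric_schedule() and canonical_order() are names invented here, the loglik_final field on the stub fits is assumed, and the number of temperature steps (8) is an assumption because the documentation only states the endpoints 10 and 1. The covariance-whitened scaling of the temperatures is not modelled.

```r
# Hypothetical helpers for illustration only; not the package internals.

# Multi-start: run a fitting closure n_starts times and keep the fit with the
# highest final log-likelihood. fit_once() is assumed to return a list with a
# numeric loglik_final element.
best_of_starts <- function(fit_once, n_starts = 5L) {
  fits <- lapply(seq_len(n_starts), function(i) fit_once())
  finals <- vapply(fits, function(f) f$loglik_final, numeric(1))
  fits[[which.max(finals)]]
}

# Geometric (constant-ratio) temperature schedule from t_max down to t_min.
# The default of 8 steps is an assumption.
geometric_schedule <- function(t_max = 10, t_min = 1, n_steps = 8L) {
  exp(seq(log(t_max), log(t_min), length.out = n_steps))
}

# Canonical component order: descending weight, ties broken by descending
# Euclidean norm of the mean. mu holds one row per component.
canonical_order <- function(weights, mu) {
  mu_norm <- sqrt(rowSums(mu^2))
  order(-weights, -mu_norm)
}
```

For example, canonical_order(c(0.3, 0.4, 0.3), rbind(c(1, 0), c(0, 1), c(3, 4))) places component 2 first (largest weight), then component 3 ahead of component 1, because the two share a weight and component 3 has the larger mean norm.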

Value

A gmm_fit with regime = "sample". When anneal = TRUE the diagnostics list also carries annealed = TRUE and the temp_schedule used.

Examples

x <- matrix(stats::rnorm(200), ncol = 2)
tgt <- gmm_target_from_samples(x)
fit <- fit_em_samples(tgt, N = 2L, max_iter = 30L, n_starts = 2L)
fit@diagnostics$loglik_final
#> [1] -284.1687