
Gaussian Mixture Clustering with RStan

\(\newcommand{\t}[1]{ {#1}^T }\) \(\newcommand{\inv}[1]{ {#1}^{-1} }\) \(\newcommand{\vec}[1]{\mathbf{#1}}\) \(\newcommand{\rv}[1]{\mathbf{#1}}\) \(\newcommand{\follows}{\thicksim}\) \(\newcommand{\normal}[1]{\mathcal{N}\!\left(#1\right)}\) \(\newcommand{\ga}[1]{\mathrm{Gamma}\!\left(#1\right)}\) \(\newcommand{\diff}[1]{\mathop{}\!\mathrm{d}#1}\) \(\newcommand{\fun}[2]{#1\left(#2\right)}\) \(\newcommand{\expect}[1]{\vec{E}\left[#1\right]}\) \(\newcommand{\prob}[1]{\fun{p}{#1}}\) \(\DeclareMathOperator*{\argmax}{arg\,max}\)
pacman::p_load("MASS")
pacman::p_load("tidyverse")
pacman::p_load("rstan")
pacman::p_load("mvtnorm")

Let’s generate some testing data first. For this post, I will generate 1000 data points from 3 bivariate normal distributions using the mvrnorm function from the MASS package.

set.seed(43) # set the seed for random number generation

Define the number of data points and clusters:

N <- 1000 # number of data points
K <- 3    # number of clusters

Now we generate an index vector; the \(i\)-th element in this index vector indicates the cluster from which the \(i\)-th data point will be generated:

# generate the indicator indices:
which.cluster <- sample(1:K, N, replace = T, prob = c(.1, .6, .3))

Make up three mean vectors mus and three covariance matrices sigmas for generating the dummy data:

mus <- list(c(2, 2), c(-2, -2), c(4, -4))
sigmas <- list(diag(c(1, 1)), diag(c(2, 2)), diag(c(2, 2)))

Now generate the data by scanning through the indices and drawing a data point from the corresponding cluster for each one:

data <- which.cluster %>%
  sapply(FUN = function(i) {
    mvrnorm(n = 1, mus[[i]], sigmas[[i]], tol = 1e-6, empirical = FALSE, EISPACK = FALSE)
  }) %>%
  t()

We can plot the data to gain a visual understanding of the clusters:

data %>%
  cbind(which.cluster) %>%
  as_tibble(.name_repair = "minimal") %>%
  `colnames<-`(c("X", "Y", "cluster")) %>%
  mutate(cluster = as.character(cluster)) %>%
  ggplot() +
  geom_point(aes(x = X, y = Y, colour = cluster))

Now we move on to defining a model in Stan that will infer the means and covariances of the three clusters from the data.

Modelling Gaussian Mixtures in Stan

Now we specify the mixture model in Stan’s modelling language. A Stan model is usually divided into several “blocks”, each of which specifies a different aspect of the model. The simplest model has three blocks:

  1. data: This block declares the properties of the input (i.e. training) data to be used with the model, such as their dimensions and constraints (e.g. “a probability vector whose elements should add up to one”, i.e. a simplex vector).
  2. parameters: This block declares the parameters we want Stan to estimate. During estimation, Stan walks through the parameter space defined by these declarations, guided by the log density specified in the model block; inferences are made by examining the distributions of the samples drawn for the parameters along the way.
  3. model: This is where we actually tell Stan how to calculate the log density (the priors plus the likelihood) for each point in the parameter space, given the data.
model_code <- "
data {

  // ===============================
  int<lower=0> N;  // number of data points
  int<lower=1> M;  // number of variables
  int<lower=1> K;  // number of clusters
  vector[M] x[N];  // the data
  // ===============================

  // ===============================
  // hyper-parameters for the cluster probabilities:
  vector[K] alpha;
  // ===============================
  
  // ===============================
  // the following two hyperparameters specify the prior distribution
  // (a normal distribution in this case) for the mean vectors for the
  // normal distributions that generated the data points for each of
  // the clusters:

  // the mean for the normal prior that governs 
  // the mean vectors for the clusters:
  vector[M] mu_prior_mu;

  // the covariance matrix for the normal prior that 
  // governs the mean vectors for the clusters:
  cov_matrix[M] mu_prior_sigma;
  // ===============================

  // ===============================
  // the following two hyperparameters specify the prior distribution
  // (an inverse Wishart distribution) for the covariances for the 
  // normal distributions that generated the data points for each of 
  // the clusters:

  real<lower=M> covariance_prior_p;
  matrix[M, M] covariance_prior_V;
  // ===============================
}

parameters {
  simplex[K] p;           // the overall weights between the K clusters
  vector[M] mu[K];        // the mean vectors for each of the K clusters
  cov_matrix[M] sigma[K]; // the covariance matrices for each of the K clusters
}

model {
  p ~ dirichlet(alpha); // p is given a Dirichlet prior

  // For each cluster, we set up the priors for
  // its mean vector and covariance matrix:
  for (k in 1:K) {
    mu[k] ~ multi_normal(mu_prior_mu, mu_prior_sigma);
    sigma[k] ~ inv_wishart(covariance_prior_p, covariance_prior_V);
  }

  for (n in 1:N) {
    // For each data point n and each cluster k,
    // compute and record the log-likelihood that
    // cluster k generated data point n:
    real lps[K];
    for (k in 1:K) {
      lps[k] = log(p[k]) + multi_normal_lpdf( x[n] | mu[k], sigma[k] );
    }

    // This is how the mixture is specified:
    // log_sum_exp exponentiates the per-cluster log-likelihoods
    // back onto the linear scale, sums them to form the overall
    // likelihood of data point n, and takes the logarithm again.
    // Note that 'target' is a reserved variable in Stan that
    // accumulates the overall log density at each iteration.
    target += log_sum_exp(lps);
  }
  
}
"
model <- stan_model(model_code = model_code)
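To make explicit what the model block accumulates for each data point: the mixture density is a weighted sum of the per-cluster normal densities, and log_sum_exp evaluates its logarithm directly from the per-cluster log terms,

\[ \prob{\vec{x}_n} = \sum_{k=1}^{K} p_k \, \normal{\vec{x}_n \mid \vec{\mu}_k, \Sigma_k}, \qquad \log \prob{\vec{x}_n} = \log \sum_{k=1}^{K} \exp\!\left( \log p_k + \log \normal{\vec{x}_n \mid \vec{\mu}_k, \Sigma_k} \right), \]

which is exactly the quantity that target += log_sum_exp(lps) adds to the overall log density.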

Stan was originally designed for Bayesian analysis. Here, however, I would like to try out its support for point estimation, which maximises the log density defined by the model (effectively a penalised maximum-likelihood estimate, given the priors above). By default Stan uses the L-BFGS algorithm to perform the optimisation.

fit <- rstan::optimizing(model, 
                  data = list(N = N, 
                              x = data, 
                              M = 2, K = 3, 
                              alpha = c(5, 5, 5), 
                              mu_prior_mu = c(0, 0), 
                              mu_prior_sigma = diag(c(10, 10)), 
                              covariance_prior_p = 2, 
                              covariance_prior_V = diag(c(1, 1))), 
                  seed = 43,
                  as_vector = FALSE)
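Before reading anything into the fit, it is worth a quick sanity check that the optimiser reports success, and noting the objective value it reached; return_code and value come from the list returned by rstan::optimizing.

# return_code is 0 when the optimiser terminated normally:
fit$return_code

# value is the log density (up to a constant) attained at the optimum:
fit$value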

Now that we have obtained a fit, let’s extract the estimates for the parameters:

# A helper function to extract the parameter estimates from Stan's fit object:
# it slices a (possibly multi-dimensional) array along its first dimension
# and returns the slices as a list.
extract_params <- function(pars) {
  1:(dim(pars)[1]) %>%  # the first dimension indexes the individual parameters
    lapply(FUN = function(k) {  # for each of the indices:
      pars %>% dim() %>% {
        # keep index k for the first dimension and expand the
        # remaining dimensions into their full index ranges:
        c(k, lapply(.[-1], FUN = function(x) { 1:x }))
      } %>%
        do.call(what = function(...) { `[`(pars, ...) })
    })
}

# extract the parameter estimates
ps  <- fit$par$p %>% extract_params()
mus <- fit$par$mu %>% extract_params()
sigmas <- fit$par$sigma %>% extract_params()

Visualising the Fit

However, the parameter estimates only tell us about the locations and shapes of the normal distributions that correspond to the clusters; to visualise the results of the clustering, we also need to figure out which cluster each data point belongs to. In other words, for each data point \(\vec{x}\) and each cluster \(k\), we calculate the (log-)likelihood that \(\vec{x}\) was generated by cluster \(k\); we then assign \(\vec{x}\) to the cluster with the highest (log-)likelihood of having generated it.
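In symbols, writing \(\hat{k}(\vec{x})\) for the cluster assigned to \(\vec{x}\), the assignment rule is

\[ \hat{k}(\vec{x}) = \argmax_{k \in \{1, \dots, K\}} \left\{ \log p_k + \log \normal{\vec{x} \mid \vec{\mu}_k, \Sigma_k} \right\}. \]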

# Create a function that calculates the log-likelihood
# that a set of data was generated from a specific cluster:
loglikelihood_for_cluster <- function(data, p, mu, sigma) {
  apply(data, MARGIN = 1,
        FUN = function(x, p, mu, sigma) {
          log(p) + dmvnorm(x, mean = mu, sigma = sigma, log = TRUE)
        },
        p, mu, sigma)
}

# Apply the above function over all clusters.
# We get a matrix where each row represents a data point and
# each of the k columns is the log-likelihood for the said data point
# when it is generated by the k-th cluster:
loglikelihoods <- mapply(function(p, mu, sigma) { loglikelihood_for_cluster(data, p, mu, sigma) },
                         ps, mus, sigmas)

# Assign the clusters and plot it.
cbind(data, cluster = loglikelihoods %>% apply(MARGIN = 1, FUN = which.max)) %>%
  as_tibble(.name_repair = "minimal") %>%
  `colnames<-`(c("X", "Y", "cluster")) %>%
  mutate(cluster = as.character(cluster)) %>%
  ggplot() +
  geom_point(aes(x = X, y = Y, colour = cluster))

Caveat

Sometimes the optimisation algorithm gets trapped in a local maximum. In such cases, the parameter estimates can end up very different from the true parameters that generated the data. To combat this, one may avail of Stan’s automatic variational inference facility, or go fully Bayesian and sample from the posterior.
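One simple, pragmatic workaround is to run the optimiser from several starting seeds and keep the fit that attains the highest objective value; the snippet below is a minimal sketch of that idea (the number of restarts and the seeds are arbitrary choices, and the data list simply restates the one used above). Alternatively, rstan::vb and rstan::sampling expose the variational and full sampling approaches for the same compiled model.

# A sketch: restart the optimiser from a handful of seeds and keep the best fit.
stan_data <- list(N = N, x = data, M = 2, K = 3,
                  alpha = c(5, 5, 5),
                  mu_prior_mu = c(0, 0),
                  mu_prior_sigma = diag(c(10, 10)),
                  covariance_prior_p = 2,
                  covariance_prior_V = diag(c(1, 1)))

fits <- lapply(1:5, function(s) {
  rstan::optimizing(model, data = stan_data, seed = s, as_vector = FALSE)
})

# fit$value holds the attained log density; pick the restart that did best:
best_fit <- fits[[which.max(sapply(fits, function(f) f$value))]]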

Test the Model on the Iris Dataset

It might be interesting to see how our model performs beyond the toy dataset we prepared. Here I will use the classic iris dataset, which is commonly used to demonstrate clustering algorithms.

fit <- rstan::optimizing(model, 
                  data = list(N = nrow(iris), 
                              x = iris[, -5], 
                              M = 4, K = 3, 
                              alpha = c(5, 5, 5), 
                              mu_prior_mu = iris[, -5] %>% colMeans(), 
                              mu_prior_sigma = diag(rep(4, 4)), 
                              covariance_prior_p = 4, 
                              covariance_prior_V = diag(rep(1, 4))), 
                  seed = 43,
                  as_vector = FALSE)

ps  <- fit$par$p %>% extract_params()
mus <- fit$par$mu %>% extract_params()
sigmas <- fit$par$sigma %>% extract_params()

loglikelihoods <- mapply(function(p, mu, sigma) { loglikelihood_for_cluster(iris[, -5], p, mu, sigma) },
                         ps, mus, sigmas)

classified <- cbind(iris, cluster = loglikelihoods %>% apply(MARGIN = 1, FUN = which.max)) %>%
  as_tibble(.name_repair = "minimal") %>%
  mutate(cluster = as.character(cluster))

Display the confusion matrix:

classified %>% dplyr::select(cluster, Species) %>% table()
##        Species
## cluster setosa versicolor virginica
##       1      0         48         0
##       2     50          0         0
##       3      0          2        50

We can see that, for the most part, the clusters found by the algorithm correspond to the actual species. We can also plot the classification:

classified %>%
  ggplot() +
  geom_point(aes(x = Sepal.Length, y = Sepal.Width, colour = cluster, shape = Species))
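To put a rough number on that agreement, here is a small sketch that labels each cluster with its most common species and computes the proportion of observations whose species matches that label; going by the confusion matrix above, this should come to 148 of the 150 observations.

# Label each cluster with its majority species, then measure the agreement:
classified %>%
  group_by(cluster) %>%
  mutate(majority_species = names(which.max(table(Species)))) %>%
  ungroup() %>%
  summarise(agreement = mean(majority_species == Species))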