EM Algorithm
Motivation
More often than not, we cannot model the data distribution \(p(x)\) with a single given distribution, because the data distribution is more complex than the given one. In this case, we introduce a latent variable \(z\) and assume that \(p(x\vert z)\) follows that distribution. In other words, we regard the data distribution \(p(x)\) as a mixture of several component distributions \(p(x\vert z)\) whose weights are given by \(p(z)\), i.e., \(p(x)=\sum_zp(z)p(x\vert z)\). In practice, we parameterize the distribution by \(\theta\) and maximize the log-likelihood for numerical stability:
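\[\begin{align} \mathcal J(\theta)=\log p(x;\theta)=\log\sum_zp(x,z;\theta)\tag {1} \end{align}\]Eq.\( (1) \) per se is intractable, since it requires summing \( p(x, z;\theta) \) over all \( z \), and the values of \( z \) may be too numerous to enumerate. In addition, the sum sits inside the logarithm, so the objective does not decompose into simple terms.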
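When \( z \) takes only a handful of values, the sum in Eq.\( (1) \) can still be evaluated directly, and the usual way to do it in log space is the log-sum-exp trick. The sketch below is only an illustration: `log_marginal` is a hypothetical helper name and the numbers are made up.

```python
import math

def log_marginal(log_joint):
    """Return log sum_z exp(log_joint[z]) without underflow.

    log_joint holds log p(x, z; theta) for every value of a discrete z.
    """
    m = max(log_joint)
    if m == -math.inf:            # every term has zero probability
        return -math.inf
    # Factor out the largest term so that every exponent is <= 0.
    return m + math.log(sum(math.exp(v - m) for v in log_joint))

# Toy values (assumed): exponentiating these directly would underflow to 0.
log_joint = [-1000.0, -1001.0, -1002.0]
print(log_marginal(log_joint))
```

Subtracting the maximum before exponentiating keeps at least one term equal to one, which is why the result stays finite.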
Intuitive Derivation of EM Algorithm
To solve Eq.\( (1) \), we first apply Bayes’ rule
\[\begin{align} \log p(x;\theta)&=\log {p(x, z;\theta)\over p(z|x;\theta)}\\\ &=\log{p(x,z;\theta)\over q(z)}+\log{q(z)\over p(z|x;\theta)}\\\ &=\mathbb E_{z\sim q(z)}\left[\log{p(x,z;\theta)}\right]-\mathbb E_{z\sim q(z)}\left[\log q(z)\right]+\mathbb E_{z\sim q(z)}\left[\log{q(z)\over p(z|x;\theta)}\right]\\\ &=\mathbb E_{z\sim q(z)}\left[\log{p(x,z;\theta)}\right] +H(Z)+ D_{KL}(q(z)\Vert p(z|x;\theta))\tag {2} \end{align}\]where the expectation taken in the third step makes sense since the left hand side of the equation \( \log p(x;\theta) \) is independent of \( z \).
Intuitively, we can see from Eq.\( (2) \) that if the KL term is zero, then maximizing the objective defined in Eq.\( (1) \) is equivalent to maximizing the first term in Eq.\( (2) \), since \( H(Z) \), the entropy of \( q(z) \), is constant w.r.t. \( \theta \). In other words, we may solve Eq.\( (1) \) by setting \( q(z)=p(z\vert x;\theta) \) and then maximizing \( \mathbb E_{z\sim q(z)}\left[\log{p(x,z;\theta)}\right] \) over \( \theta \) with \( q(z) \) held fixed. The second step, however, changes \( \theta \), so that \( q(z) \ne p(z\vert x; \theta) \) in general. We therefore reassign \( q(z) \) and repeat the above process until convergence, which is exactly what the EM algorithm does.
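The decomposition in Eq.\( (2) \) is easy to check numerically for a discrete latent variable. The sketch below uses a made-up joint table for one fixed \( x \) (the numbers and the helper name `decomposition` are assumptions) and confirms that the three terms add up to \( \log p(x;\theta) \) for any \( q \), with the KL term vanishing when \( q \) is the posterior.

```python
import math

# Hypothetical joint p(x, z; theta) for one fixed x and z in {0, 1, 2}.
joint = [0.10, 0.25, 0.05]
p_x = sum(joint)                          # p(x; theta)
posterior = [j / p_x for j in joint]      # p(z | x; theta)

def decomposition(q):
    """Return the three terms of Eq. (2) for a strictly positive q over z."""
    expected_log_joint = sum(qz * math.log(j) for qz, j in zip(q, joint))
    entropy = -sum(qz * math.log(qz) for qz in q)
    kl = sum(qz * math.log(qz / pz) for qz, pz in zip(q, posterior))
    return expected_log_joint, entropy, kl

for q in ([0.6, 0.3, 0.1], posterior):
    a, h, kl = decomposition(q)
    print(f"sum={a + h + kl:.6f}  log p(x)={math.log(p_x):.6f}  KL={kl:.6f}")
```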
Summing it up, the EM algorithm iterates the following two steps:
- E-step: Evaluate the conditional distribution over \( z \) given the dataset and the current model parameters \( p(z\vert x;\theta^{(t)}) \), then we have \( Q(\theta, \theta^{(t)}):=\mathbb E_{z\sim p(z\vert x;\theta^{(t)})}[\log p(x,z;\theta)] \)
- M-step: Maximize \( Q(\theta, \theta^{(t)})\), \( \theta^{(t+1)}\leftarrow \arg\max_\theta Q(\theta,\theta^{(t)}) \)
A detailed proof that this process never decreases \(p(x;\theta)\) is given at the end.
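To make the two steps concrete, here is a minimal sketch of EM on a small model that is not discussed elsewhere in this post: a mixture of two biased coins, where each observation is the number of heads in \( n \) flips of one unknown coin. The data, the initial guess, and all function names are assumptions chosen for illustration.

```python
import math

def log_likelihood(heads, n, pi, p):
    """log p(x; theta) for a mixture of K binomial coins."""
    total = 0.0
    for h in heads:
        total += math.log(sum(
            pi[k] * math.comb(n, h) * p[k] ** h * (1 - p[k]) ** (n - h)
            for k in range(len(pi))))
    return total

def e_step(heads, n, pi, p):
    """Posterior p(z_i = k | x_i; theta^(t)) for every observation."""
    resp = []
    for h in heads:
        w = [pi[k] * p[k] ** h * (1 - p[k]) ** (n - h) for k in range(len(pi))]
        s = sum(w)
        resp.append([wk / s for wk in w])
    return resp

def m_step(heads, n, resp):
    """Closed-form argmax of Q(theta, theta^(t)) for this model."""
    K = len(resp[0])
    mass = [sum(r[k] for r in resp) for k in range(K)]
    pi = [m / len(heads) for m in mass]
    p = [sum(r[k] * h for r, h in zip(resp, heads)) / (n * mass[k])
         for k in range(K)]
    return pi, p

heads = [9, 8, 2, 7, 1, 3, 9, 2]   # made-up data: heads out of n = 10 flips
n = 10
pi, p = [0.5, 0.5], [0.6, 0.4]     # arbitrary initial guess
history = [log_likelihood(heads, n, pi, p)]
for _ in range(20):
    resp = e_step(heads, n, pi, p)
    pi, p = m_step(heads, n, resp)
    history.append(log_likelihood(heads, n, pi, p))
print(pi, p, history[-1])
```

The list `history` records \( \mathcal J(\theta) \) after every iteration, so the monotonic improvement claimed above can be inspected directly.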
Pros and Cons of EM Algorithm
The EM algorithm has some very appealing advantages, which make it a good approach for computing maximum likelihood or maximum a posteriori estimates of parameters in statistical models:
- It is guaranteed never to decrease the likelihood, i.e., \( p(x; \theta^{(t+1)})\ge p(x;\theta^{(t)}) \)
- It works well in practice
On the other hand, it also comes with some cons:
- Because it only ever moves uphill, it easily gets stuck in a local optimum. One possible way to ease this problem is to run it from several different initializations and then select the run with the highest \( \mathcal J(\theta) \) defined in Eq.\( (1) \).
- Convergence can be slow
- It is an algorithm that is specialized to exponential families, such as Gaussian distributions and Bernoulli distributions
EM Algorithm for Gaussian Mixture Models
Gaussian Mixture Models (GMMs)
Before diving into the EM algorithm for GMMs, let’s first briefly recap GMMs. A Gaussian mixture model is a probabilistic model that assumes all \( N \) data points are generated from a mixture of \( K \) Gaussian distributions with unknown parameters (\( {\mu_k} \), \( {\Sigma_k} \)), where each Gaussian component \( \mathcal N(x\vert \mu_k, \Sigma_k) \) is weighted by a prior probability \( \pi_k \). Together, a Gaussian mixture model is defined as follows
\[\begin{align} p(x;\theta)=\sum_{k=1}^Kp(z_k)p(x|z_k)=\sum_{k=1}^K\pi_k\mathcal N(x;\mu_k, \Sigma_k)\\\ \mathrm{where\ }\theta=\{\pi_k,\mu_k, \Sigma_k\} \end{align}\]Or, in the latent form, for each data point \( x_i \) we sample a latent variable \( z_i \) from \( \mathrm{Categorical}(\pi) \), indicating which cluster (i.e., Gaussian component) \( x_i \) belongs to, and we have
\[\begin{align} z_i\sim\mathrm{Categorical}(\pi)\\\ p(x_i|z_i=k;\theta)=\mathcal N(x_i;\mu_k, \Sigma_k) \end{align}\]
EM Algorithm for GMMs
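Before turning to the updates, the generative process just described can be sketched in NumPy: draw \( z_i \) from the categorical prior, then draw \( x_i \) from the selected Gaussian. The function name `sample_gmm` and the parameter values are assumptions made for illustration.

```python
import numpy as np

def sample_gmm(n, pi, mus, sigmas, seed=0):
    """Draw n points from a GMM; returns data x (n, d) and labels z (n,)."""
    rng = np.random.default_rng(seed)
    z = rng.choice(len(pi), size=n, p=pi)        # z_i ~ Categorical(pi)
    x = np.empty((n, mus.shape[1]))
    for k in range(len(pi)):
        idx = np.flatnonzero(z == k)
        # x_i | z_i = k ~ N(mu_k, Sigma_k)
        x[idx] = rng.multivariate_normal(mus[k], sigmas[k], size=len(idx))
    return x, z

# Assumed toy parameters: two components in two dimensions.
pi = np.array([0.3, 0.7])
mus = np.array([[0.0, 0.0], [5.0, 5.0]])
sigmas = np.stack([np.eye(2), 0.5 * np.eye(2)])
x, z = sample_gmm(1000, pi, mus, sigmas)
print(x.shape, np.bincount(z) / len(z))
```

In the EM setting only `x` is observed; the labels `z` are exactly the latent variables whose posterior the E-step estimates.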
To ease notation, we start off by defining a matrix \( r \) of shape \( (N, K) \), where \( r_{i,k} \) is the probability (or weight) that \( x_i \) belongs to cluster \( k \). Furthermore, we define \( r_k \) as the total weight that all data points contribute to cluster \( k \). Mathematically, we have
\[\begin{align} r_{i,k}&=p(z_i=k|x_i;\theta^{(t)})\tag {3}\\\ r_k&=\sum_{i=1}^Nr_{i,k} \end{align}\]Notice that
\[\begin{align} \sum_{k=1}^Kr_k&=\sum_{i=1}^N\sum_{k=1}^Kr_{i,k}\\\ &=\sum_{i=1}^N1 \\\ &=N \end{align}\]The detailed derivation of the EM algorithm for GMMs is given at the end; here we just outline the results.
E-step: This is a general step for any mixture model, in which we compute all \( r_{i,k} \) for \( \theta^{(t)} \) according to Eq.\( (3) \)
M-step: Update parameters:
\[\begin{align} \pi_k^{(t+1)}&={r_k\over N}\\\ \mu_k^{(t+1)}&={1\over r_k}\sum_{i=1}^Nr_{i,k}x_i\\\ \Sigma_k^{(t+1)}&={1\over r_k}\sum_{i=1}^Nr_{i,k}(x_i-\mu_k^{(t+1)})(x_i-\mu_k^{(t+1)})^T \end{align}\]These results give some nice intuitions
- The prior probability \( \pi_k \) is the average probability that a data point in the dataset belongs to cluster \( k \)
- The mean \( \mu_k \) is the mean of the data points in the dataset, weighted by \( r_{i,k} \)
- The covariance matrix \( \Sigma_k \) is the weighted sample covariance matrix using the new mean computed in the previous step
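The E-step and M-step updates above translate almost line by line into NumPy. The following is a minimal sketch on synthetic data, not a robust implementation: the function names, the initialization, and the data are assumptions, and no safeguard against singular covariances is included.

```python
import numpy as np

def log_gaussian(x, mu, sigma):
    """Row-wise log N(x_i; mu, sigma) for x of shape (N, d)."""
    d = x.shape[1]
    diff = x - mu
    _, logdet = np.linalg.slogdet(sigma)
    maha = np.sum(diff * np.linalg.solve(sigma, diff.T).T, axis=1)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)

def e_step(x, pi, mus, sigmas):
    """Return responsibilities r (N, K) from Eq. (3) and the log-likelihood."""
    log_w = np.stack([np.log(pi[k]) + log_gaussian(x, mus[k], sigmas[k])
                      for k in range(len(pi))], axis=1)
    m = log_w.max(axis=1, keepdims=True)             # log-sum-exp over k
    log_px = m[:, 0] + np.log(np.exp(log_w - m).sum(axis=1))
    return np.exp(log_w - log_px[:, None]), log_px.sum()

def m_step(x, r):
    """Closed-form updates for pi, mu and Sigma given responsibilities r."""
    r_k = r.sum(axis=0)                              # total weight per cluster
    pi = r_k / x.shape[0]
    mus = (r.T @ x) / r_k[:, None]
    sigmas = np.empty((len(r_k), x.shape[1], x.shape[1]))
    for k in range(len(r_k)):
        diff = x - mus[k]                            # uses the new mean
        sigmas[k] = (r[:, k, None] * diff).T @ diff / r_k[k]
    return pi, mus, sigmas

# Synthetic data (assumed): two well-separated blobs in two dimensions.
rng = np.random.default_rng(0)
x = np.vstack([rng.normal(0, 1, (200, 2)), rng.normal(4, 1, (200, 2))])
pi = np.full(2, 0.5)
mus = x[[0, -1]].copy()                              # one point from each blob
sigmas = np.stack([np.eye(2), np.eye(2)])
history = []
for _ in range(30):
    r, ll = e_step(x, pi, mus, sigmas)
    history.append(ll)
    pi, mus, sigmas = m_step(x, r)
print(pi, mus, history[-1])
```

Each row of `r` sums to one, so `r.sum()` equals \( N \), matching the identity \( \sum_k r_k = N \) noted earlier.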
Supplementary Materials
Proof that EM Algorithm Actually Works
To argue why the EM algorithm works, all we have to do is show that \( \log p(x;\theta^{(t+1)})\ge\log p(x;\theta^{(t)}) \), where \( \theta^{(t+1)} \) is produced by the M-step. Let us start with the following observation
\[\begin{align} \log p(x;\theta)&=\log p(x,z;\theta) - \log p(z|x;\theta)\\\ &=\mathbb E_{z\sim q(z)}\left[\log p(x,z;\theta)\right]-\mathbb E_{z\sim q(z)}[\log p(z|x;\theta)] \end{align}\]where \( q(z) \) could be any arbitrary distribution on \( z \).
By setting \( \theta=\theta^{(t+1)} \) and \( q(z)=p(z\vert x;\theta^{(t)}) \), we get
\[\begin{align} \log p(x;\theta^{(t+1)})&=\mathbb E_{z\sim p(z|x;\theta^{(t)})}\left[\log p\left(x,z;\theta^{(t+1)}\right)\right]-\mathbb E_{z\sim p(z|x;\theta^{(t)})}\left[\log p\left(z|x;\theta^{(t+1)}\right)\right]\\\ &=Q(\theta^{(t+1)},\theta^{(t)})-\mathbb E_{z\sim p(z|x;\theta^{(t)})}\left[\log p\left(z|x;\theta^{(t+1)}\right)\right]\\\ &\ge Q(\theta^{(t)},\theta^{(t)})-\mathbb E_{z\sim p(z|x;\theta^{(t)})}\left[\log p\left(z|x;\theta^{(t+1)}\right)\right]\tag {4} \end{align}\]where the inequality follows from the definition of the M-step. Writing \( p(\theta) \) as shorthand for \( p(z\vert x;\theta) \), we have
\[\begin{align} -\mathbb E_{z\sim p(\theta^{(t)})}\left[\log p(\theta^{(t+1)})\right]&=-\mathbb E_{z\sim p(\theta^{(t)})}\left[\log p(\theta^{(t+1)})-\log p(\theta^{(t)})+\log p(\theta^{(t)})\right]\\\ &=D_{KL}\left(p(\theta^{(t)})\Vert p(\theta^{(t+1)})\right) -\mathbb E_{z\sim p(\theta^{(t)})}\left[\log p(\theta^{(t)})\right]\\\ &\ge -\mathbb E_{z\sim p(\theta^{(t)})}\left[\log p(\theta^{(t)})\right]\tag{5} \end{align}\]where the inequality holds because the KL divergence is non-negative. Substituting Eq.\( (5) \) back into Eq.\( (4) \) gives
\[\begin{align} \log p(x;\theta^{(t+1)})&\ge Q(\theta^{(t)},\theta^{(t)})-\mathbb E_{z\sim p(z|x;\theta^{(t)})}\left[\log p\left(z|x;\theta^{(t)}\right)\right]\\\ &=\log p\left(x;\theta^{(t)}\right) \end{align}\]where the last equality is the observation above with \( \theta=\theta^{(t)} \) and \( q(z)=p(z\vert x;\theta^{(t)}) \). We have thus shown that an iteration of the EM algorithm never decreases \( p(x;\theta) \).
Derivation of EM Algorithm for GMMs
E-Step
In the E-step, we first compute \( p(z\vert x;\theta^{(t)}) \), i.e., all \( r_{i,k} \) at time step \(t\), from Bayes’ theorem. Then we compute \( Q(\theta, \theta^{(t)}) \) as follows
\[\begin{align} Q(\theta,\theta^{(t)})&=\mathbb E_{z\sim p(z|x;\theta^{(t)})}[\log p(x_1, \dots, x_N,z_1,\dots, z_N;\theta)]\\\ &=\sum_{i=1}^N\mathbb E_{z_i\sim p(z_i|x_i;\theta^{(t)})}[\log p(x_i,z_i;\theta)]\\\ &=\sum_{i=1}^N\Big(\mathbb E_{z_i\sim p(z_i|x_i;\theta^{(t)})}[\log p(z_i)]+\mathbb E_{z_i\sim p(z_i|x_i;\theta^{(t)})}[\log p(x_i|z_i;\theta)]\Big)\\\ &=\sum_{i=1}^N\sum_{k=1}^K\Big(r_{i,k}\log\pi_k+r_{i,k}\log p(x_i|z_i=k;\theta)\Big)\tag {6} \end{align}\]where the second step uses the assumption that the data points are independent. Notice that so far we have not used any knowledge specific to GMMs; Eq.\( (6) \) is applicable to any mixture model. As we will see shortly, the M-step update for \( \pi \) is also the same for any mixture model.
M-Step
In this step, we optimize all three kinds of parameters in turn.
Update \(\pi\)
From Eq.\( (6) \), we can see that updating \( \pi \) is equivalent to solving the following constrained optimization problem:
\[\begin{align} \pi^{(t+1)}&=\arg\max_\pi\sum_{i=1}^N\sum_{k=1}^Kr_{i,k}\log \pi_k\\\ &=\arg\max_\pi\sum_{k=1}^Kr_k\log \pi_k\\\ s.t. \sum_{k=1}^K\pi_k&=1\\\ \pi_k&\ge0,\quad k=1,\dots, K \end{align}\]This is a concave maximization problem and can be easily solved using the method of Lagrange multipliers. The Lagrangian is
\[\begin{align} \mathcal L(\pi_1,\dots, \pi_K,\lambda)=\sum_{k=1}^Kr_k\log\pi_k-\lambda\left(\sum_{k=1}^K\pi_k-1\right) \end{align}\]By taking the gradient of the Lagrangian and setting it to zero, we get
\[\begin{align} \lambda&=\sum_{k=1}^Kr_k=N\\\ \pi_k&={r_k\over N},\quad k=1,\dots,K \end{align}\]The non-negativity constraints are satisfied automatically because \( r_k\ge 0 \). As mentioned earlier, the update of \( \pi \) does not depend on the form of the component distributions, and therefore the above process is applicable to any mixture model.
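As a quick sanity check of this closed form, one can compare \( \pi_k=r_k/N \) against many random points on the probability simplex. This is only an illustrative sketch; the weights `r_k` below are made-up numbers and `objective` is a hypothetical helper.

```python
import math
import random

def objective(pi, r_k):
    """The pi-dependent part of Q: sum_k r_k log pi_k."""
    return sum(r * math.log(p) for r, p in zip(r_k, pi))

r_k = [12.5, 30.0, 7.5]              # hypothetical cluster weights, N = 50
N = sum(r_k)
pi_star = [r / N for r in r_k]       # closed-form solution pi_k = r_k / N

random.seed(0)
best_random = -math.inf
for _ in range(10000):
    # Normalized exponentials give a uniform draw from the simplex interior.
    w = [random.expovariate(1.0) for _ in r_k]
    total = sum(w)
    pi = [v / total for v in w]
    best_random = max(best_random, objective(pi, r_k))

print(objective(pi_star, r_k), best_random)
```

No random candidate should beat the closed-form value, which is consistent with the problem being concave with a unique maximizer.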
Update \( \mu \) and \( \Sigma \)
Both \( \mu \) and \( \Sigma \) are computed by taking the gradient of Eq.\( (6) \) and setting it to zero, respectively. Concretely, we first expand the second term in Eq.\( (6) \) by plugging in the Gaussian PDF
\[\begin{align} &\max_\theta\sum_{i=1}^N\sum_{k=1}^K r_{i,k}\log p(x_i|z_i=k;\theta)\\\ =&\max_{\mu, \Sigma}\sum_{i=1}^N\sum_{k=1}^K r_{i,k}\left(-{d\over 2}\log 2\pi-{1\over 2}\Big(\log\det(\Sigma_k)+(x_i-\mu_k)^T\Sigma_k^{-1}(x_i-\mu_k)\Big)\right)\\\ =&\max_{\mu, \Sigma}\sum_{i=1}^N\sum_{k=1}^K r_{i,k}\left(-{1\over 2}\Big(\log\det(\Sigma_k)+(x_i-\mu_k)^T\Sigma_k^{-1}(x_i-\mu_k)\Big)\right)\tag {7} \end{align}\]where \( d \) is the dimension of \( x_i \), \( 2\pi \) in the second line is the numerical constant rather than the mixing weights, and the constant term is dropped in the last line since it does not affect the maximizer. To compute the optimal value for \( \mu_k \), we take the gradient of Eq.\( (7) \) w.r.t. \( \mu_k \) and set it to zero, which gives
\[\begin{align} \sum_{i=1}^Nr_{i,k}\Sigma_k^{-1}(x_i-\mu_k)=0\\\ \mu_k={\sum_{i=1}^N r_{i,k}x_i\over r_{k}},\quad k=1,\dots,K \end{align}\]The optimal value for \( \Sigma_k \) is computed in the same way, but with a bit more machinery from matrix calculus. Assuming \( X \) and \( A \) are both symmetric, we have
\[\begin{align} \nabla_XTr(X^{-1}A)=-X^{-1}AX^{-1}\\\ \nabla_X\log\det(X)=X^{-1} \end{align}\]Then the gradient w.r.t. \( \Sigma_k \) is
\[\begin{align} &\nabla_{\Sigma_k}\left(\sum_{i=1}^N\sum_{k=1}^K r_{i,k}\left(-{1\over 2}\Big(\log\det(\Sigma_k)+(x_i-\mu_k)^T\Sigma_k^{-1}(x_i-\mu_k)\Big)\right)\right)\\\ =&-{1\over 2}\sum_{i=1}^Nr_{i,k}\nabla_{\Sigma_k}\Big(\log\det(\Sigma_k)+Tr\big(\Sigma_k^{-1}(x_i-\mu_k)(x_i-\mu_k)^T\big)\Big)\\\ =&-{1\over 2}\left(r_k\Sigma_k^{-1}-\sum_{i=1}^Nr_{i,k}\Sigma_k^{-1}(x_i-\mu_k)(x_i-\mu_k)^T\Sigma_k^{-1}\right) \end{align}\]where the second line rewrites the quadratic form as a trace. Setting the gradient to zero gives
\[\begin{align} \Sigma_k={1\over r_k}{\sum_{i=1}^Nr_{i,k}(x_i-\mu_k)(x_i-\mu_k)^T} \end{align}\]Observe that, when doing this, nothing prevents the next iterate of \( \Sigma \) from becoming singular. One can add \( \epsilon I_d \) to the estimate for some small \( \epsilon >0 \) to ensure non-degeneracy, but this is mostly a hack.
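The degenerate case is easy to reproduce: if a component puts all of its weight on a single data point, the weighted covariance collapses to the zero matrix. The sketch below shows this and the effect of adding \( \epsilon I_d \); the helper name `weighted_covariance`, the data, and the value of `eps` are assumptions for illustration.

```python
import numpy as np

def weighted_covariance(x, r_col, mu, eps=0.0):
    """(1/r_k) sum_i r_ik (x_i - mu)(x_i - mu)^T, plus eps * I."""
    diff = x - mu
    sigma = (r_col[:, None] * diff).T @ diff / r_col.sum()
    return sigma + eps * np.eye(x.shape[1])

# Made-up degenerate case: the component's weight sits on a single point.
x = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 1.0]])
r_col = np.array([1.0, 0.0, 0.0])
mu = (r_col @ x) / r_col.sum()

raw = weighted_covariance(x, r_col, mu)
reg = weighted_covariance(x, r_col, mu, eps=1e-6)
print(np.linalg.matrix_rank(raw), np.linalg.matrix_rank(reg))
```

With the regularizer the matrix is positive definite again, so its inverse and log-determinant in the next E-step are well defined, although the choice of `eps` remains arbitrary.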
References
https://www.youtube.com/playlist?list=PLD0F06AA0D2E8FFBA
Stephen Tu. Learning Mixture Models