EM Algorithm

At the end of the last section we found the mle’s for the Old Faithful data using Newton’s algorithm. For problems of this kind there is another way to find the mle’s, called the EM (Expectation-Maximization) algorithm. The idea is as follows.

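Say there were a second variable \(Z_i\) which is 1 if the \(i\)-th waiting time is a short one (group 1) and 0 otherwise. If we knew those \(Z_i\)’s it would be easy to estimate the \(\mu\)’s and \(\sigma\)’s, because we could compute the mean and standard deviation of each group separately:

\[ \hat{\mu}_1=\text{mean}(x_i \mid z_i=1),\quad \hat{\mu}_2=\text{mean}(x_i \mid z_i=0)\\ \hat{\sigma}_1=\text{sd}(x_i \mid z_i=1),\quad \hat{\sigma}_2=\text{sd}(x_i \mid z_i=0) \]

Notice that this is just what we did in the previous section for our first guess. There we split the data into waiting times of at most 65 and those above 65, so we used \(z_i=I_{[0, 65]}(x_i)\).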
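Here is a sketch of this complete-data step in R, using the split at 65 as the known \(z_i\). It assumes that the document's `faithful$Waiting.Time` contains the same waiting times as the `waiting` column of R's built-in `datasets::faithful`, which is the column the sketch reads. The names `z`, `alpha0`, `mu0` and `sigma0` are illustrative only. Note that `sd()` divides by \(n-1\), while the mle of a standard deviation divides by \(n\).

```r
# Waiting times from R's built-in Old Faithful data (column `waiting`)
x <- datasets::faithful$waiting
# Complete-data indicator from the split at 65: 1 = short wait (group 1)
z <- as.numeric(x <= 65)
# With z known, each component is estimated from its own observations
alpha0 <- mean(z)                             # share of group 1
mu0 <- c(mean(x[z == 1]), mean(x[z == 0]))
sigma0 <- c(sd(x[z == 1]), sd(x[z == 0]))
round(c(alpha0, mu0, sigma0), 3)
```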

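On the other hand, if we knew all the means and standard deviations (and had a current value of \(\alpha\)), it would also be easy to estimate \(\alpha\):

\[ \psi_i= \alpha f_{i,1}+(1-\alpha) f_{i,2}\\ w_i=\frac{\alpha f_{i,1}}{\psi_i}\\ \hat{\alpha}=\text{mean}(w_i) \]

Here \(f_{i,j}\) is the density of the \(j\)-th normal component, \(N(\mu_j, \sigma_j^2)\), evaluated at \(x_i\). The quantity \(\psi_i\) is the density of \(x_i\) under the full mixture model, and \(w_i\) is the fraction of that density contributed by group 1. If \(w_i\) is close to 1, almost all of the mixture density at \(x_i\) comes from component 1, so this is most likely an observation from group 1.

The \(w_i\) are called the weights. The formulas follow from Bayes’ theorem. The weight \(w_i=P(Z_i=1\mid X_i=x_i)\) is the conditional probability that observation \(i\) belongs to group 1 given its value. Averaging these probabilities gives an estimate of \(\alpha=P(Z_i=1)\).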
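A quick simulation can check the claim that \(w_i\) is the conditional probability of group 1. The sketch below draws data from a two-component normal mixture whose parameter values are assumptions, chosen only for illustration. It computes the weights from the true parameters, then checks two things: the weights should average to about \(\alpha\), and observations with large weights should come almost entirely from group 1.

```r
set.seed(111)
n.sim <- 1e5
alpha.sim <- 0.35                  # illustrative parameter values (assumed)
mu.sim <- c(54, 80)
sigma.sim <- c(6, 6)
z.sim <- rbinom(n.sim, 1, alpha.sim)          # 1 = observation from group 1
x.sim <- ifelse(z.sim == 1,
                rnorm(n.sim, mu.sim[1], sigma.sim[1]),
                rnorm(n.sim, mu.sim[2], sigma.sim[2]))
# Weights computed from the true parameters
psi.sim <- alpha.sim*dnorm(x.sim, mu.sim[1], sigma.sim[1]) +
  (1-alpha.sim)*dnorm(x.sim, mu.sim[2], sigma.sim[2])
w.sim <- alpha.sim*dnorm(x.sim, mu.sim[1], sigma.sim[1])/psi.sim
# mean(w) should be near alpha; among cases with w > 0.9 nearly all have z = 1
c(mean.w = mean(w.sim), share.group1 = mean(z.sim[w.sim > 0.9]))
```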

This suggests the following algorithm:

  • choose a starting point for the parameters
  • find the weights (the E step)
  • find the next estimates of the parameters (the M step)
  • iterate until convergence

Here is an implementation of the algorithm for the Old Faithful data:

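alpha <- 0.355            # starting values
mu <- c(50, 80)
sigma <- c(5.4, 5.9)
x <- faithful$Waiting.Time
w <- rep(0, length(x))
k <- 0
repeat {
  k <- k+1
  # E step: mixture density psi and weight w of group 1 for each observation
  psi <- (alpha*dnorm(x, mu[1], sigma[1]) + 
        (1-alpha)*dnorm(x, mu[2], sigma[2]))
  w <- alpha*dnorm(x, mu[1], sigma[1])/psi
  # M step: weighted estimates, each component separately
  alpha <- mean(w)
  mu[1] <- sum(w*x)/sum(w)
  mu[2] <- sum((1-w)*x)/sum(1-w)
  sigma[1] <- sqrt(sum(w*(x-mu[1])^2)/sum(w)) 
  sigma[2] <- sqrt(sum((1-w)*(x-mu[2])^2)/sum(1-w))
  # mixture density at the new parameters
  psi1 <- (alpha*dnorm(x, mu[1], sigma[1]) + 
        (1-alpha)*dnorm(x, mu[2], sigma[2]))
  # print alpha, the means, the sds and the log-likelihood
  cat(round(alpha,4), " ",
      round(mu, 1), " ",
      round(sigma, 2), " ",
      round(sum(log(psi1)), 5), "\n")
  # stop when the fitted density no longer changes, or after 100 iterations
  if(sum(abs(psi-psi1))<0.001) break
  if(k>100) break
}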
## 0.3345   53.8 79.5   5.21 6.46   -1035.686 
## 0.3422   54 79.7   5.42 6.31   -1034.87 
## 0.348   54.2 79.8   5.55 6.17   -1034.425 
## 0.3522   54.3 79.9   5.65 6.07   -1034.199 
## 0.3551   54.4 80   5.72 6   -1034.091 
## 0.357   54.5 80   5.77 5.95   -1034.041 
## 0.3583   54.5 80   5.8 5.92   -1034.019 
## 0.3592   54.6 80.1   5.82 5.9   -1034.009 
## 0.3598   54.6 80.1   5.84 5.89   -1034.005 
## 0.3602   54.6 80.1   5.85 5.88   -1034.003 
## 0.3604   54.6 80.1   5.86 5.88   -1034.002 
## 0.3606   54.6 80.1   5.86 5.87   -1034.002 
## 0.3607   54.6 80.1   5.87 5.87   -1034.002 
## 0.3607   54.6 80.1   5.87 5.87   -1034.002 
## 0.3608   54.6 80.1   5.87 5.87   -1034.002

One big advantage of the EM algorithm is that it lets us deal with each component of the mixture separately. In the M step above, each component only needs a weighted mean and a weighted standard deviation. That turns out to be much easier than working with the full model.

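Notice one important feature of the EM algorithm: an iteration can never decrease the likelihood, as the last column of the output above shows. This does not guarantee convergence to the mle, though. Depending on the starting point, EM may end up at a local maximum or another stationary point of the likelihood.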
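The monotonicity property can be checked numerically. The sketch below reruns the same E and M steps from the same starting point, using R's built-in `datasets::faithful$waiting` (assumed to hold the same data as `faithful$Waiting.Time`). It records the log-likelihood after each iteration and verifies that the log-likelihood never goes down. The helper name `em.trace`, the fixed count of 50 iterations and the rounding tolerance are arbitrary choices.

```r
# Hypothetical helper: run a fixed number of EM iterations for a two-component
# normal mixture and return the log-likelihood after each iteration
em.trace <- function(x, alpha, mu, sigma, iter = 50) {
  loglik <- numeric(iter)
  for (k in seq_len(iter)) {
    # E step: weight of group 1 for each observation
    f1 <- alpha*dnorm(x, mu[1], sigma[1])
    f2 <- (1-alpha)*dnorm(x, mu[2], sigma[2])
    w <- f1/(f1 + f2)
    # M step: weighted estimates for each component
    alpha <- mean(w)
    mu <- c(sum(w*x)/sum(w), sum((1-w)*x)/sum(1-w))
    sigma <- c(sqrt(sum(w*(x-mu[1])^2)/sum(w)),
               sqrt(sum((1-w)*(x-mu[2])^2)/sum(1-w)))
    # mixture log-likelihood at the updated parameters
    loglik[k] <- sum(log(alpha*dnorm(x, mu[1], sigma[1]) +
                         (1-alpha)*dnorm(x, mu[2], sigma[2])))
  }
  loglik
}
ll <- em.trace(datasets::faithful$waiting, alpha = 0.355,
               mu = c(50, 80), sigma = c(5.4, 5.9))
# TRUE if the log-likelihood never decreased (up to rounding error)
all(diff(ll) >= -1e-8)
```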

Here is the fitted line plot:

library(ggplot2)
# grid of waiting times and the fitted mixture density at the final estimates
x <- seq(min(faithful$Waiting.Time),
         max(faithful$Waiting.Time),
         length=250)
y <- alpha*dnorm(x, mu[1], sigma[1]) + 
     (1-alpha)*dnorm(x, mu[2], sigma[2])
df <- data.frame(x=x, y=y)
bw <- diff(range(faithful$Waiting.Time))/50 
ggplot(faithful, aes(Waiting.Time)) +
  geom_histogram(aes(y = after_stat(density)),
                 color = "black", 
                 fill = "white", 
                 binwidth = bw) + 
  labs(x = "Waiting Times", y = "Density") +
  geom_line(aes(x, y), data=df,
            linewidth=1, colour="blue")

Example

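Consider the following data set. It is stored in a data frame df with a column x of values in the interval [0, 1]; the code that generated it is not shown here.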

bw <- diff(range(df$x))/50 
ggplot(df, aes(x)) +
  geom_histogram(aes(y = after_stat(density)),
                 color = "black", 
                 fill = "white", 
                 binwidth = bw) + 
  labs(x = "x", y = "Density") 

Let’s assume we know that this is a mixture of a normal distribution and an exponential distribution truncated to the interval [0, 1]. That is, we have the density

\[f(x;\alpha, \lambda, \mu, \sigma)=\alpha\frac{\lambda \exp \{-\lambda x\}}{1-\exp \{-\lambda\}}+(1-\alpha)\frac1{\sqrt{2\pi \sigma^2}}\exp\{-\frac{(x-\mu)^2}{2\sigma^2}\}\]
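To experiment with this model, one can simulate data from a density of this form. The sketch below samples the truncated exponential by inverting its distribution function \(F(x)=(1-e^{-\lambda x})/(1-e^{-\lambda})\) on [0, 1]. The function name `rtexp` and the values \(n=1000\), \(\alpha=0.9\), \(\lambda=1\), \(\mu=0.6\), \(\sigma=0.03\) are hypothetical choices for illustration. They are not the parameters behind the data set above.

```r
# Hypothetical sampler for the exponential truncated to [0, 1], by inversion of
# F(x) = (1 - exp(-lambda*x))/(1 - exp(-lambda)):
# x = -log(1 - u*(1 - exp(-lambda)))/lambda, written with log1p/expm1 for accuracy
rtexp <- function(n, lambda) {
  u <- runif(n)
  -log1p(u*expm1(-lambda))/lambda
}
set.seed(2024)
n <- 1000
comp <- rbinom(n, 1, 0.9)          # 1 = truncated exponential component (assumed alpha)
x.new <- ifelse(comp == 1, rtexp(n, 1), rnorm(n, 0.6, 0.03))
sim <- data.frame(x = x.new)       # simulated data frame with a single column x
```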

If we want to use the EM algorithm, we can reuse the code from above except for the part where we need to estimate \(\lambda\). With the unknown group indicators replaced by the weights \(w_i\), the part of the log-likelihood that involves \(\lambda\) becomes

\[ \begin{aligned} &l(\lambda) = \log \prod (\frac{\lambda \exp \{-\lambda x_i\}}{1-\exp \{-\lambda\}})^{w_i} = \\ &\sum w_i \left[\log \lambda-\lambda x_i-\log(1-\exp \{-\lambda\})\right] = \\ &(\sum w_i) \log \lambda-\lambda \sum w_i x_i-\log(1-e^{-\lambda})(\sum w_i) \\ \end{aligned} \]
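Since \(\frac{d}{d\lambda}\log(1-e^{-\lambda})=\frac{1}{e^{\lambda}-1}\), setting the derivative of \(l(\lambda)\) to zero gives

\[ \frac{\sum w_i x_i}{\sum w_i}=\frac1{\lambda}-\frac1{e^{\lambda}-1} \]

The right-hand side is the mean of the truncated exponential, so at the maximum the model mean equals the weighted sample mean. The hypothetical function `lambda.hat2` below sketches an alternative to a general-purpose optimizer by solving this equation with `uniroot`. It assumes the weighted mean lies between 0.01 and 1/2, so that a positive root exists inside the arbitrary search interval \((10^{-6}, 100)\).

```r
# Hypothetical alternative to lambda.hat: solve the score equation directly.
# Assumes 0.01 < weighted mean < 1/2, so the root lies in (1e-6, 100).
lambda.hat2 <- function(x, w) {
  xbar.w <- sum(w*x)/sum(w)                    # weighted sample mean
  # model mean 1/l - 1/(exp(l) - 1) minus the weighted mean;
  # expm1 keeps exp(l) - 1 accurate for small l
  g <- function(l) 1/l - 1/expm1(l) - xbar.w
  uniroot(g, interval = c(1e-6, 100), tol = 1e-10)$root
}
set.seed(1)
x.test <- rexp(1e5, 1)
x.test <- x.test[x.test < 1]                   # exponential(1) truncated to [0, 1]
lambda.hat2(x.test, rep(1, length(x.test)))
```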

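and we can maximize this in R by minimizing the negative log-likelihood with nlm:

lambda.hat <- function(x, w) {
  # negative weighted log-likelihood of the truncated exponential
  fn <- function(l) {
     -(log(l)*sum(w)-l*sum(x*w)-log(1-exp(-l))*sum(w))
  }
  # start the search at 1/mean(x)
  nlm(fn, 1/mean(x))$estimate
}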

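Let’s do a little test. We use data from an exponential distribution with rate 1 truncated to [0, 1], with all weights equal to 1, so the estimate should be close to 1: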

x <- rexp(1e5, 1)
x <- x[x<1]
length(x)
## [1] 63192
w <- rep(1, length(x))
lambda.hat(x, w)
## [1] 0.989678

so that looks good. Now we can run the full EM algorithm:

# density of the exponential distribution truncated to [0, 1]
dtexp <- function(x, lambda)
  lambda*exp(-lambda*x)/(1-exp(-lambda))
x <- df$x
# starting values
alpha <- 0.5
mu <- 0.6
sigma <- 0.05
lambda <- 1/mean(x)
w <- rep(0, length(x))
k <- 0
repeat {
  k <- k+1
  # E step: weight of the truncated exponential component for each observation
  psi <- alpha*dtexp(x, lambda) + 
        (1-alpha)*dnorm(x, mu, sigma)
  w <- alpha*dtexp(x, lambda)/psi
  # M step: lambda numerically, mu and sigma as weighted estimates
  alpha <- mean(w)
  lambda <- lambda.hat(x, w)
  mu <- sum((1-w)*x)/sum(1-w)
  sigma <- sqrt(sum((1-w)*(x-mu)^2)/sum(1-w))
  # mixture density at the new parameters
  psi1 <- alpha*dtexp(x, lambda) + 
        (1-alpha)*dnorm(x, mu, sigma)
  cat(round(alpha, 2), " ",
      round(lambda, 2), " ",
      round(mu, 3), " ",
      round(sigma, 4), " ",
      round(sum(log(psi1)), 5), "\n")
  if(sum(abs(psi-psi1))<0.001) break
  if(k>1000) break
}
## 0.75   1.43   0.6   0.0613   29.95453 
## 0.81   1.23   0.601   0.063   47.4751 
## 0.83   1.14   0.601   0.0618   52.40959 
## 0.85   1.1   0.601   0.0592   54.87027 
## 0.86   1.07   0.601   0.0559   56.67691 
## 0.87   1.05   0.601   0.0524   58.29363 
## 0.87   1.04   0.601   0.0489   59.83651 
## 0.88   1.03   0.602   0.0456   61.30018 
## 0.88   1.02   0.603   0.0425   62.63315 
## 0.88   1.01   0.603   0.0397   63.78085 
## 0.88   1.01   0.604   0.0374   64.71343 
## 0.89   1   0.604   0.0354   65.43338 
## 0.89   1   0.605   0.0338   65.96668 
## 0.89   0.99   0.605   0.0325   66.3493 
## 0.89   0.99   0.606   0.0315   66.61717 
## 0.89   0.99   0.606   0.0306   66.80121 
## 0.89   0.98   0.606   0.03   66.92584 
## 0.9   0.98   0.607   0.0294   67.00929 
## 0.9   0.98   0.607   0.029   67.06471 
## 0.9   0.98   0.607   0.0287   67.10127 
## 0.9   0.98   0.607   0.0284   67.12527 
## 0.9   0.97   0.607   0.0282   67.14096 
## 0.9   0.97   0.607   0.028   67.15119 
## 0.9   0.97   0.607   0.0279   67.15785 
## 0.9   0.97   0.607   0.0278   67.16217 
## 0.9   0.97   0.607   0.0277   67.16498 
## 0.9   0.97   0.607   0.0276   67.16679 
## 0.9   0.97   0.608   0.0275   67.16797 
## 0.9   0.97   0.608   0.0275   67.16873 
## 0.9   0.97   0.608   0.0275   67.16922 
## 0.9   0.97   0.608   0.0274   67.16954 
## 0.9   0.97   0.608   0.0274   67.16974 
## 0.9   0.97   0.608   0.0274   67.16988 
## 0.9   0.97   0.608   0.0274   67.16996 
## 0.9   0.97   0.608   0.0274   67.17002 
## 0.9   0.97   0.608   0.0274   67.17005 
## 0.9   0.97   0.608   0.0273   67.17008 
## 0.9   0.97   0.608   0.0273   67.17009 
## 0.9   0.97   0.608   0.0273   67.1701 
## 0.9   0.97   0.608   0.0273   67.17011 
## 0.9   0.97   0.608   0.0273   67.17011 
## 0.9   0.97   0.608   0.0273   67.17011 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012 
## 0.9   0.97   0.608   0.0273   67.17012

and here is the graph:

x <- seq(0, 1, length=250)
y <- alpha*dtexp(x, lambda) + 
     (1-alpha)*dnorm(x, mu, sigma)
df1 <- data.frame(x=x, y=y)
bw <- 1/50 
ggplot(df, aes(x)) +
  geom_histogram(aes(y = after_stat(density)),
                 color = "black", 
                 fill = "white", 
                 binwidth = bw) + 
  labs(x = "x", y = "Density") +
  geom_line(aes(x, y), data=df1,
            linewidth=1, colour="blue")

Example

Consider the following example:

x <- c(rnorm(1000), rnorm(1000, 0.5) )

This is a mixture of \(N(0,1)\) and \(N(0.5,1)\) with \(\alpha=0.5\), and again we want to estimate the parameters. We can put the EM loop from above into a function:

em.mix <- function(alpha, mu, sigma, x, maxit = 1000) {
  w <- rep(0, length(x))
  for (k in seq_len(maxit)) {      # at most maxit iterations
    # E step
    psi <- (alpha*dnorm(x, mu[1], sigma[1]) + 
        (1-alpha)*dnorm(x, mu[2], sigma[2]))
    w <- alpha*dnorm(x, mu[1], sigma[1])/psi
    # M step
    alpha <- mean(w)
    mu[1] <- sum(w*x)/sum(w)
    mu[2] <- sum((1-w)*x)/sum(1-w)
    sigma[1] <- sqrt(sum(w*(x-mu[1])^2)/sum(w)) 
    sigma[2] <- sqrt(sum((1-w)*(x-mu[2])^2)/sum(1-w))
    # stop when the mixture density at the data points no longer changes
    psi1 <- (alpha*dnorm(x, mu[1], sigma[1]) + 
        (1-alpha)*dnorm(x, mu[2], sigma[2]))
    if(sum(abs(psi-psi1))<0.001) break
  }
  round(c(alpha, mu, sigma), 4)
}
em.mix(alpha=0.5, mu=c(0, 0.5), sigma=c(1, 1), x=x)
## [1]  0.5002 -0.0077  0.5052  0.9962  1.0094

and that looks fine. However:

em.mix(alpha=0.5, mu=c(0, 1), sigma=c(1, 2), x=x)
## [1] 0.5792 0.1401 0.3980 1.0192 1.0380

and that fails badly: the estimated means, 0.14 and 0.40, are far from the true values 0 and 0.5. To see why, here is a graph of the data:

df <- data.frame(x=x)
bw <- diff(range(x))/50 
ggplot(df, aes(x)) +
  geom_histogram(aes(y = after_stat(density)),
                 color = "black", 
                 fill = "white", 
                 binwidth = bw) + 
  labs(x = "x", y = "Density") 

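We see that the two components are so close together, relative to their standard deviations, that the histogram shows only a single peak. (A mixture of two normals with equal weights and equal standard deviation \(\sigma\) is unimodal whenever the means differ by at most \(2\sigma\).) In this case the likelihood surface is very flat and has ridges: many different parameter values give almost the same likelihood. Where EM stops therefore depends heavily on the starting point. This problem is called (near) non-identifiability.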
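One way to see the flatness is to evaluate the mixture log-likelihood directly at different parameter values. The helper `mix.loglik` below is a hypothetical sketch for that purpose. It takes the parameters in the order `em.mix` returns them. The usage lines draw a fresh sample of the same kind, so they only roughly compare the true parameters with the second set of estimates reported above. Applied to the `x` used above, the same helper would compare the two `em.mix` results directly.

```r
# Hypothetical helper: log-likelihood of a two-component normal mixture,
# with par = c(alpha, mu1, mu2, sigma1, sigma2) in the order em.mix returns
mix.loglik <- function(par, x) {
  sum(log(par[1]*dnorm(x, par[2], par[4]) +
          (1-par[1])*dnorm(x, par[3], par[5])))
}
set.seed(3)
x.mix <- c(rnorm(1000), rnorm(1000, 0.5))   # fresh data of the same kind as above
# log-likelihood at the true parameters and at the second set of estimates above
mix.loglik(c(0.5, 0, 0.5, 1, 1), x.mix)
mix.loglik(c(0.5792, 0.1401, 0.3980, 1.0192, 1.0380), x.mix)
```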

mixtools

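The package mixtools provides EM-based functions for fitting mixture distributions, for example normalmixEM for mixtures of normals. Note that its argument lambda holds the mixing proportions (our \(\alpha\)):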

library(mixtools)
res <- normalmixEM(faithful$Waiting.Time, 
            lambda = .5, 
            mu = c(55, 80), 
            sigma = c(5, 55))
## number of iterations= 24
round(c(res$lambda[1], res$mu, res$sigma), 3)
## [1]  0.361 54.615 80.091  5.871  5.868

The EM algorithm was given its general form and its name by Dempster, Laird and Rubin in 1977. A classic problem it handles is censoring, which is common in Statistics.

Say we are doing a study on the survival of patients after cancer surgery. Any such study has a time limit \(T\) after which we have to start the data analysis, but hopefully some patients will still be alive at that point. We don’t know their survival times, but we do know that they are greater than \(T\). We say the data is censored at time \(T\). The number of patients with survival times greater than \(T\) is important information and should be used in the analysis. Suppose we order the observations so that \(x_1, \dots, x_n\) are the uncensored observations (the survival times of those patients who have died) and \(x_{n+1}, \dots, x_{n+m}\) are the censored ones. Then the likelihood function can be written as

\[ L(\theta|x)=\left[1-F(T|\theta)\right]^m\prod_{i=1}^{n}f(x_i|\theta) \]

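where \(F\) is the cumulative distribution function of the density \(f\), so \(1-F(T|\theta)\) is the probability that a survival time is greater than \(T\).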

Of course, if we knew the actual survival times \(z_{n+1}, \dots, z_{n+m}\) of the \(m\) censored patients, we could write the complete data likelihood:

\[ L(\theta|x, z)=\prod_{i=1}^{n}f(x_i|\theta)\prod_{i=n+1}^{n+m}f(z_i|\theta) \]

This suggests the EM algorithm:

  • in the E step assume you know \(\theta\) and estimate the z’s by their expected values given that they exceed \(T\)
  • in the M step assume you know the z’s and estimate \(\theta\)

Example: Censored exponential survival times

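Say \(X_i \sim \text{Exp}(\theta)\), where \(\theta\) is the rate. We have data \(x_1, \dots, x_n\), and we know that \(m\) observations were censored at \(T\). Given complete data, the mle of \(\theta\) is the number of observations divided by their sum. By the memoryless property of the exponential distribution, a censored survival time has conditional expectation \(E[Z_i \mid Z_i>T]=T+1/\theta\). So the M and E steps are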

\[ \hat{\theta}=\frac{n+m}{\sum x_i + \sum z_i}\\ z_i=\frac1{\theta}+T \]
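The censored likelihood can also be maximized directly. For exponential survival times \(f(x|\theta)=\theta e^{-\theta x}\) and \(1-F(T|\theta)=e^{-\theta T}\), so the log of the likelihood \(L(\theta|x)\) above is \(n\log\theta-\theta\sum x_i-m\theta T\). The sketch below maximizes it numerically with R's `optimize` on simulated data. The name `loglik.cens`, the rate 0.1, the censoring time 10 and the search interval are arbitrary choices for illustration.

```r
# Hypothetical helper: censored exponential log-likelihood
# n*log(theta) - theta*sum(x) - m*theta*T
loglik.cens <- function(theta, x, m, T) {
  length(x)*log(theta) - theta*sum(x) - m*theta*T
}
set.seed(5)
y <- rexp(1000, 0.1)            # simulated survival times, rate 0.1 (illustrative)
x.obs <- y[y < 10]              # uncensored observations
m.cens <- sum(y >= 10)          # number censored at T = 10
theta.direct <- optimize(function(th) loglik.cens(th, x.obs, m.cens, 10),
                         interval = c(1e-4, 1), maximum = TRUE,
                         tol = 1e-10)$maximum
theta.direct
```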

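em.exp <- function(x, m, T) {
  theta.old <- 1/mean(x)            # start from the estimate that ignores censoring
  repeat {
    z <- rep(1/theta.old+T, m)      # E step: expected survival time of each censored case
    theta.new <- 1/mean(c(x, z))    # M step: mle from the completed data
    print(theta.new, 5)
    if(abs(theta.new-theta.old)<0.0001) break
    theta.old <- theta.new
  }
  invisible(theta.new)
}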
x <- rexp(1000, 0.1)
1/mean(x)
## [1] 0.09747232
em.exp(x, 0, 0)
## [1] 0.097472
x <- x[x<20]
m <- 1000 - length(x)
m
## [1] 151
1/mean(x)
## [1] 0.1470817
em.exp(x, m, 20)
## [1] 0.10184
## [1] 0.097324
## [1] 0.096676
## [1] 0.096579
x <- x[x<10]
m <- 1000 - length(x)
m
## [1] 395
1/mean(x)
## [1] 0.2492553
em.exp(x, m, 10)
## [1] 0.1256
## [1] 0.10502
## [1] 0.098634
## [1] 0.096321
## [1] 0.095437
## [1] 0.095092
## [1] 0.094957
## [1] 0.094904
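The iterations above have a fixed point that can be found by hand. Substituting \(z_i=1/\theta+T\) into \(\hat{\theta}=(n+m)/(\sum x_i+\sum z_i)\) and solving \(\theta=(n+m)/(\sum x_i+m/\theta+mT)\) for \(\theta\) gives

\[ \hat{\theta}=\frac{n}{\sum x_i + mT} \]

This is exactly the maximizer of the censored exponential log-likelihood \(n\log\theta-\theta\sum x_i-m\theta T\). em.exp stops as soon as successive values differ by less than 0.0001, so its last printed value may still be a little away from this limit when convergence is slow. The sketch below uses a hypothetical variant, `em.exp2`, that returns its estimate and uses a much tighter tolerance. It compares that estimate with the closed form on simulated data, with rate 0.1 and censoring at \(T=10\), both arbitrary.

```r
# Hypothetical variant of em.exp: returns the estimate instead of printing each step
em.exp2 <- function(x, m, T, tol = 1e-10, maxit = 10000) {
  n <- length(x)
  theta <- 1/mean(x)
  for (k in seq_len(maxit)) {
    z <- 1/theta + T                          # E step: E[Z | Z > T] (memoryless property)
    theta.new <- (n + m)/(sum(x) + m*z)       # M step: complete-data mle
    if (abs(theta.new - theta) < tol) break
    theta <- theta.new
  }
  theta.new
}
# Closed-form mle for exponential data censored at T
theta.closed <- function(x, m, T) length(x)/(sum(x) + m*T)

set.seed(4)
y <- rexp(1000, 0.1)
x.obs <- y[y < 10]                 # observed (uncensored) survival times
m <- sum(y >= 10)                  # number censored at T = 10
c(EM = em.exp2(x.obs, m, 10), closed.form = theta.closed(x.obs, m, 10))
```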

We have concentrated on maximum likelihood as a general method for parameter estimation. There are, however, many other methods as well.