EM (Expectation-Maximization)

The EM algorithm seems at first to solve a very specific problem but it turns out to be quite useful in general.

Example

Let’s return to the normal mixture model we considered earlier: \[ \begin{aligned} &Y_1 \sim N(\mu_1, \sigma_1) \\ &Y_2 \sim N(\mu_2, \sigma_2) \\ &Z \sim \text{Ber}(p)\\ &X=(1-Z) Y_1 + Z Y_2 \end{aligned} \] Let’s assume for the moment that in addition to \(X\) we also observe \(Z\). Then the complete-data likelihood is \[ L(p, \mu_1, \mu_2, \sigma_1, \sigma_2) = \prod_{i=1}^n \left[ (1-p)\phi(x_i; \mu_1, \sigma_1) \right]^{1-z_i} \left[ p\,\phi(x_i; \mu_2, \sigma_2) \right]^{z_i} \]

To simplify a bit let’s assume \(\sigma_1 = \sigma_2 =1\); then maximizing the log-likelihood yields \[ \hat p = \frac{1}{n}\sum_{i=1}^n z_i, \quad \hat \mu_1 = \frac{\sum_{i=1}^n (1-z_i)x_i}{\sum_{i=1}^n (1-z_i)}, \quad \hat \mu_2 = \frac{\sum_{i=1}^n z_i x_i}{\sum_{i=1}^n z_i} \]

Notice that \(\hat \mu_1\) is just the mean of the observations from group 1, which we can identify because we know the z’s. It is therefore easy to guess what will happen if we also let the \(\sigma\)’s float: \(\hat \sigma_i\) is the sample standard deviation of the observations in group \(i\).
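
For instance, here is a quick illustration of these complete-data estimates on a small simulated data set (just a sketch; the data and names here are throwaway, not the simulation used below):

# toy illustration: if the z's are observed, the estimates are group summaries
z <- rbinom(500, 1, 0.3)                        # group indicators, p = 0.3
x <- rnorm(500, mean = ifelse(z == 1, 3.5, 0))  # sigma_1 = sigma_2 = 1
phat <- mean(z)                                 # proportion in group 2
muhat <- c(mean(x[z == 0]), mean(x[z == 1]))    # group means
sigmahat <- c(sd(x[z == 0]), sd(x[z == 1]))     # group standard deviations
round(c(phat, muhat, sigmahat), 3)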

So if we knew the \(z_i\)’s this would be a simple problem. On the other hand,

\[ \begin{aligned} &E[Z|X=x]=\\ &0 \cdot P(Z=0|X=x)+ 1 \cdot P(Z=1|X=x)=\\ &P(Z=1|X=x)=\\ &\frac{P(Z=1,X=x)}{P(X=x)}=\\ &\frac{p\phi(x; \mu_2, \sigma_2)}{(1-p)\phi(x; \mu_1, \sigma_1)+p\phi(x; \mu_2, \sigma_2)} \end{aligned} \] so if we knew the parameters we could estimate each of the \(z_i\)’s. This is then the basic idea of the EM algorithm:

  • in the M step assume you know the \(z_1 ,..,z_n\) and estimate the parameters.

  • in the E step use these parameters to estimate the \(z_1,..,z_n\). This is done by setting \(z_i=1\) if \(E[Z_i|X_i]>0.5\) and \(z_i=0\) otherwise.

Here is the implementation:

emNormalMix <- function(x, p=0.5, mu=c(0, 3.5), 
                  sigma=c(1, 1), start=c(p, mu, sigma)) {
    # log-likelihood of the mixture; for simplicity it uses unit standard
    # deviations, consistent with the derivation above
    loglike <- function(x, p, mu) 
        sum(log((1-p)*dnorm(x, mu[1])+p*dnorm(x, mu[2])))
    n <- length(x)
    a <- start
    p <- a[1]
    mu <- a[2:3]
    sigma <- a[4:5]
    print(round(c(a, loglike(x, p, mu)), 3))
    repeat {
        aold <- a
        # E step: hard assignment, z=1 if P(Z=1|X=x) > 0.5
        z <- ifelse(p * dnorm(x, mu[2])/((1 - p)*dnorm(x, mu[1]) + 
            p * dnorm(x, mu[2])) > 0.5, 1, 0)
        # M step: complete-data estimates given the current z's
        p <- sum(z)/n
        mu[1] <- mean(x[z == 0])
        sigma[1] <- sd(x[z == 0])
        mu[2] <- mean(x[z == 1])
        sigma[2] <- sd(x[z == 1])
        a <- c(p, mu, sigma)
        print(round(c(a, loglike(x, p, mu)), 3))
        # stop when the parameter estimates no longer change
        if (sum(abs(a - aold)) < 1e-04) 
            break
    }
    # draw the fitted mixture density over a histogram of the data
    x.points <- seq(min(x), max(x), length = 100)
    y.points <- (1 - p) * dnorm(x.points, mu[1], sigma[1]) + 
        p * dnorm(x.points, mu[2], sigma[2])
    hist(x, freq=FALSE, main="", ylim=c(0, max(y.points)))
    lines(x.points, y.points, lwd=2)
}
n <- 1000
p <- 0.3
mu <- c(0, 3.5)
sigma <- c(1, 2)
z <- sample(c(0, 1), size=n, replace=TRUE, prob=c(1-p, p))
x <- (1 - z) * rnorm(n, mu[1], sigma[1]) + 
           z * rnorm(n, mu[2], sigma[2])
emNormalMix(x)
## [1]     0.500     0.000     3.500     1.000     1.000 -2190.555
## [1]     0.287    -0.035     3.887     0.879     1.517 -2058.708
## [1]     0.250     0.063     4.175     0.958     1.413 -2039.484
## [1]     0.240     0.092     4.255     0.985     1.386 -2037.173
## [1]     0.236     0.104     4.286     0.996     1.376 -2036.613
## [1]     0.233     0.113     4.309     1.005     1.369 -2036.331
## [1]     0.231     0.119     4.325     1.011     1.365 -2036.210
## [1]     0.230     0.122     4.333     1.014     1.362 -2036.169
## [1]     0.228     0.129     4.349     1.020     1.357 -2036.129
## [1]     0.226     0.135     4.365     1.026     1.352 -2036.146
## [1]     0.225     0.138     4.373     1.029     1.350 -2036.174
## [1]     0.225     0.138     4.373     1.029     1.350 -2036.174
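
As an aside, the E step above makes a hard 0/1 assignment. The classical EM for mixtures instead keeps the posterior probabilities \(E[Z_i|X_i=x_i]\) as weights in the M step. Here is a minimal sketch of that weighted variant for the same model and starting values (softEMNormalMix is an ad-hoc name, not part of the routine above):

softEMNormalMix <- function(x, p=0.5, mu=c(0, 3.5), sigma=c(1, 1), eps=1e-4) {
    repeat {
        old <- c(p, mu, sigma)
        # E step: responsibilities P(Z=1 | X=x) under the current parameters
        gam <- p*dnorm(x, mu[2], sigma[2]) /
               ((1-p)*dnorm(x, mu[1], sigma[1]) + p*dnorm(x, mu[2], sigma[2]))
        # M step: weighted versions of the complete-data estimates
        p <- mean(gam)
        mu[1] <- sum((1-gam)*x)/sum(1-gam)
        mu[2] <- sum(gam*x)/sum(gam)
        sigma[1] <- sqrt(sum((1-gam)*(x-mu[1])^2)/sum(1-gam))
        sigma[2] <- sqrt(sum(gam*(x-mu[2])^2)/sum(gam))
        if(sum(abs(c(p, mu, sigma) - old)) < eps) break
    }
    round(c(p=p, mu=mu, sigma=sigma), 3)
}
softEMNormalMix(x)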

Let’s apply the algorithm to a famous data set, the Old Faithful data, specifically the waiting times between eruptions (Waiting.Time):

attach(faithful)
hist(Waiting.Time, main="")

To run the routine we need some starting values. It seems that the two groups are those with data less than and more than 70, so

mu <- c(mean(Waiting.Time[Waiting.Time<70]), mean(Waiting.Time[Waiting.Time>70]))
sigma <- c(sd(Waiting.Time[Waiting.Time<70]), sd(Waiting.Time[Waiting.Time>70]))
print(c(mu,sigma), digits=2)
## [1] 55.2 80.7  6.3  5.3
emNormalMix(Waiting.Time, mu = mu, sigma = sigma)
## [1]     0.500    55.155    80.745     6.267     5.268 -4892.593
## [1]     0.632    54.750    80.285     5.895     5.627 -4856.739
## [1]     0.632    54.750    80.285     5.895     5.627 -4856.739


The EM algorithm was originally formalized by Dempster, Laird and Rubin in 1977. One common problem in Statistics it can deal with is called censoring:

Say we are doing a study on the survival of patients after cancer surgery. Any such study has a time limit after which we have to start the data analysis, but hopefully there will still be some patients who are alive. We don’t know their survival times, only that they are greater than the time that has passed so far. We say the data is censored at time T.

The number of patients with survival times greater than T is important information and should be used in the analysis. If we order the observations so that \((x_1,..,x_n)\) are the uncensored observations (the survival times of the patients who have died) and \((x_{n+1},..,x_{n+m})\) are the censored ones, the likelihood function can be written as \[ L(\theta) = \prod_{i=1}^n f(x_i|\theta) \cdot \left[ 1-F(T|\theta) \right]^m \]

because all we know of the censored data is that \[ P(X_i>T)=1-F(T|\theta) \] If we had also observed the survival times of the censored patients, say \(z=(z_{n+1},..,z_{n+m})\), we could have written the complete-data likelihood \[ L(\theta) = \prod_{i=1}^n f(x_i|\theta) \prod_{j=n+1}^{n+m} f(z_j|\theta) \]

and again we can use the EM algorithm to estimate \(\theta\):

  • in the M step assume you know the \(z_{n+1},..,z_{n+m}\) and estimate \(\theta\).

  • in the E step use \(\theta\) to estimate the \(z_{n+1},..,z_{n+m}\).

Example

Say \(X_i \sim \text{Exp}(\theta)\) and we have data \((x_1, .., x_n)\) and we know that m observations were censored at T. Now the complete-data log-likelihood is \[ (n+m)\log \theta - \theta\left( \sum_{i=1}^n x_i + \sum_{j=n+1}^{n+m} z_j \right) \] so the complete-data MLE is \(\hat \theta = (n+m)/\left( \sum x_i + \sum z_j \right)\), and by the memoryless property of the exponential \[ E[Z|Z>T] = T + \frac{1}{\theta} \]

so the EM algorithm proceeds as follows:

  • in the M step assume you know the \(z_{n+1},..,z_{n+m}\) and estimate \(\hat \theta = 1/\text{mean}(x_1, .., x_n, z_{n+1}, .., z_{n+m})\).

  • in the E step use \(\theta\) to estimate \(z_{n+1} = .. = z_{n+m} = 1/\theta+T\).

emCensExp <- function (n = 1000, T = 1, m = 0, theta = 1, start = theta) 
{
    # log-likelihood of the censored exponential data: each of the m censored
    # observations contributes P(X>T) = exp(-theta*T)
    loglike <- function(x, theta, m, T) {
        -theta * T * m + sum(log(dexp(x, theta)))
    }
    # simulate the uncensored observations
    x <- rexp(n, theta)
    # plot the log-likelihood curve around the true theta
    u <- seq(theta * 0.75, 1.25 * theta, length = 100)
    ll <- rep(0, 100)
    for (i in 1:100) ll[i] <- loglike(x, u[i], m, T)
    plot(u, ll, type = "l", lwd = 2, xlab = expression(theta), 
        ylab = "Log-Likelihood")
    truetheta <- theta
    theta <- start
    print(round(c(theta, loglike(x, theta, m, T)), 3))
    abline(v = theta)
    repeat {
        thetaold <- theta
        # E step: replace each censored observation by E[Z|Z>T] = T + 1/theta
        z <- rep(1/theta + T, m)
        # M step: complete-data MLE of the exponential rate
        theta <- 1/mean(c(x, z))
        print(round(c(theta, loglike(x, theta, m, T)), 3))
        abline(v = theta)
        if (abs(theta - thetaold) < 1e-04) 
            break
    }
    theta
}

Let’s first check the case without censoring:

emCensExp()

## [1]     1.000 -1016.039
## [1]     0.984 -1015.912
## [1]     0.984 -1015.912
## [1] 0.984214

And now with 200 censored events:

emCensExp(m=200)

## [1]     1.000 -1194.576
## [1]     0.860 -1178.174
## [1]     0.841 -1177.802
## [1]     0.838 -1177.792
## [1]     0.837 -1177.792
## [1]     0.837 -1177.792
## [1] 0.8371344
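
As a check, for the exponential the censored-data MLE is actually available in closed form: the log-likelihood is \(n \log \theta - \theta\left( \sum x_i + mT \right)\), so \(\hat \theta = n/\left( \sum x_i + mT \right)\), and the fixed point of the EM iteration above, \(\theta = (n+m)/\left( \sum x_i + m(T+1/\theta) \right)\), reduces to exactly this value. A small numerical sketch (not tied to the run above, since emCensExp simulates its own data internally):

n <- 1000; m <- 200; T <- 1
x <- rexp(n, 1)                          # the uncensored observations
theta.direct <- n/(sum(x) + m*T)         # closed-form MLE with m censored at T
theta <- 1                               # EM iteration as in emCensExp
repeat {
    thetaold <- theta
    z <- rep(1/theta + T, m)             # E step: E[Z | Z > T] = T + 1/theta
    theta <- 1/mean(c(x, z))             # M step: complete-data MLE
    if(abs(theta - thetaold) < 1e-06) break
}
round(c(direct = theta.direct, EM = theta), 4)   # the two should agree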

Often the EM algorithm is useful in a Catch-22 situation: if we knew A we could do B, and if we knew B we could do A! So the idea is to start by pretending to know A, and then iterate.

Example

Nonparametric density estimation using Bernstein polynomials.

Say we have data \((X_1, .., X_n)\) from some continuous but unknown density f, and we want to estimate f(x) for any x. One way to do this is to approximate f by a polynomial of some degree d, denoted by \(p_d(x)\), with the coefficients estimated via maximum likelihood. A big problem with this approach is that polynomials are not natural choices for densities because they easily take negative values, and just finding out where \(p_d(x)<0\) is a nontrivial problem if d>2. One way around this issue is to use polynomials that are naturally non-negative, and a popular choice are the so-called Bernstein polynomials: \[ x^k(1-x)^{d-k} \] for \(0<x<1\) and \(k=0,..,d\).

Of course these are essentially Beta densities, which leads to another nice feature: it is easy to normalize the polynomials so they are proper densities: \[ b(k,d,x) = \frac{(d+1)!}{k!(d-k)!} x^k(1-x)^{d-k} \]
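
In fact \(b(k,d,x)\) is exactly the Beta(k+1, d-k+1) density, so in R it can be evaluated with dbeta; a quick check (with an arbitrary k, d and grid):

d <- 5; k <- 2
xgrid <- seq(0.01, 0.99, length = 10)
all.equal(factorial(d+1)/(factorial(k)*factorial(d-k)) * xgrid^k * (1-xgrid)^(d-k),
          dbeta(xgrid, k+1, d-k+1))   # TRUE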

It can be shown that any continuous density f on [0,1] can be approximated uniformly by a linear combination of Bernstein polynomials, that is, for any \(\epsilon>0\) there exist a d and numbers \(a_0,..,a_d\) with \(a_0+..+a_d=1\) such that \[ \max\left\{|f(x)- \sum_{k=0}^d a_k b(k,d,x)|:0<x<1\right\} < \epsilon \] Bernstein polynomials are defined on [0,1]; if f is positive on an interval [A,B] we first need to use the transform \(y=(x-A)/(B-A)\).

If f is defined on \([A, \infty )\) or \((-\infty , \infty)\), other transforms can be used but we won’t discuss that here.

Let’s set \[ p(x;a_0,..,a_d) = \sum_{k=0}^d a_k b(k,d,x) \] So how can we find \(a_0,..,a_d\), as well as the smallest d for which this works? Let’s assume for a moment that d is known; then we can estimate \(a_0,..,a_d\) via maximum likelihood, that is, we need to find \[ \max\left\{ \sum_x \log(p(x;a_0,..,a_d)); 0<a_0,..,a_d<1 \text{ and } a_0+..+a_d=1\right\} \] In calculus we have the method of Lagrange multipliers for this type of constrained maximization, but here (if d>1) this leads to a nonlinear system of equations which cannot be solved analytically.

Moreover, this is also a difficult problem numerically, because most standard optimization algorithms (such as Newton-Raphson) do not allow for these types of constraints.

Instead we can use the EM algorithm. Even easier: because the Bernstein polynomials themselves do not have parameters, we don’t even need the M step!

The algorithm:

  • use as start value a = rep(1, d+1)/(d+1)

  • at each iteration compute \[ w_k = \text{mean}\left( a_k b(k,d,x)/p(x) \right), \quad k=0,..,d \] and then set \(a_k = w_k\)

  • stop when (say) \(\sum |a_k-w_k|<0.01\)

dBernstein <- function(x, a, returnMatrix=FALSE) {
  # evaluate the Bernstein density p(x) = sum a_k b(k,d,x); note that
  # b(k,d,x) is the Beta(k+1, d-k+1) density
  d <- length(a)-1
  n <- length(x)
  Z <- matrix(0, n, d+1)
  for(i in 0:d) Z[, i+1] <- a[i+1]*dbeta(x, i+1, d+1-i)    
  if(returnMatrix) return(Z)  # one column per basis polynomial
  apply(Z, 1, sum)
}

fitBernstein <- function(x, d) {
  # EM-style iteration for the weights a_0,..,a_d
  a <- rep(1, d+1)/(d+1)   # start with equal weights
  k <- 0
  repeat {
    k <- k+1
    Z <- dBernstein(x, a, returnMatrix=TRUE)
    p <- apply(Z, 1, sum)
    for(i in 0:d) Z[, i+1] <- Z[, i+1]/p   # a_k b(k,d,x_i)/p(x_i)
    w <- apply(Z,  2, mean)
    if( sum(abs(a-w))<0.01) break   # stop when the weights no longer change
    a <- w
    if(k>100) break   # safety cap on the number of iterations
  }
  a
}

Here is an example:

x <- rbeta(1000, 2, 5)
hist(x, 50, freq=FALSE, main="")
t <- seq(0, 1, length=100)
cols <- c("black", "blue", "red", "green")
for(i in 1:4) {
  a <- fitBernstein(x, d=2*i)
  lines(t, dBernstein(t, a=a), col=cols[i])
}  
legend(0.6, 2.5, legend=paste("d=", 2*1:4), lty=rep(1, 4), col = cols)
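
If the data live on some other interval [A,B], the transform \(y=(x-A)/(B-A)\) mentioned earlier can be wrapped around fitBernstein; the fitted density then has to be divided by B-A to get back to the original scale. A small sketch (fitBernsteinAB is an ad-hoc name):

fitBernsteinAB <- function(x, d, A=min(x), B=max(x)) {
  y <- (x - A)/(B - A)                 # rescale the data to [0, 1]
  a <- fitBernstein(y, d)
  # return a function evaluating the density estimate on the original scale
  function(t) dBernstein((t - A)/(B - A), a)/(B - A)
}
xab <- 2 + 8*rbeta(1000, 2, 5)         # toy data on [2, 10]
f <- fitBernsteinAB(xab, d=6)
hist(xab, 50, freq=FALSE, main="")
t <- seq(2, 10, length=250)
lines(t, f(t), lwd=2)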

How can we find a good degree d? We can use the likelihood ratio test:

Say we want to compare the fit of \(p_d\) with that of \(p_{d+1}\). Let \(p_d^*\) be \(p_d\) evaluated at the data x, using the respective MLEs as coefficients. Then by the large-sample theory of the likelihood ratio test \[ (-2) \left( \sum \log p_d^* - \sum \log p_{d+1}^* \right) \sim \chi^2 (1) \] So we first test the null hypothesis \(d=0\) against \(d=1\). If we reject it we test \(d=1\) against \(d=2\), and so on until we fail to reject the null.

a_0 <- fitBernstein(x, d=0)
p_0star <- dBernstein(x, a_0)
a_1 <- fitBernstein(x, d=1)
p_1star <- dBernstein(x, a_1)
chi2 <- (-2)*(sum(log(p_0star))-sum(log(p_1star)))
crit <- qchisq(0.9, 1)
cat("Critical value=", round(crit, 3), "\n")
## Critical value= 2.706
d <- 1
cat("d =", d-1, "Chisquare Statistic =", round(chi2, 3),"\n") 
## d = 0 Chisquare Statistic = 667.212
repeat {
  d <- d+1
  p_0star <- p_1star
  p_1star <- dBernstein(x, fitBernstein(x, d=d))
  chi2 <- (-2)*(sum(log(p_0star))-sum(log(p_1star)))
  cat("d =", d-1, "Chisquare Statistic =", round(chi2, 3),"\n") 
  if(chi2<crit) break
  if(d>20) break
}
## d = 1 Chisquare Statistic = 101.557 
## d = 2 Chisquare Statistic = 32.389 
## d = 3 Chisquare Statistic = 104.702 
## d = 4 Chisquare Statistic = 48.702 
## d = 5 Chisquare Statistic = 4.206 
## d = 6 Chisquare Statistic = -0.319

There is a problem, though: consider this example:

x <- sort(rbeta(1000, 5, 5))
hist(x, 100, freq=FALSE, main="")
a_0 <- fitBernstein(x, d=0)
p_0star <- dBernstein(x, a_0)
lines(x, p_0star, type="l")
a_1 <- fitBernstein(x, d=1)
p_1star <- dBernstein(x, a_1)
lines(x, p_1star, type="l")

chi2 <- (-2)*(sum(log(p_0star))-sum(log(p_1star)))
crit <- qchisq(0.9, 1)
cat("Critical value=", round(crit, 3), "\n")
## Critical value= 2.706
d <- 1
cat("d =", d-1, "Chisquare Statistic =", round(chi2, 3),"\n") 
## d = 0 Chisquare Statistic = 1.998
repeat {
  d <- d+1
  p_0star <- p_1star
  p_1star <- dBernstein(x, fitBernstein(x, d=d))
  chi2 <- (-2)*(sum(log(p_0star))-sum(log(p_1star)))
  cat("d =", d-1, "Chisquare Statistic =", round(chi2, 3),"\n") 
#  if(chi2<crit) break
  if(d>10) break
}
## d = 1 Chisquare Statistic = 602.356 
## d = 2 Chisquare Statistic = 7.588 
## d = 3 Chisquare Statistic = 235.952 
## d = 4 Chisquare Statistic = 7.35 
## d = 5 Chisquare Statistic = 102.172 
## d = 6 Chisquare Statistic = 2.016 
## d = 7 Chisquare Statistic = 36.984 
## d = 8 Chisquare Statistic = 0.858 
## d = 9 Chisquare Statistic = 5.536 
## d = 10 Chisquare Statistic = 2.296

So the routine would stop at the very first test, even though obviously both fits are very bad. That is actually the problem: both are equally bad! In general, in addition to the hypothesis test we should also make a visual check to see that the fit looks reasonably OK. The next time the test fails to reject the null is at d=6:

hist(x, 100, freq=FALSE, main="")
a <- fitBernstein(x, d=6)
p <- dBernstein(x, a)
lines(x, p, type="l")

and this looks quite alright!