ml_recsys 부스트캠프 AITech recsys 통계 EM Algorithm MLE

EM Algorithm

Kwangjin Park

Aug 30, 2024 · 3 min read

Follow

1. EM Algorithm

1.1 개요

  • MLE를 구하는 기법
  • 확률분포의 Latent Variable을 추정하여, 이를 기반으로 모델 파라미터를 반복적으로 개선하는 알고리즘
  • 두 단계를 번갈아 반복적으로 수행하며 모델 파라미터를 개선해 나감
    1. E-Algorithm
      • 변수의 정보를 업데이트
      • 기댓값에 이전 파라미터를 이용하여 업데이트
    2. M-Algorithm
      • 변수들의 분포 형태에 대한 가설을 업데이트
      • 업데이트된 기댓값으로 모델 파라미터를 최적화
  • 역시나, posterior를 구하기 어렵기 때문에 사용되는 방법 중 하나

1.2 유도과정

  • Mixture of Bernoulli Distribution
    • N개의 Binary Data points($x_1, x_2, \ldots ,x_N$)
    • K개의 Bernoulli Distribution
    • Likelihood
      • $p(x_n \vert q_k) = q_k^{x_n}(1-q_k)^{1-x_n}$
      • “$K$개의 데이터에서 $N$개의 Data 만듦”
    • 전체 데이터에 대한 모델링
      • $p(x) = \prod_{n=1}^N \sum_{k=1}^K \pi_k p (x_n \vert q_k)$
        • $\pi_k : k$번째 클러스터의 비중
        • $q_k$ : 확률값
          • 추정해야 하는 파라미터
          • = 확률분포의 파라미터(여기서는 Bernoulli 분포의 파라미터)
  • MLE를 구하는 것이 목표(늘 그래왔듯)
    • Log-likelihood
      • $\log p(x) = \sum_{n=1}^N \log \sum_{k=1}^K \pi_k p (x_n \vert q_k)$
    • Log-likelihood 최대화(MLE)를 위한 방법
      • 추정해야하는 파라미터에 대한 편미분 값이 0인 값 구하기
      • $q_k, \pi_k$ 에 대한 편미분 값이 0이 되는 $q, k$값 구하기

        \[\frac{\partial \log p(x)}{\partial q_k} = 0\] \[\frac{\partial \log p(x)}{\partial \pi_k} = 0\]
  • 유도과정 상세
    1. $\frac{\partial \log p(x)}{\partial q_k} = 0$ : $q_k$에 대한 편미분값
      • 편미분 해보면

        \[\frac{\partial \log p(x)}{\partial q_k} = \sum_{n=1}^{N} \frac{\pi_k}{\sum_{j=1}^{K} \pi_j p(x_n \vert q_j)}\cdot\frac{\partial p(x_n \vert q_k)}{\partial q_k}= 0\]
      • $p(x_n \vert q_k)$에 대한 편미분 값도 구해보면

        \[\frac{\partial p(x_n \vert q_k)}{\partial q_k} = (\frac{x_n}{q_k}- \frac{1-x_n}{1-q_k})p(x_n \vert q_k) = f(q_k;x_n)p(x_n \vert q_k)\]
      • $\frac{\partial p(x_n \vert q_k)}{\partial q_k}$ 결과를 $\frac{\partial \log p(x)}{\partial q_k}$에 대입하여 정리하면

        \[\sum_{n=1}^N \frac {\pi_k p(x_n \vert q_k)}{\sum_{j=1}^{K}\pi_j p(x_n \vert q_j)} \cdot f(q_k;x_n)=0\]
      • Posterior 정의
        • $p(z_{nk}=1 \vert x_n, \theta)$ → n번째 데이터가 주어졌을 때, 그것이 k번째 클러스터에서 나왔을 확률
          • $z_{nk}$ : 우리가 모르는 데이터
          • $x_n, \theta$ : 이미 관측되어, 우리가 알고 있는 데이터들
      • Posterior 정의로 식을 유도해보면

        \[\begin{align*} p(z_{nk} = 1 \vert x_n, \theta) &= \frac{p(z_{nk} = 1)\, p(x_n \vert z_{nk} = 1)}{\sum_{j=1}^{K} p(z_{nj} = 1)\, p(x_n \vert z_{nj} = 1)} \\ &= \frac{\pi_k\, p(x_n \vert q_k)}{\sum_{j=1}^{K} \pi_j\, p(x_n \vert q_j)} \end{align*}\]
      • 유도한 식을 보면, 여전이 q에 관한 복잡한 식으로 구성되어있음 → 편의 상 상수로 대체
        • $z_{nk}$가 binary이기 때문에 posterior를 기댓값($E[z_{nk}]$)으로 대체 가능
        \[\sum_{n=1}^N \frac {\pi_k p(x_n \vert q_k)}{\sum_{j=1}^{K}\pi_j p(x_n \vert q_j)} \cdot f(q_k;x_n)=0\] \[\sum_{n=1}^N \gamma(z_{nk}) \cdot f(q_k;x_n)=0\]
    2. $\frac{\partial \log p(x)}{\partial \pi_k} = 0$ : $\pi_k$에 대한 편미분값
      • $\pi_k$는 분포 별 비중값임 → 합해서 1이 되어야 한다는 제약조건이 필요 → Lagrangian Multiplier를 활용하여 제약조건 해결

        \[\begin{align*} \log p(x) &= \sum_{n=1}^{N} \log \sum_{k=1}^{K} \pi_k p(x_n \vert q_k) \\ L &= \sum_{n=1}^{N} \log \sum_{k=1}^{K} \pi_k p(x_n \vert q_k) + \lambda \left( \sum_{k=1}^{K} \pi_k - 1 \right) \end{align*}\]
      • 편미분 하면

        \[\frac{\partial{L}}{\partial \pi_k} = \sum_{n=1}^N \frac{p(x_n \vert q_k)}{\sum_{j=1}^K \pi_j p(x_n \vert q_j)} + \lambda = 0\]
      • 양변에 $\pi_k \cdot \sum_{k=1}^{K}$곱해주면

        \[\sum_{n=1}^{N} \frac{\sum_{k=1}^{K} \pi_k p(x_n \vert q_k)}{ \sum_{j=1}^{K} \pi_j p(x_n \vert q_j)}+ \sum_{k=1}^{K} \pi_k \lambda = 0\]
        • 첫번째 항에서, 분자 및 분모는 소거 가능(변수명만 다르고 같은 수식)
        • $\sum_{k=1}^{K}\pi_k = 1$ 이므로($\because$ 확률분포들의 비중의 합은 1)

          \[\sum_{n=1}^N \frac{1}{1} + 1 \cdot \lambda = 0\] \[N+\lambda=0 \rightarrow \lambda=-N\]
      • 이를 $\sum_{n=1}^{N} \frac{ \sum_{k=1}^{K} \pi_k p(x_n \vert q_k) }{ \sum_{j=1}^{K} \pi_j p(x_n \vert q_j) } + \sum_{k=1}^{K} \pi_k \lambda = 0$ 에 넣어서 정리해 주면

        \[\sum_{n=1}^N \frac{\pi_kp(x_n \vert q_k)}{\sum_{j=1}^{K}}=\pi_k N\]
      • posterior는 상수로 대체하면

        \[\sum_{n=1}^{N} \gamma(z_{nk})=\pi_k N\]
      • 이를 정리하면

        \[\pi_k = \frac{\sum_{n=1}^N \gamma(z_{nk})}{N}\]

      ⇒ 위 식의 $\gamma (z_{nk})$가 $\pi_{k}$를 포함하고 있는 식이기에 $\pi_k$를 정확히 구할 수는 없으나, $\gamma(z_{nk})$를 이미 알려져 있는 값으로 생각

1.3 EM-Algorithm 정리

  • 개념
    • E-step, M-step 지속적으로 반복하며 업데이트
      • M-step : $q, \pi_k$에 대한 미분값 찾기 (Step 1, 2)
      • E-step : M-step에서 찾은 값을 posterior에 넣어 $\gamma(z_{nk})$ 업데이트 (Step 3)

  • 상세
    1. For q ($\gamma(z_{nk})$는 상수) : $\sum_{n=1}^{N} \gamma(z_{nk}) \cdot f(q_k;x_n) = 0$ → 정리하면 q값 나옴
    2. For $\pi$ (($\gamma(z_{nk})$는 상수)) : $\pi_k = \frac{\sum_{n=1}^{N}\gamma(z_{nk})}{N}$ → p값 나옴
    3. 1, 2번에서 구한 $q, \pi_k$를 $\gamma(z_{nk})$(posterior)에 넣어 posterior 업데이트 → $q, \pi_k$ 값 넣어서 $\gamma(z_{nk})$ 업데이트

      \[\begin{align*} \gamma(z_{nk}) = p(z_{nk} = 1 \vert x_n, \theta) &= \frac{\pi_k\, p(x_n \vert q_k)}{\sum_{j=1}^{K} \pi_j\, p(x_n \vert q_j)} \end{align*}\]
chat_bubble 0

chat_bubble 댓글남기기

댓글남기기