ソフトマックス関数を用いた混合ガウスモデルの計算

機械学習
信号処理
Python
Published

May 30, 2021

はじめに

多峰性のある分布表現する方法として,ガウス分布の線形結合を用いる混合ガウスモデルが広く知られています. しかしがなら,観測したサンプルから混合ガウスモデルのパラメータを解析的の求めることはできません. そのため,EMアルゴリズム等を用いて数値計算的にパラメータを推定します.

EMアルゴリズムを用いた混合ガウスモデルのパラメータ推定は多くの資料で解説されています. これらの資料を読み解こうと何度か挑戦しましたが,どうもしっくりと来ませんでした. ですが,参考[1]述べられている,線形結合の重みとしてソフトマックス関数を用いる方法がとても理解しやすかったので,自分なりの理解をまとめました.

やったこと

  • 線形結合の重みとしてソフトマックス関数を用いた場合における各パラメータの導関数を計算
  • 勾配法によって混合ガウスモデルのパラメータを推定
  • EMアルゴリズムによって混合ガウスモデルのパラメータを推定

混合ガウスモデルについて

平均を\(\mu\),標準偏差を\(\sigma\)としたとき,ガウス分布\(\phi\left(x; \mu, \sigma\right)\)は以下の様に表されます.

\[ \begin{aligned} \phi\left(x; \mu, \sigma\right) & = \frac{1}{\sqrt{2 \pi \sigma^{2}}} \exp{\left(- \frac{\left(x - \mu\right)^{2}}{2\sigma^{2}}\right)} \end{aligned} \]

この時,ガウス分布の線形結合からなる混合ガウスモデル\(q\left(x; \vec{\theta}\right)\)は以下の様に表すことができます.

\[ \begin{aligned} q\left(x; \vec{\theta}\right) &= \sum_{l=1}^{m} w_{l}\phi\left(x; \mu_{l}, \sigma_{l}\right) \end{aligned} \]

ここで,\(\vec{\theta} = \left(w_{1}, \dots, w_{m},\mu_{1}, \dots, \mu_{m}, \sigma_{1}, \dots, \sigma_{m}\right)\)であり,\(w_{l}\)は線形結合の重みを,\(\mu_{l}\)は各ガウス分布の平均を,\(\sigma_{l}\)は各ガウス分布の標準偏差をそれぞれ表します.

\(q\left(x; \vec{\theta}\right)\) が分布であるためには,\(x\)に対して

\[ \begin{aligned} q\left(x; \vec{\theta}\right) &\geq 0 \\ \int q\left(x; \vec{\theta}\right) &= 1 \end{aligned} \]

である必要があります.従って,\(w_{1}, \dots, w_{m}\)は以下の条件を満たす必要があります.

\[ \begin{aligned} w_{1} \dots w_{m} &\geq 0 \\ \sum_{l=1}^{m} w_{l} &= 1 \end{aligned} \]

次に最尤推定によってパラメータ\(\vec{\theta}\)を推定します.ここでは,尤度関数\(L\left(\vec{\theta}\right)\)を以下の様に定義します.

\[ \begin{aligned} L\left(\vec{\theta}\right) &= \prod _{i=1}^{n} q\left(x_{i}; \vec{\theta}\right) \end{aligned} \]

また,\(w_{l}\)に対する拘束条件を考慮しなくてはなりません.従って,最尤推定量\(\vec{\theta}\)は以下の様に表されます.

\[ \begin{aligned} \hat{\vec{\theta}} &= \argmax_{\vec{\theta}} L\left(\vec{\theta}\right) & \text{subject to } \begin{cases} w_{1} \dots w_{m} &\geq 0\\ \sum_{l=1}^{m} w_{l} &= 1 \end{cases} \end{aligned} \]

拘束条件付きの最適化にはラグランジュの未定乗数法が多く用いられていると思います.今回のケースにおいても同様です.

しかしながら,ラグランジュの未定乗数法を用いる方法はあまりしっくり来きませんでした.(理解できなかった) そこで,ここでは参考[1]にある\(w_{l}\)としてソフトマック関数を用いる方法を採用します.従って\(w_{l}\)\(\gamma_{l}\)を用いて

\[ \begin{aligned} w_{l} &= \frac{\exp{\left(\gamma_{l}\right)}}{\sum_{l' = 1}^{m}\exp{\left(\gamma_{l’}\right)}} \end{aligned} \]

と表現することにします.ソフトマックス関数は上述の拘束条件を満たすため,拘束条件を考えることなく最適化を行うことができます.また,ここでは1次元の混合ガウスモデルについて述べましたが,多次元の場合も同様です.

勾配法によるパラメータの推定

それでは,対数尤度関数\(\log L\left(\vec{\theta}\right)\)の各パラメータに対する導関数を計算していきます.まずは,\(\gamma_{l}\)についてです.

\[ \begin{aligned} \frac{\partial w_{k}}{\partial \gamma_{l}} &= \frac{\partial}{\partial \gamma_{l}} \frac{\exp{\left(\gamma_{k}\right)}}{\sum_{k'=1}^{m}\exp{\left(\gamma_{k'}\right)}} \\ &= \frac{\exp{\left(\gamma_{k}\right)}}{\sum_{k'=1}^{m}\exp{\left(\gamma_{k'}\right)}}\delta_{k,l} - \frac{\exp{\left(\gamma_{k}\right)}\exp{\left(\gamma_{l}\right)}}{\left(\sum_{k'=1}^{m}\exp{\left(\gamma_{k'}\right)}\right)^{2}}\\ &= \frac{\exp{\left(\gamma_{k}\right)}}{\sum_{k'=1}^{m}\exp{\left(\gamma_{k'}\right)}}\left(\delta_{k,l} - \frac{\exp{\left(\gamma_{l}\right)}}{\sum_{k'=1}^{m}\exp{\left(\gamma_{k'}\right)}}\right)\\ &= w_{k}\left(\delta_{k,l} - w_{l}\right)\\ \frac{\partial}{\partial \gamma_{l}} \log L\left(\vec{\theta}\right) &= \frac{\partial}{\partial \gamma_{l}} \sum_{i=1}^{n} \log \sum_{k}^{m} w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) \\ &= \sum_{i=1}^{n}\frac{1}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \sum_{k=1}^{m} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) \frac{\partial w_{k}}{\partial \gamma_{l}}\\ &= \sum_{i=1}^{n} \frac{\sum_{k=1}^{m} w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)\delta_{k,l} -w_{l}\sum_{k=1}^{m} w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \\ &= \sum_{i=1}^{n} \frac{w_{l}\phi\left(x_{i}; \mu_{l}, \sigma_{l}\right)}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} - w_{l} \\ \end{aligned} \]

次に,\(\mu_{l}\)についてです.

\[ \begin{aligned} \frac{\partial}{\partial \mu_{l}}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) &= \frac{1}{\sqrt{2 \pi \sigma_{k}^{2}}}\frac{\partial}{\partial \mu_{l}} \exp{\left(- \frac{\left(x_{i} - \mu_{k}\right)^{2}}{2\sigma_{k}^{2}}\right)} \\ &= \frac{1}{\sqrt{2 \pi \sigma_{k}^{2}}} \exp{\left(- \frac{\left(x_{i} - \mu_{k}\right)^{2}}{2\sigma_{k}^{2}}\right)} \frac{x_{i} - \mu_{k}}{\sigma_{k}^{2}} \delta_{k,l} \\ &= \frac{x_{i} - \mu_{k}}{\sigma_{k}^{2}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) \delta_{k,l}\\ \frac{\partial}{\partial \mu_{l}} \log L\left(\vec{\theta}\right) &= \frac{\partial}{\partial \mu_{l}} \sum_{i=1}^{n} \log \sum_{k}^{m} w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) \\ &= \sum_{i=1}^{n}\frac{1}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \sum_{k=1}^{m} w_{k} \frac{\partial}{\partial \mu_{l}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)\\ &= \sum_{i=1}^{n}\frac{1}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \sum_{k=1}^{m} w_{k} \frac{x_{i} - \mu_{k}}{\sigma_{k}^{2}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) \delta_{k,l}\\ &= \sum_{i=1}^{n}\frac{1}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} w_{l} \frac{x_{i} - \mu_{l}}{\sigma_{l}^{2}} \phi\left(x_{i}; \mu_{l}, \sigma_{l}\right)\\ &= \frac{1}{\sigma_{l}^{2}} \sum_{i=1}^{n} \left(x_{i} - \mu_{l}\right) \frac{w_{l} \phi\left(x_{i}; \mu_{l}, \sigma_{l}\right) }{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \\ \end{aligned} \]

最後に,\(\sigma_{l}\)についてです.

\[ \begin{aligned} \frac{\partial}{\partial \sigma_{l}}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) &= \frac{1}{\sqrt{2 \pi}} \frac{\partial}{\partial \sigma_{l}} \sigma_{k}^{-1} \exp{\left(- \frac{\left(x_{i} - \mu_{k}\right)^{2}}{2\sigma_{k}^{2}}\right)} \\ &= -\frac{1}{\sqrt{2\pi}\sigma_{k}^{2}} \exp{\left(- \frac{\left(x - \mu_{k}\right)^{2}}{2\sigma_{k}^{2}}\right)} \delta_{k,l} + \frac{\left(x_{i} - \mu_{k}\right)^{2}}{\sqrt{2\pi}\sigma_{k}^{4}} \exp{\left(- \frac{\left(x - \mu_{k}\right)^{2}}{2\sigma_{k}^{2}}\right)} \delta_{k,l}\\ &= \frac{\left(x_{i} - \mu_{k}\right)^{2} - \sigma_{k}^{2}}{\sigma_{k}^{3}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)\delta_{k,l}\\ \frac{\partial}{\partial \sigma_{l}} \log L\left(\vec{\theta}\right) &= \frac{\partial}{\partial \sigma_{l}} \sum_{i=1}^{n} \log \sum_{k}^{m} w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) \\ &= \sum_{i=1}^{n}\frac{1}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \sum_{k=1}^{m} w_{k} \frac{\partial}{\partial \sigma_{l}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)\\ &= \sum_{i=1}^{n}\frac{1}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \sum_{k=1}^{m} w_{k} \frac{\left(x_{i} - \mu_{k}\right)^{2} - \sigma_{k}^{2}}{\sigma_{k}^{3}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)\delta_{k,l}\\ &= \frac{1}{\sigma_{l}^{3}}\sum_{i=1}^{n} \left(\left(x_{i} - \mu_{l}\right)^{2} - \sigma_{l}^{2}\right) \frac{w_{l}\phi\left(x_{i}; \mu_{l}, \sigma_{l}\right)}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \end{aligned} \]

ここで,

\[ \begin{aligned} \eta_{i,l} &= \frac{w_{l}\phi\left(x_{i}; \mu_{l}, \sigma_{l}\right)}{\sum_{k'=1}^{m} w_{k'}\phi\left(x_{i}; \mu_{k'}, \sigma_{k'}\right)} \end{aligned} \]

と置きます.すると最終的な導関数は以下の様に表されます.

\[ \begin{aligned} \frac{\partial}{\partial \gamma_{l}} \log L\left(\vec{\theta}\right) &= \sum_{i=1}^{n} \eta_{i,l} - w_{l} \\ \frac{\partial}{\partial \mu_{l}} \log L\left(\vec{\theta}\right) &= \frac{1}{\sigma_{l}^{2}} \sum_{i=1}^{n} \left(x_{i} - \mu_{l}\right) \eta_{i,l}\\ \frac{\partial}{\partial \sigma_{l}} \log L\left(\vec{\theta}\right) &=\frac{1}{\sigma_{l}^{3}}\sum_{i=1}^{n} \left(\left(x_{i} - \mu_{l}\right)^{2} - \sigma_{l}^{2}\right)\eta_{i,l} \end{aligned} \]

次に勾配法を用いてパラメータの推定を行っていきます.具体的な手順は以下の通りです.

  1. \(\hat{\gamma_{l}}\), \(\hat{\mu_{l}}\), \(\hat{\sigma_{l}}\)をランダムな値で初期化
  2. 以下に示す様にパラメータを更新
    \[ \begin{aligned} \hat{\gamma_{l}} &= \hat{\gamma_{l}} + \epsilon \frac{\partial}{\partial \gamma_{l}} \log L\left(\vec{\theta}\right) \\ \hat{\mu_{l}} &= \hat{\mu_{l}} + \epsilon \frac{\partial}{\partial \mu_{l}} \log L\left(\vec{\theta}\right) \\ \hat{\sigma_{l}} &= \hat{\sigma_{l}} + \epsilon \frac{\partial}{\partial \sigma_{l}} \log L\left(\vec{\theta}\right) \\ \end{aligned} \]
  3. 一定回数または収束するまで2.を繰り返す

実際にパラメータ推定に関するコードを記述する前に,推定を行うデータを作成します. 今回,\(\vec{\mu} = \left(3.1, -1.3, 0.5\right)\), \(\vec{\sigma} = \left(0.8, 1.0, 0.5 \right)\), \(\vec{w} = \left(0.25, 0.25, 0.5\right)\)として,2048個の標本をサンプリングすることで検証用のデータを作成しました.

import numpy as np
np.random.seed(99)

mus =  np.array([3.1, -1.3, 0.5])
sigmas =  np.array([0.8, 1.0, 0.5])
ns = np.array([512, 512, 1024])
N = np.sum(ns)
ws = ns / N

data = [np.random.normal(m, s, n) for n, m, s in zip(ns, mus, sigmas)]
all_data = np.hstack(data)

勾配法によるパラメータ推定を行います.train_grad が上述したパラメータ推定の根幹になります. ここでは,\(\epsilon = 0.0001\)とし,パラメータの変化量の絶対値和が一定値以下となる場合を収束と判定しています.

def phi(xs, mus, sigmas):
    num = np.exp(-0.5 * ((xs - mus[:,np.newaxis]) / sigmas[:,np.newaxis]) ** 2.0)
    den = np.sqrt(2 * np.pi) * sigmas[:,np.newaxis]
    return num / den

def eta(xs, ws, mus, sigmas):
    num = ws[:, np.newaxis] * phi(xs, mus, sigmas)
    den = np.sum(num, axis=0)
    return num / den

def softmax(gammas):
    num = np.exp(gammas)
    den = np.sum(num, axis=0)
    return num / den

def grad_gamma(etas, ws):
    return np.sum(etas, axis=1) - etas.shape[1] * ws

def grad_mu(xs, etas, mus, sigmas):
    num = np.sum((xs - mus[:, np.newaxis]) * etas, axis=1)
    den =sigmas ** 2.0
    return num / den

def grad_sigma(xs, etas, mus, sigmas):
    num = np.sum((((xs - mus[:, np.newaxis]) / sigmas[:,np.newaxis]) ** 2.0 - 1.0) * etas, axis=1)
    den = sigmas 
    return num / den

def train_grad(xs):
    mus_hat = 2 * np.max(np.abs(xs)) * (np.random.rand(3) - 0.5)
    sigmas_hat = np.random.rand(3)
    gammas_hat = np.random.rand(3)

    eps = 0.0001
    loss = float('inf')
    criteria = 1e-10
    iter = 0;
    while iter < 30000:
        ws_hat = softmax(gammas_hat)
        etas_hat = eta(xs, ws_hat, mus_hat, sigmas_hat)

        delta_gammas = eps * grad_gamma(etas_hat, ws_hat)
        delta_mus = eps * grad_mu(xs, etas_hat, mus_hat, sigmas_hat)
        delta_sigmas = eps * grad_sigma(xs, etas_hat, mus_hat, sigmas_hat)

        gammas_hat += delta_gammas
        mus_hat += delta_mus
        sigmas_hat += delta_sigmas
        loss = np.sum(np.abs(np.hstack([delta_gammas, delta_mus, delta_sigmas])))
        if loss < criteria:
            break
        iter += 1

    ws_hat = softmax(gammas_hat)
    return loss, iter, (ws_hat, mus_hat, sigmas_hat)

loss, iter, (ws_hat_grad, mus_hat_grad, sigmas_hat_grad) = train_grad(all_data)
print(f"loss: {loss}, iterations: {iter}")
loss: 9.995055674503072e-08, iterations: 16943

最後に推定したパラメータws_hat_grad, mus_hat_grad, sigmas_hat_grad を用いた分布と,訓練データとをプロットして比較します. 勾配法による推定は,パラメータの並びについて曖昧さがあります.これは,\(\hat{w_{l}}\)\(w_{l}\)とが必ずしも対応するとは限らないことを意味します.(他のパラメータも同様)

そこで,プロットに先立ちws_hat_grad, mus_hat_grad, sigmas_hat_gradws, mus, sigmasとを比較して,並び順を入れ変えることでこの問題に対応します. この処理はpermute にて行われます.

import altair as alt
import pandas as pd

def permute(ws, mus, sigmas, ws_hat, mus_hat, sigmas_hat):
    preds = np.vstack([ws_hat, mus_hat, sigmas_hat])
    acts = np.vstack([ws, mus, sigmas])
    corr = preds.T @ acts / (np.linalg.norm(preds, axis=0)[:, np.newaxis] @ np.linalg.norm(acts, axis=0)[np.newaxis,:])
    perm = np.argmax(corr, axis=0)
    ws_hat = ws_hat[perm]
    mus_hat = mus_hat[perm]
    sigmas_hat = sigmas_hat[perm]
    return (ws_hat, mus_hat, sigmas_hat)

def plot(data, ws_hat, mus_hat, sigmas_hat):
    bins = np.linspace(-6, 6, 128)
    hists = {f"Class {i}": (128 / (12 * N)) * np.histogram(d, bins=bins)[0] for i, d in enumerate(data)}

    bin_centers = (bins[:-1] + bins[1:]) / 2
    bars = alt.Chart(pd.DataFrame({
        "Bin": bin_centers,
        **hists,
    })).transform_fold(
        fold=[f"Class {i}" for i in range(len(data))],
        as_=["Class", "Probability"]
    ).mark_bar(opacity=0.5).encode(
        alt.X("Bin:Q"),
        alt.Y("Probability:Q"),
        alt.Color('Class:N')
    )

    envelopes = {f"Class {i}": v for i, v in enumerate(ws_hat[:, np.newaxis] * phi(bin_centers, mus_hat, sigmas_hat))}
    lines = alt.Chart(
    pd.DataFrame({
        "Bin": bin_centers,
        **envelopes,
    })).transform_fold(
        fold=[f"Class {i}" for i in range(len(data))],
        as_=["Class", "Probability"]
    ).mark_line(size=3).encode(
        alt.X("Bin:Q"),
        alt.Y("Probability:Q"),
        alt.Color('Class:N')
    )

    return bars + lines

ws_hat_grad, mus_hat_grad, sigmas_hat_grad = permute(ws, mus, sigmas, ws_hat_grad, mus_hat_grad, sigmas_hat_grad)
plot(data, ws_hat_grad, mus_hat_grad, sigmas_hat_grad)

以上の結果より,適切に推定できていることが確認できました.しかしながら,勾配法による方法は

  • 初期値によっては推定に失敗する場合がある
  • \(\epsilon\)の選択によっては推定に失敗する場合がある
  • 収束が遅い

という問題があります.これらの問題に対応するため,EMアルゴリズムによるパラメータ推定を以下に述べます.

EMアルゴリズムによるパラメータの推定

次にEMアルゴリズムによるパラメータの推定を行います.EMアルゴリズムは混合ガウスモデルのパラメータ推定以外にも利用可能な汎用的なアルゴリズムです.しかしながら,ここではあまり深入りせずに混合ガウスモデルのパラメータ推定方法の一つとして取り扱います.

EMアルゴリズムはEステップとMステップを交互に繰り返すことで実現します.両ステップについての詳細を以下に述べます.

Eステップ

Eステップでは,対数尤度関数\(\log L\left(\vec{\theta}\right)\)に対して,現在の推定値\(\hat{\vec{\theta}}\)で接する下界\(b\left(\vec{\theta}\right)\)を求めます.すなわち,\(b\left(\vec{\theta}\right)\)は以下の関係を満たす必要があります.

\[ \begin{aligned} b\left(\vec{\theta}\right) &\leq \log L\left(\vec{\theta}\right) \\ b\left(\hat{\vec{\theta}}\right) &= \log L\left(\hat{\vec{\theta}}\right) \\ \end{aligned} \]

下界\(b\left(\vec{\theta}\right)\)を求めると聞くと非常に複雑そうに思えます.しかしながら,混合ガウスモデルのパラメータ推定という用途に限定するとそこまで複雑でもありません.まず,対数尤度関数\(\log L\left(\vec{\theta}\right)\)に対してイェンセンの不等式を適用し,下界\(b'\left(\vec{\theta}; \bf{A}\right)\)を求めていきます.ここで,\(\bf{A} = \left\{a_{i,l}\right\}\)であり,\(b\left(\vec{\theta}\right) = b'\left(\vec{\theta}; \hat{\bf{A}}\right)\)であるとします.

\[ \begin{aligned} \log L\left(\vec{\theta}\right) &= \sum_{i=1}^{n} \log \sum_{l}^{m} w_{l}\phi\left(x_{i}; \mu_{l}, \sigma_{l}\right) \\ &= \sum_{i=1}^{n} \log \sum_{l}^{m} a_{i,l}\frac{w_{l}\phi\left(x_{i}; \mu_{l}, \sigma_{l}\right)}{a_{i,l}} \\ &\geq \sum_{i=1}^{n} \sum_{l}^{m} a_{i,l} \log\frac{w_{l}\phi\left(x_{i}; \mu_{l}, \sigma_{l}\right)}{a_{i,l}} \\ &= b'\left(\vec{\theta}; \bf{A}\right) \end{aligned} \]

次に,\(b'\left(\hat{\vec{\theta}}; \hat{\bf{A}}\right) = \log L\left(\hat{\vec{\theta}}\right)\)となる\(\hat{\bf{A}}\)を求めます. \[ \begin{aligned} b'\left(\hat{\vec{\theta}}; \bf{A}\right) &= \sum_{i=1}^{n} \sum_{l}^{m} a_{i,l} \log\frac{\hat{w_{l}}\phi\left(x_{i}; \hat{\mu_{l}}, \hat{\sigma_{l}}\right)}{a_{i,l}} \\ & \text{ここで$\hat{\vec{\theta}}$に対する$\eta_{i,l} $である$\hat{\eta}_{i,l}$を考えます} \\ \hat{\eta}_{i,l} &= \frac{\hat{w}_{l}\phi\left(x_{i}; \hat{\mu}_{l}, \hat{\sigma}_{l}\right)}{\sum_{k'=1}^{m} \hat{w}_{k'}\phi\left(x_{i}; \hat{\mu}_{k'}, \hat{\sigma}_{k'}\right)} \\ & \text{そして$a_{i,l} = \hat{\eta}_{i,l}$とすると} \\ b'\left(\hat{\vec{\theta}}; \bf{A}\right) &= \sum_{i=1}^{n} \sum_{l}^{m} \hat{\eta}_{i,l} \log\frac{\hat{w_{l}}\phi\left(x_{i}; \hat{\mu_{l}}, \hat{\sigma_{l}}\right)}{\hat{\eta}_{i,l}} \\ &= \sum_{i=1}^{n} \sum_{l}^{m} \hat{\eta}_{i,l} \log \sum_{l'=1}^{m} \hat{w}_{l'}\phi\left(x_{i}; \hat{\mu}_{l'}, \hat{\sigma}_{l'}\right) \\ &= \left(\sum_{l}^{m} \hat{\eta}_{i,l}\right)\sum_{i=1}^{n} \log \sum_{l'=1}^{m} \hat{w}_{l'}\phi\left(x_{i}; \hat{\mu}_{l'}, \hat{\sigma}_{l'}\right) \\ &= \sum_{i=1}^{n} \log \sum_{l'=1}^{m} \hat{w}_{l'}\phi\left(x_{i}; \hat{\mu}_{l'}, \hat{\sigma}_{l'}\right) \\ &= \log L\left(\vec{\theta}\right) \end{aligned} \]

従って,\(\hat{\bf{A}} = \left\{\hat{\eta}_{i,l}\right\}\)とすると,

\[ \begin{aligned} b\left(\vec{\theta}\right) &= b'\left(\vec{\theta}; \hat{\bf{A}}\right)\\ &= \sum_{i=1}^{n} \sum_{l}^{m} \hat{\eta}_{i,l} \log\frac{w_{l}\phi\left(x_{i}; \mu_{l}, \sigma_{l}\right)}{\hat{\eta}_{i,l}} \\ \end{aligned} \]

以上の様に,下界\(b\left(\vec{\theta}\right)\)を求めることは,現在の推定値\(\hat{\vec{\theta}}\)を用いて\(\hat{\eta}_{i,l}\)を求めることで実現します.

Mステップ

Mステップでは,\(b\left(\vec{\theta}\right)\)が最大となる\(\hat{\theta}'\)を求めます.\(\hat{\theta}'\)\(b\left(\vec{\theta}\right)\)の各パラメータに対する導関数が\(0\)となる\(\theta\)を採用します.まずは, \(\gamma_{l}\)についてです

\[ \begin{aligned} \frac{\partial}{\partial \gamma_{l}} b\left(\vec{\theta}\right) &= \frac{\partial}{\partial \gamma_{l}} \sum_{i=1}^{n} \sum_{k}^{m} \hat{\eta}_{i,k} \log\frac{w_{k}\phi\left(x_{i}; \mu_{k}, {\sigma_{k}}\right)}{\hat{\eta}_{i,k}} \\ &= \sum_{i=1}^{n} \sum_{k}^{m} \hat{\eta}_{i,k} \frac{\hat{\eta}_{i,k}}{w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)} \frac{\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)}{\gamma_{l}} \frac{\partial w_{k}}{\partial \gamma_{l}} \\ &= \sum_{i=1}^{n} \sum_{k}^{m} \frac{\hat{\eta}_{i,k}}{w_{k}} w_{k}\left(\delta_{k,l} - w_{l}\right) \\ &= \sum_{i=1}^{n} \sum_{k}^{m} \hat{\eta}_{i,k}\delta_{k,l} - w_{l}\hat{\eta}_{i,k} \\ &= \sum_{i=1}^{n} \hat{\eta}_{i,l} - w_{l}\left(\sum_{k}^{m} \hat{\eta}_{i,k}\right) \\ &= \sum_{i=1}^{n} \hat{\eta}_{i,l} - w_{l} \\ & \text{$\frac{\partial}{\partial \gamma_{l}} b\left(\hat{\vec{\theta}}\right) = 0$より}\\ \sum_{i=1}^{n} \hat{w}_{l} &= \sum_{i=1}^{n} \hat{\eta}_{i,l} \\ n\hat{w}_{l} &= \sum_{i=1}^{n} \hat{\eta}_{i,l} \\ \hat{w}_{l} &= \frac{1}{n} \sum_{i=1}^{n} \hat{\eta}_{i,l} \end{aligned} \]

次に,\(\mu_{l}\)についてです.

\[ \begin{aligned} \frac{\partial}{\partial \mu_{l}} b\left(\vec{\theta}\right) &= \frac{\partial}{\partial \mu_{l}} \sum_{i=1}^{n} \sum_{k}^{m} \hat{\eta}_{i,k} \log\frac{w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)}{\hat{\eta}_{i,k}} \\ &= \sum_{i=1}^{n} \sum_{k}^{m} \hat{\eta}_{i,k} \frac{\hat{\eta}_{i,k}}{w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)} \frac{ w_{k}}{\hat{\eta}_{i,k} } \frac{\partial }{\partial \mu_{l}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)\\ &= \sum_{i=1}^{n} \sum_{k}^{m} \frac{\hat{\eta}_{i,k}}{\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)} \frac{x_{i} - \mu_{k}}{\sigma_{k}^{2}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right) \delta_{k,l}\\ &= \frac{1}{\sigma_{l}^{2}} \sum_{i=1}^{n} \hat{\eta}_{i,l}x_{i} - \hat{\eta}_{i,l}\mu_{l}\\ & \text{$\frac{\partial}{\partial \mu_{l}} b\left(\hat{\vec{\theta}}\right) = 0$より}\\ \hat{\mu}_{l} \sum_{i=1}^{n}\hat{\eta}_{i,l}&= \sum_{i=1}^{n}\hat{\eta}_{i,l}x_{i} \\ \hat{\mu}_{l} &= \frac{\sum_{i=1}^{n}\hat{\eta}_{i,l}x_{i} }{\sum_{i=1}^{n}\hat{\eta}_{i,l}}\\ \end{aligned} \]

最後に,\(\sigma_{l}\)についてです.

\[ \begin{aligned} \frac{\partial}{\partial \sigma_{l}} b\left(\vec{\theta}\right) &= \frac{\partial}{\partial \sigma_{l}} \sum_{i=1}^{n} \sum_{k}^{m} \hat{\eta}_{i,k} \log\frac{w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)}{\hat{\eta}_{i,k}} \\ &= \sum_{i=1}^{n} \sum_{k}^{m} \hat{\eta}_{i,k} \frac{\hat{\eta}_{i,k}}{w_{k}\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)} \frac{ w_{k}}{\hat{\eta}_{i,k} } \frac{\partial }{\partial \sigma_{l}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)\\ &= \sum_{i=1}^{n} \sum_{k}^{m} \frac{\hat{\eta}_{i,k}}{\phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)}\frac{\left(x_{i} - \mu_{k}\right)^{2} - \sigma_{k}^{2}}{\sigma_{k}^{3}} \phi\left(x_{i}; \mu_{k}, \sigma_{k}\right)\delta_{k,l}\\ &= \frac{1}{\sigma_{l}^{3}} \sum_{i=1}^{n} \hat{\eta}_{i,l}\left(x_{i} - \mu_{l}\right)^{2} - \hat{\eta}_{i,l}\sigma_{l}^{2} \\ & \text{$\frac{\partial}{\partial \mu_{l}} b\left(\hat{\vec{\theta}}\right) = 0$より}\\ \hat{\sigma}_{l}^{2} \sum_{i=1}^{n}\hat{\eta}_{i,l}&= \sum_{i=1}^{n} \hat{\eta}_{i,l}\left(x_{i} - \mu_{l}\right)^{2} \\ \hat{\sigma}_{l} &= \sqrt{\frac{\sum_{i=1}^{n} \hat{\eta}_{i,l}\left(x_{i} - \mu_{l}\right)^{2}}{\sum_{i=1}^{n}\hat{\eta}_{i,l}}}\\ \end{aligned} \]

まとめると,Mステップでは以下の様にパラメータを更新します.

\[ \begin{aligned} \hat{w}_{l} &= \frac{1}{n} \sum_{i=1}^{n} \hat{\eta}_{i,l}\\ \hat{\mu}_{l} &= \frac{\sum_{i=1}^{n}\hat{\eta}_{i,l}x_{i} }{\sum_{i=1}^{n}\hat{\eta}_{i,l}}\\ \hat{\sigma}_{l} &= \sqrt{\frac{\sum_{i=1}^{n} \hat{\eta}_{i,l}\left(x_{i} - \mu_{l}\right)^{2}}{\sum_{i=1}^{n}\hat{\eta}_{i,l}}}\\ \end{aligned} \]

次にEMアルゴリズムを用いてパラメータの推定を行っていきます.具体的な手順は以下の通りです.

  1. \(\hat{\gamma_{l}}\), \(\hat{\mu_{l}}\), \(\hat{\sigma_{l}}\)をランダムな値で初期化
  2. \(\hat{\eta_{i,l}}\)を更新し,下界\(b\left(\vec{\theta}\right)\)を求める(Eステップ) \[ \begin{aligned} \hat{\eta_{i,l}} &= \frac{\hat{w}_{l}\phi\left(x_{i}; \hat{\mu}_{l}, \hat{\sigma}_{l}\right)}{\sum_{k'=1}^{m} \hat{w}_{k'}\phi\left(x_{i}; \hat{\mu}_{k'}, \hat{\sigma}_{k'}\right)} \end{aligned} \]
  3. 下界\(b\left(\vec{\theta}\right)\)を最大化する\(\hat{\vec{\theta}}'\)を求める(Mステップ) \[ \begin{aligned} \hat{w}_{l} &= \frac{1}{n} \sum_{i=1}^{n} \hat{\eta}_{i,l}\\ \hat{\mu}_{l} &= \frac{\sum_{i=1}^{n}\hat{\eta}_{i,l}x_{i} }{\sum_{i=1}^{n}\hat{\eta}_{i,l}}\\ \hat{\sigma}_{l} &= \sqrt{\frac{\sum_{i=1}^{n} \hat{\eta}_{i,l}\left(x_{i} - \mu_{l}\right)^{2}}{\sum_{i=1}^{n}\hat{\eta}_{i,l}}}\\ \end{aligned} \]
  4. 一定回数または収束するまで2〜3を繰り返す

EMアルゴリズムによるパラメータ推定を行うコードは以下の通りです.勾配法の時と同様にtrain_em が上述したパラメータ推定の根幹になります. そして,全パラメータの変化量の絶対値和が一定値以下となる場合を収束と判定しています.また,入力については勾配法で用いたものを再度使用します.

def opt_w(etas):
    return np.average(etas, axis=1)

def opt_mu(xs, etas):
    num = np.sum(etas * xs, axis=1)
    den = np.sum(etas, axis=1)
    return num / den

def opt_sigma(xs, etas, mus):
    num = np.sum((xs - mus[:, np.newaxis]) ** 2 * etas, axis=1)
    den = np.sum(etas, axis=1)
    return np.sqrt(num / den)

def train_em(xs):
    mus_hat = 2 * np.max(np.abs(xs)) * (np.random.rand(3) - 0.5)
    sigmas_hat = np.random.rand(3)
    ws_hat = np.random.rand(3)

    loss = float('inf')
    criteria = 1e-10
    iter = 0;
    while iter < 10000:
        # E step
        etas_hat = eta(xs, ws_hat, mus_hat, sigmas_hat)

        # M Step
        delta_ws = ws_hat
        delta_mus = mus_hat
        delta_sigmas = sigmas_hat
        ws_hat = opt_w(etas_hat)
        mus_hat = opt_mu(xs, etas_hat)
        sigmas_hat = opt_sigma(xs, etas_hat, mus_hat)
        delta_ws -= ws_hat
        delta_mus -= mus_hat
        delta_sigmas -= sigmas_hat

        loss = np.sum(np.abs(np.hstack([delta_ws, delta_mus, delta_sigmas])))
        if loss < criteria:
            break
        iter += 1

    return loss, iter, (ws_hat, mus_hat, sigmas_hat)

loss, iter, (ws_hat_em, mus_hat_em, sigmas_hat_em) = train_em(all_data)
print(f"loss: {loss}, iterations: {iter}")
print(ws_hat_em, mus_hat_em, sigmas_hat_em)
loss: 9.965148151103165e-11, iterations: 1254
[0.47878854 0.27353509 0.24767637] [ 0.51716133 -1.10900049  3.16175044] [0.51084106 1.06776561 0.76372732]

勾配法と比較して,イテレーション回数が10分の1程度であることが確認できます. 勾配法の時と同様に訓練データと推定したパラメータによる分布とをプロットして確認します.

ws_hat_em, mus_hat_em, sigmas_hat_em = permute(ws, mus, sigmas, ws_hat_em, mus_hat_em, sigmas_hat_em)
plot(data, ws_hat_em, mus_hat_em, sigmas_hat_em)

以上の結果より,適切に推定できていることが確認できました.EMアルゴリズムの勾配法に対する利点としては,

  • \(\epsilon\)を設定する必要がない
  • 勾配法よりも高速に収束する

という点があげられます.しかしながら,

  • 初期値によっては推定に失敗する場合がある

という問題は継続して存在します.したがって,実用を考えると複数の初期値でパラメータを推定するなどの工夫が必要になるかもしれません.

参考文献