カルマンフィルター(その3)

このページでは、線形カルマンフィルターでは扱えない問題にも適用可能な粒子フィルターについて考えてみる。

観測情報を用いた推定

トロッコ問題において、観測データ$z(t)$を用いて、状態$x(t),v(t)$の推定を行う方法について、改めて吟味してみたい。

ある時点でのトロッコの位置が$z(t)$であった際に、トロッコの内部状態が $x(t), v(t)$ である確率は、ベイズの定理から $$ P(x(t),v(t)|z(t)) = \frac{P(z(t)|x(t),v(t)) P(x(t),v(t))}{P(z(t))} $$ で与えられる。 $P(z(t)|x(t),v(t))$ は、観測モデルから計算でき、トロッコ問題の場合は観測誤差はガウス分布すると仮定していたが、ベイズの定理自体は、どのような確率分布の場合でも成り立つ。 他方、$P(x(t),v(t))$ のほうは「未知」ではあるが、初期状態を与えれば、内部状態モデルによって$\Delta$先が予測できるので、逐次的に計算する方法がとれる。

もう少し丁寧な考察

時刻$t$での状態を確率変数$X_t$で、観測量を$Z_t$で表すことにすると、 観測から内部状態を推定するということは、条件付き確率 $$ P(X_t | Z_t,Z_{t-1},\cdots,Z_0) $$ を見積もることに他ならない。

時刻$t$で我々が利用可能な情報は、$t-1$での状態の推定分布 $$ P(X_{t-1} | Z_{t-1},\cdots,Z_0) $$ および、状態モデルの遷移確率 $$ P(X_t | X_{t-1}) $$ そして、観測モデルに基づいて予測される確率 $$ P(Z_t|X_t) $$ である。 さらに、具体的な観測値$z$が与えられている。

一方で、一般的な関係 $$ P(X|Y,Z) = \frac{P(X,Y|Z)}{P(Y|Z)} $$ から $$ P(X_t | Z_t=z , Z_{t-1},\cdots, Z_0) =\frac{P(Z_t=z,X_t | Z_{t-1},\cdots,Z_0)} {\sum_x P(Z_t=z,X_t=x| Z_{t-1}, \cdots, Z_0)} $$ となるが、分子のところを周辺化すると $$ \begin{eqnarray} P(Z_t=z,X_t | Z_{t-1},\cdots,Z_0) = \sum_{x'} P(Z_t,X_t, X_{t-1}=x'|Z_{t-1},\cdots,Z_0) \\ = \sum_{x'} P(Z_t=z | X_t) P(X_t, X_{t-1}=x'|Z_{t-1},\cdots,Z_0) \\ = \sum_{x'} P(Z_t=z | X_t) P(X_t | X_{t-1}=x') P(X_{t-1}=x' | Z_{t-1},\cdots,Z_0) \end{eqnarray} $$ となるので、規格化定数を除けば $$ P(X_t| Z_t, Z_{t-1},\cdots,Z_0) \propto \sum_{x'} P(Z_t=z| X_t) P(X_t|X_{t-1}=x') P(X_{t-1}=x' | Z_{t-1},\cdots,Z_0) $$ が得られる。

すなわち、$t-1$での状態$x'$についての分布 $P(X_{t-1}=x' | Z_{t-1},\cdots,Z_0)$、 状態モデルによる遷移確率 $P(X_t|X_{t-1}=x')$、 そして、時刻$t$で$z$を観測する確率 $P(Z_t=z| X_t)$ を乗じて規格化することで、 $t$での状態の分布が推定できる。

多点サンプリングによる状態の推定

多数のサンプル点を状態モデルに従って確率的に「動かし」ながら分布の変化を推定する手法は、サンプル点を「粒子」になぞらえ、粒子フィルターと呼ばれている。 粒子フィルターを用いたモンテカルロ法は、状態モデルや観測モデルが非線形な場合や、揺らぎや誤差が非ガウス的である場合にも適用可能であり、効率を別とすれば、汎用性に優れている。

トロッコ問題のケースでは、以下の手順によってサンプリングすることができる。

状態モデルによる更新ステップ

時刻 $t-\Delta$ の状態の推定値を $\hat{x}_i(t-\Delta), \hat{v}_i(t-\Delta)$ としたとき、$t$での推定値は、状態モデルを使って $$ \begin{eqnarray} \hat{x}_i(t) & = & \hat{x}_i(t-\Delta) + \hat{v}(t-\Delta) \Delta \\ \hat{v}_i(t) & = & \hat{v}_i(t-\Delta) + a_i(t) \sqrt{\Delta} \end{eqnarray} $$ で与えられる。 ここで$i$はサンプルの番号で、$a(t)$について仮定している分布に従う乱数 $a_i(t)$ を生成することで、 多数の標本 $\left\{ (\hat{x}_i(t), \hat{v}_i(t)) | i=1,\cdots,N \right\}$ が得られる。 ここではサンプルの総数を$N$としよう。

条件付き確率に応じた重み付け

次に、状態 $\hat{x}_i(t), \hat{v}_i(t)$ が与えられている条件の下で、時刻$t$で観測値 $z(t)$ が得られる条件付き確率 $P(z(t)| \hat{x}_i(t), \hat{v}_i(t) )$ は、観測モデル $$ z(t) = x(t) + b(t) $$ から求めることができる。 すなわち、$b(t)$は平均0で分散が$Q$のガウス分布に従うと仮定しているのだから、 $$ w_i(t) = P(z(t)| \hat{x}_i(t), \hat{v}_i(t) ) = \frac{1}{\sqrt{2 \pi Q}} \exp\left[ - \frac{(z(t) - \hat{x}_i(t))^2}{2Q} \right] $$ である。 つまり、観測値$z(t)$に照らして状態が$(\hat{x}_i(t), \hat{v}_i(t))$であるような「重み」が $w_i(t)$ ということになる。

以上を踏まえると、事後確率分布 $P(x(t),v(t)|z(t))$ に従ったサンプルとして、$(\hat{x}_i(t), \hat{v}_i(t))$ に $w_i(t)$ の重みを考慮すれば良いことがわかり、例えば、時刻$t$での(補正された)状態の平均は $$ \begin{eqnarray} \tilde{x}(t) & = & \frac{\sum_{i=1}^N w_i(t) \, \hat{x}_i(t)}{\sum_{j=1}^N w_j(t)} \\ \tilde{v}(t) & = & \frac{\sum_{i=1}^N w_i(t) \, \hat{v}_i(t)}{\sum_{j=1}^N w_j(t)} \end{eqnarray} $$ で与えられる。

事後確率に応じたサンプル点集合の再調整

最後に、$N$個のサンプル点が$P(x(t),v(t)|z(t))$ の重みに従って分布するように、再調整する。 具体的には、$N$個の点から、重複を許しながら$i$番目の点を$w_i(t)$の重みで選ぶ操作を$N$回繰り返せばよい。 こうして、時刻$t$でのサンプリング点の集合が得られるので、1ステップ時間を進めて、上の手順を繰り返す。

粒子フィルターによる「トロッコ」問題の計算例

トロッコ問題について、$N=1000$のサンプル点を使って状態を推定するコードの例を以下に示す。

# coding: utf-8

import numpy as np
import math
import random
import matplotlib.pyplot as plt

N=1000
dt=0.1

xsamp = np.random.normal(0,1,N)
vsamp = np.random.normal(0,1,N)
wt = np.zeros((N,))

XT=[ ]
ZT=[ ]
XEST=[ ]
T=[ ]

R=1
Q=1

x=0
v=1

t=0
while t<10:
    xsamp = xsamp + vsamp*dt
    vsamp = vsamp + np.random.normal(0,math.sqrt(R), N) * math.sqrt(dt)

    z = x + random.gauss(0,math.sqrt(Q)) 

    wt = np.exp(-np.square(z-xsamp)/(2*Q))
    wsum = np.sum(wt)
    wt = wt/wsum
    
    xest = wt.dot(xsamp)
    vest = wt.dot(vsamp)

    XT.append(x)
    ZT.append(z)
    XEST.append(xest)
    T.append(t)

    xv=[[x,v] for x,v in zip(xsamp,vsamp)]
    new_sample = np.array(random.choices(xv, k=N, weights = wt) )
    xsamp = new_sample[:,0]
    vsamp = new_sample[:,1]
    
    x = x + v * dt
    v = v + random.gauss(0,math.sqrt(R))*math.sqrt(dt)
    t = t + dt

plt.plot(T,XT, color='blue', linewidth=1.0, label='State Model')
plt.plot(T,XEST, 'o', color='red',label='Estimate')
plt.plot(T,ZT, '*', color='green', linewidth=1.0, label='Observation')
plt.xlabel('T')
plt.ylabel('X')
plt.grid(True)
plt.legend()
plt.show()