損失関数とは
損失関数は一言で言うと「目標」と「実際」の結果の差を表したもの。
損失関数は一般的には$\mathcal L$で表される。
損失関数の種類にはいろいろあり平均二乗誤差(MSE:Mean Squared Error)、平均絶対誤差(MAE)、交差クロスエントロピー誤差 (Cross-entropy Loss)などがある。
では損失関数を用いてどのように学習するのかを見ていく
損失関数を用いたモデルの学習
モデルの学習は以下のように表される
$$ \mathcal w_{t+1} = \mathcal w_t - \alpha \frac{\partial L}{\partial \mathcal w} $$
この式の意味するところは、損失関数を小さくする方向に重み$w$を学習するということである。
交差クロスエントロピー
ここで交差クロスエントロピーの定義と学習する方法について具体的に見ていく
交差クロスエントロピーは以下の式で定義される
$$ \mathcal L(p, q) = - \Sigma_i p(i) logq(i) $$
- pは真の確率分布、qは推定した確率分布
- pは正解データの確率分布、qはモデルが出力した確率分布
例1
仮に真の確率分布p=(1,0,0)、推定した確率分布q=(1,0,0)の場合を考える。
この時交差クロスエントロピーは $$ \begin{aligned} \mathcal L(p, q) &= - \Sigma_i p(i) logq(i) \ &= - 1 \cdot log 1 - 0 \cdot log 0 - 0 \cdot log 0 \ &= 0 \end{aligned} $$
このように真の確率分布と推定した確率分布が同じ場合は損失関数は0に収束する
この値を下記の式を用いて考える。
$$ \mathcal w_{t+1} = \mathcal w_t - \alpha \frac{\partial L}{\partial \mathcal w} $$
上記の損失関数が0に収束する場合は上記の$\alpha \frac{\partial L}{\partial \mathcal w}$は0になるため、パラーメーターを更新する必要はない
例2
仮に真の確率分布p=(1,0,0)、推定した確率分布q=(0.5,0,5)の場合を考える。
この時交差クロスエントロピーは $$ \begin{aligned} \mathcal L(p, q) &= - \Sigma_i p(i) logq(i) \ &= - 1 \cdot log 0.5 - 0 \cdot log 0.5 - 0 \cdot log 0 \ &= 0.693147… \end{aligned} $$
このように真の確率分布と推定した確率分布が異なる場合は0.69..となる。
この値を下記の式を用いて考える。
$$ \mathcal w_{t+1} = \mathcal w_t - \alpha \frac{\partial \mathcal L}{\partial \mathcal w} $$
q(x)は、より詳しく見ると入力xと重みw,誤差bに依存する。
具体的には、$q(x)=w\cdot x+b$となる
$$ \frac{\partial \mathcal L}{\partial w} = \frac{\partial (- \Sigma_i p(i) logq(i) )}{\partial w}= \frac{\partial (- \Sigma_i p(i) logq(i) )}{\partial w} $$