逆伝播のデモ
$$ \definecolor{input}{RGB}{66, 133, 244} \definecolor{output}{RGB}{219, 68, 55} \definecolor{dinput}{RGB}{244, 180, 0} \definecolor{doutput}{RGB}{15, 157, 88} \definecolor{dweight}{RGB}{102, 0, 255} $$

逆伝播アルゴリズム

バックプロパゲーション アルゴリズムは、大規模なニューラル ネットワークを迅速にトレーニングするために不可欠です。この記事では、アルゴリズムの仕組みについて説明します。

下にスクロールしてください...

シンプルなニューラル ネットワーク

右側には、入力が 1 つ、出力ノードが 1 つ、2 つのノードの非表示レイヤ 2 つを備えたニューラル ネットワークが表示されています。

隣接するレイヤのノードは、ネットワーク パラメータであるウェイト \(w_{ij}\)で接続されます。

活性化関数

各ノードには、合計入力 \(\color{input}x\)、アクティベーション関数 \(f(\color{input}x\color{black})\)、出力 \(\color{output}y\color{black}=f(\color{input}x\color{black})\)があります。 \(f(\color{input}x\color{black})\) は非線形関数でなければなりません。そうでなければ、ニューラル ネットワークは線形モデルのみを学習できます。

よく使用されるアクティベーション関数は Sigmoid 関数です。 \(f(\color{input}x\color{black}) = \frac{1}{1+e^{-\color{input}x}}\)

誤差関数

目標は、 \(\color{output}y_{output}\) すべての入力 \(\color{input}x_{input}\)で予測出力が目標に近づくように、データからネットワークの重みを自動的に学習することです。 \(\color{output}y_{output}\)

目標からどのくらい離れているかを測定するには、エラー関数を使用します \(E\)。よく使用されるエラー関数は \(E(\color{output}y_{output}\color{black},\color{output}y_{target}\color{black}) = \frac{1}{2}(\color{output}y_{output}\color{black} - \color{output}y_{target}\color{black})^2 \)です。

転送の伝播

まず、入力例を受け取り、 \((\color{input}x_{input}\color{black},\color{output}y_{target}\color{black})\) ネットワークの入力レイヤを更新します。

一貫性を保つため、入力は他のノードと同様のものですが、活性化関数がないものと仮定して、出力が入力と等しくなるようにします(例: \( \color{output}y_1 \color{black} = \color{input} x_{input} \))。

転送の伝播

ここでは、最初の隠しレイヤを更新します。前のレイヤのノードの出力を取得し、 \(\color{output}y\) 重みを使用して次のレイヤのノードの入力を計算します。 \(\color{input}x\)
$$ \color{input} x_j \color{black} = $$$$ \sum_{i\in in(j)} w_{ij}\color{output} y_i\color{black} +b_j$$

転送の伝播

次に、最初の隠れ層のノードの出力を更新します。これには、アクティベーション関数 \( f(x) \)を使用します。
$$ \color{output} y \color{black} = f(\color{input} x \color{black})$$

転送の伝播

この 2 つの数式を使用して、ネットワークの残りの部分に伝播し、ネットワークの最終出力を取得します。
$$ \color{output} y \color{black} = f(\color{input} x \color{black})$$
$$ \color{input} x_j \color{black} = $$$$ \sum_{i\in in(j)} w_{ij}\color{output} y_i \color{black} + b_j$$

エラーのデリバティブ

逆伝播アルゴリズムは、予測出力を特定の例の目的の出力と比較した後、ネットワークの各重みを更新する量を決定します。このために、各重みに関してエラーがどのように変化するかを計算する必要があります \(\color{dweight}\frac{dE}{dw_{ij}}\)。
エラーの導関数を取得したら、シンプルな更新ルールを使用して重みを更新できます。
$$w_{ij} = w_{ij} - \alpha \color{dweight}\frac{dE}{dw_{ij}}$$
ここで、 \(\alpha\) は正の定数です。これは学習率と呼ばれ、経験的に微調整する必要があります。

[注] 更新ルールはきわめてシンプルです。重みが増すとエラーが発生する(\(\color{dweight}\frac{dE}{dw_{ij}}\color{black} < 0\))場合は重みを上げ、重みを上げるとエラーが発生する(\(\color{dweight}\frac{dE}{dw_{ij}} \color{black} > 0\))場合は重みを減らします。

その他のデリバティブ

\(\color{dweight}\frac{dE}{dw_{ij}}\)を計算しやすくするため、ノードごとにさらに 2 つのデリバティブ、つまりエラーによる変化を次のように保存します。
  • ノードの合計入力 \(\color{dinput}\frac{dE}{dx}\)
  • ノードの出力 \(\color{doutput}\frac{dE}{dy}\)。

逆伝播

エラーのデリバティブの逆伝播を開始します。この特定の入力例の予測出力があるため、その出力によってエラーがどのように変化するかを計算できます。エラー関数から、 \(E = \frac{1}{2}(\color{output}y_{output}\color{black} - \color{output}y_{target}\color{black})^2\) 次のようになります。
$$ \color{doutput} \frac{\partial E}{\partial y_{output}} \color{black} = \color{output} y_{output} \color{black} - \color{output} y_{target}$$

逆伝播

これで、チェーンルールを \(\color{doutput} \frac{dE}{dy}\) 使用して \(\color{dinput}\frac{dE}{dx}\) 取得できます。
$$\color{dinput} \frac{\partial E}{\partial x} \color{black} = \frac{dy}{dx}\color{doutput}\frac{\partial E}{\partial y} \color{black} = \frac{d}{dx}f(\color{input}x\color{black})\color{doutput}\frac{\partial E}{\partial y}$$
\(\frac{d}{dx}f(\color{input}x\color{black}) = f(\color{input}x\color{black})(1 - f(\color{input}x\color{black}))\) の場合、 \(f(\color{input}x\color{black})\) が Sigmoid アクティベーション関数です。

逆伝播

ノードの合計入力に関する誤差導関数を取得したらすぐに、そのノードに入る重みに関するエラー導関数を取得できます。
$$\color{dweight} \frac{\partial E}{\partial w_{ij}} \color{black} = \frac{\partial x_j}{\partial w_{ij}} \color{dinput}\frac{\partial E}{\partial x_j} \color{black} = \color{output}y_i \color{dinput} \frac{\partial E}{\partial x_j}$$

逆伝播

チェーンルールを使うと、前のレイヤからも \(\frac{dE}{dy}\) 取得できます。丸で囲むように作成しました。
$$ \color{doutput} \frac{\partial E}{\partial y_i} \color{black} = \sum_{j\in out(i)} \frac{\partial x_j}{\partial y_i} \color{dinput} \frac{\partial E}{\partial x_j} \color{black} = \sum_{j\in out(i)} w_{ij} \color{dinput} \frac{\partial E}{\partial x_j}$$

逆伝播

あとは、すべてのエラーの導関数を計算するまで、前の 3 つの数式を繰り返すだけです。

おしまい。

コンピューティング...