跳至主要內容
Fredrick's Lab

【李宏毅老師2021系列】類神經網路訓練不起來怎麼辦(二):Momentum

medium

【李宏毅老師2021系列】類神經網路訓練不起來怎麼辦(二):Momentum

這個系列是在觀看李宏毅老師 2021系列的筆記,希望能用更濃縮的方式將內容整理下來。

延續上篇【李宏毅老師2021系列】類神經網路訓練不起來怎麼辦(一):Local Minima and Saddle Point 裡面所介紹的 saddle point,這一堂課會介紹的 Momentum 是一種結合 Gradient Descent 的更新參數的方式,也是可以逃離 saddle point 的方式。

此篇筆記來自課程:【機器學習2021】類神經網路訓練不起來怎麼辦 (二): 批次 (batch) 與動量 (momentum) — YouTube,Momentum 是在這個影片的最後才介紹的,但為了接續上一篇課程的 saddle point,我拆成兩篇文章來紀錄。

同樣因為 notation 的關係,如果想看比較美觀的版本也可以直接看 notion 的筆記版本:https://quixotic-revolve-92a.notion.site/Tips-for-training-Batch-and-Momentum-cfcbead0a7a040dab13a5802f3ed2e05


Momentum

Momentum 是除了 Gradient Descent 之外,另一種更新並優化 function 參數的方式,也是除了透過 Hessian 以外,能夠逃離 Saddle Point 的另一種方式。

圖表解釋概念

  • 在物理世界中,如圖1假設 error surface 就是二維的斜坡(實際上不是,因為 feature 的維度非常多)。
  • 而參數是一個球。
  • 如果是用 gradient descent 的方式優化 loss 並更新參數,那麼球會卡在平坦的 saddle point(中間的球),到不了 local minima(最右邊的球)。
  • 而在物理世界中,球因為斜坡往下的速度,並不會卡在 saddle point,反而會繼續往前到 local minima,甚至斜坡夠陡峭,會翻過 local minima 繼續往右走。
  • 這種模擬物理世界中球體的前進方式,就是 momentum 的概念。

圖1 - 截圖自李宏毅老師課程

公式解

  • Vanilla Gradient Descent (一般的 gradient descent)

* 初始參數 theta⁰。

* 計算 g⁰。

* 往 g⁰ 的反方向更新參數 theta¹(公式: theta¹ = theta⁰ - eta g⁰)。

* 有了 theta¹ 再計算 g¹。

* …

圖2— 截圖自李宏毅老師課程

Gradient Descent + Momentum

  • 這次更新的方向與距離 = 上一個步驟的 m (momentum) 加上這次的 g (gradient)。

* 初始參數 theta⁰、初始 m⁰ = 0。

* 計算 g⁰。

* 計算 m¹ = lambda * m⁰ - eta * g⁰。
* theta⁰ 根據 g⁰ 的反方向加上 m¹ 移動到 theta¹(公式: theta¹ = theta⁰ + m¹)。

* 有了 theta¹ 之後再計算 g¹。

* 計算 m² = lambda * m¹ - eta * g¹。

* theta² = theta¹ + m², theta² 不只會往 g¹ 的反方向前進(紅色虛線),也會考慮 $m¹$ 的方向(藍色虛線)。

* …

圖3— 截圖自李宏毅老師課程

  • m 其實可以被寫作是 g⁰, g¹,…,g^i-1 的 weighted sum。

* 因為 m⁰ = 0,因此 m¹ = — eta g⁰。

* 接著再繼續帶入 m²。

* m² = -lambda * eta * g⁰ — eta * g¹。

* …

  • 因此對於 momentum 的解讀:
  1. 一種是 g 的負的反方向 + 前一次的方向。
  2. 另一種是,參數更新的方向,是對過去所有 g 的總和。

圖4— 截圖自李宏毅老師課程

圖解逃離 saddle point

  • 跟上一張 Small Gradient 的圖一樣,這次用 Gradient Descent + Momentum 的方式來更新參數。
  • 參數更新的方向(藍色實線) = 負的 g (紅線)+ 上一次參數更新的方向(藍色虛線)
  • 因此可以清楚地看見,參數並不會卡在 saddle point,儘管 g = 0,但因為還考慮了上一次參數更新的方向,因次還有辦法繼續往右走。
  • 而到了最右邊的高峰時,儘管負的 g 要參數往左更新,但若上一次的更新方向的力道大於往左的力道,那麼參數甚至會往右更新。

圖5— 截圖自李宏毅老師課程


結語

除了 Momentum 以外,訓練的 Optimizer 還有很多其他種類,常聽到的還有 AdaGrad、RMSProp、Adam(實務上常用),這篇筆記就是紀錄 Momentum 這種 Optimizer 概念與實際上是如何更新參數的。

以上就是這堂課程關於 Momentum 這個部分的筆記了,如果喜歡我的筆記,歡迎給個clap或留下留言!