In this post, I write about different ways of updating
neural network weights.
This post is based on the notes I took for the class
“Improving Deep
Neural Networks: Hyperparameter tuning, Regularization and Optimization”.
Specifically, I discuss optimizing the neural network
weights using non-traditional gradient descent algorithms.
Applying proper optimization techniques has a direct effect
on the performance of the neural network.
Prerequisites for understanding this content:
- Neural Networks
- Gradient Descent
- Back Propagation
- Cost functions
To explain the following, I assume the neural network
weights θ
are a 1D/2D vector (i.e., a vector or a matrix).
The different ways to update the neural network weights are
described as follows:
Traditional Gradient Descent:
- It applies the raw gradient directly at each step (a minimal sketch of this update follows the list) and is prone to large oscillations during its descent to the global/local optimum of the cost function. This hurts convergence.
- To reduce the oscillations, we use exponentially weighted moving-average-style approaches to damp out their effects.
- This gives a smoother path to the global/local optimum while optimizing the cost function.
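For reference, here is a minimal Python/NumPy-style sketch of the plain gradient-descent update that the variants below modify; the gradients dW and db are assumed to come from backpropagation, and the step size alpha is just an illustrative default:

def gradient_descent_step(W, b, dW, db, alpha=0.01):
    # Plain gradient descent: step directly along the negative gradient.
    W = W - alpha * dW
    b = b - alpha * db
    return W, b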
Exponentially Weighted Moving Averages:
Assume we have a 1D quantity V that changes
over time (like temperature) and is prone to large local variations.
To smooth out those quick and abrupt changes, we maintain a running average:
Vt = (β * Vt-1) + ((1 - β) * θt)
where θt is the raw observation at time t. In other words, Vt is
approximately an average over the last (1.0 / (1.0 - β)) observations.
For example:
- If β = 0.9, Vt is approximately an average of the last 10 observations.
- If β = 0.98, Vt is approximately an average of the last 50 observations.
- Larger values of β give smoother curves (as opposed to the zig-zag/abrupt movement observed with pure gradient descent).
- Bias correction for exponentially weighted moving averages does not significantly affect gradient-descent-style algorithms that use them, so it can be added if needed; it is, however, required in the Adam (ADAptive Moment estimation) algorithm. A sketch of the moving average, with bias correction, follows this list.
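Here is a minimal Python sketch of an exponentially weighted moving average with optional bias correction; the function name ewma and the made-up temperature data are just for illustration:

import numpy as np

def ewma(values, beta=0.9, bias_correction=True):
    # v_t = beta * v_{t-1} + (1 - beta) * x_t, starting from v_0 = 0.
    v = 0.0
    smoothed = []
    for t, x in enumerate(values, start=1):
        v = beta * v + (1 - beta) * x
        # Bias correction fixes the low early estimates caused by v_0 = 0.
        smoothed.append(v / (1 - beta ** t) if bias_correction else v)
    return np.array(smoothed)

# Example: noisy "temperature" readings around 20 degrees.
temps = 20 + np.random.randn(365)
print(ewma(temps, beta=0.9)[:5])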
Momentum
Now that we understand the basics of exponentially weighted moving
averages, we can apply them to the weight-update step in neural network
optimization.
We update the weights by performing the following steps:
Vdw = (β * Vdw) + ((1 - β) * dw)
Vdb = (β * Vdb) + ((1 - β) * db)
Weight update step:
W = W - (α * Vdw) instead of W = W - (α * dw)
b = b - (α * Vdb) instead of b = b - (α * db)
Usually β = 0.9.
- This weight update procedure gives us smoother convergence to the global/local minimum.
- The Vdw, Vdb terms are exponentially weighted moving averages of the gradients dw and db (see the sketch after this list).
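A minimal sketch of one Momentum step, assuming v_dW and v_db start as zero arrays of the same shape as W and b, and the gradients dW, db come from backpropagation:

def momentum_step(W, b, dW, db, v_dW, v_db, alpha=0.01, beta=0.9):
    # Exponentially weighted moving averages of the gradients.
    v_dW = beta * v_dW + (1 - beta) * dW
    v_db = beta * v_db + (1 - beta) * db
    # Step along the smoothed gradients instead of the raw ones.
    W = W - alpha * v_dW
    b = b - alpha * v_db
    return W, b, v_dW, v_db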
RMSProp
Root Mean Squared Propagation. Interestingly, this was proposed by Geoffrey Hinton in his Coursera course back in 2011-2012. While applying gradient descent, we update the weights by performing the following:
Sdw = (β * Sdw) + ((1 - β) * (dw)^2)
Sdb = (β * Sdb) + ((1 - β) * (db)^2)
(The squares are element-wise.)
Weight update step:
W = W - (α * (dw / (sqrt(Sdw) + ε))) instead of W = W - (α * dw)
b = b - (α * (db / (sqrt(Sdb) + ε))) instead of b = b - (α * db)
The Sdw, Sdb terms are exponentially weighted moving averages of the
squared gradients (a sketch of one RMSProp step follows).
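A minimal NumPy sketch of one RMSProp step, assuming s_dW and s_db start as zero arrays of the same shape as W and b (the default alpha and eps values here are typical choices, not part of the course notes):

import numpy as np

def rmsprop_step(W, b, dW, db, s_dW, s_db, alpha=0.001, beta=0.9, eps=1e-8):
    # Exponentially weighted moving averages of the squared gradients.
    s_dW = beta * s_dW + (1 - beta) * dW ** 2
    s_db = beta * s_db + (1 - beta) * db ** 2
    # Divide each gradient by the root of its running squared magnitude,
    # so directions with large recent gradients take smaller steps.
    W = W - alpha * dW / (np.sqrt(s_dW) + eps)
    b = b - alpha * db / (np.sqrt(s_db) + eps)
    return W, b, s_dW, s_db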
Adam: ADAptive Moment estimation
This is the most commonly used/popular optimization
algorithm in the computer vision community.
Its update combines the Momentum and RMSProp steps above,
with bias correction applied to both.
While applying gradient descent, we update the weights by
performing the following:
From Momentum we have:
Vdw = (β1 * Vdw) + ((1.0 - β1) * dw)
Vdb = (β1 * Vdb) + ((1.0 - β1) * db)
You have to perform bias correction (t is the current iteration number):
Vdw_corrected = Vdw / (1.0 - (β1)^t)
Vdb_corrected = Vdb / (1.0 - (β1)^t)
From RMSProp we have:
Sdw = (β2 * Sdw) + ((1.0 - β2) * (dw)^2)
Sdb = (β2 * Sdb) + ((1.0 - β2) * (db)^2)
You have to perform bias correction:
Sdw_corrected = Sdw / (1.0 - (β2)^t)
Sdb_corrected = Sdb / (1.0 - (β2)^t)
Finally, we perform weight updates:
W = W - (α * (Vdw_corrected / (sqrt(Sdw_corrected) + ε))) instead of W = W - (α * dw).
b = b - (α * (Vdb_corrected / (sqrt(Sdb_corrected) + ε))) instead of b = b - (α * db).
All of this reminds me of Kalman filters, where the observed
signal value is not exactly correct: there is process noise and
observation noise, and we try to incorporate both to obtain an
updated/corrected estimate of the signal from the sensor.