Learn how backpropagation works in an LSTM network.
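In this blog post, I’ll discuss how backpropagation through time (BPTT) works in an LSTM cell and how it is currently implemented for the LSTM layer in TMVA. It is called ‘backpropagation through time’ because the network is unrolled over the input sequence and the gradient computation is repeated at each timestep, walking backward from the last step to the first. Looked at carefully, each of these steps is just ordinary backpropagation applied to the unrolled network, with the extra detail that the same weights are shared across all timesteps.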
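To make the "repeat at every timestep" idea concrete, here is a minimal sketch of BPTT for a scalar vanilla RNN, h_t = tanh(w x_t + u h_{t-1} + b), where the loss sends a gradient dLdh[t] into each hidden state. The function and struct names are hypothetical and this is not TMVA code. It shows the two ingredients that distinguish BPTT from single-step backpropagation: gradients for the shared weights are summed over timesteps, and an extra gradient term flows from h_t back to h_{t-1}.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Gradients of the loss with respect to the shared RNN parameters.
struct RnnGrads {
    double dw = 0.0;  // input weight
    double du = 0.0;  // state (recurrent) weight
    double db = 0.0;  // bias
};

// Forward pass of a scalar RNN: h_t = tanh(w * x_t + u * h_{t-1} + b), with h_{-1} = 0.
std::vector<double> rnnForward(const std::vector<double>& x, double w, double u, double b) {
    std::vector<double> h(x.size());
    double prev = 0.0;
    for (std::size_t t = 0; t < x.size(); ++t) {
        h[t] = std::tanh(w * x[t] + u * prev + b);
        prev = h[t];
    }
    return h;
}

// Backpropagation through time: walk the timesteps in reverse.
// dLdh[t] is the gradient of the loss arriving at h_t from the output at step t.
RnnGrads rnnBackward(const std::vector<double>& x, const std::vector<double>& h,
                     const std::vector<double>& dLdh, double u) {
    assert(x.size() == h.size() && h.size() == dLdh.size());
    RnnGrads g;
    double dhNext = 0.0;  // gradient reaching h_t through h_{t+1}
    for (std::size_t t = x.size(); t-- > 0;) {
        const double dh = dLdh[t] + dhNext;          // total gradient at h_t
        const double da = dh * (1.0 - h[t] * h[t]);  // back through tanh
        const double hPrev = (t > 0) ? h[t - 1] : 0.0;
        g.dw += da * x[t];   // the same w, u, b are used at every step,
        g.du += da * hPrev;  // so their gradients are summed over time
        g.db += da;
        dhNext = da * u;     // gradient passed back to h_{t-1}
    }
    return g;
}
```

With a loss of L = Σ_t h_t (so every entry of dLdh is 1), the three gradients can be checked against central finite differences of rnnForward.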
Here is the pictorial representation of the LSTM cell:
Figure-1: The pictorial representation of the LSTM cell with four interacting gates.
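Backpropagation in an LSTM cell differs from that in a vanilla RNN because, at each timestep of the backward pass, gradients must also be computed for the input gate, forget gate and output gate (as well as for the candidate value and the cell state), not only for the hidden state.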
Here are the mathematical equations for reference:
Figure-2: Computing the gradient of each gate during backpropagation in the LSTM cell.
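Since the figure is an image, here is one standard way to write the same relations as equations: the forward pass of a single timestep, followed by the gradients with respect to each gate. I am assuming the usual formulation with input gate i_t, forget gate f_t, output gate o_t, candidate g_t, cell state c_t and hidden state h_t; δ denotes a gradient of the loss, and δh_t is the total gradient reaching h_t (from the output at step t plus the recurrent path from step t+1). The symbols may not match Figure-2 exactly.

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t) \\[6pt]
\delta o_t &= \delta h_t \odot \tanh(c_t) \\
\delta c_t &= \delta h_t \odot o_t \odot \bigl(1 - \tanh^2(c_t)\bigr) + \delta c_{t+1} \odot f_{t+1} \\
\delta i_t &= \delta c_t \odot g_t \\
\delta f_t &= \delta c_t \odot c_{t-1} \\
\delta g_t &= \delta c_t \odot i_t
\end{aligned}
```

The second term of δc_t is what lets the gradient travel along the cell state from one timestep to the previous one, scaled only by the forget gate.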
Figure-3: Final parameter updates during the backward pass in the LSTM cell. dW, dU and db denote the gradients with respect to the input weights, state weights and biases, respectively.
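Under the usual LSTM formulation (an assumption; the symbols may differ from Figure-3), the gradients with respect to the gate activations, δi_t, δf_t, δo_t and δg_t, are first pushed through the gate nonlinearities to get pre-activation gradients δz^k_t, and these are then summed over timesteps to give dW, dU and db for each gate k ∈ {i, f, o, g}. Here i_t, f_t, o_t are the sigmoid gates, g_t is the tanh candidate, x_t is the input, h_{t-1} is the previous hidden state, and ℓ_t is the part of the loss computed at step t.

```latex
\begin{aligned}
\delta z^i_t &= \delta i_t \odot i_t \odot (1 - i_t) \\
\delta z^f_t &= \delta f_t \odot f_t \odot (1 - f_t) \\
\delta z^o_t &= \delta o_t \odot o_t \odot (1 - o_t) \\
\delta z^g_t &= \delta g_t \odot (1 - g_t \odot g_t) \\[6pt]
dW_k &= \sum_t \delta z^k_t \otimes x_t \\
dU_k &= \sum_t \delta z^k_t \otimes h_{t-1} \\
db_k &= \sum_t \delta z^k_t \\[6pt]
\delta h_{t-1} &= \frac{\partial \ell_{t-1}}{\partial h_{t-1}} + \sum_{k \in \{i,f,o,g\}} U_k^{\top} \delta z^k_t
\end{aligned}
```

The sums over t appear because W_k, U_k and b_k are shared by every timestep, and the last line is the recurrent term that carries the gradient back to the previous hidden state.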
The current implementation covers the final update of the internal parameters: the state weights, input weights, and biases. Computing the per-gate gradients is still in progress, because it requires many parameters to be passed from the Backward() method. I will share progress and performance results for it in the next blog post.
- In the equations above, ⨀ denotes the element-wise (Hadamard) product.
- ⋅ denotes an inner (dot) product.
- ⨂ denotes an outer product.
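As a small illustration of the three operations, here is a minimal C++ sketch on std::vector<double>. The helper names (hadamard, inner, outer) are my own and are not TMVA functions. In backpropagation, ⨀ appears whenever a gradient passes through a gate, and ⨂ is what builds a weight-shaped gradient matrix from a gate gradient vector and an input vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // row-major: m[row][col]

// Hadamard (element-wise) product: (a ⨀ b)_k = a_k * b_k.
Vec hadamard(const Vec& a, const Vec& b) {
    assert(a.size() == b.size());
    Vec out(a.size());
    for (std::size_t k = 0; k < a.size(); ++k) out[k] = a[k] * b[k];
    return out;
}

// Inner (dot) product: a ⋅ b = sum_k a_k * b_k, a scalar.
double inner(const Vec& a, const Vec& b) {
    assert(a.size() == b.size());
    double s = 0.0;
    for (std::size_t k = 0; k < a.size(); ++k) s += a[k] * b[k];
    return s;
}

// Outer product: (a ⨂ b)_{jk} = a_j * b_k, an |a| x |b| matrix.
Mat outer(const Vec& a, const Vec& b) {
    Mat out(a.size(), Vec(b.size()));
    for (std::size_t j = 0; j < a.size(); ++j)
        for (std::size_t k = 0; k < b.size(); ++k) out[j][k] = a[j] * b[k];
    return out;
}
```

For example, with a hidden size of 3 and an input size of 2, outer(dz, x) yields a 3 × 2 matrix, the same shape as the corresponding input weight matrix.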
In the TBasicLSTMLayer class, the backward pass is implemented in the Backward() method, which initializes tensors to store the gradients of the hidden state, input gate, forget gate, output gate, and candidate value. These tensors are then passed to the CellBackward() method, which performs the final update of the layer's internal parameters.
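The same two-level split can be sketched outside TMVA. The sketch below assumes a scalar LSTM (one input, one hidden unit) so that the outer products reduce to ordinary multiplications: a sequence-level loop keeps the running hidden-state and cell-state gradients, and a cell-level function turns them into per-gate gradients, which are then accumulated into dW, dU and db. All names here (Params, StepCache, lstmForward, cellBackward, lstmBackward) are hypothetical and are not the TBasicLSTMLayer API; the gate equations are the standard LSTM formulation.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical scalar LSTM (input size 1, hidden size 1), for illustration only.
// Gate order in every array: 0 = input i, 1 = forget f, 2 = output o, 3 = candidate g.
struct Params {
    std::array<double, 4> w{};  // input weights (W)
    std::array<double, 4> u{};  // state weights (U)
    std::array<double, 4> b{};  // biases (b)
};

// Values saved during the forward pass and reused by the backward pass.
struct StepCache {
    double x, hPrev, cPrev;
    std::array<double, 4> gate;  // activated i, f, o, g
    double c, h;
};

double sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

std::vector<StepCache> lstmForward(const std::vector<double>& xs, const Params& p) {
    std::vector<StepCache> cache;
    double h = 0.0, c = 0.0;  // initial hidden and cell state
    for (double x : xs) {
        StepCache s{x, h, c, {}, 0.0, 0.0};
        for (int k = 0; k < 4; ++k) {
            const double z = p.w[k] * x + p.u[k] * h + p.b[k];
            s.gate[k] = (k == 3) ? std::tanh(z) : sigmoid(z);
        }
        s.c = s.gate[1] * c + s.gate[0] * s.gate[3];  // c_t = f*c_{t-1} + i*g
        s.h = s.gate[2] * std::tanh(s.c);             // h_t = o*tanh(c_t)
        h = s.h;
        c = s.c;
        cache.push_back(s);
    }
    return cache;
}

// Cell-level step: given the total gradient dh at h_t and, in dc, the gradient
// arriving at c_t from step t+1, return the pre-activation gradient of each gate.
// On return, dc holds the gradient to pass on to c_{t-1}.
std::array<double, 4> cellBackward(const StepCache& s, double dh, double& dc) {
    const double tc = std::tanh(s.c);
    const double i = s.gate[0], f = s.gate[1], o = s.gate[2], g = s.gate[3];
    dc += dh * o * (1.0 - tc * tc);  // total gradient at c_t
    const std::array<double, 4> dz = {
        dc * g * i * (1.0 - i),        // input gate
        dc * s.cPrev * f * (1.0 - f),  // forget gate
        dh * tc * o * (1.0 - o),       // output gate
        dc * i * (1.0 - g * g)         // candidate
    };
    dc *= f;  // gradient for c_{t-1}
    return dz;
}

// Sequence-level BPTT: dLdh[t] is the direct gradient of the loss at h_t.
// Returns dW, dU and db packed in a Params struct.
Params lstmBackward(const std::vector<StepCache>& cache, const std::vector<double>& dLdh,
                    const Params& p) {
    assert(cache.size() == dLdh.size());
    Params grad;  // all gradients start at zero
    double dhNext = 0.0, dc = 0.0;
    for (std::size_t t = cache.size(); t-- > 0;) {
        const StepCache& s = cache[t];
        const std::array<double, 4> dz = cellBackward(s, dLdh[t] + dhNext, dc);
        dhNext = 0.0;
        for (int k = 0; k < 4; ++k) {
            grad.w[k] += dz[k] * s.x;      // outer products reduce to scalar products here
            grad.u[k] += dz[k] * s.hPrev;
            grad.b[k] += dz[k];
            dhNext += p.u[k] * dz[k];      // U^T dz flows back to h_{t-1}
        }
    }
    return grad;
}
```

Keeping the cache of gate activations from the forward pass is what makes the backward pass cheap: no gate has to be recomputed, which is also why a real implementation ends up passing many tensors from the sequence-level method into the cell-level one. A finite-difference check on a loss such as L = Σ_t h_t is a simple way to verify a layer like this.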
I’ve included a good overall reference on LSTM networks that walks through a numerical example, which is useful for verifying the layer design.
In my next post, I’ll share the results of the LSTM layer, along with my final work and future plans for the TMVA-DNN project.