Long Short-Term Memory: From Zero to Hero with PyTorch
Just like us, Recurrent Neural Networks (RNNs) can be very forgetful. This struggle with short-term memory makes RNNs less effective on tasks that depend on long-range context. Fret not, though: Long Short-Term Memory networks (LSTMs) have much better memories and can retain information that the vanilla RNN cannot!
LSTMs are a particular variant of RNNs, so having a grasp of the concepts surrounding RNNs will significantly aid your understanding of LSTMs in this article. I covered the mechanism of RNNs in my previous article here.
A quick recap on RNNs:
RNNs process inputs in a sequential manner, where the context from the previous input is considered when computing the output of the current step. This allows the neural network to carry information over different time steps rather than keeping all the inputs independent of each other.
Process Flow of RNNs:
However, a significant shortcoming that plagues the typical RNN is the problem of vanishing/exploding gradients. This problem arises when back-propagating through the RNN during training, especially over long sequences, where the unrolled network is effectively very deep. Because of the chain rule, the gradients go through repeated matrix multiplications during back-propagation, causing them to either shrink exponentially (vanish) or blow up exponentially (explode). A gradient that is too small prevents the weights from updating and learning, whereas extremely large gradients make training unstable.
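To get a feel for how quickly this happens, here is a minimal sketch in plain Python. It is a toy model, not real back-propagation: it assumes the gradient is scaled by the same made-up factor at every time step, whereas a real RNN multiplies by a matrix that changes from step to step. The function name `gradient_after` and the factors 0.5 and 1.5 are illustrative only.

```python
def gradient_after(steps, factor, start=1.0):
    """Scale a gradient by `factor` once per time step (toy model)."""
    grad = start
    for _ in range(steps):
        grad *= factor
    return grad


# A factor below 1 shrinks the gradient towards zero (vanishing).
vanishing = gradient_after(steps=50, factor=0.5)

# A factor above 1 makes the gradient grow without bound (exploding).
exploding = gradient_after(steps=50, factor=1.5)

print(f"factor 0.5 over 50 steps: {vanishing:.3e}")
print(f"factor 1.5 over 50 steps: {exploding:.3e}")
```

Even with factors this close to 1, fifty steps are enough to push the gradient many orders of magnitude away from its starting value in either direction.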
Due to these issues, RNNs struggle to work with longer sequences and to hold on to long-term dependencies, making them suffer from “short-term memory”.
What are LSTMs:
While LSTMs are a kind of RNN and function similarly to traditional RNNs, their gating mechanism is what sets them apart. This feature addresses the “short-term memory” problem of RNNs.
Vanilla RNN vs LSTM:
As we can see from the image, the difference lies mainly in the LSTM’s ability to preserve long-term memory. This is especially important in the majority of Natural Language Processing (NLP), time-series, and other sequential tasks. For example, let’s say we have a network generating text based on some input given to it. At the start of the text, it is mentioned that the author has a “dog named Cliff”. After a few other sentences where there is no mention of a pet or dog, the author brings up his pet again, and the model has to generate the next word after “However, Cliff, my pet ____”. As the word pet appeared right before the blank, an RNN can deduce that the next word will likely be an animal that can be kept as a pet.
RNNs are unable to remember information from much earlier:
However, due to the short-term memory, the typical RNN will only be able to use the contextual information from the text that appeared in the last few sentences – which is not useful at all. The RNN has no clue as to what animal the pet might be as the relevant information from the start of the text has already been lost.
On the other hand, the LSTM can retain the earlier information that the author has a pet dog. Because it still has this context from a much earlier time step, the model is more likely to choose “dog” when generating the text at that point.
Inner workings of the LSTM:
The secret sauce to the LSTM lies in its gating mechanism within each LSTM cell. In the normal RNN cell, the input at a time step and the hidden state from the previous time step are combined and passed through a tanh activation function to obtain a new hidden state and output.
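As a rough illustration of the vanilla RNN cell just described, here is a minimal sketch in plain Python using a scalar input and a scalar hidden state. The weights `w_x`, `w_h` and the bias `b` are made-up values (in a real network they are learned), and the function name `rnn_step` is hypothetical.

```python
import math


def rnn_step(x_t, h_prev, w_x, w_h, b):
    """One vanilla RNN step: weighted input plus weighted previous state, then tanh."""
    return math.tanh(w_x * x_t + w_h * h_prev + b)


inputs = [0.5, -1.0, 2.0]  # a toy sequence of three time steps
h = 0.0                    # initial hidden state
hidden_states = []

for x_t in inputs:
    # The hidden state from the previous step is fed back in at every step.
    h = rnn_step(x_t, h, w_x=0.8, w_h=0.5, b=0.0)
    hidden_states.append(h)

print(hidden_states)
```

The only memory the cell has is the single value `h`, which is overwritten at every step. That is the limitation the LSTM's extra long-term memory is meant to address.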
Inner workings of an RNN cell:
LSTMs, on the other hand, have a slightly more complex structure. At each time step, the LSTM cell takes in 3 different pieces of information — the current input data, the short-term memory from the previous cell (similar to hidden states in RNNs) and lastly the long-term memory.
The short-term memory is commonly referred to as the hidden state, and the long-term memory is usually known as the cell state.
The cell then uses gates to regulate the information to be kept or discarded at each time step before passing on the long-term and short-term information to the next cell.
These gates can be seen as water filters. Ideally, these gates selectively remove any irrelevant information, similar to how water filters prevent impurities from passing through. At the same time, only water and beneficial nutrients can pass through these filters, just like how the gates only hold on to the useful information. Of course, these gates need to be trained to accurately filter what is useful and what is not.
LSTM Gates can be seen as filters:
These gates are called the Input Gate, the Forget Gate, and the Output Gate. They go by various other names elsewhere; nevertheless, the calculations and workings of these gates are mostly the same.
Workings of the LSTM cell:
Let’s go through the mechanisms of these gates one-by-one.
The input gate decides what new information will be stored in the long-term memory. It only works with the information from the current input and the short-term memory from the previous time step. Therefore, it has to filter out the information from these variables that is not useful.
Input Gate computations:
Mathematically, this is achieved using 2 layers. The first layer can be seen as the filter which selects what information can pass through it and what information is discarded. To create this layer, we pass the short-term memory and current input through a weighted layer followed by a sigmoid function. The sigmoid function will transform the values to be between 0 and 1, with values near 0 indicating that part of the information is unimportant, whereas values near 1 indicate that the information will be used. This helps to decide the values to be kept and used, and also the values to be discarded. As the layer is trained through back-propagation, its weights will be updated such that it learns to only let the useful information pass through while discarding the less critical features.
The second layer takes the short-term memory and current input as well and passes them through an activation function, usually the $$\tanh$$ function, which squashes the values to between -1 and 1 to regulate the network.
The outputs from these 2 layers are then multiplied element-wise, and the final outcome represents the new information to be added to the long-term memory, which in turn feeds into the output.
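The two layers of the input gate can be sketched in plain Python as follows. This is only an illustration under a few assumptions: the hidden state has 2 units, the input has 1 feature, and the weights `W_i`, `W_g` and biases `b_i`, `b_g` are made-up numbers standing in for learned parameters. The helper names `sigmoid` and `linear` and the variable names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def linear(weights, bias, values):
    """Weighted sum of `values` plus a bias, one weight row per output unit."""
    return [sum(w * v for w, v in zip(row, values)) + b
            for row, b in zip(weights, bias)]


h_prev = [0.2, -0.4]     # short-term memory from the previous time step
x_t = [1.0]              # current input
combined = h_prev + x_t  # both layers see the same combined vector

# Made-up weights; a trained model would have learned these.
W_i, b_i = [[0.5, -0.3, 0.8], [0.1, 0.7, -0.6]], [0.0, 0.0]
W_g, b_g = [[0.9, 0.2, -0.5], [-0.4, 0.6, 0.3]], [0.0, 0.0]

# Layer 1: sigmoid filter, each value between 0 (discard) and 1 (keep).
i_t = [sigmoid(z) for z in linear(W_i, b_i, combined)]

# Layer 2: tanh candidate values, each between -1 and 1.
g_t = [math.tanh(z) for z in linear(W_g, b_g, combined)]

# Element-wise product: the candidate values, scaled down by the filter.
input_gate_out = [i * g for i, g in zip(i_t, g_t)]

print(i_t, g_t, input_gate_out)
```

Because every filter value lies strictly between 0 and 1, each entry of `input_gate_out` is a scaled-down copy of the corresponding candidate value.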
The forget gate decides which information from the long-term memory should be kept or discarded. This is done by multiplying the incoming long-term memory by a forget vector generated by the current input and incoming short-term memory.
Just like the first layer in the Input gate, the forget vector is also a selective filter layer. To obtain the forget vector, the short-term memory and current input are passed through a sigmoid function, similar to the first layer in the Input Gate above, but with different weights. The vector will be made up of values between 0 and 1 and will be multiplied element-wise with the long-term memory to choose which parts of the long-term memory to retain.
The outputs from the Input gate and the Forget gate will undergo a pointwise addition to give a new version of the long-term memory, which will be passed on to the next cell. This new long-term memory will also be used in the final gate, the Output gate.
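Here is a minimal sketch of the forget gate and the long-term memory update in plain Python. It assumes a 2-unit memory and a 1-feature input. The weights `W_f`, the bias `b_f`, the previous long-term memory `c_prev` and the input gate result `input_gate_out` are all made-up values chosen for illustration, and the helper names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def linear(weights, bias, values):
    """Weighted sum of `values` plus a bias, one weight row per output unit."""
    return [sum(w * v for w, v in zip(row, values)) + b
            for row, b in zip(weights, bias)]


c_prev = [1.5, -0.8]     # incoming long-term memory (cell state)
h_prev = [0.2, -0.4]     # incoming short-term memory (hidden state)
x_t = [1.0]              # current input
combined = h_prev + x_t

# Made-up forget gate weights; a trained model would have learned these.
W_f, b_f = [[0.3, 0.6, -0.2], [-0.7, 0.4, 0.5]], [0.0, 0.0]

# Forget vector: values between 0 (forget) and 1 (keep).
f_t = [sigmoid(z) for z in linear(W_f, b_f, combined)]

# Scale each part of the old long-term memory by its forget value.
kept = [f * c for f, c in zip(f_t, c_prev)]

# Assumed output of the input gate for this time step.
input_gate_out = [0.3, -0.1]

# Pointwise addition gives the new long-term memory.
c_t = [k + n for k, n in zip(kept, input_gate_out)]

print(f_t, kept, c_t)
```

Note that the update to the long-term memory is an addition rather than a repeated squashing through an activation, which is part of why information can survive across many time steps.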
The output gate will take the current input, the previous short-term memory, and the newly computed long-term memory to produce the new short-term memory/hidden state which will be passed on to the cell in the next time step. The output of the current time step can also be drawn from this hidden state.
First, the previous short-term memory and current input will be passed into a sigmoid function (Yes, this is the 3rd time we’re doing this) with different weights yet again to create the third and final filter. Then, we put the new long-term memory through a $$\tanh$$ activation function. The outputs from these 2 processes will be multiplied element-wise to produce the new short-term memory.
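The output gate can be sketched in the same style. As before, this is an illustration in plain Python with a 2-unit memory and a 1-feature input. The weights `W_o`, the bias `b_o` and the new long-term memory `c_t` are made-up values, and the helper names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def linear(weights, bias, values):
    """Weighted sum of `values` plus a bias, one weight row per output unit."""
    return [sum(w * v for w, v in zip(row, values)) + b
            for row, b in zip(weights, bias)]


h_prev = [0.2, -0.4]     # previous short-term memory (hidden state)
x_t = [1.0]              # current input
combined = h_prev + x_t

c_t = [0.9, -0.5]        # assumed new long-term memory for this time step

# Made-up output gate weights; a trained model would have learned these.
W_o, b_o = [[0.4, -0.1, 0.6], [0.2, 0.5, -0.3]], [0.0, 0.0]

# Third filter: values between 0 and 1.
o_t = [sigmoid(z) for z in linear(W_o, b_o, combined)]

# New short-term memory: filtered tanh of the new long-term memory.
h_t = [o * math.tanh(c) for o, c in zip(o_t, c_t)]

print(o_t, h_t)
```

Since the filter is in (0, 1) and tanh is in (-1, 1), every entry of the new hidden state `h_t` also stays between -1 and 1.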
The short-term and long-term memory produced by these gates will then be carried over to the next cell for the process to be repeated. The output of each time step can be obtained from the short-term memory, also known as the hidden state.
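In PyTorch, this whole loop over time steps is handled by `torch.nn.LSTM`. The sketch below assumes PyTorch is installed and uses arbitrary sizes (a batch of 2 sequences, 5 time steps, 4 input features, 8 hidden units) on random data, purely to show where the short-term and long-term memory appear in the return values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A single-layer LSTM; batch_first=True means inputs are (batch, seq_len, features).
lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

x = torch.randn(2, 5, 4)  # random toy batch: 2 sequences of 5 time steps

# output: hidden state (short-term memory) at every time step
# h_n:    final hidden state, c_n: final cell state (long-term memory)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # (batch, seq_len, hidden_size)
print(h_n.shape)     # (num_layers, batch, hidden_size)
print(c_n.shape)     # (num_layers, batch, hidden_size)
```

For a single-layer, unidirectional LSTM like this one, the last time step of `output` is the same tensor of values as `h_n[0]`, which matches the description above: the output at each step is drawn from the hidden state.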
That’s all there is to the mechanisms of the typical LSTM structure. Not all that tough, eh?