An Introduction to Recurrent Neural Networks & LSTMs
In this introductory guide to deep learning, we'll discuss two important concepts: recurrent neural networks (RNNs) and long short-term memory (LSTM) networks.
A recurrent neural network attempts to model time-based or sequence-based data.
A few applications of recurrent neural networks include natural language processing (NLP), predicting stock prices, predicting energy demand, and more.
Recurrent neural networks are designed to learn from sequences of data by passing the hidden state from one step in the sequence to the next step and combining this with the input.
This brings us to Long Short-Term Memory (LSTM) networks:
An LSTM network is a type of RNN that uses special units as well as standard units.
What are these special units?
LSTM units include a 'memory cell' that can keep information in memory for longer periods of time.
LSTMs are particularly useful when the neural network needs to switch between using recent information and making use of older data in order to make predictions.
RNNs vs. LSTMs
Let's say we have a regular neural network that is used for image recognition.
If we pass an image of a dog to the neural network, it will ideally output a high probability that the image is a dog, maybe a small chance that it is a wolf, and an even smaller chance that it is a cat.
But what if the image was actually a wolf? How would the neural network know?
Now let's say that we have a sequence of images, say on a nature TV show, and the previous images were a wolf, a bear, and a fox.
In this case, we analyze the images with the same copy of the neural network, but instead use the output of the network as part of the input to the next one.
This sequential process can improve our results.
To do this mathematically, we combine the vectors with a linear function and then pass the result through an activation function, either the sigmoid or the hyperbolic tangent (tanh).
By doing it this way, the final neural network will know that the TV show is about wild animals that live in forests, and can use this information to predict the image is a wolf as opposed to a dog.
This example illustrates how recurrent neural networks work.
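To make the recurrence concrete, here's a minimal sketch of a single RNN step in NumPy (the weight matrices and dimensions are purely illustrative, not tied to any particular library):

```python
import numpy as np

def rnn_step(x_t, h_prev, W_x, W_h, b):
    """One step of a vanilla RNN: combine the current input with the
    previous hidden state in a linear function, then squash with tanh."""
    return np.tanh(W_x @ x_t + W_h @ h_prev + b)

# Illustrative sizes: 4-dimensional inputs, 3-dimensional hidden state
rng = np.random.default_rng(0)
W_x, W_h, b = rng.normal(size=(3, 4)), rng.normal(size=(3, 3)), np.zeros(3)

h = np.zeros(3)                      # initial hidden state
for x in rng.normal(size=(5, 4)):    # a sequence of 5 inputs
    h = rnn_step(x, h, W_x, W_h, b)  # the hidden state carries context forward
```

The hidden state `h` plays the role of the network's memory of the earlier images in the sequence.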
The problem with recurrent neural networks is that the memory of the network is generally short-term memory.
For example, if in between the bear and the final dog-or-wolf image we had a flower and a tree (which could be thought of as either domestic or wild), the network would have a hard time remembering the significance of the bear and the fox.
In short, RNNs are not particularly well suited to storing and making use of long-term memory. To solve this, LSTM networks, which use both long-term and short-term memory, are better suited than plain RNNs.
To recap, below is an overview of how RNNs work:
- Memory comes in and merges with a current event
- The output comes out as a prediction of what the input is
- The output is also part of the input of the next iteration of the neural network
In a similar way, an LSTM works as follows:
- The neural network keeps track not just of short term memory, but also of long term memory
- In every step of the sequence, the long and short term memory in the step get merged
- From this, we get a new long term memory, short term memory, and prediction
By doing it this way, the network can remember information from a long time ago.
Now let's look at the architecture of LSTMs.
LSTM Network Architecture
In our previous example of the nature TV show, we have:
- Long term memory about nature and forest animals
- Short term memory about flowers and trees
- An event, which is the new image that could be a dog or wolf
We want to combine all this information to create a prediction about what our image is.
The long-term memory gives us a hint that we should favour the wolf prediction over the dog prediction.
We also want all three variables to help us update the long term and short term memory of the network.
To accomplish this, the architecture of the LSTM contains several gates:
- A forget gate
- A learn gate
- A remember gate
- And a use gate
Here's an overview of how these gates work together:
- The long term memory goes into the forget gate, where it forgets everything that's not useful
- The short term memory and the event are combined in the learn gate
- The long term memory that hasn't been forgotten plus the new information we learned get joined in the remember gate, which outputs an updated long term memory
- The use gate decides what information to use, from what was already known plus what we just learned, to make a prediction
- The output becomes both the prediction and the new short term memory
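Using the notation from the next sections (LTM for long-term memory, STM for short-term memory, and $E_t$ for the event), one way to summarize this data flow is:

$$\mathrm{LTM}_t = \mathrm{Remember}\big(\mathrm{Forget}(\mathrm{LTM}_{t-1}),\ \mathrm{Learn}(\mathrm{STM}_{t-1}, E_t)\big)$$

$$\mathrm{STM}_t = \mathrm{Use}\big(\mathrm{Forget}(\mathrm{LTM}_{t-1}),\ \mathrm{STM}_{t-1},\ E_t\big)$$

The prediction at each step is simply the new short-term memory $\mathrm{STM}_t$.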
Now let's dive a bit deeper into the different gates.
The Forget Gate
The forget gate takes the long term memory and decides what part to keep and what to forget.
How does this work mathematically?
- The long-term memory (LTM) from time $t-1$ is multiplied by a forget factor $f_t$.
- The forget factor is calculated from the short-term memory (STM) and the event information $E_t$.
- We run a small one-layer neural network, a linear function combined with the sigmoid function, to calculate the forget factor.
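Written out, with $W_f$ and $b_f$ denoting the weights and bias of that small network and $[\mathrm{STM}_{t-1}, E_t]$ denoting the concatenation of the short-term memory and the event, this is:

$$f_t = \sigma\!\left(W_f\,[\mathrm{STM}_{t-1}, E_t] + b_f\right)$$

$$\text{output of the forget gate} = \mathrm{LTM}_{t-1} \cdot f_t$$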
Remember Gate
The remember gate takes the outputs of the forget gate and the learn gate and adds them together.
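The learn gate itself isn't spelled out above, but a common formulation (used here as an assumption) computes a vector of new information $N_t$ and an "ignore factor" $i_t$ from the short-term memory and the event, with the same kind of small one-layer networks used for the forget factor:

$$N_t = \tanh\!\left(W_n\,[\mathrm{STM}_{t-1}, E_t] + b_n\right), \qquad i_t = \sigma\!\left(W_i\,[\mathrm{STM}_{t-1}, E_t] + b_i\right)$$

The remember gate then produces the updated long-term memory:

$$\mathrm{LTM}_t = \mathrm{LTM}_{t-1}\cdot f_t + N_t \cdot i_t$$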
Use Gate
Also called the output gate, this gate uses the long-term memory that came out of the forget gate, together with the short-term memory and the event, to come up with a new short-term memory and an output (these are the same thing).
Here's how this is done mathematically:
- It applies a small neural network with the tanh activation function to the long-term memory
- It applies another small neural network with the sigmoid activation function to the short-term memory and the event
- As a final step, it multiplies these two results together to get the new output
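Putting the four gates together, here's a minimal NumPy sketch of a single LSTM step. It follows the standard LSTM equations (which apply tanh to the updated long-term memory rather than to the forget gate's output, a small difference from the description above); the weight names and sizes are illustrative:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(event, stm_prev, ltm_prev, W, b):
    """One LSTM step: takes the event (input), the previous short-term memory
    (hidden state) and long-term memory (cell state), and returns the updated
    memories. The output/prediction is the new short-term memory."""
    x = np.concatenate([stm_prev, event])        # short-term memory + event

    f = sigmoid(W["forget"] @ x + b["forget"])   # forget factor
    i = sigmoid(W["ignore"] @ x + b["ignore"])   # learn gate: how much new info to keep
    n = np.tanh(W["learn"] @ x + b["learn"])     # learn gate: the new information itself
    o = sigmoid(W["use"] @ x + b["use"])         # use gate: what to output

    ltm = ltm_prev * f + n * i                   # remember gate: kept LTM + learned info
    stm = np.tanh(ltm) * o                       # use gate: new STM, which is also the output
    return stm, ltm

# Illustrative sizes: 4-dimensional events, 3-dimensional memories
rng = np.random.default_rng(0)
W = {k: rng.normal(size=(3, 7)) for k in ["forget", "ignore", "learn", "use"]}
b = {k: np.zeros(3) for k in W}

stm, ltm = np.zeros(3), np.zeros(3)
for event in rng.normal(size=(5, 4)):            # a sequence of 5 events
    stm, ltm = lstm_step(event, stm, ltm, W, b)
```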
Character-wise RNNs
Character-wise RNNs are networks that learn text one character at a time, and generate new text one character at a time.
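As a small illustration, here's one common way to prepare text for such a network (the text and encoding here are just an example, not from any particular implementation):

```python
import numpy as np

text = "hello world"

# Map each distinct character to an integer index and back
chars = sorted(set(text))
char2idx = {c: i for i, c in enumerate(chars)}
idx2char = {i: c for c, i in char2idx.items()}

# Encode the text as integers, then as one-hot vectors the network can consume
encoded = np.array([char2idx[c] for c in text])
one_hot = np.eye(len(chars))[encoded]   # shape: (len(text), vocabulary size)

# At each position the network is trained to predict the next character
inputs, targets = encoded[:-1], encoded[1:]
```

To generate text, the trained network's predicted character is fed back in as the next input, one character at a time.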
Sequence Batching
One of the hardest parts of building recurrent neural networks can be getting the batches right.
Below is an overview of how batching works for RNNs:
- With RNNs we're training on sequences of data, such as text, audio, or stock prices
- By splitting the sequences into smaller sequences we can use matrix operations to improve the efficiency of training
- For example, if we have the sequence of numbers from 0 to 9, we can either pass it in as one sequence, or we could split it into 2 sequences, i.e. [0-4] and [5-9]
- The batch size corresponds to the number of sequences, so here the batch size would be 2
We also choose the length of the sequences we feed to the network at each step; for example, we could use the first 3 numbers of each sequence.
We can retain the hidden state from one batch and use it for the next, so the sequence information is carried across batches for each mini-sequence.
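Here's a minimal sketch of that batching scheme in NumPy, using the example above (a hypothetical helper, not from any particular library); the hidden state from the end of one window would then be carried over to the start of the next:

```python
import numpy as np

def get_batches(arr, batch_size, seq_length):
    """Split one long sequence into `batch_size` shorter sequences (the rows),
    then step through them `seq_length` columns at a time."""
    n_per_row = len(arr) // batch_size
    arr = arr[: batch_size * n_per_row].reshape((batch_size, n_per_row))
    for i in range(0, n_per_row, seq_length):
        yield arr[:, i : i + seq_length]

# The example above: numbers 0-9, batch size 2, sequence length 3
for batch in get_batches(np.arange(10), batch_size=2, seq_length=3):
    print(batch)
# [[0 1 2]
#  [5 6 7]]
# [[3 4]
#  [8 9]]   <- the last window is shorter when the rows don't divide evenly
```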
Summary: Recurrent Neural Networks & LSTMs
- RNNs are designed to learn from sequences of data by passing the hidden state from one step in the sequence to the next step and combining this with the input.
- LSTM networks are a type of RNN that use special units as well as standard units.
- LSTM units include a memory cell that can keep information in memory for long periods of time.
- LSTMs are particularly useful when our neural network needs to switch between remembering recent features and features from a long time ago.