This post is a quick and dirty introduction to RNNs, with a natural emphasis on mathematical formulations, so that one has just enough working knowledge to implement a basic model. Traditional neural networks treat each input as independent. But in many real-world scenarios, the order and context of data matter. A sequence of words in a sentence, a time series of stock prices, or a series of sensor readings all have inherent temporal dependencies.
RNNs solve this by introducing a "hidden state" - think of it as the network's working memory. This hidden state acts like a messenger, carrying important information from previous time steps to the current one.
Recall that the traditional feedforward neural network architecture consists of an input layer, some hidden middle layers, and an output layer, with connections between consecutive layers. The RNN architecture adds memory via self-loops at the nodes in the hidden layers, which means that information from previous steps can be fed back into the network.
In particular this makes them much better at sequential classification tasks than feedforward networks, which are better at independent classification tasks (e.g. image classification).
To implement this idea, we have the following setup. Suppose our labeled training data is sequential, i.e. $(x_1, y_1), (x_2, y_2), \dots, (x_T, y_T)$. We refer to datapoint $(x_t, y_t)$ as the point at timestep $t$.
At each $t$, we have an input given by $x_t$, a hidden vector $h_t$, and an output vector $\hat{y}_t$ (the hat denotes prediction). The hidden layer represents the information that we want to pass from one timestep to the next, and is a function of the input $x_t$ at time $t$ as well as the information $h_{t-1}$ from the previous timestep. We can summarize this as follows.
- $x_t \in \mathbb{R}^n$ = input vector
- $h_t \in \mathbb{R}^d$ = hidden vector
- $\hat{y}_t \in \mathbb{R}^k$ = output vector
We will describe a simple recurrent network (SRN) to illustrate the basic idea. We model the relationship as follows.

$$h_t = f(W_{xh} x_t + W_{hh} h_{t-1} + b_h), \qquad \hat{y}_t = g(W_{hy} h_t + b_y)$$
The functions $f, g$ are activation functions (e.g. $\tanh$, $\sigma$), the $W$'s are the weight matrices, and the $b$'s are the biases. Collectively, the weights and biases are referred to as the parameters $\theta$. Intuitively, we are just updating our internal memory as a function of the input at time $t$ and our memory at the previous timestep, and our prediction for $y_t$ at time $t$ is just a function of our memory at time $t$.
Simple Cell RNN Forward Pass
At each timestep $t$, we update our hidden state in the following fashion. We first randomly initialize $h_0$. For the hidden layer, we use the activation function $\tanh$ to get an output in the range $(-1, 1)$, although other activation functions can also be used. For a vector we just take $\tanh$ coordinate-wise, and we have the following dimensions of the weights and biases (although you can just assume we make everything the correct dimension).
- $W_{xh} \in \mathbb{R}^{d \times n}$ = input weights for hidden layer
- $W_{hh} \in \mathbb{R}^{d \times d}$ = hidden layer weights for hidden layer
- $b_h \in \mathbb{R}^{d}$ = bias for hidden layer
- $W_{hy} \in \mathbb{R}^{1 \times d}$ = hidden layer weights for prediction
- $b_y \in \mathbb{R}$ = bias for prediction
Then just as before,

$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h).$$
We will assume that we are performing binary classification, i.e. $y_t \in \{0, 1\}$. Since we are doing binary classification, our prediction should be a probability, and we compute our prediction for $y_t$ using this hidden layer and the sigmoid function $\sigma(z) = \frac{1}{1 + e^{-z}}$.

$$\hat{y}_t = \sigma(W_{hy} h_t + b_y)$$
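To make this concrete, here is a minimal NumPy sketch of one forward step. The dimensions, initialization scale, and variable names are illustrative assumptions on my part, not part of the derivation above.

```python
import numpy as np

rng = np.random.default_rng(0)

n, d = 4, 8  # illustrative input and hidden dimensions (assumed values)

# Randomly initialized parameters (small scale keeps tanh out of saturation)
W_xh = rng.normal(scale=0.1, size=(d, n))  # input -> hidden
W_hh = rng.normal(scale=0.1, size=(d, d))  # hidden -> hidden
b_h = np.zeros(d)
W_hy = rng.normal(scale=0.1, size=(1, d))  # hidden -> output
b_y = np.zeros(1)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward_step(x_t, h_prev):
    """One timestep: update the hidden state, then predict P(y_t = 1)."""
    h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)
    y_hat = sigmoid(W_hy @ h_t + b_y)
    return h_t, y_hat

h = np.zeros(d)          # h_0 (zero init, a common simple choice)
x = rng.normal(size=n)   # a single toy input vector
h, y_hat = forward_step(x, h)
```

Running the full forward pass is just a loop over the sequence, feeding each returned `h` back in as `h_prev`.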
Some examples of applications include language modeling problems like sentiment analysis: given an embedding of a sequence of words, determine whether it is positive or negative. It's possible to use other activation functions for different tasks, like the identity for continuous modeling or softmax for multiclass classification.
We can compute cross-entropy loss at time $t$ with the following formula.

$$L_t = -\left( y_t \log \hat{y}_t + (1 - y_t) \log(1 - \hat{y}_t) \right)$$
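In code, the per-timestep loss might look like the sketch below; the clipping epsilon is a numerical-stability assumption of mine, not part of the formula.

```python
import numpy as np

def binary_cross_entropy(y_true, y_hat, eps=1e-12):
    """L_t = -(y log yhat + (1 - y) log(1 - yhat)), clipped so log never sees 0."""
    y_hat = np.clip(y_hat, eps, 1.0 - eps)
    return -(y_true * np.log(y_hat) + (1.0 - y_true) * np.log(1.0 - y_hat))
```

A confident correct prediction gives a loss near zero, while a confident wrong one is penalized heavily.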
But which timestep should we use to compute loss and update the weights? The answer is to compute loss at each timestep according to a sliding context window of a fixed number of previous timesteps, and perform backpropagation. This is referred to as backpropagation through time.
Backpropagation Through Time (BPTT)
The idea for backpropagation here is to unfold the RNN through some number of timesteps and run backpropagation in the usual sense. That is, we don't just want to run backprop on the data from time $t$, but also on the memory carried over from times $t-1, t-2, \dots$
Fix some window size $\tau$ which we will determine later. Starting at $t = 1$, set $h_0 = 0$ and forward propagate through the unfolded network according to the previous section. That is,

$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h), \qquad \hat{y}_t = \sigma(W_{hy} h_t + b_y), \qquad t = 1, \dots, \tau.$$
Then we compute the gradient of the loss $L_\tau$ and perform usual backpropagation (using vectorized notation for the chain rule). Writing $a_\tau = W_{hy} h_\tau + b_y$ for the output pre-activation and using $\sigma'(a) = \sigma(a)(1 - \sigma(a))$,

$$\frac{\partial L_\tau}{\partial a_\tau} = \hat{y}_\tau - y_\tau, \qquad \frac{\partial L_\tau}{\partial W_{hy}} = (\hat{y}_\tau - y_\tau)\, h_\tau^\top,$$
which implies by a symmetric calculation that

$$\frac{\partial L_\tau}{\partial b_y} = \hat{y}_\tau - y_\tau.$$
For $W_{hh}$, the loss depends on $h_\tau$, which in turn depends on every earlier hidden state, so we propagate the gradient backward through the unfolded network. We have

$$\frac{\partial L_\tau}{\partial h_\tau} = W_{hy}^\top (\hat{y}_\tau - y_\tau).$$

Note $h_t = \tanh(z_t)$ where $z_t = W_{xh} x_t + W_{hh} h_{t-1} + b_h$, so for $t < \tau$

$$\frac{\partial h_{k}}{\partial h_{k-1}} = \mathrm{diag}\left(\tanh'(z_{k})\right) W_{hh},$$

and we have

$$\frac{\partial L_\tau}{\partial h_t} = \left( \prod_{k=t+1}^{\tau} \frac{\partial h_k}{\partial h_{k-1}} \right)^{\!\top} \frac{\partial L_\tau}{\partial h_\tau}, \qquad \frac{\partial L_\tau}{\partial W_{hh}} = \sum_{t=1}^{\tau} \left( \frac{\partial L_\tau}{\partial h_t} \odot \tanh'(z_t) \right) h_{t-1}^\top.$$
We apply $\tanh$ entry-wise, so when we take the chain rule, we let $\tanh'$ denote the entry-wise derivative $\tanh'(z) = 1 - \tanh^2(z)$, which we just showed how to compute, and take the entry-wise (Hadamard) product $\odot$. We compute the transpose of the partial to save a lot of ugly transpose notation.
Similar computations give $\frac{\partial L_\tau}{\partial W_{xh}}$ and $\frac{\partial L_\tau}{\partial b_h}$. Then we just update our parameters, where $\eta$ is our learning rate.

$$\theta \leftarrow \theta - \eta \frac{\partial L_\tau}{\partial \theta}$$
Nice! We have successfully performed backprop through time. Note we only did this for the window ending at $t = \tau$, but the same idea applies as the window slides forward. That is, start with $h_0 = 0$ and input $x_1$. Perform the above forward and backward pass to update the parameters $\theta$, and set (with our newly updated parameters)

$$h_1 = \tanh(W_{xh} x_1 + W_{hh} h_0 + b_h).$$
Then perform the passes again with inputs $x_2, \dots, x_{\tau+1}$ to update $\theta$, etc... until we reach the last iteration, the window ending at $x_T$.
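Putting the forward and backward passes together, here is a sketch of one truncated-BPTT update over a window of length $\tau$, with the loss taken at the window's final step as in the derivation above. The toy data, dimensions, and learning rate are my own illustrative choices, not a tuned implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, tau, eta = 3, 5, 4, 0.1  # assumed toy dimensions, window, learning rate

W_xh = rng.normal(scale=0.1, size=(d, n))
W_hh = rng.normal(scale=0.1, size=(d, d))
b_h = np.zeros(d)
W_hy = rng.normal(scale=0.1, size=(1, d))
b_y = np.zeros(1)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bptt_step(xs, y_true, h0):
    """One truncated-BPTT update: forward over the window, loss at the final
    step, gradients back through every step, then an in-place SGD update."""
    # --- forward pass, storing hidden states h_0 .. h_tau ---
    hs = [h0]
    for x in xs:
        hs.append(np.tanh(W_xh @ x + W_hh @ hs[-1] + b_h))
    y_hat = sigmoid(W_hy @ hs[-1] + b_y)

    # --- backward pass ---
    delta = y_hat - y_true                  # dL/da at the output
    dW_hy = np.outer(delta, hs[-1])
    db_y = delta.copy()
    dh = W_hy.T @ delta                     # dL/dh_tau
    dW_xh = np.zeros_like(W_xh)
    dW_hh = np.zeros_like(W_hh)
    db_h = np.zeros_like(b_h)
    for t in range(len(xs), 0, -1):
        dz = dh * (1.0 - hs[t] ** 2)        # through tanh (Hadamard product)
        dW_xh += np.outer(dz, xs[t - 1])
        dW_hh += np.outer(dz, hs[t - 1])
        db_h += dz
        dh = W_hh.T @ dz                    # propagate to the previous timestep

    # --- SGD update (in-place, so the module-level arrays are modified) ---
    for P, dP in [(W_xh, dW_xh), (W_hh, dW_hh), (b_h, db_h),
                  (W_hy, dW_hy), (b_y, db_y)]:
        P -= eta * dP
    return float(y_hat[0])

xs = [rng.normal(size=n) for _ in range(tau)]
before = bptt_step(xs, 1.0, np.zeros(d))
after = bptt_step(xs, 1.0, np.zeros(d))  # same window again: prediction moves toward 1
```

Note the accumulation over `t` in the backward loop, which corresponds to the sum over timesteps in the formula for $\partial L_\tau / \partial W_{hh}$.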
A question you might have is: what is a reasonable choice for $\tau$? If we pick $\tau$ large, notice that the contribution of the earlier information decays geometrically over time in the above formula for $\frac{\partial L_\tau}{\partial h_t}$; notice the repeated factor $\mathrm{diag}(\tanh'(z_k))\, W_{hh}$ where $k$ ranges from $t+1$ to $\tau$. Thus the gradient becomes too small and the updated parameters aren't able to train effectively. This is known as the vanishing gradient problem. The solution to this problem, known as long short-term memory (LSTM), is the subject of the next post.
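You can watch the decay numerically. Since $|\tanh'| \le 1$, each backward step multiplies the gradient by a matrix bounded by $W_{hh}$, so a contractive $W_{hh}$ shrinks the signal geometrically. The toy matrix below is a deliberately simple stand-in I chose to make the effect obvious.

```python
import numpy as np

d = 6
W_hh = 0.5 * np.eye(d)   # deliberately contractive recurrent weights (toy choice)
g = np.ones(d)           # stand-in for the gradient dL/dh_tau at the last step

norms = []
for _ in range(20):
    # Each backward step multiplies by diag(tanh'(z_k)) W_hh; since tanh' <= 1,
    # W_hh itself bounds the Jacobian, which is what we iterate here.
    g = W_hh.T @ g
    norms.append(np.linalg.norm(g))
```

After 20 steps the gradient norm has collapsed by a factor of roughly $0.5^{20}$, which is the vanishing gradient problem in miniature.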
The other possibility is that, instead of vanishing, the gradients become unreasonably large. This is the exploding gradients problem and can be solved by defining some threshold such that when the gradient blows up past this threshold, we normalize the gradient to scale it back down. Another approach that could solve both issues at once is to initialize the weight matrix $W_{hh}$ to be orthogonal, since products of orthogonal matrices don't explode or vanish, but this limits our weights and leads to more restrictive models.
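The threshold trick (gradient clipping by norm) is short enough to sketch in full; the threshold value here is an arbitrary choice for illustration.

```python
import numpy as np

def clip_gradient(grad, threshold=5.0):
    """If the gradient's norm exceeds the threshold, rescale it so its norm
    equals the threshold exactly; the direction is preserved."""
    norm = np.linalg.norm(grad)
    if norm > threshold:
        grad = grad * (threshold / norm)
    return grad
```

Applied to each parameter gradient just before the update step, this caps the size of any single update without changing its direction.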
In practice, a modest choice of $\tau$ usually works well, although one is better off using more advanced RNN cells to mitigate this issue.
Conclusion
RNNs are a very powerful framework for training neural networks on sequential data and represent an intuitive adaptation of the ideas used to build feedforward neural networks to solve these tasks. In the next post we will discuss some more complex cells that we can choose to solve the vanishing gradients problem.
Namely, instead of using a simple activation on the input and the previous hidden state, we keep track of five different gates and states: the internal memory of the cell, the hidden state that we output across time, the input gate to determine how much input flows to cell memory, the forget gate to determine how much input and previous cell memory flows to cell memory, and the output gate to determine how much input and previous cell memory flows into the output hidden state.