Deep Q-Networks

March 29, 2025

Recall that $Q$-learning is an off-policy, value-based reinforcement learning algorithm built around a $Q$ (quality) function

$$Q(s,a) := \mathbb{E}\left[r_1 + \gamma r_2 + \gamma^2 r_3 + \cdots \mid s, a, \pi^*\right]$$

where $\pi^*$ is an optimal policy for the Markov decision process, $r_1$ is the reward for taking action $a$ in state $s$, $r_i$ is the reward for transitioning from $s_i$ to $s_{i+1}$ while following $\pi^*$ thereafter, and $\gamma \in (0,1)$ is the discount factor, which weighs future reward lower than immediate reward.

Essentially, $Q(s,a)$ represents the quality of taking action $a$ while in state $s$. $Q$-learning then works by iteratively updating state-action $Q$-values in a $Q$-table to learn this function, and having the agent pick the action from a given state that maximizes the quality $Q(s,\cdot)$.

We update this table using the Bellman update

$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[R(s,a) + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]$$

where $\alpha$ is the learning rate, $R(s,a)$ is the immediate reward for taking action $a$ in state $s$, and $\max_{a'} Q(s',a')$ is the value of the best action available in the next state $s'$ according to our current $Q$ function.
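Before moving to neural networks, here is a minimal sketch of tabular $Q$-learning with this update rule. The environment interface (`env.reset()` returning a state index, `env.step(a)` returning `(next_state, reward, done)`), the function name, and the hyperparameter values are illustrative assumptions, not a definitive implementation.

import numpy as np

# Minimal tabular Q-learning sketch (hypothetical discrete env interface:
# env.reset() -> state index, env.step(a) -> (next_state, reward, done)).
def tabular_q_learning(env, n_states, n_actions,
                       alpha=0.1, gamma=0.99, eps=0.1, episodes=1000):
    Q = np.zeros((n_states, n_actions))  # the Q-table
    for _ in range(episodes):
        s = env.reset()
        done = False
        while not done:
            # epsilon-greedy action selection
            if np.random.rand() < eps:
                a = np.random.randint(n_actions)
            else:
                a = int(np.argmax(Q[s]))
            s_next, r, done = env.step(a)
            # Bellman update toward r + gamma * max_a' Q(s', a')
            best_next = 0.0 if done else np.max(Q[s_next])
            Q[s, a] += alpha * (r + gamma * best_next - Q[s, a])
            s = s_next
    return Q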

The issue with tabular $Q$-learning, where we maintain a table data structure that maps state-action pairs to $Q$-values and iteratively update it, is that it isn't practical for large state-action spaces. The solution is to instead use a neural network to approximate $Q$.

By the Bellman equation, optimizing $Q$ amounts to driving the TD error $R(s,a) + \gamma \max_{a'} Q(s',a') - Q(s,a)$ to zero, so we can use it to define the loss function for our neural network. Let's take a first stab at a training loop that updates the parameters at every step.

Training Loop (Take 1):

  • Compute $Q(s,a;\theta)$ for every action $a$ using the network and pick an action $a$ using the $\varepsilon$-greedy strategy.
  • Compute the target $Q$-value for the transition $(s,a,r,s')$, where $r = R(s,a)$ and $s'$ is the next state returned by the environment after taking action $a$ in state $s$ (the transition may be deterministic or stochastic).
$$\text{target} = r + \gamma \max_{a'} Q(s',a'; \theta)$$
  • Update $\theta$ with a gradient descent step $\theta \leftarrow \theta - \eta\nabla J(\theta)$ on the $\ell_2$ temporal-difference (TD) loss $J(\theta) = \frac{1}{2} \delta(\theta)^2$, where $\delta(\theta)$ is the TD error
$$\delta(\theta) := r + \gamma \max_{a'} Q(s',a'; \theta) - Q(s,a;\theta)$$
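To make the trouble with this loop concrete, here is a minimal single-transition sketch of Take 1 in PyTorch; the names `q_net`, `optimizer`, and the transition variables are assumptions made for the example. Notice that the same parameters $\theta$ produce both the prediction and the target, which is the moving-target problem discussed below.

import torch

# Sketch of Take 1 on a single transition (s, a, r, s_prime, done).
# The same parameters theta are used for BOTH the prediction and the
# target -- the "moving target" problem discussed below.
def naive_td_step(q_net, optimizer, s, a, r, s_prime, done, gamma=0.99):
    q_sa = q_net(s)[0, a]                      # prediction Q(s, a; theta)
    with torch.no_grad():                      # semi-gradient: treat the target as a constant
        target = r + (0.0 if done else gamma * q_net(s_prime).max())
    delta = target - q_sa                      # TD error delta(theta)
    loss = 0.5 * delta ** 2                    # l2 TD loss J(theta)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()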

Recall the $\varepsilon$-greedy strategy, which balances exploration and exploitation: we choose the greedy action with probability $1-\varepsilon$ and a random action with probability $\varepsilon$. As training progresses, $\varepsilon$ decays exponentially so that the agent shifts from exploring to exploiting the learned policy.

The above training loop suffers from the moving-target problem: the parameters $\theta$ being trained are used to compute both the prediction and the target value, so every update also shifts the target we are chasing. The solution is to keep track of two networks.

Target Networks and Experience Replay

We will keep two neural networks: an online network and a target network. The online network is the one trained at every learning step, while the target network is used to compute the target $Q$-values for those updates. Every $N$ steps, we copy the online network's weights over to the target network.

We will denote the target $Q$-function by $Q_{\text{target}}$ with parameters $\theta_{\text{target}}$, and the online $Q$-function by $Q_{\text{online}}$ with parameters $\theta_{\text{online}}$.

Training Loop, Online Network (Take 2):

  • Compute $Q_{\text{online}}(s,a; \theta_{\text{online}})$ for every action $a$ using the online network and pick $a$ using the $\varepsilon$-greedy strategy.
  • For the transition $(s,a,r,s')$, compute
$$\text{target} = r + \gamma \max_{a'} Q_{\text{target}}(s', a'; \theta_{\text{target}})$$
  • Update $\theta_{\text{online}}$ with a gradient descent step $\theta_{\text{online}} \leftarrow \theta_{\text{online}} - \eta \nabla J(\theta_{\text{online}})$ on the TD loss $J(\theta_{\text{online}}) = \frac{1}{2} \delta(\theta_{\text{online}})^2$, where
$$\delta(\theta_{\text{online}}) := r + \gamma \max_{a'} Q_{\text{target}}(s',a'; \theta_{\text{target}}) - Q_{\text{online}}(s,a;\theta_{\text{online}})$$

Training Loop, Target Network (Take 2):

  • Every $N$ learning steps of the online network, update the target network's weights toward the online network's via
$$\theta_{\text{target}} \leftarrow \tau \theta_{\text{online}} + (1-\tau) \theta_{\text{target}}$$

We introduce $\tau$ in the target network update to smooth it and keep the target weights stable, since the update may happen in the middle of an episode while the online network is still training; a small $\tau$ moves the target network only a little toward the online network each time (a soft, or Polyak, update).

Another issue is that early updates to the online network typically have high variance, and hence large TD errors $\delta$. This can destabilize training by producing massive gradient steps, so we want to limit the gradient magnitude for large errors.

We define the Huber loss by the function

$$\mathcal{H}(\delta) := \begin{cases} \frac{1}{2} \delta^2 & |\delta| \leq 1 \\ |\delta| - \frac{1}{2} & |\delta| > 1 \end{cases}$$

where the $-1/2$ term makes the loss continuous at $|\delta| = 1$. We use this in place of the naive $\ell_2$ TD loss to update the online network.
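As a sanity check, here is the piecewise definition written out in PyTorch next to the built-in `nn.HuberLoss` (whose default `delta=1.0` matches the threshold above); the helper name `huber` is just for the example.

import torch
import torch.nn as nn

def huber(delta: torch.Tensor) -> torch.Tensor:
    # Quadratic for small errors, linear for large ones;
    # the -1/2 keeps the two pieces continuous at |delta| = 1.
    abs_d = delta.abs()
    return torch.where(abs_d <= 1.0, 0.5 * delta ** 2, abs_d - 0.5)

d = torch.linspace(-3, 3, 7)
print(huber(d))
print(nn.HuberLoss(reduction='none')(d, torch.zeros_like(d)))  # identical values

Note that for $|\delta| > 1$ the derivative of $\mathcal{H}$ has magnitude exactly $1$, which is precisely the gradient-limiting behavior we wanted for large TD errors.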

There is yet another issue: consecutive transitions in reinforcement learning are highly correlated. It would be better to remember previous transitions $(s,a,r,s')$ and randomly sample among them when choosing what to train on next. That way we don't train along a single episode but instead jump randomly across episodes, which helps us learn a more general policy.

The action $a$ we choose in $(s,a,r,s')$ depends on our $Q_{\text{online}}$ function, which we continuously update, so rather than sampling from all past experience we keep a replay buffer $\mathcal{M}$ that stores only recent transitions.

Since the online agent no longer trains along a single episode path, instead of updating the weights one transition at a time we can learn on batches $\mathcal{B}$ of transitions randomly sampled from the replay buffer $\mathcal{M}$. These measures keep the network from overfitting to its most recent experiences, make convergence more stable, and yield a more general policy rather than one overly tuned to recent experience.

Training Loop, Online Network (Final Take):

  • Given the current state $s_i$, compute $Q_{\text{online}}(s_i,a; \theta_{\text{online}})$ for every action $a$ using the online network and pick $a_i$ using the $\varepsilon$-greedy strategy. Add the transition $(s_i,a_i,r_i,s_{i+1})$ to the replay memory buffer $\mathcal{M}$.
  • Randomly sample a batch $\mathcal{B}$ from replay memory $\mathcal{M}$, and for each $(s,a,r,s') \in \mathcal{B}$ compute
$$\text{target} = r + \gamma \max_{a'} Q_{\text{target}}(s', a'; \theta_{\text{target}})$$
  • Update $\theta_{\text{online}}$ with a gradient descent step $\theta_{\text{online}} \leftarrow \theta_{\text{online}} - \eta \nabla \mathcal{H}(\theta_{\text{online}}; \mathcal{B})$ on the batch Huber loss $\mathcal{H}(\theta_{\text{online}}; \mathcal{B})$, where
$$\begin{aligned} \mathcal{H}(\theta_{\text{online}}; \mathcal{B}) &:= \frac{1}{|\mathcal{B}|}\sum_{(s,a,r,s') \in \mathcal{B}} \mathcal{H}\big(\delta(\theta_{\text{online}})\big) \\ \delta(\theta_{\text{online}}) &:= r + \gamma \max_{a'} Q_{\text{target}}(s',a'; \theta_{\text{target}}) - Q_{\text{online}}(s,a;\theta_{\text{online}}) \end{aligned}$$

Training Loop, Target Network (Final Take):

  • Every $N$ learning steps of the online network, update the target network's weights toward the online network's via
$$\theta_{\text{target}} \leftarrow \tau \theta_{\text{online}} + (1-\tau) \theta_{\text{target}}$$

PyTorch Implementation

We will implement these ideas on the CartPole environment from Gymnasium (the maintained fork of OpenAI Gym). This follows PyTorch's official DQN tutorial, except that I have added the RecordVideo wrapper so we can watch the agent as it trains, and updated the terminology to match the math described above.

import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.init as init

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 600
    training_period = 100
else:
    num_episodes = 300
    training_period = 25

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(env, video_folder="cartpole-agent", name_prefix="training",
                  episode_trigger=lambda x: x % training_period == 0)
env = RecordEpisodeStatistics(env)

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display  # type: ignore

plt.ion()

We define a replay memory class to make things easier in our training update functions.

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    # Add transition to memory
    def push(self, *args):
        self.memory.append(Transition(*args))

    # Sample batch from memory
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
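As a toy usage example (plain numbers stand in for the tensors we will actually store later):

buffer = ReplayMemory(capacity=3)
buffer.push(0, 1, 1.0, 1)      # (state, action, reward, next_state)
buffer.push(1, 0, 1.0, 2)
buffer.push(2, 1, 1.0, 3)
buffer.push(3, 0, 1.0, None)   # capacity exceeded: the oldest transition is evicted
print(len(buffer))             # 3
print(buffer.sample(2))        # two uniformly random Transition tuples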

Recall the model class we defined in the first section. We start by initializing all of our constants, our online network, and our target network.

class model(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(model, self).__init__()
        self.layer1 = nn.Linear(input_dim, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, output_dim)

    def init_weights(self):
        if isinstance(self, nn.Linear):
            init.kaiming_normal_(self.weight)

    # Compute either one element to determine action
    # or a batch during training.
    # Returns tensor([[Qleft, Qright] ...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)


# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# ETA is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
TAU = 0.005
ETA = 1e-4
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

online_net = model(n_observations, n_actions).apply(model.init_weights).to(device)
target_net = model(n_observations, n_actions).apply(model.init_weights).to(device)
# Copy parameters from online to target
target_net.load_state_dict(online_net.state_dict())

# Optimize with Adam/AMSGrad
optimizer = optim.AdamW(online_net.parameters(), lr=ETA, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0

Adam and AMSGrad are adaptive optimizers that adjust learning rates on a per-parameter basis; they differ in how they handle second-moment estimates. Adam uses an exponential moving average of past squared gradients, while AMSGrad modifies this by tracking the maximum of these averages so that the effective learning rate never increases unexpectedly.

Adam can run into convergence issues in some RL applications because the effective per-parameter learning rate can blow up; the upside of AMSGrad is that it comes with stronger convergence guarantees.
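Schematically, the difference between the two is a single line of second-moment bookkeeping. Below is a simplified per-parameter sketch that omits bias correction and weight decay; the function name and the caller-maintained state tensors `m`, `v`, `v_hat` are illustrative assumptions, not PyTorch internals.

import torch

def adam_like_step(p, grad, m, v, v_hat, lr=1e-4, b1=0.9, b2=0.999,
                   eps=1e-8, amsgrad=True):
    # Exponential moving averages of the gradient and squared gradient.
    m.mul_(b1).add_(grad, alpha=1 - b1)
    v.mul_(b2).addcmul_(grad, grad, value=1 - b2)
    if amsgrad:
        # AMSGrad: divide by the running MAX of v, so the denominator never
        # shrinks and the effective per-parameter step size never grows.
        torch.maximum(v_hat, v, out=v_hat)
        denom = v_hat.sqrt().add_(eps)
    else:
        # Plain Adam: v itself can shrink, which can inflate the step size.
        denom = v.sqrt().add_(eps)
    p.addcdiv_(m, denom, value=-lr)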

Now we implement the $\varepsilon$-greedy strategy, where we decay our epsilon threshold for choosing a random action according to

$$\varepsilon\text{-threshold} = A + (B - A) e^{-s/D}$$

where $B$ is the starting threshold, $A$ is the lowest threshold, $D$ is the decay rate, and $s$ is the number of learning steps we have done so far.

# Select an action based on the epsilon-greedy policy
def select_action(state):
    global steps_done
    eps = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1 * steps_done / EPS_DECAY)
    steps_done += 1
    if eps > eps_threshold:
        with torch.no_grad():
            return online_net(state).max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], dtype=torch.long, device=device)

We'll also set up some plotting so we can view how long the agent keeps the pole balanced in each episode, as well as the 100-episode moving average.

episode_durations = []

def plot_durations(show_result=False):
    plt.figure(1)
    # Keep this tensor on the CPU so .numpy() works for plotting
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

Now we implement the main batched training step for our online model. For terminal next states $s'$ we set $\max_{a'} Q(s',a') = 0$, so the target reduces to the immediate reward $r$.

# Perform one step of the optimization (on the online network)
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the array of transitions into a Transition of arrays,
    # e.g. batch.state, batch.action, ... are tuples of tensors.
    batch = Transition(*zip(*transitions))

    # Tensor of next_states in the batch which are not final
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s,a) for each state-action pair in the batch
    state_action_values = online_net(state_batch).gather(1, action_batch)

    # Compute max Q(s',a') for all next states s', or 0 if s' is terminal
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Boolean mask of next states s' in the batch that are not final
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  dtype=torch.bool, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

    # target = r + gamma * max Q(s', a')
    target_state_action_values = reward_batch + GAMMA * next_state_values

    # Compute Huber loss
    huber = nn.HuberLoss()
    J = huber(state_action_values, target_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    J.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(online_net.parameters(), 100)
    optimizer.step()

Gradient clipping was briefly discussed in one of the previous posts; in short, it keeps the parameters' gradients from exploding by clamping each gradient component to a fixed range (here $[-100, 100]$) if it gets too large.
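As a quick toy illustration of what `clip_grad_value_` does (not part of the training loop; the gradients are set by hand just for the demonstration):

import torch
import torch.nn as nn

layer = nn.Linear(2, 1)
layer.weight.grad = torch.tensor([[250.0, -3.0]])  # pretend backward() produced these
layer.bias.grad = torch.tensor([-500.0])

# Clamp every gradient component into [-100, 100], in place.
torch.nn.utils.clip_grad_value_(layer.parameters(), 100)
print(layer.weight.grad)   # tensor([[100.,  -3.]])
print(layer.bias.grad)     # tensor([-100.])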

Lastly, we have the main training loop. If a GPU is available we can use it to speed up training and run many more episodes. Training RL agents is noisy, so a small number of episodes is likely not enough to see good convergence on CartPole.

# Main episodic training loop
for i_episode in range(1, num_episodes + 1):
    if i_episode % training_period == 0:
        print(f"Episode {i_episode} of {num_episodes}")

    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Push the transition to memory
        memory.push(state, action, reward, next_state)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the online network)
        optimize_model()

        # Soft update of the target network's weights
        target_net_state_dict = target_net.state_dict()
        online_net_state_dict = online_net.state_dict()
        for key in online_net_state_dict:
            target_net_state_dict[key] = online_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()