Recall that a $Q$-network is an off-policy, value-based reinforcement learning method that works by defining a $Q$ (quality) function

$$Q^{\pi^*}(s, a) \;=\; \mathbb{E}\!\left[\, r(s, a) + \gamma\, V^{\pi^*}(s') \,\right],$$

where $\pi^*$ is an optimal policy for the Markov Decision Process, $r(s, a)$ is the reward for taking action $a$ in state $s$, $V^{\pi^*}(s')$ is the expected reward for transitioning from $s$ to $s'$ and following $\pi^*$ thereafter, and $\gamma \in [0, 1)$ is the discount factor which weighs future reward lower than immediate reward.
Essentially, $Q(s, a)$ represents the quality of taking action $a$ while in state $s$. Then $Q$-learning works by iteratively updating state-action $Q$-values using a $Q$-table to learn this function, and then having an agent pick the action $a$ from a given state $s$ that maximizes the quality $Q(s, a)$.
The way we update this table is by the Bellman equation

$$Q(s, a) \;\leftarrow\; r(s, a) + \gamma \max_{a'} Q(s', a'),$$

where $r(s, a)$ is the immediate reward for taking action $a$ in state $s$, and $\max_{a'} Q(s', a')$ is the value of the optimal action $a'$ to take in the next state $s'$ according to our current $Q$ function.
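To make the tabular update concrete, here is a minimal sketch assuming a dictionary-based $Q$-table, a small hypothetical action set, and a single observed transition; none of these names are reused later.

from collections import defaultdict

# Minimal tabular Q-learning sketch (illustrative only).
Q = defaultdict(float)   # Q-table: maps (state, action) pairs to Q-values, default 0
actions = [0, 1]         # hypothetical action set
gamma = 0.99             # discount factor

def bellman_update(s, a, r, s_next):
    # Q(s, a) <- r(s, a) + gamma * max_a' Q(s', a')
    Q[(s, a)] = r + gamma * max(Q[(s_next, a2)] for a2 in actions)

# Example: we took action 1 in state 0, received reward 1.0, and landed in state 2
bellman_update(s=0, a=1, r=1.0, s_next=2)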
The issue with tabular $Q$-learning, where we have a table data structure that maps state-action pairs to $Q$-values which we iteratively update, is that for large state-action spaces this isn't practical. The solution is to instead use a neural network $Q_\theta$, with parameters $\theta$, to approximate $Q$.
According to the Bellman equation, notice that optimizing $Q_\theta$ is equivalent to minimizing the TD error

$$\delta \;=\; r(s, a) + \gamma \max_{a'} Q_\theta(s', a') - Q_\theta(s, a).$$

Thus we can use this to define the loss function for our neural network. Let's take a stab at defining our main training loop for updating the parameters $\theta$ at every step.
Training Loop (Take 1):
- Compute $Q_\theta(s, a)$ for every action $a$ using the network and pick an action using the $\varepsilon$-greedy strategy.
- Compute the target $Q$-value $y$ for the transition $(s, a, r, s')$, where $y = r(s, a) + \gamma \max_{a'} Q_\theta(s', a')$ and $s'$ is the state given by the world for taking action $a$ at state $s$ (might be deterministic or stochastic).
- Update $\theta$ with gradient descent on the temporal-difference (TD) loss $J(\theta) = \delta^2$, where $\delta = y - Q_\theta(s, a)$ is the temporal difference error.
Recall the $\varepsilon$-greedy strategy, which is a way to balance exploration and exploitation. We choose the optimal action with probability $1 - \varepsilon$, and a random action with probability $\varepsilon$. As training progresses, $\varepsilon$ exponentially decays so that we rely more and more on the learned policy.
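Putting the first take together, here is a rough sketch of one training step with a single network. It assumes a hypothetical `q_net` mapping a `(1, n_observations)` state tensor to `(1, n_actions)` $Q$-values, an `optimizer` over its parameters, and a Gymnasium-style `env`; the squared TD loss here is the naive version we refine below.

import random
import torch
import torch.nn.functional as F

# Sketch of one "Take 1" training step with a single network (illustrative only).
def train_step_take1(q_net, optimizer, env, state, gamma, epsilon):
    # Epsilon-greedy action selection
    if random.random() < epsilon:
        action = env.action_space.sample()
    else:
        with torch.no_grad():
            action = int(q_net(state).argmax(dim=1))

    next_obs, reward, terminated, truncated, _ = env.step(action)
    next_state = torch.as_tensor(next_obs, dtype=torch.float32).unsqueeze(0)

    # TD target y = r + gamma * max_a' Q(s', a'), computed with the *same* network
    with torch.no_grad():
        bootstrap = 0.0 if terminated else q_net(next_state).max(dim=1).values.item()
    target = torch.tensor(reward + gamma * bootstrap)

    # Squared TD loss delta^2 and a gradient step on theta
    prediction = q_net(state)[0, action]
    loss = F.mse_loss(prediction, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return next_state, terminated or truncated

Note that `q_net` appears in both the prediction and the target here; that is exactly the problem addressed next.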
The above training loop suffers from the moving target problem: the parameters $\theta$ being trained are used to compute both the prediction $Q_\theta(s, a)$ and the target $y$ it is chasing. The solution is to instead keep track of two networks.
Target Network and Experience Replay
We will have two neural networks, an online network and a target network. The target network will keep track of the target $Q$-values, and the online network is the one that will be trained at every learning step. Every $C$ steps, we will copy the online network weights over to the target network. At every learning step for the online network, we use the target $Q$-values computed using the target network.
We will denote the target $Q$-function by $Q_{\theta^-}$ with parameters $\theta^-$, and the online $Q$-function by $Q_\theta$ with parameters $\theta$.
Training Loop, Online Network (Take 2):
- Compute $Q_\theta(s, a)$ for every action $a$ using the online network and pick $a$ using the $\varepsilon$-greedy strategy.
- For the transition $(s, a, r, s')$, compute the target $y = r + \gamma \max_{a'} Q_{\theta^-}(s', a')$ using the target network.
- Update $\theta$ with gradient descent on the temporal-difference (TD) loss $J(\theta) = \left(y - Q_\theta(s, a)\right)^2$.
Training Loop, Target Network (Take 2):
- For every $C$ learning steps the online network takes, update the weights from the online network to the target network by $\theta^- \leftarrow \tau\, \theta + (1 - \tau)\, \theta^-$.
The reason we introduce $\tau$ in the target network update is to smooth it and keep the target weights more stable, since we may update the target weights in the middle of an episode while training the online network.
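In PyTorch, this soft update can be sketched as below, assuming the `online_net`, `target_net`, and `TAU` objects we define in the implementation section later (the training loop there does the same thing through `state_dict`).

import torch

# Soft update: theta_target <- tau * theta_online + (1 - tau) * theta_target
with torch.no_grad():
    for p_target, p_online in zip(target_net.parameters(), online_net.parameters()):
        p_target.mul_(1 - TAU).add_(TAU * p_online)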
Another issue that comes up is that early training updates to the online network typically have large variance, and hence large TD error $\delta$. This could destabilize training by taking massive gradient steps, so we want to limit the gradient magnitude for large errors.
We define the Huber loss by the function

$$\mathcal{L}_{\text{huber}}(\delta) \;=\; \begin{cases} \frac{1}{2}\delta^2 & \text{if } |\delta| \le 1, \\ |\delta| - \frac{1}{2} & \text{otherwise,} \end{cases}$$

where the factor of $\frac{1}{2}$ is so that the loss is continuous at $|\delta| = 1$. We use this to update the online network instead of the naive squared TD loss.
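A direct translation of this piecewise definition is below; as a sanity check, it matches PyTorch's built-in `nn.HuberLoss` with its default threshold `delta=1.0`, which is what the implementation later uses.

import torch
import torch.nn as nn

# Piecewise Huber loss on the TD error delta, with the threshold fixed at 1
def huber(delta: torch.Tensor) -> torch.Tensor:
    quadratic = 0.5 * delta ** 2    # small errors: behaves like the squared loss
    linear = delta.abs() - 0.5      # large errors: grows only linearly
    return torch.where(delta.abs() <= 1.0, quadratic, linear)

deltas = torch.tensor([-3.0, -0.5, 0.0, 0.5, 3.0])
print(huber(deltas))  # tensor([2.5000, 0.1250, 0.0000, 0.1250, 2.5000])
print(nn.HuberLoss(reduction='none')(deltas, torch.zeros_like(deltas)))  # same values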
There is yet another issue, which is that consecutive transitions in reinforcement learning are highly correlated. It would be better if we could remember previous transitions and randomly sample from them to choose the next transitions to train on. Thus, we don't train along a single episode but rather randomly jump across episodes to try to learn a general policy.
The action we choose in state $s$ depends on our $Q_\theta$ function, which we continuously update, so instead of randomly sampling from all memory we should instead keep a replay buffer $D$ which stores only recent transition experiences.
Since our online agent is no longer training along an episode path, instead of updating the weights on only the most recent transition we can now learn from batches of transitions randomly sampled from the replay buffer $D$. These measures prevent the network from overfitting to recent experiences and contribute to more stable convergence. The diversity of training samples from the replay buffer also ensures that the network learns a more general policy, rather than one that is overly tuned to recent experiences.
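The "only recent experiences" behavior falls out of Python's `collections.deque` with a `maxlen`, which is what the `ReplayMemory` class below is built on; a tiny sketch of the idea:

from collections import deque
import random

# A bounded buffer: once full, appending evicts the oldest entry automatically,
# so only the most recent experiences are kept.
buffer = deque(maxlen=5)
for t in range(10):
    buffer.append(t)

print(list(buffer))              # [5, 6, 7, 8, 9]
print(random.sample(buffer, 3))  # a batch of 3 drawn uniformly from the recent entries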
Training Loop, Online Network (Final Take):
- Given our current state $s$, compute $Q_\theta(s, a)$ for every action $a$ using the online network and pick $a$ using the $\varepsilon$-greedy strategy. Add the transition $(s, a, r, s')$ to the replay memory buffer $D$.
- Randomly sample a batch $B$ from replay memory $D$, and for each transition $(s_i, a_i, r_i, s'_i) \in B$ compute $y_i = r_i + \gamma \max_{a'} Q_{\theta^-}(s'_i, a')$.
- Update $\theta$ with gradient descent on the Huber batch loss $J(\theta) = \frac{1}{|B|} \sum_{i=1}^{|B|} \mathcal{L}_{\text{huber}}(\delta_i)$, where $\delta_i = y_i - Q_\theta(s_i, a_i)$.
Training Loop, Target Network (Final Take):
- For every $C$ learning steps the online network takes, update the weights from the online network to the target network by $\theta^- \leftarrow \tau\, \theta + (1 - \tau)\, \theta^-$.
PyTorch Implementation
We will implement these ideas using the classic CartPole environment via Gymnasium. This follows the DQN tutorial that PyTorch provides, except I have added the RecordVideo wrapper to record the agent during training and also updated the terminology to match the math described above.
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.init as init

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 600
    training_period = 100
else:
    num_episodes = 300
    training_period = 25

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(env, video_folder="cartpole-agent", name_prefix="training",
                  episode_trigger=lambda x: x % training_period == 0)
env = RecordEpisodeStatistics(env)

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display  # type: ignore

plt.ion()
We have the replay memory class to make things easier during our training update functions.
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    # Add a transition to memory
    def push(self, *args):
        self.memory.append(Transition(*args))

    # Sample a random batch from memory
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
Recall the model class we defined in the first section. We start by initializing all of our constants, our online network, and our target network.
class model(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(model, self).__init__()
        self.layer1 = nn.Linear(input_dim, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, output_dim)

    def init_weights(self):
        if isinstance(self, nn.Linear):
            init.kaiming_normal_(self.weight)

    # Compute either one element to determine an action
    # or a batch during training.
    # Returns tensor([[Qleft, Qright]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# ETA is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
TAU = 0.005
ETA = 1e-4
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000

# Get the number of actions from the gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

online_net = model(n_observations, n_actions).apply(model.init_weights).to(device)
target_net = model(n_observations, n_actions).apply(model.init_weights).to(device)
# Copy parameters from online to target
target_net.load_state_dict(online_net.state_dict())

# Optimize with Adam/AMSGrad
optimizer = optim.AdamW(online_net.parameters(), lr=ETA, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0
Adam and AMSGrad are adaptive optimizers that adjust learning rates on a per-parameter basis; they differ in how they handle second-moment estimates. Adam uses an exponential moving average of past squared gradients, while AMSGrad modifies this by tracking the maximum of these averages to ensure the effective learning rate doesn't increase unexpectedly.
There can be convergence issues with Adam in some RL applications due to exploding learning rates, so the upside of AMSGrad is that we get certain convergence guarantees.
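As a toy illustration of that difference, here is a schematic of how the two optimizers track the second-moment estimate for a single gradient entry (first-moment terms and bias correction are omitted); this is not the actual `torch.optim` internals.

import torch

beta2, eps, lr = 0.999, 1e-8, 1e-4
v = torch.zeros(1)       # Adam: exponential moving average of squared gradients
v_hat = torch.zeros(1)   # AMSGrad: running maximum of v

for g in [torch.tensor([2.0]), torch.tensor([0.1]), torch.tensor([0.1])]:
    v = beta2 * v + (1 - beta2) * g ** 2
    v_hat = torch.maximum(v_hat, v)             # never decreases
    adam_scale = lr / (v.sqrt() + eps)          # can grow again once v decays
    amsgrad_scale = lr / (v_hat.sqrt() + eps)   # monotonically non-increasing
    print(adam_scale.item(), amsgrad_scale.item())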
Now we implement the $\varepsilon$-greedy strategy, where we decay our epsilon threshold for choosing a random action according to

$$\varepsilon \;=\; \varepsilon_{\text{end}} + (\varepsilon_{\text{start}} - \varepsilon_{\text{end}})\, e^{-\, n / d},$$

where $\varepsilon_{\text{start}}$ is the starting threshold, $\varepsilon_{\text{end}}$ is the lowest threshold, $d$ is the decay rate, and $n$ is the number of learning steps we have done so far.
# Select an action based on the epsilon-greedy policy
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1 * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        # Exploit: pick the action with the highest Q-value from the online network
        with torch.no_grad():
            return online_net(state).max(1).indices.view(1, 1)
    else:
        # Explore: pick a random action
        return torch.tensor([[env.action_space.sample()]], dtype=torch.long, device=device)
We'll also set up some plotting so we can view how long the agent is able to keep the pole balanced in each episode, as well as the 100-episode moving average.
episode_durations = []

def plot_durations(show_result=False):
    plt.figure(1)
    # Keep this tensor on the CPU so we can call .numpy() for plotting
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())
Now we implement the main batch training step for our online model. For terminal states $s'$ in each episode we set $Q_{\theta^-}(s', a') = 0$ for all actions $a'$.
# Perform one step of the optimization (on the online network)
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the array of Transitions into a Transition of arrays.
    # E.g. batch.state, batch.action, ... are tuples of tensors.
    batch = Transition(*zip(*transitions))

    # Tensor of next_states in the batch which are not final
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s, a) for each state-action pair in the batch
    state_action_values = online_net(state_batch).gather(1, action_batch)

    # Compute max_a' Q(s', a') for all next states s', or 0 if s' is terminal
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Boolean mask of next states s' in the batch that are not final
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  dtype=torch.bool, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # target y = r + gamma * max_a' Q(s', a')
    target_state_action_values = reward_batch + GAMMA * next_state_values

    # Compute the Huber loss
    huber = nn.HuberLoss()
    J = huber(state_action_values, target_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    J.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(online_net.parameters(), 100)
    optimizer.step()
Gradient clipping was briefly discussed in one of the previous posts, but basically it stops the parameter gradients from exploding by clamping them if they get too big.
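To be precise, `clip_grad_value_` clamps each gradient entry element-wise into `[-100, 100]` (unlike `clip_grad_norm_`, which rescales the whole gradient vector); a tiny standalone demonstration:

import torch
import torch.nn as nn

# Demonstration: value clipping clamps every gradient entry into [-100, 100].
layer = nn.Linear(4, 2)
layer.weight.grad = torch.full_like(layer.weight, 1e6)   # pretend the gradients exploded
layer.bias.grad = torch.full_like(layer.bias, -1e6)
torch.nn.utils.clip_grad_value_(layer.parameters(), 100)
print(layer.weight.grad.max().item(), layer.bias.grad.min().item())  # 100.0 -100.0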
Lastly, we have the training loop. If GPUs are available we can utilize them to speed up training and complete a lot more episodes. Training RL agents is noisy so a small number of episodes is likely not enough to see good convergence for CartPole.
# Main episodic training loop
for i_episode in range(1, num_episodes + 1):
    if i_episode % training_period == 0:
        print(f"Episode {i_episode} of {num_episodes}")

    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Push the transition to replay memory
        memory.push(state, action, reward, next_state)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the online network)
        optimize_model()

        # Soft update of the target network's weights:
        # theta_target <- tau * theta_online + (1 - tau) * theta_target
        target_net_state_dict = target_net.state_dict()
        online_net_state_dict = online_net.state_dict()
        for key in online_net_state_dict:
            target_net_state_dict[key] = online_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()