PyTorch DQN Agent Walkthrough¶
If you have not read the blank version yet, start with Blank RL Agent Template.
Reading the scaffold first, then this PyTorch version, gives a clean learning path.
This page shows how the blank RL skeleton—act, remember, learn—maps onto a minimal DQN-style agent using PyTorch.
Core picture:
state → neural network → Q-values → action
The example below is educational and intentionally minimal. It is a readable DQN-style teaching example, not a production RL system.
What this agent does¶
A small DQN network predicts one scalar per action for a given state (Q-values).
act mixes epsilon-greedy exploration with argmax on predicted Q-values.
remember stores transitions (here, a plain list).
learn samples a mini-batch, builds Bellman-style targets for the taken action, runs MSE loss, then backprop and an optimizer step.
You can read this file top to bottom once, then map each block to the section headings below.
Full minimal file¶
States are wrapped with an extra batch dimension so nn.Linear receives shape (batch, features)—a detail full projects handle in data loaders.
import random

import torch
import torch.nn as nn
import torch.optim as optim


class DQN(nn.Module):
    """Two-layer MLP: state_size -> 64 -> action_size."""

    def __init__(self, state_size, action_size):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, action_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x  # one Q-value per action


class Agent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.model = DQN(state_size, action_size)
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.MSELoss()
        self.memory = []            # unbounded list of transitions
        self.gamma = 0.99           # discount factor
        self.epsilon = 1.0          # start fully exploratory
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01

    def act(self, state):
        # Explore: uniform random action with probability epsilon.
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)
        # Exploit: argmax over predicted Q-values for a batch of one.
        state_t = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.model(state_t).squeeze(0)
        return torch.argmax(q_values).item()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def learn(self, batch_size=32):
        # Wait until there are enough transitions to sample a full batch.
        if len(self.memory) < batch_size:
            return

        batch = random.sample(self.memory, batch_size)

        # One forward pass and one optimizer step per sampled transition.
        for state, action, reward, next_state, done in batch:
            state_t = torch.FloatTensor(state).unsqueeze(0)
            next_state_t = torch.FloatTensor(next_state).unsqueeze(0)

            q_values = self.model(state_t).squeeze(0)
            with torch.no_grad():
                next_q = self.model(next_state_t).squeeze(0)

            # Copy the prediction (detached, so the target is a constant) and
            # overwrite only the entry for the action that was taken.
            target = q_values.detach().clone()
            if done:
                target[action] = reward
            else:
                target[action] = reward + self.gamma * torch.max(next_q)

            loss = self.criterion(q_values, target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Decay exploration once per learn call that performed updates.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
If your environment returns NumPy arrays, wrap them with torch.as_tensor(..., dtype=torch.float32) before the network sees them.
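A minimal sketch of that conversion, assuming the observation is a 1-D NumPy array. The helper name to_state_tensor and the array values are made up for illustration and are not part of the listing.

```python
import numpy as np
import torch


def to_state_tensor(observation):
    """Convert one observation to a float32 tensor of shape (1, features)."""
    # NumPy defaults to float64; nn.Linear weights are float32 by default,
    # so the dtype is set explicitly before adding the batch dimension.
    return torch.as_tensor(observation, dtype=torch.float32).unsqueeze(0)


obs = np.array([0.1, -0.2, 0.3, 0.4])  # hypothetical 4-feature observation
state_t = to_state_tensor(obs)
print(state_t.dtype, tuple(state_t.shape))  # torch.float32 (1, 4)
```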
Imports¶
random: epsilon-greedy and replay sampling.
torch / nn / optim: network, loss, optimizer.
DQN network¶
DQN subclasses nn.Module. Two linear layers with ReLU map state_size → 64 → action_size.
forward returns a vector of length action_size: one predicted value per discrete action.
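To make the shapes concrete, the sketch below repeats the same two-layer network and pushes a batch of one and a batch of 32 through it. The sizes (4 state features, 2 actions) are assumptions picked for illustration.

```python
import torch
import torch.nn as nn


class DQN(nn.Module):
    """Same two-layer layout as the listing: state_size -> 64 -> action_size."""

    def __init__(self, state_size, action_size):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, action_size)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


net = DQN(state_size=4, action_size=2)  # illustrative sizes
single = torch.zeros(1, 4)              # batch of one state
batch = torch.zeros(32, 4)              # batch of 32 states

print(tuple(net(single).shape))             # (1, 2)
print(tuple(net(single).squeeze(0).shape))  # (2,)
print(tuple(net(batch).shape))              # (32, 2)
```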
Agent class¶
The agent holds:
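| Member | Role |
|---|---|
| model | Predicts Q-values |
| optimizer | Adam on model.parameters() |
| criterion | MSE loss (nn.MSELoss) |
| memory | List of transitions |
| gamma | Discount factor for future reward |
| epsilon, epsilon_decay, epsilon_min | Exploration schedule |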
Action selection¶
With probability epsilon, pick a uniform random action. Otherwise:
Build a batch-of-one tensor, run model, and take the argmax over the Q-values.
torch.no_grad() avoids building a graph for inference-only forward passes.
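The selection rule can be isolated from the network as a small function. This is a sketch, not the listing's method: epsilon_greedy is a hypothetical helper that takes an already-computed Q-vector, and the Q-values are made up.

```python
import random

import torch


def epsilon_greedy(q_values, epsilon):
    """Pick a uniform random action with probability epsilon, else the argmax."""
    if random.random() < epsilon:
        return random.randrange(q_values.numel())
    return int(torch.argmax(q_values).item())


q = torch.tensor([0.1, 0.7, 0.2])               # made-up Q-values, three actions
greedy_action = epsilon_greedy(q, epsilon=0.0)  # never explores
random_action = epsilon_greedy(q, epsilon=1.0)  # always explores
print(greedy_action)  # 1
```

Because random.random() returns a value in [0, 1), epsilon=0.0 never takes the random branch and epsilon=1.0 always does.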
Memory¶
remember appends (state, action, reward, next_state, done).
This toy uses an unbounded list. Production-style replay uses a fixed-size buffer that overwrites the oldest transitions once it is full.
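One common way to bound the memory is a deque with maxlen, sketched below. The ReplayBuffer class, its method names, and the capacity are assumptions for illustration; the listing itself keeps the plain list.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size FIFO buffer: the oldest transition is dropped when full."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Copy to a list so sampling uses fast index access.
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


buf = ReplayBuffer(capacity=3)
for step in range(5):
    buf.push([float(step)], 0, 1.0, [float(step + 1)], False)

print(len(buf))          # 3
print(buf.buffer[0][0])  # [2.0]  (steps 0 and 1 were evicted)
```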
Learning¶
If the learn() step feels abstract, the staged progression in Blank RL Agent Template shows how the same conceptual slot grows: from simple score tallying in memory, through tabular Q-style updates, to neural batches with a loss and an optimizer. Reading it first can make the specifics here easier to follow.
learn requires at least batch_size entries, then uniformly samples transitions.
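The loop runs one forward pass, one loss, and one optimizer step per transition, so a batch of 32 means 32 separate gradient steps rather than a single vectorized update. That is easier to read; serious code batches tensors for throughput, especially on GPU.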
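For comparison, a batched version of the same update might look like the sketch below. It is an assumption-laden illustration, not the listing's code: batched_update is a hypothetical function, states are assumed to be plain lists of floats, and an nn.Sequential stands in for the DQN class.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def batched_update(model, optimizer, batch, gamma=0.99):
    """One gradient step over a whole mini-batch of transitions."""
    states, actions, rewards, next_states, dones = zip(*batch)
    states_t = torch.tensor(states, dtype=torch.float32)            # (B, S)
    next_states_t = torch.tensor(next_states, dtype=torch.float32)  # (B, S)
    actions_t = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)  # (B, 1)
    rewards_t = torch.tensor(rewards, dtype=torch.float32)          # (B,)
    dones_t = torch.tensor(dones, dtype=torch.float32)              # (B,)

    # Q(s, a) for the actions that were actually taken.
    q_taken = model(states_t).gather(1, actions_t).squeeze(1)       # (B,)

    with torch.no_grad():
        max_next_q = model(next_states_t).max(dim=1).values         # (B,)
        # (1 - done) zeroes the bootstrap term on terminal transitions.
        targets = rewards_t + gamma * max_next_q * (1.0 - dones_t)

    loss = nn.functional.mse_loss(q_taken, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Usage with a stand-in network and two made-up transitions.
model = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch = [([0.0, 0.1, 0.2, 0.3], 1, 1.0, [0.1, 0.2, 0.3, 0.4], False),
         ([0.5, 0.4, 0.3, 0.2], 0, 0.0, [0.0, 0.0, 0.0, 0.0], True)]
loss_value = batched_update(model, optimizer, batch)
```

The loss here is taken only over the selected actions and averaged over the batch, so its scale differs from the per-transition loop, which averages each sample's error over all action_size entries.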
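After each learn call that performs updates, epsilon is multiplied by epsilon_decay and clamped at epsilon_min. Calls that return early because memory holds fewer than batch_size transitions leave epsilon unchanged.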
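To get a feel for how long exploration lasts with the listing's constants, the short sketch below replays the decay rule on its own and counts how many updating learn calls it takes to reach the floor.

```python
epsilon, epsilon_decay, epsilon_min = 1.0, 0.995, 0.01

updates = 0
while epsilon > epsilon_min:
    # Same rule as the last line of learn().
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    updates += 1

print(updates)  # 919
```

The count follows from 0.995^n <= 0.01, which gives n >= ln(0.01) / ln(0.995), roughly 918.7.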
Bellman update¶
For each sample, the code clones the predicted Q-vector and overwrites one index—the action that was taken—with a target:
If done: the target is the immediate reward.
Else: reward + gamma * max_a' Q(s', a'), using the same network for next_state in this minimal snippet.
That one-step bootstrap is the core Bellman idea in Q-learning style updates. Stabilized DQN variants often use a separate target network; this example does not.
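As a rough sketch of what a target network adds, the block below keeps a frozen copy of the online network for computing targets and syncs it periodically. The names online_net, target_net, bootstrap_target, maybe_sync, and the SYNC_EVERY interval are all hypothetical; none of this is in the listing.

```python
import copy

import torch
import torch.nn as nn

online_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
# Frozen copy used only to compute bootstrap targets.
target_net = copy.deepcopy(online_net)
target_net.eval()

SYNC_EVERY = 100  # hypothetical sync interval, counted in learn steps


def bootstrap_target(reward, next_state_t, done, gamma=0.99):
    """Bellman target computed from the target network, not the online one."""
    if done:
        return torch.tensor(float(reward))
    with torch.no_grad():
        return reward + gamma * target_net(next_state_t).max()


def maybe_sync(step):
    """Copy online weights into the target network every SYNC_EVERY steps."""
    if step % SYNC_EVERY == 0:
        target_net.load_state_dict(online_net.state_dict())


y = bootstrap_target(1.0, torch.zeros(1, 4), done=False)
maybe_sync(step=100)
```

Between syncs the targets do not move when the online network is updated, which is the stabilizing effect such variants aim for.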
Backpropagation¶
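loss = criterion(q_values, target) compares the full Q-vector to the clone with one edited component. The unedited components match the prediction exactly, so their error is zero and only the selected action's value is supervised. Because nn.MSELoss averages over all action_size entries by default, that single error term is scaled by 1/action_size.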
zero_grad → backward → optimizer.step clears stale gradients, computes new ones, and updates the fc1 and fc2 weights and biases.
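The effect of the cloned target can be checked on a hand-made Q-vector. In this sketch a leaf tensor with requires_grad=True stands in for the network output so the gradient can be inspected directly; all numbers are made up.

```python
import torch
import torch.nn as nn

# Stand-in for the network's predicted Q-values (three actions).
q_values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
action, reward, gamma = 1, 1.0, 0.99
next_q = torch.tensor([0.5, 4.0, 1.5])  # made-up Q(s', .)

# Detached copy with only the taken action's entry replaced.
target = q_values.detach().clone()
target[action] = reward + gamma * torch.max(next_q)  # 1.0 + 0.99 * 4.0 = 4.96

loss = nn.MSELoss()(q_values, target)
loss.backward()

print(target)         # tensor([1.0000, 4.9600, 3.0000])
print(q_values.grad)  # nonzero only at index 1
```

The gradient at index 1 is 2 * (2.0 - 4.96) / 3, where the division by 3 comes from the mean over the three actions.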
Important caveats¶
This is not a production RL system. The design is for readability, not competitive scores or research-grade runs.
Intentionally missing or simplified:
Target network (and periodic sync)
Bounded replay buffer and prioritized replay
Batched, vectorized GPU updates across the whole mini-batch
Seed control and reproducibility tooling
Reward curves, logging, and checkpointing
A full Gymnasium (or other) training loop in this listing
Systematic evaluation scripts and benchmark reporting
Double DQN, n-step returns, and other stabilizers
Solving a toy task or getting non-trivial reward on a small environment does not indicate general intelligence—it means the training loop and updates are functioning in that narrow setting.
Where this connects next¶
Blank RL Agent Template — same slots without PyTorch, to see the scaffold first.
RL Agent Skeleton — short hub that links these articles if you want a single place to start.
From here, a natural direction is integrating Agent into a Gymnasium loop, adding proper evaluation and logging, then layering improvements (target net, replay cap, batched losses) once the readable baseline is clear.
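A first step in that direction could look like the episode loop sketched below. Everything in it is hypothetical: CorridorEnv is a toy environment that only imitates the shape of Gymnasium's reset and step return values, and RandomAgent is a stand-in exposing the same act / remember / learn slots so the loop runs without PyTorch.

```python
import random


class CorridorEnv:
    """Toy environment: walk right from cell 0 to the last cell.

    reset() -> (obs, info); step() -> (obs, reward, terminated, truncated, info).
    """

    def __init__(self, length=5, max_steps=50):
        self.length, self.max_steps = length, max_steps

    def reset(self):
        self.pos, self.steps = 0, 0
        return [float(self.pos)], {}

    def step(self, action):  # action 0 = left, 1 = right
        self.pos = max(0, self.pos + (1 if action == 1 else -1))
        self.steps += 1
        terminated = self.pos == self.length - 1
        truncated = self.steps >= self.max_steps
        reward = 1.0 if terminated else 0.0
        return [float(self.pos)], reward, terminated, truncated, {}


class RandomAgent:
    """Stand-in with the same three slots as Agent."""

    def __init__(self, action_size):
        self.action_size, self.memory = action_size, []

    def act(self, state):
        return random.randrange(self.action_size)

    def remember(self, *transition):
        self.memory.append(transition)

    def learn(self, batch_size=32):
        pass  # a learning agent would update its network here


def run_episode(env, agent):
    state, _ = env.reset()
    total_reward, done = 0.0, False
    while not done:
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # Store terminated (not truncated) as done, so a time-limit cutoff
        # does not suppress bootstrapping; this is a design choice.
        agent.remember(state, action, reward, next_state, terminated)
        agent.learn()
        state, total_reward = next_state, total_reward + reward
        done = terminated or truncated
    return total_reward


env, agent = CorridorEnv(), RandomAgent(action_size=2)
episode_return = run_episode(env, agent)
```

Swapping RandomAgent for the Agent class from the listing (with state_size=1, action_size=2) should work with the same loop, since only act, remember, and learn are called.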