Abstract
Reinforcement learning is an important branch of machine learning in which an agent learns a policy by interacting with an environment so as to maximize long-term reward. This article shows how to implement two classic reinforcement learning algorithms with PyTorch: Deep Q-Network (DQN) and Asynchronous Advantage Actor-Critic (A3C). Starting from environment setup, we implement the core parts of each algorithm step by step and provide complete code examples.
1. Introduction
Reinforcement learning (RL) is a method that lets an agent learn how to make decisions by interacting with an environment. During this process, the agent takes actions in an attempt to maximize cumulative reward. In recent years, with the development of deep learning, RL algorithms that incorporate deep neural networks have achieved breakthrough results in many domains.
2. Environment Setup
To run reinforcement learning experiments we need a simulation environment. OpenAI Gym is a widely used open-source library that provides a large collection of environments for researchers.
pip install gym
pip install torch
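After installing, a quick sanity check confirms that Gym and PyTorch import correctly and that the CartPole environment can be created and stepped. This minimal sketch assumes the classic Gym API (gym < 0.26), where reset() returns an observation and step() returns a 4-tuple:

import gym
import torch

env = gym.make('CartPole-v1')
state = env.reset()                           # initial observation (4 floats for CartPole)
action = env.action_space.sample()            # random action: 0 or 1
next_state, reward, done, info = env.step(action)
print(torch.__version__, state.shape, reward, done)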
3. DQN Implementation
DQN is a deep Q-learning algorithm that approximates the Q-function with a neural network (a small fully-connected network in this example) and stabilizes training with experience replay and a target network.
3.1 DQN Algorithm Overview
- Experience replay: past transitions are stored in a buffer and random mini-batches are sampled for learning, which reduces correlation between consecutive samples.
- Target network: a periodically updated copy of the policy network is used to estimate the value of the next state, which improves training stability (the update target is written out below).
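The two ideas combine in the temporal-difference target that DQN regresses toward. Writing Q_target for the target network and γ for the discount factor:

y = r + γ · max_a' Q_target(s', a')    (y = r when s' is terminal)

The policy network's estimate Q(s, a) is then fit to y with a Huber (smooth L1) loss, which is exactly what the optimize_model method below computes.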
3.2 Environment Initialization
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
import random

# This code assumes the classic Gym API (gym < 0.26), where reset() returns the
# observation and step() returns (obs, reward, done, info).
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]   # 4 state variables for CartPole
output_dim = env.action_space.n              # 2 discrete actions

# One transition stored in the replay buffer.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
3.3 Network Definition
class DQN(nn.Module):
    """Fully-connected Q-network: maps a state to one Q-value per action."""
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
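A quick forward pass helps verify the tensor shapes before training. This is only a sanity-check sketch, not part of the training loop:

net = DQN(input_dim, output_dim)
dummy_state = torch.randn(1, input_dim)   # batch of one CartPole state
q_values = net(dummy_state)               # one Q-value per action
print(q_values.shape)                     # torch.Size([1, 2])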
3.4 Training Loop
class DQNAgent:
    def __init__(self, input_dim, output_dim, learning_rate=0.001, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
                 batch_size=64):
        self.policy_net = DQN(input_dim, output_dim)
        self.target_net = DQN(input_dim, output_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        self.output_dim = output_dim
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.replay_buffer = deque(maxlen=10000)

    def select_action(self, state):
        # Epsilon-greedy: exploit the policy network with probability 1 - epsilon.
        if random.random() > self.epsilon:
            with torch.no_grad():
                return self.policy_net(state).max(1)[1].view(1, 1)
        else:
            return torch.tensor([[random.randrange(self.output_dim)]], dtype=torch.long)

    def optimize_model(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        transitions = random.sample(self.replay_buffer, self.batch_size)
        batch = Transition(*zip(*transitions))
        # Mask of transitions whose next state is not terminal.
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                batch.next_state)), dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        # Q(s, a) for the actions actually taken.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)
        # max_a' Q_target(s', a'), zero for terminal next states.
        next_state_values = torch.zeros(self.batch_size)
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch
        loss = F.smooth_l1_loss(state_action_values,
                                expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

    def update_epsilon(self):
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())
agent = DQNAgent(input_dim, output_dim)
num_episodes = 1000

for i_episode in range(num_episodes):
    state = env.reset()
    state = torch.from_numpy(state).float().unsqueeze(0)
    for t in range(10000):
        action = agent.select_action(state)
        next_state, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], dtype=torch.float32)
        if done:
            next_state = None      # mark terminal transitions
        else:
            next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        # Store the transition, then take one optimization step.
        agent.replay_buffer.append(Transition(state, action, next_state, reward))
        state = next_state
        agent.optimize_model()
        agent.update_epsilon()
        if done:
            break
    if i_episode % 10 == 0:
        agent.update_target_network()   # sync the target network periodically
        print(f"Episode {i_episode} completed out of {num_episodes} episodes")
4. A3C Implementation
A3C (Asynchronous Advantage Actor-Critic) is an asynchronous actor-critic method in which multiple workers collect experience in parallel and use it to update a shared global model.
4.1 A3C Algorithm Overview
- Actor-Critic: a single network with two heads is used here, one that outputs an action distribution (the actor) and one that estimates the state value (the critic).
- Asynchronous updates: several workers interact with their own copies of the environment in parallel and periodically push gradients to the global model (the per-step loss is written out below).
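For each step t of a rollout, the worker below accumulates a discounted return R_t bootstrapped from the critic, forms the advantage, and sums a per-step loss that combines the critic loss, the policy-gradient loss, and an entropy bonus (β = 0.01 in the code):

R_t = r_t + γ · R_{t+1}    (R starts from V(s_end) if the rollout was cut off before a terminal state, otherwise 0)
A_t = R_t - V(s_t)
loss_t = 0.5 · A_t² - log π(a_t | s_t) · A_t - β · H(π(· | s_t)),  where H(π) = -Σ_a π(a) · log π(a)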
4.2 Environment Initialization
import multiprocessing
from threading import Thread

class Environment:
    """Thin wrapper so every worker owns its own, independently seeded environment."""
    def __init__(self, env_name, seed):
        self.env = gym.make(env_name)
        self.env.seed(seed)   # classic Gym API; newer versions seed via reset(seed=...)

    def step(self, action):
        return self.env.step(action)

    def reset(self):
        return self.env.reset()

    def render(self):
        self.env.render()
4.3 Global Model Definition
class GlobalModel(nn.Module):
    """Shared actor-critic network: one trunk, a policy head and a value head."""
    def __init__(self, input_dim, output_dim):
        super(GlobalModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc_pi = nn.Linear(128, output_dim)   # actor head (action logits)
        self.fc_v = nn.Linear(128, 1)             # critic head (state value)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        pi = self.fc_pi(x)
        v = self.fc_v(x)
        return F.softmax(pi, dim=-1), v
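As a quick illustration, a single forward pass returns a probability distribution over actions together with a scalar value estimate; this sketch is for shape-checking only:

model = GlobalModel(input_dim, output_dim)
probs, value = model(torch.randn(1, input_dim))
print(probs.shape, value.shape)   # torch.Size([1, 2]) torch.Size([1, 1]); probs sums to 1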
4.4 Worker Thread Definition
class Worker(Thread):
    def __init__(self, global_model, optimizer, env_name, seed,
                 gamma=0.99, max_steps=20, max_episodes=1000):
        super(Worker, self).__init__()
        self.global_model = global_model
        self.optimizer = optimizer
        self.env = Environment(env_name, seed)
        self.gamma = gamma
        self.max_steps = max_steps
        self.max_episodes = max_episodes
        self.local_model = GlobalModel(input_dim, output_dim)
        self.local_model.load_state_dict(global_model.state_dict())

    def run(self):
        episodes = 0
        state = torch.from_numpy(self.env.reset()).float().unsqueeze(0)
        while episodes < self.max_episodes:
            # Start each rollout from a fresh copy of the global parameters.
            self.local_model.load_state_dict(self.global_model.state_dict())
            log_probs, values, entropies, rewards = [], [], [], []
            done = False
            for _ in range(self.max_steps):
                policy, value = self.local_model(state)        # policy: (1, n_actions)
                action = policy.multinomial(num_samples=1).item()
                next_state, reward, done, _ = self.env.step(action)
                log_probs.append(torch.log(policy[0, action]))
                entropies.append(-(policy * torch.log(policy)).sum())
                values.append(value)
                rewards.append(reward)
                if done:
                    break
                state = torch.from_numpy(next_state).float().unsqueeze(0)
            # Bootstrap the return from the critic if the rollout was cut off.
            R = torch.zeros(1, 1)
            if not done:
                _, R = self.local_model(state)
                R = R.detach()
            loss = 0
            for i in reversed(range(len(rewards))):
                R = self.gamma * R + rewards[i]
                advantage = R - values[i]
                # critic loss + policy-gradient loss - entropy bonus
                loss = loss + 0.5 * advantage.pow(2) \
                            - log_probs[i] * advantage.detach() \
                            - 0.01 * entropies[i]
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.local_model.parameters(), 40)
            # Copy local gradients onto the shared global model, then step its optimizer.
            for local_param, global_param in zip(self.local_model.parameters(),
                                                 self.global_model.parameters()):
                global_param._grad = local_param.grad
            self.optimizer.step()
            if done:
                episodes += 1
                state = torch.from_numpy(self.env.reset()).float().unsqueeze(0)
4.5 Main Program
def main():
    global_model = GlobalModel(input_dim, output_dim)
    optimizer = optim.Adam(global_model.parameters(), lr=0.0001)
    workers = []
    # One worker thread per CPU core, each with a differently seeded environment.
    for i in range(multiprocessing.cpu_count()):
        worker = Worker(global_model, optimizer, 'CartPole-v1', i)
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()

if __name__ == '__main__':
    main()
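Note that Python threads share the interpreter's global lock, so this thread-based version mainly demonstrates the structure of A3C; the gradient sharing works, but model computations do not run truly in parallel. A process-based variant would use torch.multiprocessing with a shared global model. The sketch below is illustrative only: train_worker is a hypothetical function that would contain the same rollout/update logic as Worker.run(), and a full implementation would also share the optimizer state across processes, which is omitted here.

import torch.multiprocessing as mp

def train_worker(rank, global_model, optimizer):
    ...   # same rollout and gradient-push logic as Worker.run(), using the shared model

if __name__ == '__main__':
    global_model = GlobalModel(input_dim, output_dim)
    global_model.share_memory()   # make parameters visible to all processes
    optimizer = optim.Adam(global_model.parameters(), lr=0.0001)
    processes = []
    for rank in range(mp.cpu_count()):
        p = mp.Process(target=train_worker, args=(rank, global_model, optimizer))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()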
5. Conclusion
This article showed how to implement two classic reinforcement learning algorithms, DQN and A3C, with PyTorch. These examples illustrate that a key strength of PyTorch is its flexibility and the ease with which complex neural network structures can be built. We hope they help you gain a deeper understanding of the basic principles and practice of reinforcement learning.