news 2026/5/12 1:36:41

用PyTorch从零实现REINFORCE算法:一个完整的离散与连续动作空间实战教程

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
用PyTorch从零实现REINFORCE算法:一个完整的离散与连续动作空间实战教程

用PyTorch从零实现REINFORCE算法:一个完整的离散与连续动作空间实战教程

强化学习领域近年来发展迅猛,其中策略梯度方法因其直接优化策略的特性备受关注。REINFORCE作为最基础的策略梯度算法,是理解更复杂方法的基石。本文将带你从零开始,用PyTorch实现REINFORCE算法,覆盖离散和连续动作空间两种场景。

1. 环境准备与基础概念

在开始编码前,我们需要配置开发环境并回顾关键概念。推荐使用Python 3.8+和PyTorch 1.10+版本,可以通过以下命令安装必要依赖:

pip install torch gym numpy matplotlib

REINFORCE算法的核心思想是通过蒙特卡洛采样来估计策略梯度。与基于值函数的方法不同,它直接参数化策略并沿着梯度方向更新参数以最大化期望回报。关键公式如下:

$$ \nabla_\theta J(\theta) = \mathbb{E}{\pi\theta}[\nabla_\theta \log \pi_\theta(a|s) G_t] $$

其中:

  • $\pi_\theta(a|s)$ 是参数化策略
  • $G_t$ 是从时刻t开始的累积回报
  • $\theta$ 是策略参数

提示:REINFORCE属于on-policy算法,意味着它使用当前策略生成的数据来更新该策略本身。

2. 离散动作空间实现:CartPole案例

我们首先以经典的CartPole环境为例,展示离散动作空间的实现。CartPole的状态空间包含4个连续变量,动作空间有2个离散选项(左/右)。

2.1 策略网络设计

策略网络将状态映射到动作概率分布。对于离散动作,通常使用softmax输出层:

class DiscretePolicy(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super().__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) def forward(self, x): x = F.relu(self.fc1(x)) x = self.fc2(x) return F.softmax(x, dim=-1)

2.2 动作选择与轨迹收集

REINFORCE需要完整的episode轨迹来计算回报。我们实现一个采样函数:

def collect_episode(env, policy, max_steps=1000): states, actions, rewards, log_probs = [], [], [], [] state = env.reset() for _ in range(max_steps): state = torch.FloatTensor(state).unsqueeze(0) probs = policy(state) dist = Categorical(probs) action = dist.sample() next_state, reward, done, _ = env.step(action.item()) states.append(state) actions.append(action) rewards.append(reward) log_probs.append(dist.log_prob(action)) state = next_state if done: break return states, actions, rewards, log_probs

2.3 策略更新与训练循环

关键训练步骤包括计算折扣回报和策略梯度更新:

def train(policy, optimizer, episodes, gamma=0.99): for _ in range(episodes): # 收集轨迹 states, actions, rewards, log_probs = collect_episode(env, policy) # 计算折扣回报 returns = [] R = 0 for r in reversed(rewards): R = r + gamma * R returns.insert(0, R) # 归一化回报 returns = torch.tensor(returns) returns = (returns - returns.mean()) / (returns.std() + 1e-9) # 计算策略梯度 policy_loss = [] for log_prob, R in zip(log_probs, returns): policy_loss.append(-log_prob * R) # 参数更新 optimizer.zero_grad() sum(policy_loss).backward() optimizer.step()

3. 连续动作空间实现:Pendulum案例

连续动作空间(如Pendulum环境)的实现与离散情况有显著差异。我们使用高斯分布来表示策略。

3.1 连续策略网络设计

连续策略网络输出动作分布的均值和方差:

class ContinuousPolicy(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super().__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc_mean = nn.Linear(hidden_dim, output_dim) self.fc_std = nn.Linear(hidden_dim, output_dim) def forward(self, x): x = F.relu(self.fc1(x)) mean = self.fc_mean(x) std = F.softplus(self.fc_std(x)) + 1e-5 # 确保标准差为正 return torch.distributions.Normal(mean, std)

3.2 连续动作采样

动作采样现在从高斯分布中抽取:

def collect_continuous_episode(env, policy, max_steps=200): states, actions, rewards, log_probs = [], [], [], [] state = env.reset() for _ in range(max_steps): state = torch.FloatTensor(state).unsqueeze(0) dist = policy(state) action = dist.sample() log_prob = dist.log_prob(action).sum(dim=-1) next_state, reward, done, _ = env.step(action.detach().numpy()[0]) states.append(state) actions.append(action) rewards.append(reward) log_probs.append(log_prob) state = next_state if done: break return states, actions, rewards, log_probs

3.3 连续空间训练技巧

连续空间训练需要注意几个关键点:

  • 动作缩放:确保动作在环境允许范围内
  • 探索控制:初始标准差设置影响探索效率
  • 梯度稳定性:使用梯度裁剪防止爆炸
def train_continuous(policy, optimizer, episodes, gamma=0.99, max_grad_norm=0.5): for _ in range(episodes): states, _, rewards, log_probs = collect_continuous_episode(env, policy) # 计算折扣回报 returns = [] R = 0 for r in reversed(rewards): R = r + gamma * R returns.insert(0, R) returns = torch.tensor(returns) returns = (returns - returns.mean()) / (returns.std() + 1e-9) # 计算损失 policy_loss = [] for log_prob, R in zip(log_probs, returns): policy_loss.append(-log_prob * R) # 参数更新 optimizer.zero_grad() sum(policy_loss).backward() nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm) optimizer.step()

4. 高级技巧与性能优化

基础REINFORCE实现虽然简单,但存在高方差问题。以下是几种实用改进方法:

4.1 基线方法(Baseline)

引入状态相关的基线可以减少梯度估计的方差:

class ValueNetwork(nn.Module): def __init__(self, input_dim, hidden_dim): super().__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, 1) def forward(self, x): x = F.relu(self.fc1(x)) return self.fc2(x) # 在训练中使用基线 advantage = returns - value_network(state).squeeze() policy_loss = -log_prob * advantage.detach()

4.2 熵正则化

添加熵项鼓励探索:

entropy = dist.entropy().mean() policy_loss = -log_prob * advantage.detach() - 0.01 * entropy

4.3 并行环境采样

使用多个环境并行采样加速训练:

from multiprocessing import Process, Queue def worker(env_name, policy, queue, max_steps): env = gym.make(env_name) while True: data = collect_episode(env, policy, max_steps) queue.put(data)

5. 调试与可视化

有效的调试技巧可以大幅提升开发效率:

5.1 关键指标监控

记录以下指标有助于分析训练过程:

指标含义期望趋势
回报单回合总奖励逐渐上升
方差回报波动程度逐渐降低
策略随机性初期高后期低

5.2 可视化工具

使用Matplotlib实时监控训练:

import matplotlib.pyplot as plt def plot_learning_curve(rewards, window=100): plt.figure(figsize=(10,5)) plt.plot(rewards, alpha=0.3, label='Raw') plt.plot(np.convolve(rewards, np.ones(window)/window, mode='valid'), label=f'Moving Avg ({window} eps)') plt.xlabel('Episode') plt.ylabel('Total Reward') plt.legend() plt.show()

5.3 常见问题排查

遇到训练失败时,检查以下方面:

  • 学习率是否合适(尝试1e-4到1e-2)
  • 折扣因子gamma是否合理(0.9-0.99)
  • 梯度是否爆炸/消失(添加裁剪)
  • 探索是否充分(调整初始熵)

在实现过程中,我发现连续动作空间的探索尤其关键。初期适当增大动作方差有助于找到有希望的策略区域,之后可以逐渐降低方差以提高稳定性。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/12 1:35:34

液态硅胶注塑加工供应商推荐

随着液态硅胶(LSR)在医疗、母婴、电子、汽车等多个领域的广泛应用,选择一个可靠的液态硅胶注塑加工供应商变得至关重要。作为天沅智能制造科技有限公司(简称TYM),我们不仅深耕于液态硅胶注射成型机械的设计…

作者头像 李华
网站建设 2026/5/12 1:32:02

从 Fork 到第一个 PR:开源新手最完整的一次实战

从 Fork 到第一个 PR:开源新手最完整的一次实战 很多开源新手第一次真正想“做点贡献”,通常会卡在两个地方。 第一是不会找入口,不知道从哪里改起;第二是不会走流程,不知道 Fork、Branch、Commit、Push、PR 之间到底…

作者头像 李华
网站建设 2026/5/12 1:31:33

AI代理工具化新范式:基于MCP协议的模块化连接器实践

1. 项目概述:一个面向AI代理的模块化连接器最近在折腾AI应用开发,特别是围绕AI Agent(智能体)的生态构建时,发现一个挺普遍的问题:如何让这些Agent高效、安全地连接和使用外部工具与服务?无论是…

作者头像 李华
网站建设 2026/5/12 1:28:51

游戏开服即“炸服“?CC攻击成游戏行业隐形杀手

2026年3月,一款备受期待的国产3A大作开启全球公测。开服当天,玩家热情高涨,官方预计同时在线将突破50万。然而就在开服后第47分钟,游戏服务器突然响应迟缓,紧接着大批玩家被强制下线。官方紧急排查后发现,不…

作者头像 李华
网站建设 2026/5/12 1:25:38

51单片机的独立按键和矩阵键盘

引言在嵌入式系统的人机交互中,按键是最基础、最直接的输入设备。从简单的功能切换,到复杂的设备控制,按键的稳定可靠检测是系统功能的基石。本教程将深入剖析8051单片机平台上两种核心的键盘输入方案:独立按键与矩阵键盘。教程概…

作者头像 李华