news 2026/6/25 16:34:16

强化学习REINFORCE求最优策略的代码实现

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
强化学习REINFORCE求最优策略的代码实现

理论基础:

注意:

1. 策略的输出要加对数,因此net输出必须softmax,将输出限制为正数。

2. 这里选择action不是greedy地选择最优action,而是按照概率分布选择action->exploration。

3. 策略更新使用的是梯度上升,因此loss取负。

4. 这里使用step一步步收集episode,而不是像之前一样直接使用generate_episode函数生成完成的path,是因为在generate_episode中是使用greedy的方法选择action的(见2)。

5. num_episodes大一些。

代码可运行:

import numpy as np import torch from torch import nn from env import GridWorldEnv from utils import drow_policy ''' policy gradient by Monte Carlo ''' class Reinforce(object): def __init__(self, env: GridWorldEnv, gamma=0.9, lr=1e-2): ''' :param env: :param gamma: discount rate :param lr: learning rate of optimizer ''' self.env = env self.action_space_size = self.env.num_actions self.state_space_size = self.env.num_states self.gamma = gamma self.net = nn.Sequential( nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, self.action_space_size) ) self.policy = np.zeros((self.state_space_size, self.action_space_size)) self.q_value = np.zeros((self.state_space_size, self.action_space_size)) self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr) def decode_state(self, state): ''' :param state: int :return: 归一化后的元组 ''' i = state // self.env.size j = state % self.env.size return torch.tensor((i / (self.env.size - 1), j / (self.env.size - 1)), dtype=torch.float32) def solve(self, num_episodes): for _ in range(num_episodes): state_int = self.env.reset() state = self.decode_state(state_int) done = False episode = [] # [[state_tensor,reward,done]...[...]] while not done: logits = self.net(state) action_probs = torch.softmax(logits, dim=0) action_dist = torch.distributions.Categorical(action_probs) # 按分布采样 action = action_dist.sample().item() next_state, reward, done = self.env.step(state_int, action) episode.append((state, action, reward)) state_int = next_state state = self.decode_state(next_state) # value update returns = [] G = 0 for _, _, reward in reversed(episode): G = reward + self.gamma * G returns.insert(0, G) # policy update self.optimizer.zero_grad() loss = 0 for (state, action, _), G in zip(episode, returns): logits = self.net(state) action_probs = torch.softmax(logits, dim=0) action_dist = torch.distributions.Categorical(action_probs) log_prob = action_dist.log_prob(torch.tensor(action)) # In Π(a_t|s_t, θ) loss -= log_prob * G # 负号是因为最小化 loss->最大化 J(θ),梯度上升更新参数 loss.backward() self.optimizer.step() def get_policy(self): for s in range(self.state_space_size): a = np.argmax(self.q_value[s]) self.policy[s, a] = 1 return self.policy def get_qvalues(self): for s in range(self.state_space_size): s_t = self.decode_state(s) logits = self.net(s_t) action_probs = torch.softmax(logits, dim=0) self.q_value[s,:] = action_probs.detach().numpy() # q_value是numpy类型,action_probs是tensor,必须转换 return self.q_value if __name__ == '__main__': env = GridWorldEnv( size=5, forbidden=[(1, 2), (3, 3)], terminal=[(4, 4)], r_boundary=-1, r_other=-0.04, r_terminal=1, r_forbidden=-1, r_stay=-0.1 ) vi = Reinforce(env=env) vi.solve(num_episodes=200) print("\n state value: ") print(vi.get_qvalues()) drow_policy(vi.get_policy(), env)

运行结果:

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/23 2:01:50

NVIDIA Project DIGITS:技术架构解析与行业解决方案全景

一、革命性技术架构深度解析 1. GB10超级芯片的异构创新 Project DIGITS的核心是NVIDIA GB10 Grace Blackwell超级芯片,这是一款真正的片上系统(SoC),通过三大突破性设计实现了桌面级Petaflop算力:NVLink-C2C芯片级互连:传统CPU与…

作者头像 李华
网站建设 2026/6/21 20:45:42

从适配到共建:密瓜智能 HAMi × 沐曦 GPU 完成兼容互认

作为一个活跃的开源项目,HAMi 由来自 15 国家、350 贡献者共同维护,已被 200 企业与机构在实际生产环境中采纳,具备良好的可扩展性与支持保障。产品兼容互认 近日,密瓜智能 与 沐曦集成电路(上海)股份有限公…

作者头像 李华
网站建设 2026/6/25 13:34:01

基于springboot二手车交易市场管理系统

基于Spring Boot的二手车交易市场管理系统是一个功能全面、用户友好、安全可靠的在线二手车交易平台。以下是对该系统的详细介绍: 一、系统架构与技术栈 后端:采用Spring Boot框架作为后端开发工具,负责处理业务逻辑,如车辆信息…

作者头像 李华
网站建设 2026/6/24 16:53:20

Python和PHP学哪个比较好?

Python和PHP的选择,核心取决于你的学习目标和应用场景。PHP是老牌Web开发语言,轻量高效,适配中小型网站快速搭建;Python则是全能型语言,覆盖Web、数据分析、AI 等多领域,那么Python和PHP学哪个比较好?详细内容请看下文…

作者头像 李华
网站建设 2026/6/23 14:27:35

Python大数据使用Vue.js构建的大数据分析与可视化系统_m1sf2x1m_c008

文章目录系统截图项目简介大数据系统开发流程主要运用技术介绍爬虫核心代码展示结论源码文档获取定制开发/同行可拿货,招校园代理 :文章底部获取博主联系方式!系统截图 Python大数据使用Vue.js构建的大数据分析与可视化系统_m1sf2x1m_c008 项目简…

作者头像 李华