news 2025/12/18 14:37:56

人类反馈强化学习(RLHF) 从强化学习架构到监督微调

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
人类反馈强化学习(RLHF) 从强化学习架构到监督微调

人类反馈强化学习(RLHF) 从架构到监督微调

关于强化学习(reinforcement learning from human feedback)架构、演员-评论家架构、近端策略优化(PPO)及DeepSpeed Chat的RLHF三阶段训练流程,并附代码实操与详细注释。

一、强化学习架构

强化学习是机器学习技术,旨在训练智能体(agent)在特定环境(environment)中通过交互决策,最大化奖励(reward)信号。

核心要素

  1. 智能体:在环境中执行行动的个体/软件。
  2. 环境:智能体所处的场景,对行动做出反应并给出奖惩。
  3. 状态(state):环境在某一时刻的描述。
  4. 动作(action):智能体在特定状态下执行的操作。
  5. 奖赏(reward):环境给智能体的反馈(正/负数值)。
  6. 策略(policy):智能体的行为函数,决定状态下的行动选择。
  7. 值函数(value function):预测状态/行动后可获得的预期回报。

常见算法

  • 基础算法:Q-Learning、SARSA、演员-评论家(Actor-Critic)。
  • 深度强化学习算法:Deep Q-Network(DQN)、深度确定性策略梯度(DDPG)。

比如下图展现了人类强化学习的核心交互逻辑:

其中:
人类反馈训练奖励模型,对智能体的动作进行评分,排序等偏好标注,让奖励模型学会人类的评价标准
每次智能体从环境获得预测,做出动作,环境把动作传递给智能奖励模型,结合人类的反馈训练输出奖励,智能体再结合奖励调整策略,再次与环境交互,这样周而复始。
很好理解吧,你会学了喵?

二、演员-评论家架构

演员-评论家(Actor-Critic)是结合值函数估计与策略梯度优化的强化学习框架,包含:

1、核心模块

  1. 演员(Actor)
  • 策略模型,根据当前状态输出动作的概率分布。
  • 通过策略梯度调整参数,最大化预期累积奖励。
  1. 评论家(Critic)
  • 价值函数估计器,评估演员动作的价值。
  • 计算优势函数(Advantage Function),指导演员更新策略。
  1. 结合与交互
  • 演员执行动作→环境反馈奖励与新状态→评论家更新价值估计→演员根据优势函数更新策略。
  • 优势:值函数辅助理解环境,策略梯度高效探索,二者独立训练提升效率。

    如图,Actor-Critic原本是两个独立网络,分别负责策略输出(Actor)和价值评估(Critic),而图中通过共享特征提取网络(绿色模块),让两者复用底层的状态特征,仅在顶层分支做差异化处理,是该架构的关键优化手段。

2. 各模块作用

  • 输入 s :代表智能体所处的环境状态(state),是整个网络的输入源。

  • 绿色共享网络:对输入状态 s 进行特征提取,输出统一的高维特征表示,供Actor和Critic分支使用。

  • 蓝色Actor网络:基于共享特征,输出动作概率分布(如图中的 left / right / fire ),决定智能体该执行的动作。

  • 橙色Critic网络:基于共享特征,输出标量值(scalar),代表当前状态/动作的价值评估(即预期奖励),用于指导Actor更新策略。

3. 参数共享的优势

  • 减少参数量:避免重复提取状态特征,降低模型存储和计算成本。

  • 提升特征一致性:Actor和Critic使用相同的底层特征,让策略决策与价值评估的依据保持统一,训练更稳定。

  • 加快收敛速度:共享特征能让两个网络的学习过程相互促进,减少训练迭代次数。

三、近端策略优化架构

近端策略优化(PPO)是演员-评论家方法的增强版,解决策略更新不稳定、高方差等问题,核心改进点:

  1. 策略更新机制:引入 clip 机制,限制策略更新幅度,避免剧烈变化。
  2. 重要性采样:重用旧策略数据,减少数据需求并加快学习速度。
  3. 优化目标:加入基于KL散度(Kullback-Leibler Divergence)的惩罚项,保持策略更新平稳。

图中是一个ppo,演员评论家等混合架构
让我们看看这是如何运作的喵:

  1. 准备“基础原料”——SFT模型与初始回答
    用户输入问题 x 后,策略模型先像“初稿写手”一样生成回答,接着系统会把“问题+回答”的问答对切成“状态、动作、奖励”这些小模块(就像把一篇文章拆成段落分析)。同时,SFT模型是提前用人类优质回答微调好的“范本”,后续会用KL散度盯着策略模型,防止它生成的内容和范本差太远(避免模型乱回答)。

  2. 打分与评估——奖励模型+评论模型
    奖励模型像“人类评委”,根据人类反馈的标准给策略模型的回答打分,打出的分数就是“奖励信号”;
    评论模型像“专业分析师”,评估当前回答的“价值”(比如这个回答在当前语境下好不好),给个量化的价值估计。

  3. 算清“好坏差距”——GAE的作用
    把奖励模型的分数、评论模型的价值估计都交给GAE(广义优势估计),它就像“数据分析师”,会算出两个关键结果:一是“优势值”(这个回答比平均水平好多少),二是“总回报”(长期来看这个回答能带来多少收益),这些数据都会存进经验缓冲(像个数据库)里。

  4. 优化“初稿写手”——策略模型的升级
    策略模型从经验缓冲里拿数据,结合预训练数据(保证模型不忘掉原有知识)开始优化:
    用PPO-clip损失限制修改幅度,避免改得太猛导致模型“学歪”;用语言模型损失保证回答符合语言逻辑,不出现病句、胡话。
    同时,评论模型也会用均方差误差损失自我修正,让后续的价值评估更准。

整个过程的核心就是让策略模型在“人类评委(奖励模型)”的打分、“分析师(评论模型)”的评估、“范本(SFT模型)”的约束下,不断修改回答,最终生成既符合人类偏好,又有知识准确性的内容。

你学会了喵~~

四、DeepSpeed Chat

DeepSpeed Chat是微软开发的深度学习框架,复刻OpenAI InstructGPT的RLHF三阶段训练策略,实现ChatGPT风格模型的高效训练。

RLHF三阶段训练流程

  1. 监督微调(SFT)
  • 用人类问答数据微调预训练语言模型(如GPT、Llama),训练演员模型,使其能根据提示生成文本响应。
  • 目标:让模型理解并响应各类查询,适配对话场景。
  1. 奖励模型训练(Reward Model Training)
  • 用小型模型学习评估生成文本的质量,从“好/坏”示例中区分高低质量回答,输出奖励信号。
  • 数据集:人工标注的 chosen dataset (高质量)和 reject dataset (低质量)。
  1. 策略优化(Policy Optimization)
  • 奖励模型对演员模型输出评分,得到优势函数近似值(奖励信号代理)。
  • 演员模型通过策略梯度更新调整权重,生成更易获得高奖励的响应。
  • 引入参考模型与KL散度约束,避免演员模型输出偏离参考模型过远。

参考资源

  • DeepSpeed Chat源码:https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeedChat
  • 推荐基础模型:OPT-1.3B(参考/演员模型)、OPT-350m(奖励/评论家模型);中文适配:Chinese-Llama-2-1.3b(https://huggingface.co/hfl/chinese-llama-2-1.3b)
  • 国内镜像源:https://hf-mirror.com/hfl/chinese-llama-2-1.3b
  1. 开源RLHF数据集

从DeepSpeed Chat源码中可获取的开源RLHF数据集:

Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gpt-pairwise yingxie/rlhf-reward-datasets openai/webgpt_comparisons stanfordnlp/SHPshu pdvuy/sharegpt_alpaca_oa_vicuna_format wangrui6/Zhihu-KOL Cohere/miracl-zh-queries-22-12 Hello-SimpleAI/HC3-Chinese Cohere/miracl-ja-queries-22-12 lmqg/qag_jaquad lmqg/qag_jaquad
  • Dahoas/rm-static:train 76.3k行,test 5.1k行;
  • Dahoas/full-hh-rlhf:train 112k行,test 12.5k行。

五、训练数据读取(含代码实操)

1、定义统一参数

通过 argparse 定义训练所需的基础参数,用于指定数据集路径、训练配置等。

下面代码的模型输出目录等路径可以写你自己的喵

importargparse# 创建参数解析器args=argparse.ArgumentParser()# 分布式训练的本地rank,单机单卡设为0args.local_rank=0# 数据集路径args.data_path=["dahoas/rm-static"]# 数据划分比例(训练、验证、测试)args.data_split="2,4,8"# 数据输出路径args.data_output_path='/tmp/data_files'# 模型输出目录args.output_dir="./output_step3_llama"# 随机种子args.seed=1234# 最大序列长度args.max_seq_len=256

2、读取Dahoas/rm-static数据集

从 dschat 工具库导入数据集读取方法,加载指定数据集并查看结构。

fromdschat.utils.dataimportraw_datasets# 使用deepspeedchat的方法# 数据集名称dataset_name="dahoas/rm-static"# 加载原始数据集raw_dataset=raw_datasets.DahoasRmstaticDataset(args.output_dir,args.seed,args.local_rank,dataset_name)# DahoasRmstaticDataset :针对Dahoas/rm-static数据集的专用读取类,处理数据加载与划分。# 获取训练集数据train_dataset=raw_dataset.get_train_data()# get_train_data() :提取训练子集,数据集包含 prompt (提示)、 response (响应)、 chosen (优选回复)、 rejected (劣选回复)四个核心字段。# 打印训练集基本信息print(train_dataset)

输出结果大概是这样喵

Dataset({ features: ['prompt', 'response', 'chosen', 'rejected'], num_rows: 76256 })

试着打印第一条数据

# 打印第一条数据的prompt+chosen拼接结果fori,tmp_datainenumerate(train_dataset):print(raw_dataset.get_prompt_and_chosen(tmp_data))# get_prompt_and_chosen() :将 prompt 与 chosen 字段拼接,用于后续SFT训练break# 只打印第一条

3、SFT监督微调数据

SFT(监督微调)需将 prompt 与 chosen 字段合并,生成模型训练的标注数据,并通过分词器转为张量。

加载分词器

使用LlamaTokenizer对文本进行分词,设置pad_token为eos_token(Llama模型无默认pad_token)。

fromtransformersimportLlamaTokenizer# 模型路径/名称model_name_or_path="ci/llama2/llama-2-7b-hf"# 加载分词器tokenizer=AutoTokenizer.from_pretrained(model_name_or_path,trust_remote_code=True)# 设置pad_token为eos_tokentokenizer.pad_token=tokenizer.eos_token

生成SFT训练数据

将 prompt+chosen 拼接后的文本分词,生成模型可接收的 input_ids 和 attention_mask 。

# 对话结束标记end_of_conversation_token="<<|endoftext|>"# 初始化数据集列表prompt_dataset=[]chosen_dataset=[]reject_dataset=[]# 遍历训练集fori,tmp_datainenumerate(train_dataset):# 拼接prompt+chosen生成优选回复文本chosen_sentence=raw_dataset.get_prompt_and_chosen(tmp_data)ifchosen_sentenceisnotNone:# 添加对话结束标记chosen_sentence+=end_of_conversation_token# 分词并转为PyTorch张量chosen_token=tokenizer(chosen_sentence,max_length=args.max_seq_len,# 最大序列长度padding="max_length",# 不足长度时补pad_tokentruncation=True,# 超过长度时截断return_tensors="pt"# 返回PyTorch张量)# 去除batch维度(squeeze(0))chosen_token["input_ids"]=chosen_token["input_ids"].squeeze(0)chosen_token["attention_mask"]=chosen_token["attention_mask"].squeeze(0)# 添加到数据集chosen_dataset.append(chosen_token)# 打印第一条数据的input_idsprint(chosen_dataset[0]["input_ids"])

4、奖励模型微调数据

奖励模型训练需要正样本(chosen)和负样本(rejected),分别由 prompt+chosen 和 prompt+rejected 拼接生成。

# 对话结束标记end_of_conversation_token="<<|endoftext|>"# 初始化数据集列表prompt_dataset=[]chosen_dataset=[]reject_dataset=[]# 遍历训练集fori,tmp_datainenumerate(train_dataset):# 获取正样本文本(prompt+chosen)chosen_sentence=raw_dataset.get_prompt_and_chosen(tmp_data)# 获取负样本文本(prompt+rejected)reject_sentence=raw_dataset.get_prompt_and_rejected(tmp_data)ifchosen_sentenceisnotNoneandreject_sentenceisnotNone:# 添加对话结束标记chosen_sentence+=end_of_conversation_token reject_sentence+=end_of_conversation_token# 正样本分词chosen_token=tokenizer(chosen_sentence,max_length=args.max_seq_len,padding="max_length",truncation=True,return_tensors="pt")# 负样本分词reject_token=tokenizer(reject_sentence,max_length=args.max_seq_len,padding="max_length",truncation=True,return_tensors="pt")# 去除batch维度并添加到数据集chosen_token["input_ids"]=chosen_token["input_ids"].squeeze(0)chosen_token["attention_mask"]=chosen_token["attention_mask"].squeeze(0)chosen_dataset.append(chosen_token)reject_token["input_ids"]=reject_token["input_ids"].squeeze(0)reject_token["attention_mask"]=reject_token["attention_mask"].squeeze(0)reject_dataset.append(reject_token)break# 仅处理第一条数据用于演示# 打印正/负样本文本和张量print(f"chosen_sentence={chosen_sentence}")print(f"chosen_dataset[0]={chosen_dataset[0]}")print(f"reject_sentence={reject_sentence}")print(f"reject_dataset[0]={reject_dataset[0]}")

代码注释:

  • 奖励模型通过对比正/负样本的输出得分,学习对优质回复的偏好。
  • 正/负样本的分词逻辑与SFT一致,保证输入格式统一。

5、RLHF微调数据

RLHF微调阶段仅使用 prompt 字段作为输入,对分词后的 input_ids 进行反转,并过滤超长序列。

filtered=0# 统计过滤的超长数据量prompt_dataset=[]fori,tmp_datainenumerate(train_dataset):# 获取prompt文本prompt=raw_dataset.get_prompt(tmp_data)ifpromptisnotNone:# 对prompt分词prompt_token=tokenizer(prompt,max_length=args.max_seq_len,padding="max_length",truncation=True,return_tensors="pt")# 打印原始input_ids形状print(f"prompt_token['input_ids'].size()={prompt_token['input_ids'].size()}")# 若序列长度超过最大限制,反转并截断ifprompt_token["input_ids"].size()[-1]>args.max_seq_len:# 反转input_ids和attention_maskforkey_wordin["input_ids","attention_mask"]:prompt_token[key_word]=prompt_token[key_word].squeeze(0).flip(0)# 反转张量# 截断至max_seq_lenprompt_token[key_word]=prompt_token[key_word][:args.max_seq_len].unsqueeze(0)# 添加到数据集prompt_dataset.append(prompt_token)else:filtered+=1# 过滤空数据break# 仅处理第一条数据用于演示# 打印处理后的第一条数据input_idsprint(f"prompt_dataset[0]['input_ids']={prompt_dataset[0]['input_ids']}")

六、监督微调(SFT):RLHF训练的核心环节

RLHF中的监督微调(SFT)与普通监督微调的核心差异在于训练数据集,仅采用RLHF数据集的 prompt 和 chosen 字段构建训练样本。本文基于DeepSpeed Chat框架,整理SFT的参数配置、数据加载与模型训练全流程,并添加详细注释。

1、设置通用训练参数

通过 argparse 定义SFT训练的基础配置,包括数据集路径、训练超参、硬件配置等。

路径一样可以换成自己的喵

importargparse# 创建参数解析器,标注为RLHF训练第三步(SFT)args=argparse.ArgumentParser(description="(Step 3) RLHF training arguments")# 分布式训练参数:单机单卡设为0args.local_rank=0# 数据集路径:Dahoas/rm-static数据集的本地挂载路径args.data_path=["/mnt/Dahoas/rm-static"]# 数据划分比例:训练/验证/测试 = 2:4:4args.data_split="2,4,4"# 数据预处理后的输出路径args.data_output_path='/tmp/data_files'# 最大序列长度:统一文本输入的长度为256args.max_seq_len=256# 训练轮数:仅训练1个epoch用于演示args.num_train_epochs=1# 随机种子:保证实验可复现args.seed=1234# 单设备训练批次大小:设为1便于观察数据格式args.per_device_train_batch_size=1# 单设备验证批次大小args.per_device_eval_batch_size=1# 是否打印损失值args.print_loss=True

2、SFT模型训练流程

加载预训练大模型与分词器,构建SFT数据集并执行单轮训练,重点验证数据格式与模型输出。

  1. 导入依赖与加载模型/分词器
# 导入因果语言模型和分词器fromtransformersimportAutoModelForCausalLM,AutoTokenizer# 导入DeepSpeed Chat的数据加载工具fromdschat.utils.dataimportDataLoaderfromdschat.utils.data.data_utilsimportcreate_prompt_dataset# 预训练模型路径:中文Llama-2-7B模型的本地路径model_name_or_path='/mnt/chinese-llama-2-7b'# 加载分词器tokenizer=AutoTokenizer.from_pretrained(model_name_or_path,trust_remote_code=True# 信任远程代码(针对自定义模型))# 加载因果语言模型,自动分配设备并启用模型卸载(节省显存)model=AutoModelForCausalLM.from_pretrained(model_name_or_path,device_map="auto",# 自动将模型层分配到GPU/CPUtrust_remote_code=True,offload_folder='/tmp'# 模型卸载的临时文件夹).eval()# 先设为评估模式,训练时再切换为train()

3、构建SFT训练/验证数据集

# 训练阶段标记:1代表SFT阶段train_phase=1# 创建SFT数据集(仅使用prompt和chosen字段)train_dataset,eval_dataset=create_prompt_dataset(args.local_rank,# 分布式rankargs.data_path,# 数据集路径args.data_split,# 数据划分比例args.data_output_path,# 数据输出路径train_phase,# 训练阶段args.seed,# 随机种子tokenizer,# 分词器args.max_seq_len,# 最大序列长度sft_only_data_path=[]# 仅SFT的数据集路径(此处为空))# 创建训练数据加载器train_dataloader=DataLoader(train_dataset,batch_size=args.per_device_train_batch_size# 批次大小)

4、执行SFT训练(单轮演示)

# 遍历训练轮数(仅1轮)forepochinrange(args.num_train_epochs):print(f"Beginning of Epoch{epoch+1}/{args.num_train_epochs}, Total Micro Batches:{len(train_dataloader)}")model.train()# 切换为训练模式# 遍历训练批次forstep,batchinenumerate(train_dataloader):# 模型前向传播:输入批次数据,关闭缓存以节省显存outputs=model(**batch,use_cache=False)loss=outputs.loss# 获取训练损失# 打印损失值(若开启)ifargs.print_loss:print(f"Epoch:{epoch}, Step:{step}, Rank: loss ={loss}")# 反向传播计算梯度loss.backward(loss)break# 仅训练1个批次用于演示,实际训练需删除

训练过程中模型会输出损失值和logits(模型输出的原始预测分数),输出大概如下喵:

# 损失值输出outputs.loss# tensor(5.4209, grad_fn=<ToCopyBackward0>)# logits输出(模型对每个token的预测分数)outputs.logits# tensor([[[-1.4668, -3.6502, -0.1360, ..., -1.3301, -0.2463, -1.1404],# [-1.4668, -3.6468, -0.1329, ..., -1.3306, -0.2472, -1.1411],# ...,# [-1.2088, ..., -1.9516, -1.8949]]])

其中:

  • loss :当前批次的训练损失,数值越低表示模型预测越接近真实标签( chosen 字段)。
  • logits :模型对每个位置token的预测概率分布(未经过SoftMax),维度为 [batch_size, seq_len, vocab_size] 。

如果你觉得文章有意思的话,别忘了赏点小鱼干喵

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2025/12/11 21:57:41

微传记【7】——程序员圣经之父:高德纳(Donald Knuth)

微传记【7】——程序员圣经之父&#xff1a;高德纳&#xff08;Donald Knuth&#xff09; 他花了60年写一本永远写不完的书&#xff0c;顺便发明了TeX和文学编程 1962年&#xff0c;24岁的高德纳接到加州理工学院出版社的电话&#xff1a; “年轻人&#xff0c;你愿不愿意给我们…

作者头像 李华
网站建设 2025/12/17 20:27:14

GPT-5.2:是创作的未来,还是创作者的终结?

创作的“命运”与AI的挑战 随着人工智能的飞速发展&#xff0c;我们已经开始看到AI技术在许多行业中的强大影响力。在内容创作领域&#xff0c;GPT-5.2等高级语言模型的出现&#xff0c;不仅提高了创作效率&#xff0c;还在某种程度上挑战了创作者的“存在意义”。AI可以自动生…

作者头像 李华
网站建设 2025/12/17 15:18:19

AI测试、大模型测试(五)AI测试工具有哪些

目录 一、AI测试工具分类 1.1 智能测试生成工具 1.2 智能测试执行与优化工具 1.3 专项领域AI测试工具 二、AI测试工具展望 一、AI测试工具分类 AI测试工具&#xff0c;可以按功能、应用场景、技术实现等等进行分类。 1.1 智能测试生成工具 (1) 什么是智能测试生成…

作者头像 李华
网站建设 2025/12/11 21:53:41

LightRAG 系列8:最佳实践与避坑指南

图片来源网络&#xff0c;侵权联系删。 LightRAG系列文章 ● LightRAG系列1&#xff1a;为什么 Web 开发者需要关注 RAG&#xff1f; ● LightRAG系列2&#xff1a;什么是 LightRAG&#xff1f;它和 LangChain 有什么区别&#xff1f; ● LightRAG系列3&#xff1a;LightRAG …

作者头像 李华
网站建设 2025/12/11 21:52:19

Wazuh+OpenCTI威胁情报集成教程(二)之OpenCTI 平台基础与规则体系

文章目录 背景 一、OpenCTI 核心认知 1. 什么是 OpenCTI? 2. 为什么要用 OpenCTI? 3. 谁适合用 OpenCTI? 二、OpenCTI 核心功能模块(附实操场景) 三、OpenCTI 安装部署(零基础教程) 1. 环境要求(核心参考) 2. 详细安装步骤(Ubuntu 22.04 示例) 四、OpenCTI 实操:10…

作者头像 李华
网站建设 2025/12/11 21:51:50

吐血整理,性能测试-正确定义性能瓶颈分析,一篇通透...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 性能测试和功能测…

作者头像 李华