人类反馈强化学习(RLHF) 从架构到监督微调
关于强化学习(reinforcement learning from human feedback)架构、演员-评论家架构、近端策略优化(PPO)及DeepSpeed Chat的RLHF三阶段训练流程,并附代码实操与详细注释。
一、强化学习架构
强化学习是机器学习技术,旨在训练智能体(agent)在特定环境(environment)中通过交互决策,最大化奖励(reward)信号。
核心要素
- 智能体:在环境中执行行动的个体/软件。
- 环境:智能体所处的场景,对行动做出反应并给出奖惩。
- 状态(state):环境在某一时刻的描述。
- 动作(action):智能体在特定状态下执行的操作。
- 奖赏(reward):环境给智能体的反馈(正/负数值)。
- 策略(policy):智能体的行为函数,决定状态下的行动选择。
- 值函数(value function):预测状态/行动后可获得的预期回报。
常见算法
- 基础算法:Q-Learning、SARSA、演员-评论家(Actor-Critic)。
- 深度强化学习算法:Deep Q-Network(DQN)、深度确定性策略梯度(DDPG)。
比如下图展现了人类强化学习的核心交互逻辑:
其中:
人类反馈训练奖励模型,对智能体的动作进行评分,排序等偏好标注,让奖励模型学会人类的评价标准
每次智能体从环境获得预测,做出动作,环境把动作传递给智能奖励模型,结合人类的反馈训练输出奖励,智能体再结合奖励调整策略,再次与环境交互,这样周而复始。
很好理解吧,你会学了喵?
二、演员-评论家架构
演员-评论家(Actor-Critic)是结合值函数估计与策略梯度优化的强化学习框架,包含:
1、核心模块
- 演员(Actor)
- 策略模型,根据当前状态输出动作的概率分布。
- 通过策略梯度调整参数,最大化预期累积奖励。
- 评论家(Critic)
- 价值函数估计器,评估演员动作的价值。
- 计算优势函数(Advantage Function),指导演员更新策略。
- 结合与交互
- 演员执行动作→环境反馈奖励与新状态→评论家更新价值估计→演员根据优势函数更新策略。
- 优势:值函数辅助理解环境,策略梯度高效探索,二者独立训练提升效率。
如图,Actor-Critic原本是两个独立网络,分别负责策略输出(Actor)和价值评估(Critic),而图中通过共享特征提取网络(绿色模块),让两者复用底层的状态特征,仅在顶层分支做差异化处理,是该架构的关键优化手段。
2. 各模块作用
输入 s :代表智能体所处的环境状态(state),是整个网络的输入源。
绿色共享网络:对输入状态 s 进行特征提取,输出统一的高维特征表示,供Actor和Critic分支使用。
蓝色Actor网络:基于共享特征,输出动作概率分布(如图中的 left / right / fire ),决定智能体该执行的动作。
橙色Critic网络:基于共享特征,输出标量值(scalar),代表当前状态/动作的价值评估(即预期奖励),用于指导Actor更新策略。
3. 参数共享的优势
减少参数量:避免重复提取状态特征,降低模型存储和计算成本。
提升特征一致性:Actor和Critic使用相同的底层特征,让策略决策与价值评估的依据保持统一,训练更稳定。
加快收敛速度:共享特征能让两个网络的学习过程相互促进,减少训练迭代次数。
三、近端策略优化架构
近端策略优化(PPO)是演员-评论家方法的增强版,解决策略更新不稳定、高方差等问题,核心改进点:
- 策略更新机制:引入 clip 机制,限制策略更新幅度,避免剧烈变化。
- 重要性采样:重用旧策略数据,减少数据需求并加快学习速度。
- 优化目标:加入基于KL散度(Kullback-Leibler Divergence)的惩罚项,保持策略更新平稳。
图中是一个ppo,演员评论家等混合架构
让我们看看这是如何运作的喵:
准备“基础原料”——SFT模型与初始回答
用户输入问题 x 后,策略模型先像“初稿写手”一样生成回答,接着系统会把“问题+回答”的问答对切成“状态、动作、奖励”这些小模块(就像把一篇文章拆成段落分析)。同时,SFT模型是提前用人类优质回答微调好的“范本”,后续会用KL散度盯着策略模型,防止它生成的内容和范本差太远(避免模型乱回答)。打分与评估——奖励模型+评论模型
奖励模型像“人类评委”,根据人类反馈的标准给策略模型的回答打分,打出的分数就是“奖励信号”;
评论模型像“专业分析师”,评估当前回答的“价值”(比如这个回答在当前语境下好不好),给个量化的价值估计。算清“好坏差距”——GAE的作用
把奖励模型的分数、评论模型的价值估计都交给GAE(广义优势估计),它就像“数据分析师”,会算出两个关键结果:一是“优势值”(这个回答比平均水平好多少),二是“总回报”(长期来看这个回答能带来多少收益),这些数据都会存进经验缓冲(像个数据库)里。优化“初稿写手”——策略模型的升级
策略模型从经验缓冲里拿数据,结合预训练数据(保证模型不忘掉原有知识)开始优化:
用PPO-clip损失限制修改幅度,避免改得太猛导致模型“学歪”;用语言模型损失保证回答符合语言逻辑,不出现病句、胡话。
同时,评论模型也会用均方差误差损失自我修正,让后续的价值评估更准。
整个过程的核心就是让策略模型在“人类评委(奖励模型)”的打分、“分析师(评论模型)”的评估、“范本(SFT模型)”的约束下,不断修改回答,最终生成既符合人类偏好,又有知识准确性的内容。
你学会了喵~~
四、DeepSpeed Chat
DeepSpeed Chat是微软开发的深度学习框架,复刻OpenAI InstructGPT的RLHF三阶段训练策略,实现ChatGPT风格模型的高效训练。
RLHF三阶段训练流程
- 监督微调(SFT)
- 用人类问答数据微调预训练语言模型(如GPT、Llama),训练演员模型,使其能根据提示生成文本响应。
- 目标:让模型理解并响应各类查询,适配对话场景。
- 奖励模型训练(Reward Model Training)
- 用小型模型学习评估生成文本的质量,从“好/坏”示例中区分高低质量回答,输出奖励信号。
- 数据集:人工标注的 chosen dataset (高质量)和 reject dataset (低质量)。
- 策略优化(Policy Optimization)
- 奖励模型对演员模型输出评分,得到优势函数近似值(奖励信号代理)。
- 演员模型通过策略梯度更新调整权重,生成更易获得高奖励的响应。
- 引入参考模型与KL散度约束,避免演员模型输出偏离参考模型过远。
参考资源
- DeepSpeed Chat源码:https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeedChat
- 推荐基础模型:OPT-1.3B(参考/演员模型)、OPT-350m(奖励/评论家模型);中文适配:Chinese-Llama-2-1.3b(https://huggingface.co/hfl/chinese-llama-2-1.3b)
- 国内镜像源:https://hf-mirror.com/hfl/chinese-llama-2-1.3b
- 开源RLHF数据集
从DeepSpeed Chat源码中可获取的开源RLHF数据集:
Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gpt-pairwise yingxie/rlhf-reward-datasets openai/webgpt_comparisons stanfordnlp/SHPshu pdvuy/sharegpt_alpaca_oa_vicuna_format wangrui6/Zhihu-KOL Cohere/miracl-zh-queries-22-12 Hello-SimpleAI/HC3-Chinese Cohere/miracl-ja-queries-22-12 lmqg/qag_jaquad lmqg/qag_jaquad- Dahoas/rm-static:train 76.3k行,test 5.1k行;
- Dahoas/full-hh-rlhf:train 112k行,test 12.5k行。
- …
五、训练数据读取(含代码实操)
1、定义统一参数
通过 argparse 定义训练所需的基础参数,用于指定数据集路径、训练配置等。
下面代码的模型输出目录等路径可以写你自己的喵
importargparse# 创建参数解析器args=argparse.ArgumentParser()# 分布式训练的本地rank,单机单卡设为0args.local_rank=0# 数据集路径args.data_path=["dahoas/rm-static"]# 数据划分比例(训练、验证、测试)args.data_split="2,4,8"# 数据输出路径args.data_output_path='/tmp/data_files'# 模型输出目录args.output_dir="./output_step3_llama"# 随机种子args.seed=1234# 最大序列长度args.max_seq_len=2562、读取Dahoas/rm-static数据集
从 dschat 工具库导入数据集读取方法,加载指定数据集并查看结构。
fromdschat.utils.dataimportraw_datasets# 使用deepspeedchat的方法# 数据集名称dataset_name="dahoas/rm-static"# 加载原始数据集raw_dataset=raw_datasets.DahoasRmstaticDataset(args.output_dir,args.seed,args.local_rank,dataset_name)# DahoasRmstaticDataset :针对Dahoas/rm-static数据集的专用读取类,处理数据加载与划分。# 获取训练集数据train_dataset=raw_dataset.get_train_data()# get_train_data() :提取训练子集,数据集包含 prompt (提示)、 response (响应)、 chosen (优选回复)、 rejected (劣选回复)四个核心字段。# 打印训练集基本信息print(train_dataset)输出结果大概是这样喵
Dataset({ features: ['prompt', 'response', 'chosen', 'rejected'], num_rows: 76256 })试着打印第一条数据
# 打印第一条数据的prompt+chosen拼接结果fori,tmp_datainenumerate(train_dataset):print(raw_dataset.get_prompt_and_chosen(tmp_data))# get_prompt_and_chosen() :将 prompt 与 chosen 字段拼接,用于后续SFT训练break# 只打印第一条3、SFT监督微调数据
SFT(监督微调)需将 prompt 与 chosen 字段合并,生成模型训练的标注数据,并通过分词器转为张量。
加载分词器
使用LlamaTokenizer对文本进行分词,设置pad_token为eos_token(Llama模型无默认pad_token)。
fromtransformersimportLlamaTokenizer# 模型路径/名称model_name_or_path="ci/llama2/llama-2-7b-hf"# 加载分词器tokenizer=AutoTokenizer.from_pretrained(model_name_or_path,trust_remote_code=True)# 设置pad_token为eos_tokentokenizer.pad_token=tokenizer.eos_token生成SFT训练数据
将 prompt+chosen 拼接后的文本分词,生成模型可接收的 input_ids 和 attention_mask 。
# 对话结束标记end_of_conversation_token="<<|endoftext|>"# 初始化数据集列表prompt_dataset=[]chosen_dataset=[]reject_dataset=[]# 遍历训练集fori,tmp_datainenumerate(train_dataset):# 拼接prompt+chosen生成优选回复文本chosen_sentence=raw_dataset.get_prompt_and_chosen(tmp_data)ifchosen_sentenceisnotNone:# 添加对话结束标记chosen_sentence+=end_of_conversation_token# 分词并转为PyTorch张量chosen_token=tokenizer(chosen_sentence,max_length=args.max_seq_len,# 最大序列长度padding="max_length",# 不足长度时补pad_tokentruncation=True,# 超过长度时截断return_tensors="pt"# 返回PyTorch张量)# 去除batch维度(squeeze(0))chosen_token["input_ids"]=chosen_token["input_ids"].squeeze(0)chosen_token["attention_mask"]=chosen_token["attention_mask"].squeeze(0)# 添加到数据集chosen_dataset.append(chosen_token)# 打印第一条数据的input_idsprint(chosen_dataset[0]["input_ids"])4、奖励模型微调数据
奖励模型训练需要正样本(chosen)和负样本(rejected),分别由 prompt+chosen 和 prompt+rejected 拼接生成。
# 对话结束标记end_of_conversation_token="<<|endoftext|>"# 初始化数据集列表prompt_dataset=[]chosen_dataset=[]reject_dataset=[]# 遍历训练集fori,tmp_datainenumerate(train_dataset):# 获取正样本文本(prompt+chosen)chosen_sentence=raw_dataset.get_prompt_and_chosen(tmp_data)# 获取负样本文本(prompt+rejected)reject_sentence=raw_dataset.get_prompt_and_rejected(tmp_data)ifchosen_sentenceisnotNoneandreject_sentenceisnotNone:# 添加对话结束标记chosen_sentence+=end_of_conversation_token reject_sentence+=end_of_conversation_token# 正样本分词chosen_token=tokenizer(chosen_sentence,max_length=args.max_seq_len,padding="max_length",truncation=True,return_tensors="pt")# 负样本分词reject_token=tokenizer(reject_sentence,max_length=args.max_seq_len,padding="max_length",truncation=True,return_tensors="pt")# 去除batch维度并添加到数据集chosen_token["input_ids"]=chosen_token["input_ids"].squeeze(0)chosen_token["attention_mask"]=chosen_token["attention_mask"].squeeze(0)chosen_dataset.append(chosen_token)reject_token["input_ids"]=reject_token["input_ids"].squeeze(0)reject_token["attention_mask"]=reject_token["attention_mask"].squeeze(0)reject_dataset.append(reject_token)break# 仅处理第一条数据用于演示# 打印正/负样本文本和张量print(f"chosen_sentence={chosen_sentence}")print(f"chosen_dataset[0]={chosen_dataset[0]}")print(f"reject_sentence={reject_sentence}")print(f"reject_dataset[0]={reject_dataset[0]}")代码注释:
- 奖励模型通过对比正/负样本的输出得分,学习对优质回复的偏好。
- 正/负样本的分词逻辑与SFT一致,保证输入格式统一。
5、RLHF微调数据
RLHF微调阶段仅使用 prompt 字段作为输入,对分词后的 input_ids 进行反转,并过滤超长序列。
filtered=0# 统计过滤的超长数据量prompt_dataset=[]fori,tmp_datainenumerate(train_dataset):# 获取prompt文本prompt=raw_dataset.get_prompt(tmp_data)ifpromptisnotNone:# 对prompt分词prompt_token=tokenizer(prompt,max_length=args.max_seq_len,padding="max_length",truncation=True,return_tensors="pt")# 打印原始input_ids形状print(f"prompt_token['input_ids'].size()={prompt_token['input_ids'].size()}")# 若序列长度超过最大限制,反转并截断ifprompt_token["input_ids"].size()[-1]>args.max_seq_len:# 反转input_ids和attention_maskforkey_wordin["input_ids","attention_mask"]:prompt_token[key_word]=prompt_token[key_word].squeeze(0).flip(0)# 反转张量# 截断至max_seq_lenprompt_token[key_word]=prompt_token[key_word][:args.max_seq_len].unsqueeze(0)# 添加到数据集prompt_dataset.append(prompt_token)else:filtered+=1# 过滤空数据break# 仅处理第一条数据用于演示# 打印处理后的第一条数据input_idsprint(f"prompt_dataset[0]['input_ids']={prompt_dataset[0]['input_ids']}")六、监督微调(SFT):RLHF训练的核心环节
RLHF中的监督微调(SFT)与普通监督微调的核心差异在于训练数据集,仅采用RLHF数据集的 prompt 和 chosen 字段构建训练样本。本文基于DeepSpeed Chat框架,整理SFT的参数配置、数据加载与模型训练全流程,并添加详细注释。
1、设置通用训练参数
通过 argparse 定义SFT训练的基础配置,包括数据集路径、训练超参、硬件配置等。
路径一样可以换成自己的喵
importargparse# 创建参数解析器,标注为RLHF训练第三步(SFT)args=argparse.ArgumentParser(description="(Step 3) RLHF training arguments")# 分布式训练参数:单机单卡设为0args.local_rank=0# 数据集路径:Dahoas/rm-static数据集的本地挂载路径args.data_path=["/mnt/Dahoas/rm-static"]# 数据划分比例:训练/验证/测试 = 2:4:4args.data_split="2,4,4"# 数据预处理后的输出路径args.data_output_path='/tmp/data_files'# 最大序列长度:统一文本输入的长度为256args.max_seq_len=256# 训练轮数:仅训练1个epoch用于演示args.num_train_epochs=1# 随机种子:保证实验可复现args.seed=1234# 单设备训练批次大小:设为1便于观察数据格式args.per_device_train_batch_size=1# 单设备验证批次大小args.per_device_eval_batch_size=1# 是否打印损失值args.print_loss=True2、SFT模型训练流程
加载预训练大模型与分词器,构建SFT数据集并执行单轮训练,重点验证数据格式与模型输出。
- 导入依赖与加载模型/分词器
# 导入因果语言模型和分词器fromtransformersimportAutoModelForCausalLM,AutoTokenizer# 导入DeepSpeed Chat的数据加载工具fromdschat.utils.dataimportDataLoaderfromdschat.utils.data.data_utilsimportcreate_prompt_dataset# 预训练模型路径:中文Llama-2-7B模型的本地路径model_name_or_path='/mnt/chinese-llama-2-7b'# 加载分词器tokenizer=AutoTokenizer.from_pretrained(model_name_or_path,trust_remote_code=True# 信任远程代码(针对自定义模型))# 加载因果语言模型,自动分配设备并启用模型卸载(节省显存)model=AutoModelForCausalLM.from_pretrained(model_name_or_path,device_map="auto",# 自动将模型层分配到GPU/CPUtrust_remote_code=True,offload_folder='/tmp'# 模型卸载的临时文件夹).eval()# 先设为评估模式,训练时再切换为train()3、构建SFT训练/验证数据集
# 训练阶段标记:1代表SFT阶段train_phase=1# 创建SFT数据集(仅使用prompt和chosen字段)train_dataset,eval_dataset=create_prompt_dataset(args.local_rank,# 分布式rankargs.data_path,# 数据集路径args.data_split,# 数据划分比例args.data_output_path,# 数据输出路径train_phase,# 训练阶段args.seed,# 随机种子tokenizer,# 分词器args.max_seq_len,# 最大序列长度sft_only_data_path=[]# 仅SFT的数据集路径(此处为空))# 创建训练数据加载器train_dataloader=DataLoader(train_dataset,batch_size=args.per_device_train_batch_size# 批次大小)4、执行SFT训练(单轮演示)
# 遍历训练轮数(仅1轮)forepochinrange(args.num_train_epochs):print(f"Beginning of Epoch{epoch+1}/{args.num_train_epochs}, Total Micro Batches:{len(train_dataloader)}")model.train()# 切换为训练模式# 遍历训练批次forstep,batchinenumerate(train_dataloader):# 模型前向传播:输入批次数据,关闭缓存以节省显存outputs=model(**batch,use_cache=False)loss=outputs.loss# 获取训练损失# 打印损失值(若开启)ifargs.print_loss:print(f"Epoch:{epoch}, Step:{step}, Rank: loss ={loss}")# 反向传播计算梯度loss.backward(loss)break# 仅训练1个批次用于演示,实际训练需删除训练过程中模型会输出损失值和logits(模型输出的原始预测分数),输出大概如下喵:
# 损失值输出outputs.loss# tensor(5.4209, grad_fn=<ToCopyBackward0>)# logits输出(模型对每个token的预测分数)outputs.logits# tensor([[[-1.4668, -3.6502, -0.1360, ..., -1.3301, -0.2463, -1.1404],# [-1.4668, -3.6468, -0.1329, ..., -1.3306, -0.2472, -1.1411],# ...,# [-1.2088, ..., -1.9516, -1.8949]]])其中:
- loss :当前批次的训练损失,数值越低表示模型预测越接近真实标签( chosen 字段)。
- logits :模型对每个位置token的预测概率分布(未经过SoftMax),维度为 [batch_size, seq_len, vocab_size] 。