news 2026/4/20 17:50:16

Search-R1: A Brief Paper Analysis and Code Walkthrough


张小明

Front-end Development Engineer


Motivation

Use a search engine to empower reasoning LLMs.

Method

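Building on PPO, trajectories are generated with the help of a given search engine $\mathcal{R}$:

$$
\mathcal{J}_{\text{PPO}}(\theta)=\mathbb{E}_{(q,a)\sim\mathcal{D},\,o\sim\pi_{\text{old}}(\cdot\mid q;\mathcal{R})}\left[\frac{1}{|o|}\sum_{t=1}^{|o|} I(o_t)\,\min\!\left(\frac{\pi_\theta(o_t\mid q,o_{<t};\mathcal{R})}{\pi_{\text{old}}(o_t\mid q,o_{<t};\mathcal{R})}A_t,\ \operatorname{clip}\!\left(\frac{\pi_\theta(o_t\mid q,o_{<t};\mathcal{R})}{\pi_{\text{old}}(o_t\mid q,o_{<t};\mathcal{R})},\,1-\epsilon,\,1+\epsilon\right)A_t\right)\right]
$$

The tokens returned by $\mathcal{R}$ must be masked out of the loss:

$$
I(o_t)=\begin{cases}0, & o_t \text{ is a retrieved token}\\ 1, & \text{otherwise}\end{cases}
$$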
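A minimal pure-Python sketch of the term inside the expectation, for a single trajectory, may help make the role of $I(o_t)$ concrete. It assumes per-token log-probabilities and advantages are already available as plain lists; the function name `masked_clipped_objective` is hypothetical (Search-R1 itself works on batched tensors inside verl), and the normalization follows the $1/|o|$ factor written above. Training would minimize the negative of this value.

```python
import math
from typing import List


def masked_clipped_objective(
    logp_new: List[float],
    logp_old: List[float],
    advantages: List[float],
    info_mask: List[int],
    eps: float = 0.2,
) -> float:
    """Clipped PPO surrogate for one trajectory o, with retrieved tokens masked out.

    info_mask[t] plays the role of I(o_t): 0 for a retrieved token, 1 otherwise.
    As in the formula above, the masked sum is divided by |o| (all tokens).
    """
    n = len(logp_new)
    if n == 0 or not (n == len(logp_old) == len(advantages) == len(info_mask)):
        raise ValueError("inputs must be non-empty and of equal length")
    total = 0.0
    for lp_new, lp_old, adv, keep in zip(logp_new, logp_old, advantages, info_mask):
        if not keep:
            continue  # I(o_t) = 0: a retrieved token contributes nothing
        ratio = math.exp(lp_new - lp_old)  # pi_theta(o_t | ...) / pi_old(o_t | ...)
        clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
        total += min(ratio * adv, clipped_ratio * adv)
    return total / n
```

Setting an entry of `info_mask` to 0 removes that token from the sum entirely, so no gradient flows through retrieved text.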

Experiments

PPO is used by default. Overall, Search-R1's RL training proves effective. The training data comes from NQ and HotpotQA.

PPO vs GRPO

The paper finds PPO more stable than GRPO and better performing, while GRPO converges faster.


Instruct model vs base model

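The paper finds that although the instruct model starts with a higher reward than the base model, the two reach comparable rewards later in training, and the base model ends up outperforming the instruct model.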

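(My view: the base model outperforming the instruct model here may be because instruction tuning (alignment via RL) reduces the model's diversity, which weakens its exploration in the search task. However, works such as WebDancer all use instruct models. I think that is because those works do not start directly with search RL: they first run RFT-style SFT so that the instruct model adapts to the RL format and absorbs domain knowledge for search tasks (planning, tool calling, summarization, and so on). If such RFT post-training (possibly on little data) were applied to a base model, the base model would tend to have instruction-following problems. That is why later WebAgent works that use SFT+RL generally build on an instruct model.)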


Response length and valid search study

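Early stage: the response length drops noticeably while the reward rises slightly (the model understands the search task better and its outputs become more concise).

Later stage: the response length grows again and the reward keeps rising (this can be traced to an increase in the number of search calls).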


Ablation of the retrieved-token mask

The mask is necessary, because the model's objective is not to predict the retrieved tokens but to learn tool calling, planning, and summarization.


Number of Retrieved Passages Study in SEARCH-R1 Training

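Retrieving more docs is not always better (the actor model is more likely to hallucinate or miss details when summarizing), and neither is retrieving fewer (even the best cook cannot make a meal without rice).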


Group size of GRPO

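A larger GRPO group size gives better results and faster convergence but is less stable. (I suspect a problem in the paper's experimental design; I have never run into this kind of sharp reward drop myself.)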
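For context on what the group size controls, here is a minimal sketch of GRPO-style group-relative advantages: each response's reward is normalized by the mean and standard deviation of the rewards of all responses to the same question. It is a simplified illustration, not verl's `compute_advantage`; the name `grpo_advantages`, the use of the sample standard deviation, and returning 0 for singleton groups are assumptions of this sketch. A larger group gives each question a less noisy baseline, at the cost of more rollouts per step.

```python
import statistics
from typing import Dict, List


def grpo_advantages(rewards: List[float], group_ids: List[str], eps: float = 1e-6) -> List[float]:
    """Group-relative advantage: (r - mean(group)) / (std(group) + eps)."""
    groups: Dict[str, List[float]] = {}
    for reward, gid in zip(rewards, group_ids):
        groups.setdefault(gid, []).append(reward)

    advantages = []
    for reward, gid in zip(rewards, group_ids):
        group = groups[gid]
        if len(group) < 2:
            advantages.append(0.0)  # assumption: a singleton group carries no relative signal
            continue
        mean = statistics.fmean(group)
        std = statistics.stdev(group)  # sample standard deviation
        advantages.append((reward - mean) / (std + eps))
    return advantages
```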


Conclusion

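The paper proposes an RL method for agents, but it does not construct SFT trajectory data, so the model cannot learn planning, single-tool calling, or how multiple tools relate to each other.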

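Code Implementation

The hard parts of implementing agent RL are the following two; for each, I compare the naive RL code with Search-R1's code:

Loop-based trajectory generation

The trajectory reward manager

1. Generating trajectory data in a loop

Unlike naive RL, Search-R1 has to extract the action and the tool at every step and issue a retrieval call.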
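As an illustration of "extracting the action at each step", here is a minimal sketch that pulls the first `<search>...</search>` or `<answer>...</answer>` block out of a decoded response. The tag names are the ones mentioned in the Search-R1 generation code discussed below; the regular expression and the function name `parse_action` are assumptions for illustration, not Search-R1's actual parser.

```python
import re
from typing import Optional, Tuple

# Hypothetical pattern: an opening tag, lazily matched content, and the SAME closing tag (\1).
_ACTION_RE = re.compile(r"<(search|answer)>(.*?)</\1>", re.DOTALL)


def parse_action(response: str) -> Tuple[Optional[str], str]:
    """Return (action, content) for the first <search>/<answer> block, or (None, '') if invalid."""
    match = _ACTION_RE.search(response)
    if match is None:
        return None, ""  # invalid action: neither a search nor an answer
    return match.group(1), match.group(2).strip()
```

A `search` result would be sent to the retriever, an `answer` ends the trajectory, and `None` counts as an invalid action.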

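First, let's look at the naive implementation of self.actor_rollout_wg.generate_sequences(gen_batch), which verl calls in verl.trainer.ppo.ray_trainer.py.

It lives in verl/workers/rollout/naive/naive_rollout.py. Note that rollout only samples, so it does not need to keep the computation graph; hence the @torch.no_grad() decorator.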

```python
import torch
import torch.nn.functional as F
from tensordict import TensorDict
from torch import nn

from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits
from verl.workers.rollout.base import BaseRollout


class NaiveRollout(BaseRollout):

    def __init__(self, module: nn.Module, config):
        """A naive rollout. It requires the module to be compatible with huggingface APIs. That is:
        The module should define __call__ to receive input_ids, attention_mask and position_ids.
        It outputs a structure that contains logits field.

        Args:
            module: module here follows huggingface APIs
            config: DictConfig
        """
        super().__init__()
        self.config = config
        self.module = module

    #########################################################################
    # Rollout does not need to keep the computation graph
    #########################################################################
    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate sequences"""
        #########################################################################
        # Note: with GRPO, batch['input_ids'] has shape (batch_size*rollout.n, prompt_length),
        # because ray_trainer.py repeats the batch beforehand
        #########################################################################
        idx = prompts.batch['input_ids']  # (bs, prompt_length)
        attention_mask = prompts.batch['attention_mask']  # left-padded attention_mask
        position_ids = prompts.batch['position_ids']

        # used to construct attention_mask
        eos_token_id = prompts.meta_info['eos_token_id']

        batch_size = idx.size(0)
        prompt_length = idx.size(1)

        self.module.eval()

        # prev_attention_mask records, per sequence, whether rollout is still running,
        # i.e. whether NO eos_token has appeared before the token generated in this iteration
        prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)

        logits_lst = []
        #########################################################################
        # Each iteration generates the next_token_id at the same position (position_id)
        # for all sequences at once, and the loop runs response_length times regardless of eos.
        # This generates all sequences in parallel with tensor ops instead of one
        # sequence at a time, which keeps rollout efficient.
        #########################################################################
        for _ in range(self.config.response_length):
            # if the sequence context is growing too long we must crop it at block_size
            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            idx_cond = idx
            # forward the model to get the logits for the index in the sequence
            # we use huggingface APIs here
            output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids)
            # logits: (bs, current_seq_len, vocab_size)
            logits = output.logits

            #########################################################################
            # Sampling:
            # temperature: divide every vocab logit of the token by temp
            # top_k: set logits outside the top k to -inf, so they get zero probability after softmax
            # do_sample: sample from the distribution, or take the most probable id
            #########################################################################
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / self.config.temperature  # (bs, vocab_size)
            # optionally crop the logits to only the top k options
            if self.config.top_k is not None:
                v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            if self.config.do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                idx_next = torch.argmax(probs, dim=-1, keepdim=True)

            #########################################################################
            # Concatenate attention_mask, position_ids and idx
            #########################################################################
            # Append the current token's mask to attention_mask;
            # whether it is masked depends only on whether an eos_token appeared earlier
            attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)

            # Once the current token is eos (or an eos appeared earlier), all later tokens are masked out
            prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
            # .to() is not in-place, so the result must be assigned back
            prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)

            position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)

            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

            logits_lst.append(logits)

        # Stack the response_length tensors of shape (bs, vocab_size) along dim 1
        logits = torch.stack(logits_lst, dim=1)  # (bs, response_length, vocab_size)

        prompts = idx[:, :prompt_length]  # (bs, prompt_length)
        response = idx[:, prompt_length:]  # (bs, response_length)
        # Log-probability of each sampled token (log-softmax, then gather at the response ids)
        log_probs = logprobs_from_logits(logits=logits, labels=response)
        batch = TensorDict(
            {
                'input_ids': prompts,
                'responses': response,
                'sequences': idx,
                'old_log_probs': log_probs,
                'attention_mask': attention_mask,
                'position_ids': position_ids,
            },
            batch_size=batch_size)

        self.module.train()

        return DataProto(batch=batch)
```
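To isolate the sampling step of the loop above (temperature scaling, top-k filtering, then sampling or argmax), here is a minimal single-sequence sketch in pure Python. The function `sample_next_token` is hypothetical and assumes `temperature > 0`; like the torch code, it keeps every logit greater than or equal to the k-th largest, so ties at the cutoff can leave more than k candidates.

```python
import math
import random
from typing import List, Optional


def sample_next_token(logits: List[float], temperature: float = 1.0,
                      top_k: Optional[int] = None, do_sample: bool = True,
                      rng: Optional[random.Random] = None) -> int:
    """Pick the next token id from one row of logits."""
    scaled = [x / temperature for x in logits]
    if top_k is not None:
        k = min(top_k, len(scaled))
        kth = sorted(scaled, reverse=True)[k - 1]  # value of the k-th largest logit
        scaled = [x if x >= kth else -math.inf for x in scaled]
    # Numerically stable softmax; exp(-inf) == 0.0, so filtered tokens get zero probability
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    if not do_sample:
        return max(range(len(probs)), key=probs.__getitem__)  # greedy (argmax)
    rng = rng or random.Random()
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```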

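Note that the response part of the batch is effectively right-padded: in every sequence, all attention_mask entries after the first eos token are 0 (the eos token itself keeps a 1). This comes from the following code: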

```python
# Append the current token's mask to attention_mask;
# whether it is masked depends only on whether an eos_token appeared earlier
attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)

# Once the current token is eos (or an eos appeared earlier), all later tokens are masked out
prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
```
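The per-sequence effect of these two lines can be reproduced in pure Python. This is a minimal sketch for a single response (the name `response_attention_mask` is hypothetical): each token's mask reflects whether an EOS appeared *before* it, so the first EOS itself is kept and everything after it becomes padding.

```python
from typing import List


def response_attention_mask(response: List[int], eos_token_id: int) -> List[int]:
    """1 up to and including the first EOS, 0 afterwards (mirrors prev_attention_mask)."""
    mask = []
    seen_eos = False
    for tok in response:
        mask.append(0 if seen_eos else 1)  # mask is decided by tokens BEFORE this one
        if tok == eos_token_id:
            seen_eos = True
    return mask
```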

Having seen how the naive rollout generates the sequences of one batch, let's now look at how agent trajectories are generated.

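A trajectory can be viewed simply as a loop over naive sequence generation, except that the sequence generated at each step must be decoded to parse the tool call, and the tool's result is appended to the sequence as part of the prompt for the next step.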
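The loop just described can be sketched for a single trajectory as follows. Everything here is a simplified, hypothetical stand-in: `generate` and `search` are caller-supplied stubs rather than a real LLM or retriever, the `<information>` wrapper around retrieved text is an assumption, and the sketch simply stops on an invalid action. It also records which segments the model generated, since only those should contribute to the loss.

```python
from typing import Callable, List, Tuple


def run_agent_loop(
    question: str,
    generate: Callable[[str], str],  # stub LLM: context -> next response
    search: Callable[[str], str],    # stub retriever: query -> retrieved text
    max_turns: int = 4,
) -> Tuple[str, List[Tuple[str, bool]]]:
    """Return the final context and its segments as (text, generated_by_model)."""
    context = question
    segments: List[Tuple[str, bool]] = []
    for _ in range(max_turns):
        response = generate(context)  # 1. model turn
        segments.append((response, True))
        context += response
        if "<answer>" in response:  # answered: the trajectory is done
            break
        start, end = response.find("<search>"), response.find("</search>")
        if start == -1 or end == -1:  # invalid action: this sketch just stops
            break
        query = response[start + len("<search>"):end].strip()  # 2. parse the tool call
        observation = f"<information>{search(query)}</information>"  # 3. call the tool (assumed wrapper)
        segments.append((observation, False))  # retrieved text: masked out of the loss
        context += observation  # 4. append it as prompt for the next turn
    return context, segments


if __name__ == "__main__":
    scripted = iter(["<search>capital of France</search>", "<answer>Paris</answer>"])
    context, segments = run_agent_loop(
        "Question: what is the capital of France?",
        generate=lambda ctx: next(scripted),
        search=lambda query: "Paris is the capital of France.",
    )
    print([by_model for _, by_model in segments])  # [True, False, True]
```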

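Search-R1's training entry point is verl.trainer.ppo.ray_trainer.py. Its main difference from the original verl is a new LLMGenerationManager.run_llm_loop() method that generates agent trajectories, so we read that core module first: search_r1.llm_agent.generation.py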

```python
from dataclasses import dataclass
from typing import Dict, Tuple

import torch

from verl import DataProto


@dataclass
class GenerationConfig:
    max_turns: int
    # Maximum length of the initial prompt
    max_start_length: int
    # Maximum accumulated prompt length (start + (response + observation) * steps)
    max_prompt_length: int
    # Maximum length of a single generated response
    max_response_length: int
    # Maximum length of the content returned by the tool
    max_obs_length: int
    num_gpus: int
    # Whether to run RL without the thinking step
    no_think_rl: bool = False
    # URL of the search engine
    search_url: str = None
    # Number of docs to retrieve
    topk: int = 3


class LLMGenerationManager:
    ...

    #################################################################
    # Generate agent trajectories by looping for up to config.max_turns turns;
    # each trajectory is the concatenation of at most max_turns * [sequence]
    #################################################################
    def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor) -> Tuple[Dict, Dict]:
        """Run main LLM generation loop."""
        #################################################################
        # Initialize state that tracks, for every trajectory in the batch and every turn,
        # the prompt, response, mask, status, action_stats and search_stats
        #################################################################
        # Left-padded
        original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
        # Right-padded
        original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}

        # Whether each trajectory is active in the current turn (unfinished and no error): (bsz*n_agent)
        # If active_mask = 0, the example has finished or failed and gets no further turns
        active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
        # Number of active turns of each trajectory (its total turn count)
        turns_stats = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        # Number of valid actions per trajectory (not necessarily equal to turns_stats, since a turn's
        # action may be invalid, i.e. neither answer nor search)
        valid_action_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        # Number of search actions per trajectory (usually turns_stats minus the number of answers, typically 1)
        valid_search_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        # Number of active trajectories at each turn
        active_num_list = [active_mask.sum().item()]
        rollings = gen_batch

        #################################################################
        # Turn loop: each turn generates a response, extracts the tool call, calls the tool,
        # gets the observation and appends it to the prompt
        #################################################################
        # Main generation loop
        for step in range(self.config.max_turns):
            if not active_mask.sum():
                break
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            # Keep only the trajectories that are still active (according to active_mask)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            # Assuming num_gpus is 1 (no data parallelism), this is equivalent to
            # gen_output = self.actor_rollout_wg.generate_sequences(rollings_active)
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            # Post-process the responses of the active examples: decode token ids to strings,
            # extract the action wrapped in <search></search> or the answer wrapped in <answer></answer>,
            # then re-encode them into (right-padded) ids
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            # Scatter back to the full batch according to active_mask: inactive examples get
            # pad_token ids and an empty string "", so there is again one row per example
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # Call the search engine; returns, per example, lists of
            # [retrieved docs, whether done, whether the action is valid, whether it is a search]
            # Execute in environment and process observations
            next_obs, dones, valid_action, is_search = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask
            )

            # A finished trajectory is masked out, hence 0
            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            # A trajectory that was inactive in an earlier turn stays inactive
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
            turns_stats[curr_active_mask] += 1
            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
            valid_search_stats += torch.tensor(is_search, dtype=torch.int)

            # Process the observations: tokenize with right padding, then truncate to max_obs_length
            next_obs_ids = self._process_next_obs(next_obs)

            # Concatenate each trajectory's rolling + response_ids + next_obs_ids.
            # Note: rollings is left-padded while responses and obs are right-padded,
            # so after concatenation the padding in the middle is moved to the left,
            # keeping the other tokens in order, so that rollings stays left-padded
            # Update states
            rollings = self._update_rolling_state(
                rollings,
                responses_ids,
                next_obs_ids
            )
            # Likewise concatenate original_right_side + response + obs,
            # but keep it right-padded
            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
                next_obs_ids
            )

        # Some examples may still be active after max_turns turns because they never produced an answer;
        # give them one final rollout (note do_search=False below)
        # final LLM rollout
        if active_mask.sum():
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # Execute in environment and process observations
            _, dones, valid_action, is_search = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask, do_search=False
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
            valid_search_stats += torch.tensor(is_search, dtype=torch.int)

            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
            )

        meta_info['turns_stats'] = turns_stats.tolist()
        meta_info['active_mask'] = active_mask.tolist()
        meta_info['valid_action_stats'] = valid_action_stats.tolist()
        meta_info['valid_search_stats'] = valid_search_stats.tolist()

        print("ACTIVE_TRAJ_NUM:", active_num_list)

        return self._compose_final_output(original_left_side, original_right_side, meta_info)

    # Concatenate original_left_side with the accumulated model outputs and tool results
    def _compose_final_output(self, left_side: Dict,
                              right_side: Dict,
                              meta_info: Dict) -> Tuple[Dict, Dict]:
        """Compose final generation output."""
        final_output = right_side.copy()
        final_output['prompts'] = left_side['input_ids']

        # Combine input IDs
        final_output['input_ids'] = torch.cat([
            left_side['input_ids'],
            right_side['responses']
        ], dim=1)

        # Create attention mask and position ids
        final_output['attention_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output['responses'])
        ], dim=1)
        # info_mask is built from responses_with_info_mask, which tracks the retrieved tokens to mask out
        final_output['info_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask'])
        ], dim=1)

        final_output['position_ids'] = self.tensor_fn.create_position_ids(
            final_output['attention_mask']
        )

        final_output = DataProto.from_dict(final_output)
        final_output.meta_info.update(meta_info)

        return final_output
```
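The comment in the main loop about moving the middle padding to the left (the `_update_rolling_state` step) can be illustrated on a single row of token ids. This is a minimal sketch, not Search-R1's implementation: `concat_and_left_pad` is a hypothetical name, and it identifies padding by token id, which assumes `pad_id` never occurs as a real token (not true for tokenizers that reuse EOS as PAD, where a separate mask would be needed).

```python
from typing import List


def concat_and_left_pad(parts: List[List[int]], pad_id: int) -> List[int]:
    """Concatenate [left-padded prompt, right-padded response, right-padded obs] and move
    every pad to the left, keeping the real tokens in their original order."""
    merged = [tok for part in parts for tok in part]
    real = [tok for tok in merged if tok != pad_id]
    return [pad_id] * (len(merged) - len(real)) + real
```

Applied row by row, every row keeps the same total width, so the batch stays rectangular and left-padded for the next turn.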

Now let's return to Search-R1's RL training loop in ray_trainer.py.

```python
#########################################################################
# Search-R1 extends verl's trainer/ppo/ray_trainer.py directly, adding a new
# generation manager for agent trajectories (LLMGenerationManager, covered in the previous code block).
# Here is Search-R1's overall training flow
#########################################################################
def fit(self):
    """
    The training loop of PPO.
    The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
    The light-weight advantage computation is done on the driver process.
    """
    logger = self.logger
    self.global_steps = 0
    # perform validation before training
    # currently, we only support validation using the reward_function.
    if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
        val_metrics = self._validate()
        pprint(f'Initial validation metrics: {val_metrics}')
        logger.log(data=val_metrics, step=self.global_steps)
        if self.config.trainer.get('val_only', False):
            return

    # we start from step 1
    self.global_steps += 1

    #########################################################################
    # Newly added: the generation module for agent trajectory data
    # Agent config preparation
    gen_config = GenerationConfig(
        max_turns=self.config.max_turns,
        max_start_length=self.config.data.max_start_length,
        max_prompt_length=self.config.data.max_prompt_length,
        max_response_length=self.config.data.max_response_length,
        max_obs_length=self.config.data.max_obs_length,
        num_gpus=self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes,
        no_think_rl=self.config.algorithm.no_think_rl,
        search_url=self.config.retriever.url,
        topk=self.config.retriever.topk,
    )

    generation_manager = LLMGenerationManager(
        tokenizer=self.tokenizer,
        actor_rollout_wg=self.actor_rollout_wg,
        config=gen_config,
    )
    #########################################################################

    #########################################################################
    # This loop is still verl's original code: iterate over the training epochs
    # start training loop
    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            print(f'epoch {epoch}, step {self.global_steps}')
            metrics = {}
            timing_raw = {}

            # Fetch one batch of training data (bsz, prompt_length)
            # and repeat it n_agent times (GRPO needs several responses per prompt)
            # Note: prompts are left-padded
            batch: DataProto = DataProto.from_single_dict(batch_dict)
            batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n_agent, interleave=True)

            # pop those keys for generation
            gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

            ####################
            # original code here
            with _timer('step', timing_raw):
                if not self.config.do_search:
                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                             dtype=object)
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                #########################################################################
                # The new Search-R1 training flow starts here
                #########################################################################
                ####################
                # Below is all about agents - the "LLM + forloop"
                ####################
                # with _timer('step', timing_raw):
                else:
                    # Left-truncate first: keep only the rightmost max_start_length prompt ids
                    first_input_ids = gen_batch.batch['input_ids'][:, -gen_config.max_start_length:].clone().long()

                    with _timer('gen', timing_raw):
                        generation_manager.timing_raw = timing_raw
                        # Generate the trajectories: (bsz*n_agent, prompt_length+response_length)
                        final_gen_batch_output = generation_manager.run_llm_loop(
                            gen_batch=gen_batch,
                            initial_input_ids=first_input_ids,
                        )

                    # final_gen_batch_output.batch.apply(lambda x: x.long(), inplace=True)
                    for key in final_gen_batch_output.batch.keys():
                        final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()

                    with torch.no_grad():
                        output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
                        final_gen_batch_output = final_gen_batch_output.union(output)

                    # batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                    #                                          dtype=object)
                    # Each question's index appears to be stored in non_tensor_batch when the data is loaded;
                    # reusing it as uid gives all responses to the same question one group id
                    batch.non_tensor_batch['uid'] = batch.non_tensor_batch['index'].copy()

                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(final_gen_batch_output)

                ####################
                ####################

                # balance the number of valid tokens on each dp rank.
                # Note that this breaks the order of data inside the batch.
                # Please take care when you implement group based adv computation such as GRPO and rloo
                self._balance_batch(batch, metrics=metrics)

                # compute global_valid tokens
                batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                # batch.batch.apply(lambda x, key: x.long() if key != "old_log_probs" else x, inplace=True, key=True)
                for key in batch.batch.keys():
                    if key != 'old_log_probs':
                        batch.batch[key] = batch.batch[key].long()

                if self.use_reference_policy:
                    # compute reference log_prob
                    with _timer('ref', timing_raw):
                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                        batch = batch.union(ref_log_prob)

                # compute values
                if self.use_critic:
                    with _timer('values', timing_raw):
                        values = self.critic_wg.compute_values(batch)
                        batch = batch.union(values)

                with _timer('adv', timing_raw):
                    # compute scores. Support both model and function-based.
                    # We first compute the scores using reward model. Then, we call reward_fn to combine
                    # the results from reward model and rule-based results.
                    if self.use_rm:
                        # we first compute reward model score
                        reward_tensor = self.rm_wg.compute_rm_score(batch)
                        batch = batch.union(reward_tensor)

                    # we combine with rule-based rm
                    reward_tensor = self.reward_fn(batch)
                    batch.batch['token_level_scores'] = reward_tensor

                    # compute rewards. apply_kl_penalty if available
                    if not self.config.actor_rollout_ref.actor.use_kl_loss:
                        batch, kl_metrics = apply_kl_penalty(batch,
                                                             kl_ctrl=self.kl_ctrl,
                                                             kl_penalty=self.config.algorithm.kl_penalty)
                        metrics.update(kl_metrics)
                    else:
                        batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

                    # compute advantages, executed on the driver process
                    batch = compute_advantage(batch,
                                              adv_estimator=self.config.algorithm.adv_estimator,
                                              gamma=self.config.algorithm.gamma,
                                              lam=self.config.algorithm.lam,
                                              num_repeat=self.config.actor_rollout_ref.rollout.n)

                # update critic
                if self.use_critic:
                    with _timer('update_critic', timing_raw):
                        critic_output = self.critic_wg.update_critic(batch)
                    critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                    metrics.update(critic_output_metrics)

                # implement critic warmup
                if self.config.trainer.critic_warmup <= self.global_steps:
                    # update actor
                    with _timer('update_actor', timing_raw):
                        if self.config.do_search and self.config.actor_rollout_ref.actor.state_masking:
                            batch, metrics = self._create_loss_mask(batch, metrics)
                        actor_output = self.actor_rollout_wg.update_actor(batch)
                    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                    metrics.update(actor_output_metrics)

                # validate
                if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                        self.global_steps % self.config.trainer.test_freq == 0:
                    with _timer('testing', timing_raw):
                        val_metrics: dict = self._validate()
                    metrics.update(val_metrics)

                if self.config.trainer.save_freq > 0 and \
                        self.global_steps % self.config.trainer.save_freq == 0:
                    with _timer('save_checkpoint', timing_raw):
                        self._save_checkpoint()

            # collect metrics
            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

            # TODO: make a canonical logger that supports various backend
            logger.log(data=metrics, step=self.global_steps)

            self.global_steps += 1

            if self.global_steps >= self.total_training_steps:
                # perform validation after training
                if self.val_reward_fn is not None:
                    val_metrics = self._validate()
                    pprint(f'Final validation metrics: {val_metrics}')
                    logger.log(data=val_metrics, step=self.global_steps)
                return
```
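Two details in `fit()` matter for GRPO. `batch.repeat(..., interleave=True)` places the copies of each prompt next to each other, and `uid` is copied from `index` so that all responses to one question share a group id. Because `_balance_batch` reorders the batch, groups have to be recovered from `uid` rather than from row positions. A minimal pure-Python sketch of both ideas follows; the helper names `repeat_interleave` and `group_positions` are hypothetical.

```python
from typing import Dict, List, TypeVar

T = TypeVar("T")


def repeat_interleave(items: List[T], n: int) -> List[T]:
    """Repeat each item n times consecutively: [a, b], n=2 -> [a, a, b, b]."""
    return [item for item in items for _ in range(n)]


def group_positions(uids: List[str]) -> Dict[str, List[int]]:
    """Map each uid to the positions of its rows, whatever order the batch is in."""
    groups: Dict[str, List[int]] = {}
    for pos, uid in enumerate(uids):
        groups.setdefault(uid, []).append(pos)
    return groups
```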
