一 函数列表
def__init__(self, args:Args)
def_init_distributed(self)
def_init_logging(self)
def_init_directories(self)
defcheck_setting(self)
defprepare_models(self)
defprepare_dataset(self)
defprepare_trainable_parameters(self)
defprepareoptimizer(self)
defprepare_for_training(self)
defprepare_for_validation(self)
defprepare_trackers(self)
deftrain(self)
defvalidate(self, step: int)
deffit(self)
defcollate_fn(self, examples: List[Dict[str, Any]])
defload_components(self)
definitialize_pipeline(self)
defencode_video(self, video: torch.Tensor)
defencode_text(self, text: str)
defcompute_loss(self, batch)
defvalidation_step(self)
def__get_training_dtype(self)
def__move_components_to_device(self, dtype, ignore_list: List[str] = []);
def__move_components_to_cpu(self, upload_list: List[str] = [])
def__prepare_saving_loading_hooks(self, transformer_lora_config);
def__maybe_save_checkpoint(self, global_step: int, must_save: bool=False):
defsave_args_and_state(self)
defsave_trainable_param_status(self):
二 函数说明
主要入口函数
1fit主训练入口
deffit(self):self.check_setting()self.prepare_models()self.prepare_dataset()self.prepare_trainable_parameters()self.prepare_optimizer()self.prepare_for_training()ifself.args.do_validation:self.prepare_for_validation()self.prepare_trackers()self.train()#最终进入训练循环作用:完整的训练流程入口,按顺序调用所有准备工作
顺序:检查设置->准备模型->准备数据->准备参数->准备优化器->准备训练->准备验证->准备跟踪器-〉开始训练
2train()核心训练循环
deftrain(self):forepochinrange(first_epoch, self.args.train_epochs):forstep,batchinenumerate(self.data_loader):#前向传播loss=self.compute_loss(batch)#反向传播accelerator.backward(loss)#参数更新self.optimizer.step()作用:执行实际的训练循环,包含前向传播,反向传播,参数更新
包含:epoch循环,batch循环,梯度累积,检查点保存,验证触发。
三 初始化函数(按调用顺序)
第1阶段:构造函数初始化
__init__(self, args: Args)构造函数
def__init__(self, args):self.args=argsself.state=State(...)#创建训练状态self.components=self.load_components()#加载模块组件self.accelerator=None#分布式加速器self._init_distributed()#初始化分布式self._init_logging()#初始化日志self._init_directories()#初始化目录作用:训练器的初始化入口
关键:创建State对象,加载组件,初始化分布时环境
2load_components()抽象方法,加载模型组件
defload_components(self)->ComponentsraiseNotImplementedError#子类必须实现作用:加载VAE,Transformer文本编码器等核心组件
实现:子类具体实现如何从HuggingFace或本地加载模型
第2阶段:分布式和系统初始化
_init_distributed()初始化分布式训练
def_init_distributed(self):project_config=ProjectConfigureation(...)ddp_kwargs=DistributedDataParallelKwargs(...)self.accelerator=Accelerator(...)#创建加速器作用:设置DDP/DeepSpeed分布式训练环境
功能:配置进程组,混合精度,梯度累积
4_init_logging()初始化日志系统
def_init_logging(self):logging.basicConfig(...)#Python标准日志transformers.utils.logging.set_verbosity_warning()#HuggingFace日志作用:配置Python日志系统和HuggingFace库的日志级别
5_init_directories()初始化输出目录
def_init_directories(self):self.args.output_dir.mkdir(parents=True, exist_ok=True)作用:创建输出目录
四 准备阶段函数(fit()中按顺序调用)
第3阶段,设置检查和模型准备
check_setting()检查训练设置
defcheck_setting(self):ifself.UPLOAD_LISTisNone:logger.warning("No unload_list specified...")作用:验证训练配置的合理性
检查:UPLOAD_LIST是否有效,防止配置错误
7prepare_models()准备模型组件
defprepare_models(self):ifself.components.vaeisnotNone:ifself.args.enable_slicing:#启用VAE切片self.components.vae.enable_slicing()作用:启用模型和内存优化功能
功能:VAE切片/分块,保存Transformer配置
第4阶段:数据和参数准备
prepare_dataset()准备数据集
defprepare_dataset(self):if self.args.model_type == "i2v":self.dataset = I2VDatasetWithResize(...)elif self.args.model_type == "t2v":self.dataset = T2VDatasetWithResize(...)self.data_loader = DataLoader(...)作用:根据模型类型创建对应的数据集和数据加载器
包含:预计算潜在表示
9prepare_trainable_parameters()准备可训练参数
defprepare_trainable_parameters(self):ifself.args.training_type=="lora":transformer_lora_config = LoraConfig(...)self.components.transformer.add_adapter(transformer_lora_config)作用:根据训练类型LoRA/SFT设置参数梯度
功能:配置LoRA适配器,冻结/解冻参数
10prepare_optimizer()准备优化器
defprepare_optimizer(self):trainable_parameters=filter(lambda p: p.requires_grad, ...)self.optimizer=get_optimizer(...)self.lr_scheduler=get_scheduler(...)作用:创建优化器和学习率调度器
功能:配置AdamW等优化器,计算训练步数
第5阶段:训练前最后准备
prepare_for_training()为训练做最后准备
defprepare_for_training(self):ifself.args.use_perceptual_loss:self.lpips_loss=pyiqa.create_metric('lpips', ...)self.components.transformer=self.accelerator.prepare(...)作用:初始化感知损失模型,使用accelerator准备组件
功能:准备光流模型,计算每个epoch的更新步数
12prepare_for_validation准备验证
defprepare_for_validation(self)validation_videos = load_videos(...)validation_prompts = load_prompts(...)self.state.validation_prompts = validation_prompts作用:加载严重数据,初始化评估指标
仅当:self.args.do_validation为True时调用
13prepare_trackers()准备实验跟踪器
defprepare_trackers(self):
self.accelerator.init_trackers(...)
作用:初始化Wandn/TensorBoard等实验跟踪工具
五 核心训练和验证函数
第6阶段:训练循环相关
compute_loss(batch)抽象方法:计算损失
defcompute_loss(self, batch)
raiseNotImplementedError# 子类必须实现
作用:相当于模型的forward函数,计算训练损失
实现,子类中会调用Transformer的forward方法
15collate_fn(examples)抽象方法,批处理函数
defcollate_fn(self, examples: List[Dict[str, Any]]):
raiseNotImplementedError#子类必须实现
作用:将多个样本组合成一个batch
功能:处理不同长度的序列,添加padding等。
第7阶段:验证和推理
validate(step)执行验证
defvalidate(self, step: int):pipe=self.initialize_pipeline()#创建推理管道foriinrange(num_validation_samples):result=self.validation_step(...)#生成样本evaluate_video_metrics(...)#评估指标作用:在训练过程中定期运行验证
频率:每validation_steps步运行一次
17validation_step()抽象方法,单步验证
defvalidation_step(self):raiseNotImplementedError#子类必须实现作用:生成单个验证样本
输出:图像或视频列表
18 initialize_pipeline抽象方法,创建推理管道
definitialize_pipeline(self)->DiffusionPipeline:raiseNotImplementedError#子类必须实现作用:创建用于推理的DiffusionPipeline
用于:验证阶段生成样本
六 编码相关函数,子类实现
19encode_video(video)抽象方法:视频编码
def encode_video(self, video: torch.Tensor)->torch.Tensor:raiseNotImplementedError作用:将视频编码到潜在空间
输入:[B,C,F,H,W]->输出[B,C',F',H',W']
20encode_text(text)抽象方法:文本编码
defencode_text(self, text: str)->torch.Tensor:raise作用:将文本编码作为嵌入向量
输出:[batch_size, sequence_length, embedding_dim]
七 辅助和工具函数
私有辅助函数
21__get_training_dtype()获取训练数据类型
def__get_training_dtype(self):if self.args.mixed_precision == "no":return torch.float32elif self.args.mixed_precision == "fp16":return torch.float16作用:根据混合精度设置返回对应的torch.dtype
22__move_components_to_devie()移动组件到设备
def__move_components_to_device(self, dtype, ignore_list=[]):forname,componentincomponents.items():component.to(self.accelerator.device, dtype=dtype)作用:将模型组件移动到指定设备和数据类型
23__move_components_to_cpu()移动组件到CPU
def__move_components_to_cpu(self, upload_list=[]):forname,componentincomponents.items():component.to("cpu")作用:卸载组件到CPU以节省GPU内存
24__prepare_saving_loading_hooks()准备保存/加载钩子
def__prepare_saving_loading_hooks(self, transformer_lora_config):defsave_model_hook(models, weights, output_dir):#保存LoRA权重defload_model_hook(models, input_dir)#加载LoRA权重作用:为LoRA训练注册自定义的保存/加载钩子
25__maybe_save_checkpoint()可能保存检查点
def__maybe_save_checkpoint(self, global_step, must_save=False):if must_saveorglobal_step%self.args.checkpointing_steps==0:self.accelerator.save_state(save_path)作用:定期或在必须时保存训练检查点
状态保存函数
26save_args_and_state保存参数和状态
defsave_args_and_state(self):save_dict={"args": {...}, "state": {...}}withopen(output_path / "args.yaml", "w")asf:yaml.safe_dump(save_dict, f)作用:将训练参数和状态保存到YAML文件
27save_trainable_param_status()保存参数训练状态
defsave_trainable_param_status(self):param_status={}forname,paraminself.components.transformer.named_parameters():param_status[name] = "trainable" if param.requires_grad else "frozen"作用:保存哪些参数是可训练的,哪些是冻结的。
八 训练流程图
初始化流程:__init__()_init_distributed()#分布式环境_init_logging()#日志系统_init_directories()#输出目录训练流程fit()fit()check_setting()#检查设置prepare_models()#准备模型prepare_dataset()#准备数据prepare_trainable_parameters()#准备可训练参数prepare_optimizer()#准备优化器prepare_for_training()#训练前准备prepare_for_validation()验证准备prepare_trackers()#准备跟踪器train()#开始训练循环compute_loss()#相当于forward每个batch反向传播和优化__maybe_save_checkpoint()#定期保存validate()#定期验证九 关键总结
训练入口函数
主入口:fit()完整训练流程的单一入口点
训练循环:train()实际执行训练的循环
前向传播:compute_loss()相当于模型的forward()函数
为什么用fit()而不是直接train()
1层次清晰fit()负责准备,train()负责执行
2用于友好:用于只需调用fit(),不用关心复杂的准备过程
模型训练初始化入口
真正的训练初始化时从fit()方法开始的,按正确的顺序调用所有准备函数,确保训练环境完全就绪后才进入train()循环。
与普通PyTorch模型的区别
#普通PyTorch模型model=MyModel()optimizer=Adam(model.parameters())loss_fn=nn.MSELoss()#需要用户自己写训练循环forepochinrange(epochs):forbatchindataloader:output=model(batch):#调用forwardloss=loss_fn(output, target)loss.backward()optimizer.step()#Dove训练器,封装了完整流程trainer=Trainer(args)trainer.fit()#一键训练,内部自动调用forward