news 2026/4/20 4:15:36

Dove模型函数分析

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Dove模型函数分析

一 函数列表

def__init__(self, args:Args)

def_init_distributed(self)

def_init_logging(self)

def_init_directories(self)

defcheck_setting(self)

defprepare_models(self)

defprepare_dataset(self)

defprepare_trainable_parameters(self)

defprepareoptimizer(self)

defprepare_for_training(self)

defprepare_for_validation(self)

defprepare_trackers(self)

deftrain(self)

defvalidate(self, step: int)

deffit(self)

defcollate_fn(self, examples: List[Dict[str, Any]])

defload_components(self)

definitialize_pipeline(self)

defencode_video(self, video: torch.Tensor)

defencode_text(self, text: str)

defcompute_loss(self, batch)

defvalidation_step(self)

def__get_training_dtype(self)

def__move_components_to_device(self, dtype, ignore_list: List[str] = []);

def__move_components_to_cpu(self, upload_list: List[str] = [])

def__prepare_saving_loading_hooks(self, transformer_lora_config);

def__maybe_save_checkpoint(self, global_step: int, must_save: bool=False):

defsave_args_and_state(self)

defsave_trainable_param_status(self):

二 函数说明

主要入口函数

1fit训练入口

deffit(self):self.check_setting()self.prepare_models()self.prepare_dataset()self.prepare_trainable_parameters()self.prepare_optimizer()self.prepare_for_training()ifself.args.do_validation:self.prepare_for_validation()self.prepare_trackers()self.train()#最终进入训练循环

作用完整训练流程入口按顺序调用所有准备工作

顺序检查设置->准备模型->准备数据->准备参数->准备优化器->准备训练->准备验证->准备跟踪器-开始训练

2train()核心训练循环

deftrain(self):forepochinrange(first_epoch, self.args.train_epochs):forstep,batchinenumerate(self.data_loader):#前向传播loss=self.compute_loss(batch)#反向传播accelerator.backward(loss)#参数更新self.optimizer.step()

作用执行实际训练循环包含前向传播反向传播参数更新

包含:epoch循环batch循环梯度累积检查点保存验证触发

三 初始化函数(按调用顺序)

1阶段构造函数初始化

__init__(self, args: Args)构造函数

def__init__(self, args):self.args=argsself.state=State(...)#创建训练状态self.components=self.load_components()#加载模块组件self.accelerator=None#分布式加速器self._init_distributed()#初始化分布self._init_logging()#初始化日志self._init_directories()#初始化目录

作用训练器初始化入口

关键创建State对象加载组件初始化分布时环境

2load_components()抽象方法加载模型组件

defload_components(self)->ComponentsraiseNotImplementedError#子类必须实现

作用加载VAETransformer文本编码核心组件

实现子类具体实现如何HuggingFace本地加载模型

2阶段分布式系统初始化

_init_distributed()初始化分布训练

def_init_distributed(self):project_config=ProjectConfigureation(...)ddp_kwargs=DistributedDataParallelKwargs(...)self.accelerator=Accelerator(...)#创建加速器

作用设置DDP/DeepSpeed分布式训练环境

功能配置进程混合精度梯度累积

4_init_logging()初始化日志系统

def_init_logging(self):logging.basicConfig(...)#Python标准日志transformers.utils.logging.set_verbosity_warning()#HuggingFace日志

作用配置Python日志系统HuggingFace日志级别

5_init_directories()初始化输出目录

def_init_directories(self):self.args.output_dir.mkdir(parents=True, exist_ok=True)

作用创建输出目录

四 准备阶段函数(fit()中按顺序调用)

3阶段设置检查模型准备

check_setting()检查训练设置

defcheck_setting(self):ifself.UPLOAD_LISTisNone:logger.warning("No unload_list specified...")

作用验证训练配置的合理性

检查UPLOAD_LIST是否有效防止配置错误

7prepare_models()准备模型组件

defprepare_models(self):ifself.components.vaeisnotNone:ifself.args.enable_slicing:#启用VAE切片self.components.vae.enable_slicing()

作用启用模型内存优化功能

功能VAE切片/分块保存Transformer配置

4阶段数据参数准备

prepare_dataset()准备数据集

defprepare_dataset(self):if self.args.model_type == "i2v":self.dataset = I2VDatasetWithResize(...)elif self.args.model_type == "t2v":self.dataset = T2VDatasetWithResize(...)self.data_loader = DataLoader(...)

作用根据模型类型创建对应数据集数据加载器

包含计算潜在表示

9prepare_trainable_parameters()准备训练参数

defprepare_trainable_parameters(self):ifself.args.training_type=="lora":transformer_lora_config = LoraConfig(...)self.components.transformer.add_adapter(transformer_lora_config)

作用根据训练类型LoRA/SFT设置参数梯度

功能配置LoRA适配器冻结/解冻参数

10prepare_optimizer()准备优化器

defprepare_optimizer(self):trainable_parameters=filter(lambda p: p.requires_grad, ...)self.optimizer=get_optimizer(...)self.lr_scheduler=get_scheduler(...)

作用创建优化器学习率调度器

功能配置AdamW优化器计算训练步数

5阶段训练最后准备

prepare_for_training()为训练最后准备

defprepare_for_training(self):ifself.args.use_perceptual_loss:self.lpips_loss=pyiqa.create_metric('lpips', ...)self.components.transformer=self.accelerator.prepare(...)

作用初始化感知损失模型使用accelerator准备组件

功能准备光流模型计算每个epoch更新步数

12prepare_for_validation准备验证

defprepare_for_validation(self)validation_videos = load_videos(...)validation_prompts = load_prompts(...)self.state.validation_prompts = validation_prompts

作用加载严重数据,初始化评估指标

仅当self.args.do_validationTrue调用

13prepare_trackers()准备实验跟踪

defprepare_trackers(self):

self.accelerator.init_trackers(...)

作用初始化Wandn/TensorBoard实验跟踪工具

五 核心训练和验证函数

6阶段训练循环相关

compute_loss(batch)抽象方法计算损失

defcompute_loss(self, batch)

raiseNotImplementedError# 子类必须实现

作用:相当于模型forward函数计算训练损失

实现子类调用Transformerforward方法

15collate_fn(examples)抽象方法批处理函数

defcollate_fn(self, examples: List[Dict[str, Any]]):

raiseNotImplementedError#子类必须实现

作用将多个样本组合成一个batch

功能处理不同长度序列添加padding

7阶段验证推理

validate(step)执行验证

defvalidate(self, step: int):pipe=self.initialize_pipeline()#创建推理管道foriinrange(num_validation_samples):result=self.validation_step(...)#生成样本evaluate_video_metrics(...)#评估指标

作用在训练过程定期运行验证

频率validation_steps运行一次

17validation_step()抽象方法单步验证

defvalidation_step(self):raiseNotImplementedError#子类必须实现

作用生成单个验证样本

输出图像视频列表

18 initialize_pipeline抽象方法创建推理管道

definitialize_pipeline(self)->DiffusionPipeline:raiseNotImplementedError#子类必须实现

作用创建用于推理DiffusionPipeline

用于验证阶段生成样本

六 编码相关函数,子类实现

19encode_video(video)抽象方法视频编码

def encode_video(self, video: torch.Tensor)->torch.Tensor:raiseNotImplementedError

作用视频编码潜在空间

输入[B,C,F,H,W]->输出[B,C',F',H',W']

20encode_text(text)抽象方法文本编码

defencode_text(self, text: str)->torch.Tensor:raise

作用将文本编码作为嵌入向量

输出:[batch_size, sequence_length, embedding_dim]

七 辅助和工具函数

私有辅助函数

21__get_training_dtype()获取训练数据类型

def__get_training_dtype(self):if self.args.mixed_precision == "no":return torch.float32elif self.args.mixed_precision == "fp16":return torch.float16

作用根据混合精度设置返回对应torch.dtype

22__move_components_to_devie()移动组件设备

def__move_components_to_device(self, dtype, ignore_list=[]):forname,componentincomponents.items():component.to(self.accelerator.device, dtype=dtype)

作用将模型组件移动到指定设备数据类型

23__move_components_to_cpu()移动组件CPU

def__move_components_to_cpu(self, upload_list=[]):forname,componentincomponents.items():component.to("cpu")

作用卸载组件CPU节省GPU内存

24__prepare_saving_loading_hooks()准备保存/加载钩子

def__prepare_saving_loading_hooks(self, transformer_lora_config):defsave_model_hook(models, weights, output_dir):#保存LoRA权重defload_model_hook(models, input_dir)#加载LoRA权重

作用LoRA训练注册自定义保存/加载钩子

25__maybe_save_checkpoint()可能保存检查点

def__maybe_save_checkpoint(self, global_step, must_save=False):if must_saveorglobal_step%self.args.checkpointing_steps==0:self.accelerator.save_state(save_path)

作用定期必须保存训练检查点

状态保存函数

26save_args_and_state保存参数状态

defsave_args_and_state(self):save_dict={"args": {...}, "state": {...}}withopen(output_path / "args.yaml", "w")asf:yaml.safe_dump(save_dict, f)

作用将训练参数状态保存YAML文件

27save_trainable_param_status()保存参数训练状态

defsave_trainable_param_status(self):param_status={}forname,paraminself.components.transformer.named_parameters():param_status[name] = "trainable" if param.requires_grad else "frozen"

作用保存哪些参数是可训练哪些冻结

八 训练流程图

初始化流程__init__()_init_distributed()#分布式环境_init_logging()#日志系统_init_directories()#输出目录训练流程fit()fit()check_setting()#检查设置prepare_models()#准备模型prepare_dataset()#准备数据prepare_trainable_parameters()#准备训练参数prepare_optimizer()#准备优化器prepare_for_training()#训练准备prepare_for_validation()验证准备prepare_trackers()#准备跟踪器train()#开始训练循环compute_loss()#相当于forward每个batch反向传播优化__maybe_save_checkpoint()#定期保存validate()#定期验证

九 关键总结

训练入口函数

主入口:fit()完整训练流程单一入口点

训练循环train()实际执行训练循环

前向传播:compute_loss()相当于模型forward()函数

为什么fit()而不是直接train()

1层次清晰fit()负责准备train()负责执行

2用于友好用于只需调用fit(),不用关心复杂准备过程

模型训练初始化入口

真正的训练初始化fit()方法开始正确顺序调用所有准备函数确保训练环境完全就绪才进入train()循环

普通PyTorch模型区别

#普通PyTorch模型model=MyModel()optimizer=Adam(model.parameters())loss_fn=nn.MSELoss()#需要用户自己训练循环forepochinrange(epochs):forbatchindataloader:output=model(batch):#调用forwardloss=loss_fn(output, target)loss.backward()optimizer.step()#Dove训练器封装完整流程trainer=Trainer(args)trainer.fit()#一键训练内部自动调用forward
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/19 3:05:46

AI 3D生成技术如何重塑数字创作新范式?

AI 3D生成技术如何重塑数字创作新范式? 【免费下载链接】Hunyuan3D-1 项目地址: https://ai.gitcode.com/hf_mirrors/tencent/Hunyuan3D-1 从手工雕琢到智能生成:3D建模的世纪难题 在数字内容爆炸式增长的时代,3D建模却始终保持着&q…

作者头像 李华
网站建设 2026/4/19 3:04:01

K8s HPA:自动扩缩容的终极指南

一、 HPA解决的问题HPA全称是 Horizontal Pod Autoscaler,也就是对k8s的workload的副本数进行自动水平扩缩容(scale)机制,也是k8s里使用需求最广泛的一种Autoscaler机制,在开始详细介绍HPA之前,先简单梳理下k8s autoscale的整个大…

作者头像 李华
网站建设 2026/4/19 3:04:04

慧荣SM32系列U盘量产工具全面解析:从入门到精通

还在为U盘批量生产而烦恼吗?🤔 慧荣SM32系列量产工具v20.02.04.21就是你的最佳选择!这款专业级工具专门针对SM3265AB、SM3271AB、SM3281AB、SM3281BB等主流芯片组设计,帮你轻松实现固件升级、坏块修复和格式化等核心功能。 【免费…

作者头像 李华
网站建设 2026/4/19 3:09:48

90亿参数打破720亿性能壁垒:GLM-4.1V-Base开启多模态推理新纪元

90亿参数打破720亿性能壁垒:GLM-4.1V-Base开启多模态推理新纪元 【免费下载链接】GLM-4.1V-9B-Base 项目地址: https://ai.gitcode.com/zai-org/GLM-4.1V-9B-Base 导语 智谱AI最新开源的GLM-4.1V-9B-Base多模态模型,以90亿参数规模在18项基准测…

作者头像 李华
网站建设 2026/4/19 3:13:53

MoveCertificate:Android设备证书管理终极指南

你是否遇到过在Android设备上安装抓包工具证书后,某些应用仍然无法正常识别的问题?这正是MoveCertificate项目要解决的核心痛点。作为一款支持Android 7到15系统的Magisk/KernelSU/APatch模块,它能将用户证书轻松移动到系统证书目录&#xff…

作者头像 李华
网站建设 2026/4/18 9:09:51

番鸽号快速原型:1小时验证产品创意

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 输入:创建一个电商产品展示页面的原型,包含产品图片轮播、价格展示、加入购物车按钮和用户评价区域。只需前端界面,不需要后端功能。要求设计简洁…

作者头像 李华