news 2026/6/9 17:42:30

DiffSynth-Studio训练踩坑记录

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DiffSynth-Studio训练踩坑记录

DiffSynth-Studio训练踩坑记录:PyTorch 2.5.1 + Meta Tensor + 新增模块 + strict=True 导致的加载失败

环境:

  • PyTorch 2.5.1
  • DiffSynth-Studio / Wan2.1-T2V-1.3B
    任务:在官方 WanVideo 模型基础上增加模块,继续训练 LoRA

这篇文章记录一次在 WanVideo 训练过程中遇到的模型加载问题,涉及到:

  • PyTorch 2.5.1 的 meta tensor 机制;
  • 使用strict=True加载权重导致的结构不匹配;
  • 给模型增加新模块后如何正确加载旧 checkpoint。

最后训练已经成功跑起来,这里把整个排查和修复过程整理一下。


1. 场景简介:在 WanVideo 上加模块继续训练

训练命令大致如下:

nohupbashexamples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh\>wan2.1-1.3B.log2>&1&

项目中使用:

self.pipe=WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16,device=device,model_configs=model_configs,tokenizer_config=tokenizer_config,audio_processor_config=audio_processor_config,)

来下载并加载 Wan2.1-T2V-1.3B 模型。

我对模型做了一个改动:在原有模型的基础上增加了一些新的模块(新的层),希望用原模型的权重作为初始化,继续训练 LoRA。

自然地,旧 checkpoint 中没有新模块对应的权重,这为后面的报错埋下了伏笔。


2. 第一层坑:strict=True 导致结构不匹配直接报错

最开始,代码里加载权重用的是默认的strict=True,类似:

model.load_state_dict(state_dict)# 默认 strict=True

当我在模型结构上增加模块之后,这些新增层的参数在 checkpoint 中不存在。
strict=True的情况下,load_state_dict的行为是:

  • 模型里有,但state_dict里没有 → 归为missing_keys,直接报错;
  • state_dict里有,但模型里没有 → 归为unexpected_keys,也会报错。

也就是说,只要你对模型结构进行了增删改,strict=True 会让加载必然失败

正确做法应该是:

  • 改成strict=False
    load_info=model.load_state_dict(state_dict,strict=False)print("missing:",load_info.missing_keys)print("unexpected:",load_info.unexpected_keys)
  • 允许结构不完全一致;
  • missing_keys/unexpected_keys明确看到哪些参数没加载上、哪些是 checkpoint 里多余的。

后面在修复时,我把这一步和 meta 问题一起处理了,最终使用:

load_info=model.load_state_dict(state_dict,assign=True,strict=False)

下面先解释第二层坑:meta tensor。


3. 什么是 meta tensor?先理解再排错

在这次问题里,最关键的一条报错是:

NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

这里的 “meta tensor” 其实是 PyTorch 提供的一种特殊机制:meta device

3.1 meta 是什么?

可以简单理解为:

“只有形状和 dtype 信息、没有真实数据、不占显存的假张量,用来先搭模型结构、后装权重。”

比如:

withtorch.device("meta"):w=torch.empty(16,16)print(w.device)# device(type='meta')print(w.shape)# torch.Size([16, 16])print(w.is_meta)# True

特点:

  • 有 shape、有 dtype;
  • w.is_meta == True
  • 不占 GPU/CPU 实际内存;
  • 但不能做任何需要真实数据的操作,例如:
    • w.to("cuda")(要拷贝数据);
    • w + 1(要访问数据);
    • 卷积、matmul 等计算。

3.2 为什么要有 meta?

大模型构建时,如果立刻在 GPU 上分配全部参数,很容易 OOM。
因此很多框架采用“空权重 / meta device”技术:

withtorch.device("meta"):model=MyBigModel()# 所有参数都是 meta tensor

这一步只搭结构,不占真实显存。
后面再按策略,把参数真正“实例化”到 GPU / CPU 并加载权重。

3.3 和这次报错的关系

在 DiffSynth-Studio 里,构建 WanVideo 模型时就用了 meta 技术。
也就是说,模型刚被创建时,参数是is_meta=True,还没有真实数据。

而在diffsynth/core/loader/model.py中,当时的代码类似:

model=model.to(dtype=torch_dtype,device=device)

这里直接对 meta 模型调用.to(),PyTorch 会尝试:

从 meta 上“拷贝”数据到目标设备。

但 meta 上根本没有数据,所以就抛出了:

Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() ...

官方的意思是:

对 meta 模型,不要直接.to(),而要用to_empty(device=...)先在目标设备上创建“空参数容器”,再用load_state_dict填入真正的 checkpoint 数据。


4. 正确的迁移方式:to_empty+load_state_dict

PyTorch 2.5.1 下的推荐做法如下:

  1. 在 meta 设备上构建模型:

    withtorch.device("meta"):model=MyModel()
  2. 不要直接model.to("cuda"),而是:

    model=model.to_empty(device="cuda")# 只分配空参数容器
  3. 然后用 checkpoint 填充:

    state_dict=torch.load("xx.pth",map_location="cpu")model.load_state_dict(state_dict)

5. 把 meta + strict + 新增模块 这三件事串起来

结合前面的分析,这次问题本质上是三件事叠加

  1. 模型是通过 meta 方式构建的;
  2. 我对模型结构做了修改(增加了模块);
  3. 一开始用的是strict=True,随后在 meta 状态下直接调用了.to()

想要一劳永逸地解决,需要做到:

  • 判断是否有 meta 参数;
  • 对 meta 模型用to_empty(device=...)而不是.to()
  • load_state_dict(..., strict=False)
  • 打印missing_keys/unexpected_keys
  • 初始化新增模块。

最终,我在load_model中整理出的核心逻辑大概是这样:

# 1. 判断是否有 meta 参数has_meta_param=any(p.is_metaforpinmodel.parameters())ifhas_meta_param:# 2. meta 模型:使用 to_empty 迁移到目标设备# 注意:PyTorch 2.5.1 要用关键字参数 device=model=model.to_empty(device=device)# 3. 之后再统一 dtype(此时已经不是 meta 了,可以正常 .to)iftorch_dtypeisnotNone:model=model.to(dtype=torch_dtype)else:# 非 meta 模型:直接 .to 即可model=model.to(dtype=torch_dtype,device=device)# 4. 加载 checkpoint(非严格模式 + assign=True)load_info=model.load_state_dict(state_dict,assign=True,strict=False)missing=load_info.missing_keys unexpected=load_info.unexpected_keysifmissing:print("未加载到的参数 (missing_keys):")forkinmissing:print(" ",k)ifunexpected:print("多余的权重 (unexpected_keys)(state_dict 中有,但模型中没有):")forkinunexpected:print(" ",k)ifnotmissingandnotunexpected:print("所有参数均已加载")

几点注意:

  • to_empty在 2.5.1 中不支持dtype=参数,所以用:
    model=model.to_empty(device=device)model=model.to(dtype=torch_dtype)
  • strict=False是为了允许新增模块不在 checkpoint 中,避免硬报错。
  • assign=True用的是 PyTorch 2.x 的新行为:直接让参数引用指向state_dict中的 tensor,减少一次拷贝。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/7 7:49:43

7、OpenWrt第三方固件使用指南

OpenWrt第三方固件使用指南 1. 配置和使用OpenWrt 在Linksys设备上安装OpenWrt后,可通过GUI或命令行进行配置。强烈建议使用命令行,它功能强大,便于实现高级配置,而GUI功能有限,仅能进行基本设置。OpenWrt命令行基于Linux/UNIX,由BusyBox处理,它是一个小而强大的可执行…

作者头像 李华
网站建设 2026/6/7 22:02:02

在调度的花园里面挖呀挖

上文使用koordinator演示gang-scheduling和binpack调度, 已经生效。4个2卡Pod龟缩在一个节点,另外一个2卡Pod被挤到另外一个节点(每节点上虚拟gpu:8卡)。此时我们再尝试申请8卡作业,pod会Pending状态。但一…

作者头像 李华
网站建设 2026/6/9 4:16:58

万亿参数Kimi K2大语言模型:如何3分钟完成快速部署的完整指南

在人工智能技术飞速发展的今天,开发者和研究者面临着一个共同的挑战:如何在有限的计算资源下部署和运行万亿参数级别的大语言模型?Moonshot AI最新开源的Kimi-K2-Base模型以其1万亿总参数和320亿激活参数的混合专家架构,为这一难题…

作者头像 李华
网站建设 2026/6/9 2:18:50

Boost 电路右半平面零点 (RHPZ) 的仿真与解析

. 右半平面零点 (RHPZ) 来源解析 Boost 电路的传递函数为: H ( s ) V g D ′ 2 D ′ 2 R − s L s 2 L C R s L R D ′ 2 该传递函数的零点位于 s D ′ 2 R L ,由于零点符号为正,因此属于右半平面零点。 为了…

作者头像 李华
网站建设 2026/6/8 17:13:37

C++内存管理相关面试题图解

用香蕉尝试制作了一些面试题图解,主要是跟C的内存管理有关,方便大家更好地理解这些概念和准备相关的面试。有些文字生成的不够准确,但是基本上还是能够认出来,见谅。

作者头像 李华