news 2026/4/29 18:05:06

TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(一)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(一)

tf.keras.losses.SparseCategoricalCrossentropy核心原理

SparseCategoricalCrossentropy(稀疏类别交叉熵)是 TensorFlow/Keras 中针对多分类任务的损失函数,专为稀疏标签(整数型标签,如0,1,2)设计,核心作用是衡量模型输出的类别概率分布与真实稀疏标签的「差异」,本质是交叉熵(Cross-Entropy)在稀疏标签场景下的优化实现。

一、先理解核心背景:交叉熵的本质

交叉熵源于信息论,用于衡量两个概率分布的「距离」(差异程度)。对于多分类任务:

  • 真实标签的分布是「one-hot 分布」(比如 3 分类中标签为 1,对应分布是[0,1,0]);
  • 模型输出是类别概率分布(经 Softmax 归一化后,和为 1,如[0.1,0.8,0.1])。

交叉熵的公式为:
H(p,q)=−∑i=1Cp(i)log⁡(q(i)) H(p,q) = -\sum_{i=1}^C p(i) \log(q(i))H(p,q)=i=1Cp(i)log(q(i))
其中:

  • ppp:真实标签的概率分布(one-hot 形式,仅目标类别为 1,其余为 0);
  • qqq:模型预测的类别概率分布;
  • CCC:类别总数。

由于ppp是 one-hot 分布,交叉熵可简化为:仅取目标类别对应的预测概率的负对数(因为其他项都是0×log⁡(q(i))=00 \times \log(q(i))=00×log(q(i))=0)。

二、SparseCategoricalCrossentropy 的核心适配:稀疏标签

普通的CategoricalCrossentropy要求标签是one-hot 编码(如 3 分类标签 1 对应[0,1,0]),而SparseCategoricalCrossentropy直接支持整数型稀疏标签(如 1),无需手动 one-hot 编码,核心优势是节省内存(尤其是类别数多的场景,比如 1000 类时,稀疏标签仅存 1 个整数,one-hot 需存 1000 维向量)。

三、完整计算逻辑(分两种场景)

SparseCategoricalCrossentropy的关键参数是from_logits(默认False),决定模型输出是否为「原始 logits(未归一化的得分)」或「Softmax 归一化后的概率」,两种场景的计算逻辑不同(TensorFlow 内部做了优化,避免数值不稳定)。

场景 1:from_logits=False(默认,模型输出是 Softmax 概率)

假设:

  • 类别数C=3C=3C=3
  • 真实稀疏标签y=1y=1y=1(对应目标类别是第 2 类,索引从 0 开始);
  • 模型输出 Softmax 概率q=[0.1,0.8,0.1]q=[0.1, 0.8, 0.1]q=[0.1,0.8,0.1]

计算步骤:

  1. 取真实标签对应的概率:q(y)=q(1)=0.8q(y)=q(1)=0.8q(y)=q(1)=0.8
  2. 计算负对数:−log⁡(q(y))=−log⁡(0.8)≈0.223-\log(q(y)) = -\log(0.8) ≈ 0.223log(q(y))=log(0.8)0.223
  3. 最终损失值即为该结果(批量数据会取均值/求和,由reduction参数控制)。

公式简化为:
loss=−log⁡(q(y)) \text{loss} = -\log(q(y))loss=log(q(y))

场景 2:from_logits=True(模型输出是原始 logits,推荐!)

模型输出的是未经过 Softmax 归一化的原始得分(logits,如z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]),此时 TensorFlow 不会先单独计算 Softmax(避免数值下溢/上溢),而是直接用log_softmax优化计算:

  1. 对 logits 计算log_softmaxlog⁡(Softmax(z))=z−log⁡(∑i=1Cezi)\log(\text{Softmax}(z)) = z - \log(\sum_{i=1}^C e^{z_i})log(Softmax(z))=zlog(i=1Cezi)
  2. 取真实标签对应的项,取负数即为损失:
    loss=−(zy−log⁡(∑i=1Cezi)) \text{loss} = - \left( z_y - \log(\sum_{i=1}^C e^{z_i}) \right)loss=(zylog(i=1Cezi))

示例计算(z=[1.0,3.0,0.5],y=1z=[1.0, 3.0, 0.5], y=1z=[1.0,3.0,0.5],y=1):

  • 先算∑ezi=e1.0+e3.0+e0.5≈2.718+20.085+1.648≈24.451\sum e^{z_i} = e^{1.0} + e^{3.0} + e^{0.5} ≈ 2.718 + 20.085 + 1.648 ≈ 24.451ezi=e1.0+e3.0+e0.52.718+20.085+1.64824.451
  • log⁡(24.451)≈3.200\log(24.451) ≈ 3.200log(24.451)3.200
  • log⁡(Softmax(z))1=3.0−3.200=−0.200\log(\text{Softmax}(z))_1 = 3.0 - 3.200 = -0.200log(Softmax(z))1=3.03.200=0.200
  • 损失值:−(−0.200)=0.200-(-0.200) = 0.200(0.200)=0.200

为什么推荐from_logits=True
Softmax 对大 logits 会产生e大值e^{大值}e大值(如e100e^{100}e100溢出),而log_softmax直接通过代数变换避免了单独计算 Softmax,提升数值稳定性。

四、批量数据的损失归约

实际训练中输入是批量数据(batch),损失会通过reduction参数归约(默认AUTO,等价于SUM_OVER_BATCH_SIZE):

  • 对每个样本计算损失值;
  • 求批量内所有样本损失的均值(或求和,取决于reduction)。

示例(batch_size=2):

样本稀疏标签模型概率单样本损失
11[0.1,0.8,0.1]0.223
20[0.9,0.05,0.05]0.105
批量损失 = (0.223 + 0.105) / 2 ≈ 0.164。

五、关键参数解析

参数作用示例
from_logits是否输入为原始 logits(非 Softmax 概率)from_logits=True(推荐)
reduction损失归约方式:
-NONE:返回每个样本的损失
-SUM:批量损失求和
-SUM_OVER_BATCH_SIZE:批量损失求均值
reduction="sum_over_batch_size"
ignore_index忽略指定标签(计算损失时跳过),适用于样本标注缺失场景ignore_index=-1
axis类别维度(默认 -1,即最后一维是类别)模型输出形状(batch, 3)时,axis=-1 对应 3 个类别

六、与CategoricalCrossentropy的对比

特性SparseCategoricalCrossentropyCategoricalCrossentropy
标签格式整数型稀疏标签(如 1,2,3)one-hot 编码标签(如 [0,1,0])
内存占用低(仅存整数)高(类别数维向量)
适用场景类别数多、标签天然为整数(如图像分类的类别索引)标签已做 one-hot 编码
核心公式同交叉熵,但直接取整数标签对应项交叉熵原始公式(遍历所有类别)

七、注意事项

  1. 标签范围:稀疏标签必须是[0,C−1][0, C-1][0,C1]范围内的整数(C 是类别数),否则会报错;
  2. 数值稳定性:优先设置from_logits=True,避免 Softmax 导致的数值溢出;
  3. 多标签任务:该损失适用于「单标签多分类」(每个样本仅属于一个类别),多标签任务需用BinaryCrossentropy

示例代码验证

importtensorflowastf# 1. 定义损失函数(from_logits=True,模型输出logits)loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)# 2. 模拟批量数据(batch_size=2,类别数=3)y_true=tf.constant([1,0])# 稀疏标签y_pred_logits=tf.constant([[1.0,3.0,0.5],[5.0,1.0,0.1]])# 模型输出logits# 3. 计算损失loss=loss_fn(y_true,y_pred_logits)print("批量损失值:",loss.numpy())# 输出约 0.15(手动计算验证)

综上,SparseCategoricalCrossentropy本质是「多分类交叉熵」在稀疏标签下的高效实现,核心是通过直接索引整数标签避免 one-hot 编码,同时优化数值计算保证稳定性,是单标签多分类任务的首选损失函数之一。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 21:56:26

ViGEmBus虚拟手柄驱动:从零开始的完整配置终极指南

想要在Windows系统中获得专业级的游戏控制体验吗?ViGEmBus虚拟手柄驱动为你打开了全新的技术大门!这款强大的内核级驱动程序能够完美重现Xbox 360和DualShock 4游戏控制器,让你无需任何硬件改造就能享受真正的游戏控制自由。 【免费下载链接】…

作者头像 李华
网站建设 2026/4/29 15:47:19

猫抓浏览器扩展终极指南:一键搞定网页视频下载与M3U8解析

猫抓浏览器扩展终极指南:一键搞定网页视频下载与M3U8解析 【免费下载链接】cat-catch 猫抓 chrome资源嗅探扩展 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 还在为无法保存喜欢的在线视频而烦恼吗?猫抓浏览器扩展正是你需要的完…

作者头像 李华
网站建设 2026/4/27 14:50:58

3分钟解决:网易云NCM加密格式的转换方案

3分钟解决:网易云NCM加密格式的转换方案 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 你是否曾经遇到过这样的困扰:从网易云音乐下载的歌曲只能在特定客户端播放,换个设备就变成了"无法播放…

作者头像 李华
网站建设 2026/4/26 3:31:57

【稀缺资源】Open-AutoGLM开源首发:掌握下一代AutoGLM引擎的3个关键步骤

第一章:Open-AutoGLM开源首发背景与意义随着大语言模型在自动化任务中的广泛应用,构建高效、可扩展的智能代理系统成为前沿研究重点。Open-AutoGLM作为首个开源的AutoGLM实现框架,旨在复现并拓展GLM系列模型在自主决策、多步推理与工具调用方…

作者头像 李华
网站建设 2026/4/18 5:18:08

Open-AutoGLM核心技术全景图(仅限资深AI架构师阅读的深度解析)

第一章:Open-AutoGLM技术原理图Open-AutoGLM 是一种面向自动化自然语言任务的开源大语言模型框架,其核心在于融合生成式语言建模与任务自适应机制。该架构通过动态路由策略,在多专家模块(MoE)之间分配计算资源&#xf…

作者头像 李华
网站建设 2026/4/28 5:23:01

Open-AutoGLM上手机到底难不难?3个关键技术突破让你立刻上手

第一章:Open-AutoGLM上手机的现状与挑战随着大模型技术在移动端的加速落地,Open-AutoGLM作为一款面向轻量化推理与自动化任务处理的开源语言模型,正逐步进入智能手机的应用生态。然而,其在移动设备上的部署仍面临多重挑战&#xf…

作者头像 李华