零基础吃透:TensorFlow稀疏张量(SparseTensor)的核心操作
稀疏张量无法直接使用tf.math.add等密集张量的算术算子,必须通过tf.sparse包下的专用工具进行操作。本文拆解加法、矩阵乘法、拼接、切片、元素级运算五大核心操作,结合示例讲清原理、用法和版本兼容细节。
前置准备(必运行)
importtensorflowastf# 复用之前的美观打印函数(调试必备)defpprint_sparse_tensor(st):s="<SparseTensor shape=%s \n values={"%(st.dense_shape.numpy().tolist(),)for(index,value)inzip(st.indices,st.values):s+=f"\n %s: %s"%(index.numpy().tolist(),value.numpy().tolist())returns+"}>"# 示例稀疏张量(后续操作会复用)st2=tf.sparse.from_dense([[1,0,0,8],[0,0,0,0],[0,0,3,0]])一、稀疏张量加法(tf.sparse.add)
核心原理
仅对同形状稀疏张量的「相同坐标非零值」相加,不同坐标的非零值直接保留,最终输出仍为稀疏张量(仅存储非零结果)。
示例代码
# 构造两个同形状的稀疏张量st_a=tf.sparse.SparseTensor(indices=[[0,2],[3,4]],values=[31,2],dense_shape=[4,10]# 4行10列)st_b=tf.sparse.SparseTensor(indices=[[0,2],[3,0]],values=[56,38],dense_shape=[4,10]# 必须与st_a形状一致)# 稀疏张量加法st_sum=tf.sparse.add(st_a,st_b)print("稀疏张量相加结果:")print(pprint_sparse_tensor(st_sum))输出解读
<SparseTensor shape=[4, 10] values={ [0, 2]: 87 # st_a[0,2]=31 + st_b[0,2]=56 [3, 0]: 38 # 仅st_b有该坐标,直接保留 [3, 4]: 2 # 仅st_a有该坐标,直接保留 }>关键注意事项
- ❌ 形状不同会报错:必须保证
dense_shape完全一致; - ✅ 结果仅保留非零值:若相加后某坐标值为0(如
st_a[0,2]=-56 + st_b[0,2]=56),会被过滤出结果。
二、稀疏×密集矩阵乘法(tf.sparse.sparse_dense_matmul)
核心原理
稀疏张量作为矩阵(需满足矩阵乘法的形状规则),与密集矩阵相乘,无需转换为密集张量,大幅节省内存(超稀疏矩阵效率提升显著)。
示例代码
# 构造2×2的稀疏矩阵(非零值:[0,1]=13,[1,0]=15,[1,1]=17)st_c=tf.sparse.SparseTensor(indices=[[0,1],[1,0],[1,1]],# 注意:原代码的indices写法有误,修正为列表格式values=[13,15,17],dense_shape=(2,2))# 构造2×1的密集矩阵mb=tf.constant([[4],[6]])# 稀疏×密集矩阵乘法product=tf.sparse.sparse_dense_matmul(st_c,mb)print("\n稀疏×密集矩阵乘法结果:")print(product)计算逻辑(验证结果)
矩阵乘法规则:C × B = [ (0×4+13×6), (15×4+17×6) ]^T
- 第一行:
0×4 + 13×6 = 78 - 第二行:
15×4 + 17×6 = 60 + 102 = 162
输出解读
tf.Tensor( [[ 78] [162]], shape=(2, 1), dtype=int32)关键注意事项
- 形状规则:稀疏张量的列数 = 密集矩阵的行数(如2×2 × 2×1 合法);
- 索引顺序:建议先通过
tf.sparse.reorder排序稀疏张量索引,避免运算异常。
三、稀疏张量拼接(tf.sparse.concat)
核心原理
沿指定轴(如列轴axis=1)拼接多个稀疏张量,要求除拼接轴外的其他轴形状一致,最终输出合并后的稀疏张量。
示例代码
# 构造3个待拼接的稀疏张量(行维度均为8,列维度不同)sparse_pattern_A=tf.sparse.SparseTensor(indices=[[2,4],[3,3],[3,4],[4,3],[4,4],[5,4]],values=[1]*6,dense_shape=[8,5]# 8行5列)sparse_pattern_B=tf.sparse.SparseTensor(indices=[[0,2],[1,1],[1,3],[2,0],[2,4],[2,5],[3,5],[4,5],[5,0],[5,4],[5,5],[6,1],[6,3],[7,2]],values=[1]*14,dense_shape=[8,6]# 8行6列)sparse_pattern_C=tf.sparse.SparseTensor(indices=[[3,0],[4,0]],values=[1]*2,dense_shape=[8,6]# 8行6列)# 沿列轴(axis=1)拼接sparse_pattern=tf.sparse.concat(axis=1,# 列轴拼接(行轴保持8不变)sp_inputs=[sparse_pattern_A,sparse_pattern_B,sparse_pattern_C])# 转换为密集张量查看拼接结果print("\n拼接后的密集张量:")print(tf.sparse.to_dense(sparse_pattern))输出解读
拼接后形状为8×(5+6+6)=8×17,非零值按原位置分布在对应列区间:
- A的非零值在列0~4;
- B的非零值在列5~10;
- C的非零值在列11~16。
关键注意事项
- 拼接轴外的维度必须一致:如示例中所有张量的行维度均为8,仅列维度不同;
- 拼接后非零值位置:原张量的列索引自动偏移(如B的列0→拼接后的列5)。
四、稀疏张量切片(tf.sparse.slice)
核心原理
沿指定轴截取稀疏张量的子区域,仅保留「切片范围内的非零值」,输出新的稀疏张量(形状为指定的size)。
函数参数
| 参数 | 作用 |
|---|---|
start | 切片起始坐标(列表/张量),长度=张量秩(如[0,0]表示行0列0开始) |
size | 切片大小(列表/张量),长度=张量秩(如[8,5]表示截取8行5列) |
示例代码
# 对拼接后的张量切片(还原原张量)sparse_slice_A=tf.sparse.slice(sparse_pattern_A,start=[0,0],size=[8,5])sparse_slice_B=tf.sparse.slice(sparse_pattern_B,start=[0,5],size=[8,6])sparse_slice_C=tf.sparse.slice(sparse_pattern_C,start=[0,10],size=[8,6])# 打印切片结果(转密集张量)print("\n切片A(8×5):")print(tf.sparse.to_dense(sparse_slice_A))print("\n切片B(8×1):")# 原B的start=[0,5],size=[8,6]但仅列5有值,故输出8×1print(tf.sparse.to_dense(sparse_slice_B))print("\n切片C(8×0):")# 无符合条件的非零值,输出空print(tf.sparse.to_dense(sparse_slice_C))输出解读
切片A(8×5): [[0 0 0 0 0] [0 0 0 0 0] [0 0 0 0 1] [0 0 0 1 1] [0 0 0 1 1] [0 0 0 0 1] [0 0 0 0 0] [0 0 0 0 0]] 切片B(8×1): [[0] [0] [1] [1] [1] [1] [0] [0]] 切片C(8×0): []关键注意事项
- 切片范围外的非零值会被过滤:如切片B仅截取列5,原B的其他列非零值被丢弃;
- 空切片:无符合条件的非零值时,输出
shape=(8,0)的空稀疏张量。
五、元素级运算(仅修改非零值)
场景:对稀疏张量的所有非零值做统一运算(如+5)
方式1:TF2.4+ 专用(tf.sparse.map_values)
tf.sparse.map_values专门对稀疏张量的values(非零值)做元素级运算,零值保持不变。
# 对st2的非零值+5st2_plus_5=tf.sparse.map_values(tf.add,st2,5)print("\nTF2.4+ 非零值+5(密集张量):")print(tf.sparse.to_dense(st2_plus_5))方式2:TF2.4前 兼容方案
手动构造新的SparseTensor,仅修改values,保留indices和dense_shape。
# 老版本兼容写法:直接修改valuesst2_plus_5_compat=tf.sparse.SparseTensor(st2.indices,# 保留原坐标st2.values+5,# 非零值+5st2.dense_shape# 保留原形状)print("\n老版本兼容 非零值+5(密集张量):")print(tf.sparse.to_dense(st2_plus_5_compat))输出解读(两种方式结果一致)
[[ 6 0 0 13] [ 0 0 0 0] [ 0 0 8 0]]- 仅非零值被修改:原
1→6、8→13、3→8; - 零值保持不变:符合稀疏张量的设计初衷(仅操作有效数据)。
核心操作总结表
| 操作 | 函数 | 核心要求 | 适用场景 |
|---|---|---|---|
| 稀疏加法 | tf.sparse.add | 张量形状完全一致 | 同形状稀疏张量逐坐标相加 |
| 稀疏-密集矩阵乘法 | tf.sparse.sparse_dense_matmul | 稀疏列数=密集行数 | 超稀疏矩阵与密集矩阵相乘 |
| 稀疏拼接 | tf.sparse.concat | 非拼接轴形状一致 | 合并多个稀疏张量的列/行 |
| 稀疏切片 | tf.sparse.slice | start/size长度=张量秩 | 截取稀疏张量的子区域 |
| 元素级运算 | tf.sparse.map_values | TF2.4+,仅修改非零值 | 对非零值做统一算术运算(+/-/*//) |
避坑关键
- 形状匹配:所有稀疏张量操作的核心是「形状兼容」,形状不匹配会直接报错;
- 索引顺序:运算前建议用
tf.sparse.reorder排序索引,避免算子异常; - 版本兼容:
tf.sparse.map_values仅TF2.4+支持,老版本需手动修改values; - 零值处理:所有操作均仅处理非零值,零值始终保持隐式存储(不占用内存)。
掌握这些操作,就能高效处理NLP(TF-IDF)、计算机视觉(稀疏像素)等场景下的超稀疏数据,大幅降低内存占用和计算开销。