news 2026/4/14 18:13:13

动手学深度学习——RNN从零开始实现

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
动手学深度学习——RNN从零开始实现

1. 前言

上一篇我们已经从概念上认识了RNN(循环神经网络)

  • 它适合处理序列数据

  • 它会把过去的信息压缩进隐藏状态

  • 当前时刻的计算依赖“当前输入 + 上一时刻隐藏状态”

  • 它非常适合做语言模型

但如果只停留在公式层面,理解还不够扎实。
这一节就正式进入:

RNN 从零开始实现

这里的“从零开始”,不是说手搓底层矩阵乘法库,
而是指:

不用现成的nn.RNN封装层,而是自己把 RNN 的参数、前向传播、状态更新、预测过程一步一步写出来。

这样做的意义非常大,因为你会真正看清楚:

  • 输入是怎么进来的

  • 隐藏状态是怎么更新的

  • 输出是怎么得到的

  • 语言模型为什么能逐步生成字符


2. 从零实现要解决哪些问题

如果把这一节拆开看,其实主要要解决 5 件事:

2.1 如何表示输入

文本已经被预处理成 token 索引序列,但 RNN 不能直接把“整数编号”当成线性特征用,所以要先做表示转换。

2.2 如何初始化参数

RNN 的输入到隐藏层、隐藏层到隐藏层、隐藏层到输出层,都需要权重矩阵和偏置。

2.3 如何初始化隐藏状态

在序列开始时,模型还没有任何历史信息,所以需要一个初始隐藏状态。

2.4 如何写前向传播

也就是按时间步循环,逐步更新隐藏状态并输出预测结果。

2.5 如何根据模型生成新文本

语言模型训练好后,要能根据前缀一步一步往后预测字符。


3. 为什么输入不能直接用整数索引

假设词表里有:

  • a -> 0

  • b -> 1

  • c -> 2

如果直接把这些编号当成输入数值,那么模型可能会误以为:

  • 21“更大”

  • 10距离更近

但实际上,字符编号本身只是一个离散标签,没有这种数值意义。

所以在 RNN 从零实现里,最常见的做法是:

把索引转成 one-hot 向量

这样每个 token 都被表示成一个稀疏的离散向量。


4. 什么是 one-hot 表示

假设词表大小是 4,字符集合为:

[a, b, c, d]

那么:

  • a可以表示成[1, 0, 0, 0]

  • b表示成[0, 1, 0, 0]

  • c表示成[0, 0, 1, 0]

  • d表示成[0, 0, 0, 1]

这就叫 one-hot 编码。

它的特点是:

  • 只有一个位置是 1

  • 其余位置全是 0

  • 不会人为引入“编号大小规律”

所以它非常适合表示离散 token。


5. 代码里怎么做 one-hot

李沐这里常见代码大致是:

import torch from torch import nn from d2l import torch as d2l F = ერთ?

不过真正常见写法是直接用 PyTorch 的 one_hot:

X = torch.arange(10).reshape((2, 5)) F.one_hot(X.T, 28).shape

它的关键点在于:

  • X是索引张量

  • 28是词表大小

  • 输出会把每个索引变成一个 one-hot 向量

如果输入形状是:

(batch_size, num_steps)

那么转成 one-hot 后通常会变成:

(num_steps, batch_size, vocab_size)

这是后面手写 RNN 前向传播的重要输入格式。


6. 为什么常把时间步维度放前面

在手写 RNN 时,经常会把输入整理成:

(num_steps, batch_size, vocab_size)

而不是:

(batch_size, num_steps, vocab_size)

这样做是因为前向传播时,我们通常会:

沿着时间维一列一列地循环

也就是说,先处理第 1 个时间步,再处理第 2 个时间步,再处理第 3 个时间步。

如果时间维在最前面,那么写循环会更自然。


7. 先初始化模型参数

RNN 从零实现最核心的一步,就是自己初始化参数。

基础 RNN 的公式是:

H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h) Y_t = H_t W_hq + b_q

所以参数至少要包括:

输入到隐藏层

W_xh

隐藏层到隐藏层

W_hh

隐藏层偏置

b_h

隐藏层到输出层

W_hq

输出层偏置

b_q

8. 参数初始化代码怎么写

常见写法类似这样:

def get_params(vocab_size, num_hiddens, device): num_inputs = num_outputs = vocab_size def normal(shape): return torch.randn(size=shape, device=device) * 0.01 W_xh = normal((num_inputs, num_hiddens)) W_hh = normal((num_hiddens, num_hiddens)) b_h = torch.zeros(num_hiddens, device=device) W_hq = normal((num_hiddens, num_outputs)) b_q = torch.zeros(num_outputs, device=device) params = [W_xh, W_hh, b_h, W_hq, b_q] for param in params: param.requires_grad_(True) return params

这段代码是整个“从零实现”最关键的基础之一。


9. 这段参数代码怎么理解

9.1num_inputs = vocab_size

因为我们输入的是 one-hot 向量,它的长度就是词表大小。

9.2num_outputs = vocab_size

因为这是字符级语言模型,输出通常也是对整个词表做分类,所以输出维度也等于词表大小。

9.3num_hiddens

表示隐藏状态维度,也就是模型内部记忆空间的大小。

9.4 小随机数初始化

权重通常初始化成较小随机值,防止一开始数值太大。

9.5requires_grad_(True)

表示这些参数后续需要参与梯度更新。


10. 如何初始化隐藏状态

RNN 在处理第一个时间步之前,没有过去信息。
所以通常会把隐藏状态初始化成全零。

常见写法如下:

def init_rnn_state(batch_size, num_hiddens, device): return (torch.zeros((batch_size, num_hiddens), device=device), )

这里返回的是一个元组,里面放一个张量。

为什么写成元组?

因为后面更复杂的循环模型,比如 LSTM,会有多个状态量。
所以这里提前统一接口形式。


11. 为什么隐藏状态形状是(batch_size, num_hiddens)

因为:

  • 每个 batch 中有batch_size条序列

  • 每条序列在当前时刻都要有自己的隐藏状态

  • 每个隐藏状态长度是num_hiddens

所以形状就是:

(batch_size, num_hiddens)

这表示:

每个样本各自维护一份当前时刻的记忆表示。


12. 手写 RNN 前向传播

这是这一节最核心的代码。

常见写法如下:

def rnn(inputs, state, params): W_xh, W_hh, b_h, W_hq, b_q = params H, = state outputs = [] for X in inputs: H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h) Y = torch.mm(H, W_hq) + b_q outputs.append(Y) return torch.cat(outputs, dim=0), (H,)

这段代码基本就是“RNN 从零实现”的灵魂。


13. 这段前向传播逐行理解

13.1 拆参数

W_xh, W_hh, b_h, W_hq, b_q = params

把前面初始化好的参数取出来。

13.2 取出当前隐藏状态

H, = state

这里state是一个元组,所以要这样解包。

13.3 准备输出列表

outputs = []

因为每个时间步都会产生一个输出,所以先放到列表里。

13.4 按时间步循环

for X in inputs:

这里每次拿到的是某一个时间步的输入,形状通常是:

(batch_size, vocab_size)

14. 隐藏状态更新公式对应哪一行

这一句:

H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)

就是基础 RNN 的状态更新公式。

它表示:

  • 当前输入X

  • 乘上输入到隐藏层权重W_xh

  • 再加上上一隐藏状态H

  • 乘上隐藏到隐藏权重W_hh

  • 再加偏置b_h

  • 最后经过tanh

得到新的隐藏状态。

这一步就是“记忆更新”。


15. 输出怎么得到

这一句:

Y = torch.mm(H, W_hq) + b_q

表示把当前隐藏状态映射到输出空间。

因为这是语言模型,所以这里的输出维度通常就是:

vocab_size

也就是说,Y可以理解为:

当前时刻对词表中每个字符/词元的打分

之后通过 softmax,就能转成概率分布。


16. 为什么最后要torch.cat(outputs, dim=0)

因为outputs列表里装的是每个时间步的输出:

  • 第 1 步一个Y

  • 第 2 步一个Y

  • 第 3 步一个Y

而训练时通常希望把它们拼成一个二维张量,方便统一计算损失。

所以:

torch.cat(outputs, dim=0)

会把所有时间步输出拼到一起。

如果每个Y形状是:

(batch_size, vocab_size)

那最后拼起来大致就是:

(num_steps * batch_size, vocab_size)

17. 为什么返回(H,)

因为前向传播结束后,我们还需要保留最后一个时间步的隐藏状态,
方便下一段序列接着用。

所以返回值一般是:

  • 所有时间步输出

  • 最终隐藏状态

写成:

return outputs, (H,)

是为了和后续 GRU、LSTM 保持接口统一。


18. 封装一个从零实现的 RNN 模型类

为了方便调用,通常会把参数初始化、前向传播、状态初始化封装到类里。

常见形式大概是:

class RNNModelScratch: def __init__(self, vocab_size, num_hiddens, device, get_params, init_state, forward_fn): self.vocab_size, self.num_hiddens = vocab_size, num_hiddens self.params = get_params(vocab_size, num_hiddens, device) self.init_state, self.forward_fn = init_state, forward_fn def __call__(self, X, state): X = F.one_hot(X.T, self.vocab_size).type(torch.float32) return self.forward_fn(X, state, self.params) def begin_state(self, batch_size, device): return self.init_state(batch_size, self.num_hiddens, device)

它本质上就是一个“手写版 RNN 容器”。


19.__call__里做了什么

这一句很关键:

X = F.one_hot(X.T, self.vocab_size).type(torch.float32)

它完成了三步:

第一步:转置

把输入从:

(batch_size, num_steps)

变成:

(num_steps, batch_size)

第二步:one-hot 编码

把每个索引变成 one-hot 向量。

第三步:转成浮点数

因为后面矩阵乘法需要浮点类型。

这样输入就符合手写前向传播函数的要求了。


20. 如何根据前缀生成新字符

语言模型最有趣的地方之一,就是可以根据前缀生成后续文本。

常见预测函数思路如下:

def predict_ch8(prefix, num_preds, net, vocab, device): state = net.begin_state(batch_size=1, device=device) outputs = [vocab[prefix[0]]] def get_input(): return torch.tensor([outputs[-1]], device=device).reshape((1, 1)) for y in prefix[1:]: _, state = net(get_input(), state) outputs.append(vocab[y]) for _ in range(num_preds): y, state = net(get_input(), state) outputs.append(int(y.argmax(dim=1).reshape(1))) return ''.join([vocab.idx_to_token[i] for i in outputs])

这段代码是“字符级生成”的核心。


21. 这段预测代码怎么理解

21.1 先初始化状态

state = net.begin_state(batch_size=1, device=device)

因为我们只生成一条序列,所以 batch_size=1。

21.2 用前缀预热模型

例如前缀是"time traveller"中的前几个字符,
模型要先把这些字符读进去,更新到合适的隐藏状态。

21.3 再一步步自回归生成

每次把上一步输出的字符作为下一步输入,
再预测新的字符。

这就是最典型的自回归文本生成过程。


22. 为什么前缀阶段不直接预测新字符

因为前缀的作用是:

给模型提供上下文

例如你让模型从"time "开始生成,
那就要先让模型把"t" "i" "m" "e" " "这些字符都读进去。

这样它的隐藏状态里才会包含这个前缀信息。
后面生成时,才能更自然地沿着这个前缀继续写。


23. 训练时损失函数怎么接

虽然这一节重点是“从零实现模型结构”,
但训练时通常还是会配合交叉熵损失。

因为每个时间步的输出本质上是在做:

对词表的多分类

所以目标标签就是“当前正确下一个字符的索引”,
输出则是对整个词表的打分。

这和普通分类任务在损失形式上是一样的,
只是多了时间维展开。


24. 为什么从零实现很重要

你可能会想:

既然 PyTorch 有现成的nn.RNN,为什么还要自己写?

原因很简单:

第一,能真正看懂 RNN 的本质

现成 API 很方便,但容易把人“封装麻木”。

第二,便于理解后续 GRU、LSTM

后面所有变种,其实都是在这个基础上改状态更新公式。

第三,能理解输入形状、状态传递、输出组织方式

这些都是后面训练和调试时特别重要的东西。

所以“从零实现”不是为了炫技,而是为了建立底层理解。


25. 这一节最该掌握什么

如果从学习重点看,这一节最重要的是下面几件事。

25.1 one-hot 输入

知道为什么离散 token 要转成 one-hot。

25.2 参数初始化

明白 RNN 至少需要哪些权重和偏置。

25.3 隐藏状态初始化

知道初始状态为什么通常设成 0。

25.4 前向传播循环

真正理解时间步递推是怎么写出来的。

25.5 预测函数

理解语言模型是如何根据前缀一步步生成字符的。


26. 本节总结

这一节我们从零开始实现了 RNN,核心内容可以总结为以下几点。

26.1 输入 token 需要先转成 one-hot 向量

这是手写 RNN 最常见的输入表示方式。

26.2 RNN 的参数包括输入到隐藏、隐藏到隐藏、隐藏到输出三部分

这些参数共同决定状态更新和输出预测。

26.3 隐藏状态会在时间步之间递推传递

这是 RNN 记忆历史信息的关键。

26.4 前向传播本质上是一个时间循环

每一步根据当前输入和上一状态更新当前状态。

26.5 语言模型可以利用 RNN 逐步生成文本

根据前缀预热,再不断自回归预测后续字符。


27. 学习感悟

这一节特别有价值,因为它让你第一次真正“摸到”RNN 的内部结构。

以前我们说:

  • RNN 有记忆

  • RNN 能建模序列

  • RNN 能预测下一个字符

这些话都比较抽象。

但一旦你自己把:

  • one-hot

  • 参数矩阵

  • 隐藏状态

  • 时间循环

  • 输出拼接

这些东西串起来,RNN 就不再神秘了。

你会发现它其实非常朴素:

就是把“当前输入”和“过去记忆”一起拿来算下一个状态。

而这份朴素,也正是它经典的原因。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/14 18:13:06

动手学深度学习——长短期记忆网络(LSTM)

1. 前言前面我们已经学了:RNNRNN 从零开始实现RNN 简洁实现GRUGRU 代码到这里,你应该已经很清楚一个主线:基础 RNN 能处理序列,但长期依赖能力弱;门控机制是改进方向。GRU 已经通过:更新门重置门让模型具备…

作者头像 李华
网站建设 2026/4/14 18:12:12

如何做好家电数码产品的AI生成式引擎优化(GEO)?

做好家电数码产品的AI生成式引擎优化(GEO),关键在于重构品牌与AI对话的方式。这是一套系统化的策略,目的是让你的品牌和产品信息,成为AI助手(如DeepSeek、豆包等)在回答用户问题时,优…

作者头像 李华
网站建设 2026/4/14 18:09:41

iOS 15 电池优化全攻略:告别电量焦虑

1. iOS 15电池耗电的真相:为什么你的iPhone掉电这么快? 每次看到手机右上角的电量图标变红,心里是不是都会咯噔一下?特别是升级到iOS 15后,很多用户都反映电池续航明显变差。作为一个从iPhone 4用到iPhone 13的老用户…

作者头像 李华