1. 前言
上一篇我们已经从概念上认识了RNN(循环神经网络):
它适合处理序列数据
它会把过去的信息压缩进隐藏状态
当前时刻的计算依赖“当前输入 + 上一时刻隐藏状态”
它非常适合做语言模型
但如果只停留在公式层面,理解还不够扎实。
这一节就正式进入:
RNN 从零开始实现
这里的“从零开始”,不是说手搓底层矩阵乘法库,
而是指:
不用现成的
nn.RNN封装层,而是自己把 RNN 的参数、前向传播、状态更新、预测过程一步一步写出来。
这样做的意义非常大,因为你会真正看清楚:
输入是怎么进来的
隐藏状态是怎么更新的
输出是怎么得到的
语言模型为什么能逐步生成字符
2. 从零实现要解决哪些问题
如果把这一节拆开看,其实主要要解决 5 件事:
2.1 如何表示输入
文本已经被预处理成 token 索引序列,但 RNN 不能直接把“整数编号”当成线性特征用,所以要先做表示转换。
2.2 如何初始化参数
RNN 的输入到隐藏层、隐藏层到隐藏层、隐藏层到输出层,都需要权重矩阵和偏置。
2.3 如何初始化隐藏状态
在序列开始时,模型还没有任何历史信息,所以需要一个初始隐藏状态。
2.4 如何写前向传播
也就是按时间步循环,逐步更新隐藏状态并输出预测结果。
2.5 如何根据模型生成新文本
语言模型训练好后,要能根据前缀一步一步往后预测字符。
3. 为什么输入不能直接用整数索引
假设词表里有:
a -> 0b -> 1c -> 2
如果直接把这些编号当成输入数值,那么模型可能会误以为:
2比1“更大”1和0距离更近
但实际上,字符编号本身只是一个离散标签,没有这种数值意义。
所以在 RNN 从零实现里,最常见的做法是:
把索引转成 one-hot 向量
这样每个 token 都被表示成一个稀疏的离散向量。
4. 什么是 one-hot 表示
假设词表大小是 4,字符集合为:
[a, b, c, d]那么:
a可以表示成[1, 0, 0, 0]b表示成[0, 1, 0, 0]c表示成[0, 0, 1, 0]d表示成[0, 0, 0, 1]
这就叫 one-hot 编码。
它的特点是:
只有一个位置是 1
其余位置全是 0
不会人为引入“编号大小规律”
所以它非常适合表示离散 token。
5. 代码里怎么做 one-hot
李沐这里常见代码大致是:
import torch from torch import nn from d2l import torch as d2l F = ერთ?不过真正常见写法是直接用 PyTorch 的 one_hot:
X = torch.arange(10).reshape((2, 5)) F.one_hot(X.T, 28).shape它的关键点在于:
X是索引张量28是词表大小输出会把每个索引变成一个 one-hot 向量
如果输入形状是:
(batch_size, num_steps)那么转成 one-hot 后通常会变成:
(num_steps, batch_size, vocab_size)这是后面手写 RNN 前向传播的重要输入格式。
6. 为什么常把时间步维度放前面
在手写 RNN 时,经常会把输入整理成:
(num_steps, batch_size, vocab_size)而不是:
(batch_size, num_steps, vocab_size)这样做是因为前向传播时,我们通常会:
沿着时间维一列一列地循环
也就是说,先处理第 1 个时间步,再处理第 2 个时间步,再处理第 3 个时间步。
如果时间维在最前面,那么写循环会更自然。
7. 先初始化模型参数
RNN 从零实现最核心的一步,就是自己初始化参数。
基础 RNN 的公式是:
H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h) Y_t = H_t W_hq + b_q所以参数至少要包括:
输入到隐藏层
W_xh隐藏层到隐藏层
W_hh隐藏层偏置
b_h隐藏层到输出层
W_hq输出层偏置
b_q8. 参数初始化代码怎么写
常见写法类似这样:
def get_params(vocab_size, num_hiddens, device): num_inputs = num_outputs = vocab_size def normal(shape): return torch.randn(size=shape, device=device) * 0.01 W_xh = normal((num_inputs, num_hiddens)) W_hh = normal((num_hiddens, num_hiddens)) b_h = torch.zeros(num_hiddens, device=device) W_hq = normal((num_hiddens, num_outputs)) b_q = torch.zeros(num_outputs, device=device) params = [W_xh, W_hh, b_h, W_hq, b_q] for param in params: param.requires_grad_(True) return params这段代码是整个“从零实现”最关键的基础之一。
9. 这段参数代码怎么理解
9.1num_inputs = vocab_size
因为我们输入的是 one-hot 向量,它的长度就是词表大小。
9.2num_outputs = vocab_size
因为这是字符级语言模型,输出通常也是对整个词表做分类,所以输出维度也等于词表大小。
9.3num_hiddens
表示隐藏状态维度,也就是模型内部记忆空间的大小。
9.4 小随机数初始化
权重通常初始化成较小随机值,防止一开始数值太大。
9.5requires_grad_(True)
表示这些参数后续需要参与梯度更新。
10. 如何初始化隐藏状态
RNN 在处理第一个时间步之前,没有过去信息。
所以通常会把隐藏状态初始化成全零。
常见写法如下:
def init_rnn_state(batch_size, num_hiddens, device): return (torch.zeros((batch_size, num_hiddens), device=device), )这里返回的是一个元组,里面放一个张量。
为什么写成元组?
因为后面更复杂的循环模型,比如 LSTM,会有多个状态量。
所以这里提前统一接口形式。
11. 为什么隐藏状态形状是(batch_size, num_hiddens)
因为:
每个 batch 中有
batch_size条序列每条序列在当前时刻都要有自己的隐藏状态
每个隐藏状态长度是
num_hiddens
所以形状就是:
(batch_size, num_hiddens)这表示:
每个样本各自维护一份当前时刻的记忆表示。
12. 手写 RNN 前向传播
这是这一节最核心的代码。
常见写法如下:
def rnn(inputs, state, params): W_xh, W_hh, b_h, W_hq, b_q = params H, = state outputs = [] for X in inputs: H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h) Y = torch.mm(H, W_hq) + b_q outputs.append(Y) return torch.cat(outputs, dim=0), (H,)这段代码基本就是“RNN 从零实现”的灵魂。
13. 这段前向传播逐行理解
13.1 拆参数
W_xh, W_hh, b_h, W_hq, b_q = params把前面初始化好的参数取出来。
13.2 取出当前隐藏状态
H, = state这里state是一个元组,所以要这样解包。
13.3 准备输出列表
outputs = []因为每个时间步都会产生一个输出,所以先放到列表里。
13.4 按时间步循环
for X in inputs:这里每次拿到的是某一个时间步的输入,形状通常是:
(batch_size, vocab_size)14. 隐藏状态更新公式对应哪一行
这一句:
H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)就是基础 RNN 的状态更新公式。
它表示:
当前输入
X乘上输入到隐藏层权重
W_xh再加上上一隐藏状态
H乘上隐藏到隐藏权重
W_hh再加偏置
b_h最后经过
tanh
得到新的隐藏状态。
这一步就是“记忆更新”。
15. 输出怎么得到
这一句:
Y = torch.mm(H, W_hq) + b_q表示把当前隐藏状态映射到输出空间。
因为这是语言模型,所以这里的输出维度通常就是:
vocab_size也就是说,Y可以理解为:
当前时刻对词表中每个字符/词元的打分
之后通过 softmax,就能转成概率分布。
16. 为什么最后要torch.cat(outputs, dim=0)
因为outputs列表里装的是每个时间步的输出:
第 1 步一个
Y第 2 步一个
Y第 3 步一个
Y…
而训练时通常希望把它们拼成一个二维张量,方便统一计算损失。
所以:
torch.cat(outputs, dim=0)会把所有时间步输出拼到一起。
如果每个Y形状是:
(batch_size, vocab_size)那最后拼起来大致就是:
(num_steps * batch_size, vocab_size)17. 为什么返回(H,)
因为前向传播结束后,我们还需要保留最后一个时间步的隐藏状态,
方便下一段序列接着用。
所以返回值一般是:
所有时间步输出
最终隐藏状态
写成:
return outputs, (H,)是为了和后续 GRU、LSTM 保持接口统一。
18. 封装一个从零实现的 RNN 模型类
为了方便调用,通常会把参数初始化、前向传播、状态初始化封装到类里。
常见形式大概是:
class RNNModelScratch: def __init__(self, vocab_size, num_hiddens, device, get_params, init_state, forward_fn): self.vocab_size, self.num_hiddens = vocab_size, num_hiddens self.params = get_params(vocab_size, num_hiddens, device) self.init_state, self.forward_fn = init_state, forward_fn def __call__(self, X, state): X = F.one_hot(X.T, self.vocab_size).type(torch.float32) return self.forward_fn(X, state, self.params) def begin_state(self, batch_size, device): return self.init_state(batch_size, self.num_hiddens, device)它本质上就是一个“手写版 RNN 容器”。
19.__call__里做了什么
这一句很关键:
X = F.one_hot(X.T, self.vocab_size).type(torch.float32)它完成了三步:
第一步:转置
把输入从:
(batch_size, num_steps)变成:
(num_steps, batch_size)第二步:one-hot 编码
把每个索引变成 one-hot 向量。
第三步:转成浮点数
因为后面矩阵乘法需要浮点类型。
这样输入就符合手写前向传播函数的要求了。
20. 如何根据前缀生成新字符
语言模型最有趣的地方之一,就是可以根据前缀生成后续文本。
常见预测函数思路如下:
def predict_ch8(prefix, num_preds, net, vocab, device): state = net.begin_state(batch_size=1, device=device) outputs = [vocab[prefix[0]]] def get_input(): return torch.tensor([outputs[-1]], device=device).reshape((1, 1)) for y in prefix[1:]: _, state = net(get_input(), state) outputs.append(vocab[y]) for _ in range(num_preds): y, state = net(get_input(), state) outputs.append(int(y.argmax(dim=1).reshape(1))) return ''.join([vocab.idx_to_token[i] for i in outputs])这段代码是“字符级生成”的核心。
21. 这段预测代码怎么理解
21.1 先初始化状态
state = net.begin_state(batch_size=1, device=device)因为我们只生成一条序列,所以 batch_size=1。
21.2 用前缀预热模型
例如前缀是"time traveller"中的前几个字符,
模型要先把这些字符读进去,更新到合适的隐藏状态。
21.3 再一步步自回归生成
每次把上一步输出的字符作为下一步输入,
再预测新的字符。
这就是最典型的自回归文本生成过程。
22. 为什么前缀阶段不直接预测新字符
因为前缀的作用是:
给模型提供上下文
例如你让模型从"time "开始生成,
那就要先让模型把"t" "i" "m" "e" " "这些字符都读进去。
这样它的隐藏状态里才会包含这个前缀信息。
后面生成时,才能更自然地沿着这个前缀继续写。
23. 训练时损失函数怎么接
虽然这一节重点是“从零实现模型结构”,
但训练时通常还是会配合交叉熵损失。
因为每个时间步的输出本质上是在做:
对词表的多分类
所以目标标签就是“当前正确下一个字符的索引”,
输出则是对整个词表的打分。
这和普通分类任务在损失形式上是一样的,
只是多了时间维展开。
24. 为什么从零实现很重要
你可能会想:
既然 PyTorch 有现成的
nn.RNN,为什么还要自己写?
原因很简单:
第一,能真正看懂 RNN 的本质
现成 API 很方便,但容易把人“封装麻木”。
第二,便于理解后续 GRU、LSTM
后面所有变种,其实都是在这个基础上改状态更新公式。
第三,能理解输入形状、状态传递、输出组织方式
这些都是后面训练和调试时特别重要的东西。
所以“从零实现”不是为了炫技,而是为了建立底层理解。
25. 这一节最该掌握什么
如果从学习重点看,这一节最重要的是下面几件事。
25.1 one-hot 输入
知道为什么离散 token 要转成 one-hot。
25.2 参数初始化
明白 RNN 至少需要哪些权重和偏置。
25.3 隐藏状态初始化
知道初始状态为什么通常设成 0。
25.4 前向传播循环
真正理解时间步递推是怎么写出来的。
25.5 预测函数
理解语言模型是如何根据前缀一步步生成字符的。
26. 本节总结
这一节我们从零开始实现了 RNN,核心内容可以总结为以下几点。
26.1 输入 token 需要先转成 one-hot 向量
这是手写 RNN 最常见的输入表示方式。
26.2 RNN 的参数包括输入到隐藏、隐藏到隐藏、隐藏到输出三部分
这些参数共同决定状态更新和输出预测。
26.3 隐藏状态会在时间步之间递推传递
这是 RNN 记忆历史信息的关键。
26.4 前向传播本质上是一个时间循环
每一步根据当前输入和上一状态更新当前状态。
26.5 语言模型可以利用 RNN 逐步生成文本
根据前缀预热,再不断自回归预测后续字符。
27. 学习感悟
这一节特别有价值,因为它让你第一次真正“摸到”RNN 的内部结构。
以前我们说:
RNN 有记忆
RNN 能建模序列
RNN 能预测下一个字符
这些话都比较抽象。
但一旦你自己把:
one-hot
参数矩阵
隐藏状态
时间循环
输出拼接
这些东西串起来,RNN 就不再神秘了。
你会发现它其实非常朴素:
就是把“当前输入”和“过去记忆”一起拿来算下一个状态。
而这份朴素,也正是它经典的原因。