【深度学习】输入长度大于训练时输入长度会发生什么?LSTM 和 Transformer对比。
问题背景:
当训练时输入长度为 2048,但在生成时输入一个长度为 4096 的文本时,LSTM 和 Transformer 内部会发生什么,以及它们是否能够记住最初的 2048 个 token。
1. LSTM 的情况
(1) LSTM 的工作机制
LSTM 在处理序列时,通过隐藏状态(Hidden State)逐步更新记忆。在训练时,模型通常以固定长度的上下文窗口(如 2048)进行截断反向传播(Truncated Backpropagation Through Time, TBPTT)。这意味着:
- 模型只会在每个窗口内更新参数。
- 隐藏状态可以在窗口之间传递,理论上允许 LSTM 记住比单个窗口更长的信息。
(2) 输入长度为 4096 时的情况
假设你将一个长度为 4096 的文本输入到 LSTM 中:
- 如果你没有手动重置隐藏状态,LSTM 的隐藏状态会随着序列逐步更新,并尝试记住整个 4096 个 token 的信息。
- 然而,由于以下原因,LSTM 很难有效记住最初的 2048 个 token:
- 梯度消失问题:即使有门控机制,长时间依赖仍然可能导致信息丢失。
- 隐藏状态容量有限:LSTM 的隐藏状态是一个固定大小的向量(如 512 维),当序列过长时,它可能无法容纳足够的历史信息。
- 训练限制:模型在训练时只见过长度为 2048 的上下文,因此对超过这个范围的依赖关系建模能力较弱。
(3) 总结
- LSTM 理论上可以记住比 2048 更长的序列,但由于梯度消失和隐藏状态容量的限制,实际上很难有效记住最初的 2048 个 token。
- 如果需要处理超长序列,可以通过分段输入或引入外部记忆模块(如 Memory Networks)来增强记忆能力。
2. Transformer 的情况
(1) Transformer 的工作机制
Transformer 的自注意力机制允许模型一次性关注整个上下文窗口内的所有 token。然而,Transformer 的上下文窗口是固定的(如 2048),这意味着:
- 在训练时,模型只能看到长度为 2048 的输入。
- 在推理时,模型也无法直接处理超过 2048 的序列。
(2) 输入长度为 4096 时的情况
假设你将一个长度为 4096 的文本输入到 Transformer 中:
- 如果你直接输入整个序列,模型会报错,因为它的上下文窗口大小为 2048,无法处理超出范围的部分。
- 为了处理长序列,通常采用以下方法:
- 滑动窗口(Sliding Window):
将序列分割成多个重叠的窗口(如每 2048 个 token 为一个窗口),分别处理后再合并结果。- 这种方法会导致模型无法直接访问超出当前窗口的历史信息。
- 缓存机制(Caching):
使用递归结构(如 Transformer-XL)或缓存过去的隐藏状态,使得模型能够在一定程度上利用历史信息。- Transformer-XL 通过缓存过去的隐藏状态,允许模型记住比单个窗口更长的信息。
- 稀疏注意力(Sparse Attention):
通过限制注意力范围,降低计算复杂度,同时支持更长的上下文。
- 滑动窗口(Sliding Window):
(3) 总结
- Transformer 默认情况下无法记住最初的 2048 个 token,因为它的上下文窗口大小限制了其视野范围。
- 如果使用滑动窗口,模型只能记住当前窗口内的信息。
- 如果使用缓存机制(如 Transformer-XL),模型可以在一定程度上记住更长的历史信息,但仍然受限于缓存的大小。
3. 对比与总结
特性 | LSTM | Transformer |
---|---|---|
训练时上下文窗口 | 2048 | 2048 |
输入长度为 4096 时 | 隐藏状态逐步更新,理论上可以记住整个序列 | 默认无法处理,需分段或使用缓存机制 |
能否记住最初 2048 个 token | 受限于梯度消失和隐藏状态容量,难以有效记住 | 默认无法记住,需使用缓存机制才能部分记住 |
长期记忆能力 | 较弱,适合短距离依赖 | 较强,但需扩展架构以支持长距离依赖 |
4. 实际建议
如果你需要处理长度为 4096 的文本,并希望模型能够记住最初的 2048 个 token,可以考虑以下方法:
(1) 对于 LSTM
- 分段输入:将序列分割成多个 2048 长度的片段,分别输入模型,并手动传递隐藏状态。
- 引入外部记忆模块:例如 Memory Networks 或 Neural Turing Machines,增强模型的长期记忆能力。
(2) 对于 Transformer
- 滑动窗口:将序列分割成多个重叠的窗口,分别处理后再合并结果。
- 使用 Transformer-XL:通过缓存过去的隐藏状态,允许模型记住比单个窗口更长的信息。
- 稀疏注意力:例如 Longformer 或 BigBird,通过稀疏化注意力机制支持更长的上下文。
5. 总结
- LSTM:理论上可以记住比 2048 更长的序列,但由于梯度消失和隐藏状态容量的限制,实际效果较差。
- Transformer:默认情况下无法记住最初的 2048 个 token,但通过缓存机制或稀疏注意力等改进方法,可以在一定程度上增强长期记忆能力。