跳到主要内容

RNN & Transformer

处理序列数据的核心架构:从 RNN/LSTM 到彻底改变 NLP 的 Transformer。

为什么需要 RNN

全连接和 CNN 假设输入之间相互独立,但序列数据(文本、语音、时间序列)的输入之间存在先后依赖关系。RNN 通过隐藏状态在时间步之间传递信息。

RNN 基础

每个时刻,RNN 接收当前输入 xtx_t 和上一步的隐藏状态 ht1h_{t-1},输出新的隐藏状态 hth_t

ht=tanh(Whht1+Wxxt+b)h_t = \tanh(W_h h_{t-1} + W_x x_t + b)
import torch.nn as nn

rnn = nn.RNN(
input_size=128, # 每个时刻的输入维度(如词向量维度)
hidden_size=256, # 隐藏状态维度
num_layers=2, # 堆叠层数
batch_first=True, # 输入形状 (batch, seq_len, input_size)
)

# 输入:batch=4,序列长度=10,每个 token 128 维
x = torch.randn(4, 10, 128)
output, h_n = rnn(x)
# output: (4, 10, 256) — 每个时刻的输出
# h_n: (2, 4, 256) — 最后一层的隐藏状态

RNN 的问题

  • 梯度消失/爆炸:长序列反向传播时梯度呈指数衰减/增长
  • 长期记忆差:序列开头的输入对末尾几乎无影响

LSTM

长短期记忆网络通过门控机制解决 RNN 的长期依赖问题:

class SimpleLSTM(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.lstm = nn.LSTM(
input_size, hidden_size,
num_layers=2,
batch_first=True,
bidirectional=True, # 双向 LSTM
)

def forward(self, x):
output, (h_n, c_n) = self.lstm(x)
return output

LSTM 的三个门:

作用
遗忘门决定丢弃哪些旧信息
输入门决定保存哪些新信息
输出门决定输出哪些信息

GRU

LSTM 的简化版,合并遗忘门和输入门为「更新门」,参数更少训练更快,效果相当。

gru = nn.GRU(input_size=128, hidden_size=256, num_layers=2, batch_first=True)

Transformer

核心思想

Transformer 抛弃了 RNN 的循环结构,完全基于自注意力机制并行处理序列。

Attention is all you need. — Vaswani et al., 2017

自注意力机制

每个 token 的表示由序列中所有 token 的加权平均得到,权重反映了 token 之间的相关性:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
  • QQ(Query):我想要什么信息
  • KK(Key):我有什么信息
  • VV(Value):实际信息内容

手写注意力

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q, K, V shape: (batch, num_heads, seq_len, d_k)
"""
d_k = Q.size(-1)
scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5) # 缩放

if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))

attn_weights = F.softmax(scores, dim=-1)
output = attn_weights @ V
return output, attn_weights

多头注意力

用多组不同的 Q,K,VQ,K,V 投影,让模型关注不同子空间的信息:

class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, num_heads=8):
super().__init__()
self.num_heads = num_heads
self.d_k = d_model // num_heads

self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)

def forward(self, x, mask=None):
B, T, D = x.shape
# 分头
Q = self.W_q(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)

attn_out, _ = scaled_dot_product_attention(Q, K, V, mask)
# 合并头
attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, D)
return self.W_o(attn_out)

位置编码

Transformer 没有 RNN 的顺序处理,需要显式注入位置信息:

def sinusoidal_position_encoding(seq_len, d_model):
pe = torch.zeros(seq_len, d_model)
position = torch.arange(seq_len).unsqueeze(1).float()
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-torch.log(torch.tensor(10000.0)) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数位
pe[:, 1::2] = torch.cos(position * div_term) # 奇数位
return pe

Transformer 整体结构

输入 → Embedding + 位置编码
→ [多头自注意力 → 残差 + LN → FFN → 残差 + LN] × N
→ 输出投影 → Softmax
class TransformerBlock(nn.Module):
def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model),
)
self.dropout = nn.Dropout(dropout)

def forward(self, x, mask=None):
# 自注意力 + 残差
x = self.norm1(x + self.dropout(self.attention(x, mask)))
# FFN + 残差
x = self.norm2(x + self.dropout(self.ffn(x)))
return x

BERT 与 GPT

维度BERTGPT
架构Transformer 编码器Transformer 解码器
注意力双向(能看到上下文)单向(因果掩码)
预训练掩码语言模型(MLM)自回归(Next Token)
擅长理解任务(分类、NER、QA)生成任务(对话、写作)
微调需要无需(In-Context Learning)

大模型时代

Prompt Engineering

通过精心设计的提示词引导模型输出,不需要训练:

prompt = """将以下句子翻译成英文:
中文:今天天气真好。
英文:"""
# 模型自动补全

RAG(检索增强生成)

先检索相关文档,再加上提示词,解决 LLM 的知识盲区:

用户问题 → 检索相关文档 → 文档 + 问题 → LLM → 带引用的回答

Fine-tuning

在小数据集上微调预训练模型,适应特定任务:

from peft import LoraConfig, get_peft_model # LoRA 高效微调

lora_config = LoraConfig(
r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(base_model, lora_config)
# 只训练 0.1% 的参数量就能达接近全量微调的效果

总结

特性说明
RNN/LSTM序列建模的经典方案,被 Transformer 逐渐替代
Transformer当前 NLP/CV 的统一架构,并行计算 + 长程建模
实践建议先学会用 Hugging Face 的预训练模型,再深入原理