
Stanford CS336 | Assignment 1 - Transformer Language Model Architecture

All the code for Assignment 1 is open-sourced at:
https://github.com/ACEEE-1222/Standford-CS336-Assignment-1
If it helps you, please leave a star!

The assignment handout is at https://github.com/stanford-cs336/assignment1-basics

The second half of Assignment 1 asks you to implement a Transformer-based language model (LM) from scratch, a key hands-on exercise for understanding the internals of modern large language models (LLMs).

This post walks through the full implementation of that Transformer language model. It covers the core components (multi-head attention, rotary position embeddings (RoPE), RMS normalization, and the SwiGLU feed-forward network) as well as the supporting training tools (a custom AdamW optimizer, learning-rate scheduling, and batch sampling), taking you from model construction all the way to a working training setup.

1. Project Overview

The goal of this part of the assignment is to build a decoder-only Transformer language model (GPT-style) that predicts the next token in a sequence. The model adopts design choices that are now mainstream (pre-normalization, RoPE, RMSNorm) to balance efficiency and quality, and comes with a complete training pipeline that can be run directly on text data.

Key features of the model and training framework:

  • A decoder-only Transformer architecture with multi-head self-attention
  • Rotary position embeddings (RoPE) to strengthen the model's handling of positional information
  • RMSNorm in place of traditional LayerNorm for more stable training
  • A SwiGLU activation in the feed-forward network, which outperforms traditional activations such as ReLU
  • A custom AdamW optimizer, cosine learning-rate scheduler, and gradient-clipping module
  • Checkpoint saving and loading, so training can resume after interruption

2. Core Components: transformer.py

The transformer.py file contains every core module of the Transformer language model. Each component is designed to be modular, which makes debugging easier and leaves room for later extensions.

2.1 Basic Shared Layers

These layers are the model's basic building blocks: they are reused across several modules and implement the most elementary tensor transformations.

2.1.1 Bias-Free Linear Layer (Linear)

A simplified linear layer with the bias term removed. It uses a Xavier-like initialization (truncated normal) for training stability, and the forward pass uses the einops library to express the dimension mapping explicitly, avoiding the confusion that manual reshapes can cause.

```python
# Imports used throughout transformer.py
import torch
import torch.nn as nn
from einops import einsum, rearrange

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
        super().__init__()
        # Weight parameter: shape (out_features, in_features)
        self.weight = nn.Parameter(torch.empty((out_features, in_features), device=device, dtype=dtype))
        # Xavier-like init: std = sqrt(2 / (in_features + out_features)),
        # truncated at +/-3 std to avoid vanishing/exploding gradients
        std = (2 / (in_features + out_features)) ** 0.5
        nn.init.trunc_normal_(self.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x @ W^T: (..., in_features) -> (..., out_features)
        return einsum(x, self.weight, "... in_features, out_features in_features -> ... out_features")
```

2.1.2 Token Embedding Layer (Embedding)

Maps discrete token IDs to dense continuous vectors; the embedding dimension (embedding_dim) equals the model's hidden size (d_model). The weights are likewise initialized from a truncated normal distribution so the initial embeddings are reasonably distributed.

```python
class Embedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,  # vocabulary size (total number of tokens)
        embedding_dim: int,   # embedding dimension (must equal d_model)
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.vocab_size = num_embeddings
        self.d_model = embedding_dim
        # Embedding weight matrix: shape (vocab_size, embedding_dim)
        self.weight = nn.Parameter(torch.empty((self.vocab_size, self.d_model), device=device, dtype=dtype))
        nn.init.trunc_normal_(self.weight, std=1, a=-3, b=3)

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
        # Input: (batch_size, seq_len) -> output: (batch_size, seq_len, embedding_dim)
        return self.weight[token_ids]  # index directly to fetch each token's embedding
```

2.1.3 RMS Normalization Layer (RMSNorm)

Compared with traditional LayerNorm, RMSNorm drops the mean-centering step and normalizes only by the root mean square (RMS) of the input. This reduces computation while improving training stability, and it is the default normalization in LLaMA and many other recent LLMs.
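
Concretely, for each position's activation vector $x \in \mathbb{R}^{d_{\text{model}}}$ with learnable gain $g$, the layer computes

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d_{\text{model}}}\sum_{i=1}^{d_{\text{model}}} x_i^2 + \epsilon}} \odot g,$$

so unlike LayerNorm there is no mean subtraction and no bias term.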

```python
class RMSNorm(nn.Module):
    def __init__(
        self,
        d_model: int,       # input dimension (must equal the model's hidden size)
        eps: float = 1e-5,  # small constant to avoid division by zero
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        # Learnable gain: shape (d_model,), initialized to 1 (identity scaling)
        self.weight = nn.Parameter(torch.ones(self.d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input: (batch_size, seq_len, d_model) -> output: same shape
        in_dtype = x.dtype              # remember the input dtype
        x = x.to(dtype=torch.float32)   # compute in float32 for numerical stability
        # 1. Root mean square over the last dimension
        rms = (x.pow(2).mean(-1, keepdim=True) + self.eps) ** 0.5
        # 2. Normalize, then apply the learnable gain
        out = x / rms * self.weight
        return out.to(dtype=in_dtype)   # restore the original dtype
```

2.2 Activation Function and Feed-Forward Network

The feed-forward network (FFN) is the Transformer's feature-transformation module, and SwiGLU is among the best-performing activation functions in current use; together they substantially increase the model's expressive power.

2.2.1 The SwiGLU Activation

SwiGLU is a variant of the GLU (Gated Linear Unit): a smooth sigmoid-weighted gate filters the output of one linear transformation against another. Compared with ReLU, it captures nonlinear interactions between features better while avoiding vanishing gradients.
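
In equations, with projections $W_1, W_3 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ and $W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$, the full SwiGLU feed-forward layer implemented below computes

$$\mathrm{FFN}(x) = W_2\,\big(\mathrm{SiLU}(W_1 x) \odot W_3 x\big), \qquad \mathrm{SiLU}(z) = z \cdot \sigma(z),$$

where $\odot$ is elementwise multiplication and $\sigma$ is the sigmoid.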

```python
class SwiGLU(nn.Module):
    def __init__(
        self,
        d_model: int,  # input dimension (model hidden size)
        d_ff: int,     # FFN inner dimension (4x d_model for a classic FFN; SwiGLU variants often use ~8/3 x d_model)
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        # Three linear layers: W1/W3 produce the gate and candidate features,
        # W2 projects back to d_model
        self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)

    # Helper: sigmoid linear unit (SiLU)
    def _silu(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)

    # Helper: gated linear unit (GLU)
    def _glu(self, x: torch.Tensor) -> torch.Tensor:
        return self._silu(self.w1(x)) * self.w3(x)  # SiLU gate x W3 candidate features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input: (batch_size, seq_len, d_model) -> output: same shape
        return self.w2(self._glu(x))  # project the gated result back to d_model
```

2.3 Positional Encoding: Rotary Position Embedding (RoPE)

The Transformer itself has no sense of position, so ordering information must be injected through a positional encoding. RoPE encodes position into the token vectors via rotation matrices, supports length extrapolation (it remains usable on sequences longer than those seen in training), and is the dominant positional-encoding scheme today.
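
In equations: each dimension pair $(x_{2i}, x_{2i+1})$ of a query or key vector at position $p$ is rotated by the angle $\theta_{p,i} = p \cdot \theta^{-2i/d_k}$:

$$\begin{aligned}
x'_{2i}   &= x_{2i}\cos\theta_{p,i} - x_{2i+1}\sin\theta_{p,i} \\
x'_{2i+1} &= x_{2i}\sin\theta_{p,i} + x_{2i+1}\cos\theta_{p,i}
\end{aligned}$$

Because each rotation angle depends only on the absolute position $p$, the dot product between a rotated query and a rotated key depends only on their relative offset, which is what lets the scheme extrapolate.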

```python
class ROPE(nn.Module):
    def __init__(
        self,
        theta: float,      # RoPE base frequency (typically 10000)
        d_k: int,          # attention head dimension (must be even: dims rotate in pairs)
        max_seq_len: int,  # maximum supported sequence length
        device: torch.device | None = None,
    ):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device
        # Precompute the cos/sin cache once at construction
        # 1. Frequency vector: shape (d_k // 2,)
        freqs_d = 1 / (theta ** (torch.arange(0, d_k, 2, device=device).float() / d_k))
        # 2. Position vector: shape (max_seq_len,)
        pos_i = torch.arange(max_seq_len, device=device).float()
        # 3. Outer product of positions and frequencies: shape (max_seq_len, d_k // 2)
        freqs = einsum(freqs_d, pos_i, "d_half, max_seq_len -> max_seq_len d_half")
        # Cache the cos and sin values (indexed directly in forward)
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def forward(
        self,
        x: torch.Tensor,               # input: (..., seq_len, d_k)
        token_positions: torch.Tensor  # position indices: (..., seq_len)
    ) -> torch.Tensor:
        # 1. Split the last dimension into even/odd indices (d_k must be even)
        x_odd = x[..., 1::2]   # odd dims: indices 1, 3, 5, ...
        x_even = x[..., ::2]   # even dims: indices 0, 2, 4, ...
        # 2. Look up the cos/sin values for these positions
        cos = self.cos_cached[token_positions]  # (..., seq_len, d_k // 2)
        sin = self.sin_cached[token_positions]  # (..., seq_len, d_k // 2)
        # 3. Apply the rotation to inject position information
        out1 = cos * x_even - sin * x_odd  # rotated even dims
        out2 = sin * x_even + cos * x_odd  # rotated odd dims
        # 4. Interleave the pairs back into the original d_k dimension
        out = torch.stack([out1, out2], dim=-1).flatten(-2)  # (..., seq_len, d_k)
        return out
```

2.4 Attention: Multi-Head Self-Attention

Attention is the heart of the Transformer: it captures dependencies between tokens within a sequence. Multi-head attention splits the hidden vector across several "heads" that compute attention in parallel, letting the model capture different kinds of dependencies.

2.4.1 Scaled Dot-Product Attention (helper function)

The basic attention computation: query (Q), key (K), and value (V) tensors produce attention weights, with a scaling factor of $\sqrt{d_k}$ that keeps the attention scores from growing so large that the softmax saturates.
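
In matrix form:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V,$$

where masked positions are set to $-\infty$ before the softmax so that they receive zero weight.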

```python
def scaled_dot_product_attention(
    query: torch.Tensor,              # Q: (batch_size, ..., seq_len_q, d_k)
    key: torch.Tensor,                # K: (batch_size, ..., seq_len_k, d_k)
    value: torch.Tensor,              # V: (batch_size, ..., seq_len_k, d_v)
    mask: torch.Tensor | None = None  # mask: (seq_len_q, seq_len_k), True = may attend
) -> torch.Tensor:
    d_k = query.shape[-1]
    # 1. Attention scores: dot products of Q and K, scaled by sqrt(d_k)
    attention_scores = einsum(
        query, key, "... seq_len_q d_k, ... seq_len_k d_k -> ... seq_len_q seq_len_k"
    ) / (d_k ** 0.5)
    # 2. Apply the mask (e.g., a causal mask that hides future tokens)
    if mask is not None:
        # Masked positions get -inf, so their softmax weight becomes 0
        attention_scores = attention_scores.masked_fill(~mask, float('-inf'))
    # 3. Softmax-normalize the scores, then take the weighted sum of V
    attention_weights = softmax(attention_scores, dim=-1)
    output = einsum(
        attention_weights, value,
        "... seq_len_q seq_len_k, ... seq_len_k d_v -> ... seq_len_q d_v"
    )
    return output

# Helper: numerically stable softmax
def softmax(x: torch.Tensor, dim: int):
    x_max = torch.max(x, dim=dim, keepdim=True).values  # subtract the max to avoid exp overflow
    x_exp = torch.exp(x - x_max)
    sum_exp = torch.sum(x_exp, dim=dim, keepdim=True)
    return x_exp / sum_exp
```

2.4.2 The Multi-Head Self-Attention Module

The input is projected to Q, K, and V by linear layers and split across several heads that compute attention in parallel; the heads' outputs are finally concatenated and projected back to the model's hidden size (d_model).

```python
class MultiHeadSelfAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        theta: float | None = None,
        max_seq_len: int | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension (d_model must be divisible by num_heads)
        self.d_v = d_model // num_heads  # simplification: V uses the same per-head dimension as Q/K
        # 1. Q/K/V projections: map d_model to num_heads * d_k (or d_v)
        self.q_proj = Linear(d_model, num_heads * self.d_k)
        self.k_proj = Linear(d_model, num_heads * self.d_k)
        self.v_proj = Linear(d_model, num_heads * self.d_v)
        # 2. Output projection: map the concatenated heads back to d_model
        self.output_proj = Linear(num_heads * self.d_v, d_model)
        # 3. If theta and max_seq_len are given, set up a RoPE module
        if theta is not None and max_seq_len is not None:
            self.rope = ROPE(theta, self.d_k, max_seq_len)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        token_positions: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Input: (batch_size, seq_len, d_model)
        *batch_dims, seq_len, _ = x.shape  # capture the leading batch dimensions
        # 1. Project to Q/K/V and split into heads
        x_q = self.q_proj(x)  # (batch_size, seq_len, num_heads * d_k)
        x_k = self.k_proj(x)  # (batch_size, seq_len, num_heads * d_k)
        x_v = self.v_proj(x)  # (batch_size, seq_len, num_heads * d_v)
        # Split heads: (batch_size, num_heads, seq_len, d_k)
        x_q = rearrange(x_q, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
                        num_heads=self.num_heads, d_k=self.d_k)
        x_k = rearrange(x_k, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
                        num_heads=self.num_heads, d_k=self.d_k)
        x_v = rearrange(x_v, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v",
                        num_heads=self.num_heads, d_v=self.d_v)
        # 2. Apply RoPE (if configured)
        if hasattr(self, "rope"):
            # Default to positions 0 .. seq_len-1 when none are given
            if token_positions is None:
                token_positions = torch.arange(seq_len, device=x.device)
            # Expand token_positions to match the leading batch dimensions
            for _ in range(len(batch_dims)):
                token_positions = token_positions.unsqueeze(0)
            # Rotate Q and K only (V needs no rotation)
            x_q = self.rope(x_q, token_positions)
            x_k = self.rope(x_k, token_positions)
        # 3. Build the mask (defaults to a causal mask that hides future tokens)
        if mask is None:
            # Causal mask: upper triangle False (blocked), lower triangle and diagonal True (visible)
            mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
        # Expand the mask to cover the batch and head dimensions
        for _ in range(len(batch_dims) + 1):  # +1 accounts for the num_heads dimension
            mask = mask.unsqueeze(0)
        # 4. Scaled dot-product attention
        attn_output = scaled_dot_product_attention(x_q, x_k, x_v, mask)
        # 5. Concatenate the heads and project back to d_model
        attn_output = rearrange(attn_output, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)",
                                num_heads=self.num_heads, d_v=self.d_v)
        output = self.output_proj(attn_output)  # project back to d_model
        return output
```

2.5 Transformer Block (TransformerBlock)

A single Transformer block is the model's basic repeating unit: multi-head self-attention followed by a feed-forward network. It uses a pre-normalization design, applying normalization before the attention and feed-forward sublayers rather than after, which has been shown to noticeably improve training stability.

```python
class TransformerBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        theta: float | None = None,
        max_seq_len: int | None = None,
    ):
        super().__init__()
        # 1. Normalization layers (pre-norm design)
        self.ln1 = RMSNorm(d_model)  # before the attention sublayer
        self.ln2 = RMSNorm(d_model)  # before the feed-forward sublayer
        # 2. Feed-forward network
        self.ffn = SwiGLU(d_model, d_ff)
        # 3. Multi-head self-attention (RoPE enabled when theta and max_seq_len are given)
        if theta is not None and max_seq_len is not None:
            self.attn = MultiHeadSelfAttention(d_model, num_heads, theta, max_seq_len)
        else:
            self.attn = MultiHeadSelfAttention(d_model, num_heads)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        token_positions: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Residual + attention: x = x + Attention(LN(x))
        x = x + self.attn(self.ln1(x), mask=mask, token_positions=token_positions)
        # Residual + feed-forward: x = x + FFN(LN(x))
        x = x + self.ffn(self.ln2(x))
        return x
```

2.6 The Full Language Model (TransformerLM)

All of the components above combine into the final Transformer language model: a token-embedding layer, a stack of Transformer blocks, a final normalization layer, and a language-model head (LM head).

```python
class TransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        context_length: int,         # context window (maximum sequence length)
        num_layers: int,             # number of Transformer blocks
        d_model: int,                # model hidden size
        num_heads: int,              # number of attention heads
        d_ff: int,                   # FFN inner dimension
        theta: float | None = None,  # RoPE base frequency
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        # 1. Token embedding: maps token IDs to d_model-dimensional vectors
        self.token_embeddings = Embedding(vocab_size, d_model)
        # 2. Stack of Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, theta, context_length)
            for _ in range(num_layers)
        ])
        # 3. Final normalization layer
        self.ln_final = RMSNorm(d_model)
        # 4. LM head: maps d_model to vocab_size (output logits)
        self.lm_head = Linear(d_model, vocab_size)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Input: (batch_size, seq_len) -> output: (batch_size, seq_len, vocab_size)
        # 1. Token embedding
        x = self.token_embeddings(inputs)  # (batch_size, seq_len, d_model)
        # 2. Pass through all Transformer blocks
        for layer in self.layers:
            x = layer(x)  # each block keeps the shape (batch_size, seq_len, d_model)
        # 3. Final normalization + projection to the vocabulary
        x = self.ln_final(x)
        logits = self.lm_head(x)  # (batch_size, seq_len, vocab_size)
        return logits
```
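
As a quick sanity check, here is a minimal sketch that instantiates a small model and verifies the output shape; the hyperparameter values are illustrative, not the assignment's prescribed settings:

```python
import torch

# Hypothetical small configuration, for shape-checking only
model = TransformerLM(
    vocab_size=10000, context_length=256, num_layers=4,
    d_model=512, num_heads=16, d_ff=1344, theta=10000.0,
)
tokens = torch.randint(0, 10000, (2, 256))  # (batch_size, seq_len)
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 256, 10000])
```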

3. Training Utilities: train.py

train.py contains the core tools needed to train the model: the loss function, the optimizer, learning-rate scheduling, batch sampling, gradient clipping, and checkpoint management.

3.1 Loss Function: Cross-Entropy

A language model's task is to predict the next token, so we use the cross-entropy loss, which measures the gap between the predicted distribution and the true token. The implementation includes the usual numerical-stability trick (subtracting the maximum) so the exponentials cannot overflow.
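
For a logit vector $o$ and true token $y$, the per-sample loss is

$$\ell(o, y) = -\log\frac{e^{o_y}}{\sum_{v} e^{o_v}} = -o_y + \log\sum_{v} e^{o_v},$$

and subtracting $\max_v o_v$ from every logit leaves this value unchanged while keeping the exponentials bounded.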

```python
# Imports used throughout train.py
import math
import os
from typing import IO, BinaryIO, Callable, Iterable

import numpy as np
import numpy.typing as npt
import torch

def cross_entropy(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Average cross-entropy loss over a batch.

    Args:
        inputs: (batch_size, vocab_size) unnormalized logits
        targets: (batch_size,) indices of the true tokens
    Returns:
        scalar tensor: the mean loss over the batch
    """
    batch_size = inputs.shape[0]
    # Numerical stability: subtract the max so exp() cannot overflow
    o_max = torch.max(inputs, dim=-1, keepdim=True).values
    o = inputs - o_max
    # Logit of each sample's target token
    target_logits = o[torch.arange(batch_size), targets]
    # log(sum(exp(o)))
    logsumexp = torch.log(torch.sum(torch.exp(o), dim=-1))
    # Per-sample loss: -target_logit + logsumexp
    loss = -target_logits + logsumexp
    # Mean over the batch
    return loss.mean(dim=0)
```

3.2 Optimizer: AdamW

AdamW is a refinement of Adam that decouples weight decay from the gradient update, which improves generalization; it is the standard optimizer for training Transformers today.
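
The update implemented below, for gradient $g_t$, base step size $\alpha$, and decay coefficient $\lambda$, is

$$\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^{2} \\
\alpha_t &= \alpha\,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}} \\
\theta_t &= \theta_{t-1} - \alpha_t\,\frac{m_t}{\sqrt{v_t}+\epsilon} - \alpha\lambda\,\theta_{t-1}.
\end{aligned}$$

Note that the decay term uses the base rate $\alpha$ rather than the bias-corrected $\alpha_t$; this decoupling from the gradient statistics is what distinguishes AdamW from Adam with L2 regularization.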

```python
class AdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,             # initial learning rate
        betas=(0.9, 0.999),  # exponential decay rates for the first and second moments
        eps=1e-8,            # small constant to avoid division by zero
        weight_decay=0.01,   # weight-decay coefficient
    ):
        # Validate hyperparameters
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= betas[0] < 1.0 or not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid betas: {betas}")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure: Callable | None = None):
        # Optional closure (re-evaluates the loss in some use cases)
        loss = None if closure is None else closure()
        # Iterate over parameter groups (each group may use its own learning rate)
        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue  # skip parameters without gradients
                grad = p.grad.data
                state = self.state[p]  # per-parameter state dict
                # Initialize state on the first update
                t = state.get("t", 1)  # step count (starts at 1)
                m = state.get("m", torch.zeros_like(grad))  # first-moment estimate (momentum)
                v = state.get("v", torch.zeros_like(grad))  # second-moment estimate
                # Update the biased first and second moments
                m = beta1 * m + (1 - beta1) * grad
                v = beta2 * v + (1 - beta2) * grad ** 2
                # Bias-corrected step size
                lr_t = lr * (1 - beta2 ** t) ** 0.5 / (1 - beta1 ** t)
                # Parameter update: gradient step, then decoupled weight decay
                p.data = p.data - lr_t * m / (v ** 0.5 + eps)
                p.data = p.data - lr * weight_decay * p.data  # weight decay, decoupled from the gradient
                # Persist state
                state["t"] = t + 1
                state["m"] = m
                state["v"] = v
        return loss
```

3.3 Learning-Rate Scheduling: Cosine Annealing

Learning-rate scheduling matters a great deal when training Transformers. We implement cosine annealing with warmup: the learning rate first increases linearly (warmup), avoiding the instability of a large learning rate early in training, then decays along a cosine curve to a minimum value, helping the model settle into a better solution.
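
With warmup length $T_w$ and total cycle length $T_c$, the schedule implemented below is

$$\alpha_t = \begin{cases}
\dfrac{t}{T_w}\,\alpha_{\max}, & t < T_w \\[6pt]
\alpha_{\min} + \dfrac{1}{2}\,\big(\alpha_{\max}-\alpha_{\min}\big)\left(1 + \cos\dfrac{(t - T_w)\,\pi}{T_c - T_w}\right), & T_w \le t < T_c \\[6pt]
\alpha_{\min}, & t \ge T_c.
\end{cases}$$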

```python
def lr_cosine_schedule(t: int, lr_max: float, lr_min: float, warmup_iters: int, cosine_cycle_iters: int):
    """Cosine-annealing learning-rate schedule with linear warmup.

    Args:
        t: current iteration
        lr_max: peak learning rate (reached at the end of warmup)
        lr_min: floor of the cosine decay
        warmup_iters: number of warmup iterations
        cosine_cycle_iters: total iterations of the cosine cycle (including warmup)
    Returns:
        the learning rate at iteration t
    """
    if t < warmup_iters:
        # Warmup phase: linear increase
        lr = t / warmup_iters * lr_max
    elif t < cosine_cycle_iters:
        # Cosine phase: smooth decay from lr_max to lr_min
        # Current phase angle, between 0 and pi
        phase = (t - warmup_iters) / (cosine_cycle_iters - warmup_iters) * math.pi
        lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(phase))
    else:
        # After the cycle: hold at lr_min
        lr = lr_min
    return lr
```

3.4 Gradient Clipping

During training, gradients can swing wildly because of unusual samples or a large learning rate (gradient explosion), destabilizing the model. Gradient clipping bounds the gradients' L2 norm to keep them in a sensible range.
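
All gradients are treated as one concatenated vector $g$: if its L2 norm exceeds the threshold $c$, every gradient is scaled by the same factor,

$$g \leftarrow g \cdot \min\!\left(1,\; \frac{c}{\lVert g \rVert_2 + \epsilon}\right),$$

which preserves the gradient's direction while bounding its magnitude.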

```python
def gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float, eps: float = 1e-6):
    """Clip the combined L2 norm of the gradients to prevent gradient explosion.

    Args:
        parameters: parameters whose gradients should be clipped
        max_l2_norm: upper bound on the total gradient L2 norm
        eps: small constant to avoid division by zero
    """
    # Collect all parameters that have gradients
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return  # nothing to clip
    # Total L2 norm across all gradients
    l2_norm = 0.0
    for g in grads:
        l2_norm += torch.sum(g ** 2)  # accumulate squared entries
    l2_norm = torch.sqrt(l2_norm)
    # Clipping coefficient: scale down only if the norm exceeds the bound
    clip_coef = min(1.0, max_l2_norm / (l2_norm + eps))
    # Apply the clipping in place
    for g in grads:
        g *= clip_coef
```

3.5 Batch Sampling: get_batch

Samples training batches from the dataset, producing input sequences (x) and label sequences (y). For a language model, y is x shifted right by one position, so each position's label is the next token: e.g., for tokens [5, 3, 9, 2] with context_length 3, x = [5, 3, 9] and y = [3, 9, 2].

```python
def get_batch(
    dataset: npt.NDArray, batch_size: int, context_length: int, device: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a training batch from the dataset.

    Args:
        dataset: 1D numpy array of token IDs
        batch_size: number of sequences per batch
        context_length: length of each sequence
        device: where to place the tensors (e.g., 'cpu' or 'cuda')
    Returns:
        x: (batch_size, context_length) input sequences
        y: (batch_size, context_length) target sequences (x shifted right by one)
    """
    # Largest valid start index (so the slice stays in bounds)
    max_start = len(dataset) - context_length - 1
    if max_start <= 0:
        raise ValueError("Dataset is shorter than the requested context_length")
    # Sample batch_size random start indices
    starts = np.random.randint(0, max_start + 1, size=batch_size)
    x_batch = []
    y_batch = []
    for s in starts:
        # Slice [s, s + context_length + 1)
        seq = dataset[s : s + context_length + 1]
        x_batch.append(seq[:-1])  # input: the first context_length tokens
        y_batch.append(seq[1:])   # target: the last context_length tokens (shifted by one)
    # Convert to PyTorch tensors on the requested device
    x = torch.tensor(x_batch, dtype=torch.long, device=device)
    y = torch.tensor(y_batch, dtype=torch.long, device=device)
    return x, y
```

3.6 Checkpoint Management

When training large models you need to save the training state periodically (model parameters, optimizer state, current iteration count) so that training can resume after an interruption and the model can be evaluated later.

```python
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    iteration: int,
    out: str | os.PathLike | BinaryIO | IO[bytes],
):
    """Save the training state to a file."""
    checkpoint = {
        'model_state': model.state_dict(),          # model parameters
        'optimizer_state': optimizer.state_dict(),  # optimizer state
        'iteration': iteration,                     # current iteration count
    }
    torch.save(checkpoint, out)

def load_checkpoint(
    src: str | os.PathLike | BinaryIO | IO[bytes],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> int:
    """Load a checkpoint and restore the model and optimizer."""
    checkpoint = torch.load(src)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    return checkpoint['iteration']  # the iteration count at save time
```
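
To show how these utilities fit together, here is a minimal training-loop sketch. The token-file name, model size, and iteration counts are illustrative placeholders, not values prescribed by the assignment:

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical file of uint16 token IDs; cast up so the tensors convert cleanly
data = np.fromfile("train_tokens.bin", dtype=np.uint16).astype(np.int64)

model = TransformerLM(vocab_size=10000, context_length=256, num_layers=4,
                      d_model=512, num_heads=16, d_ff=1344, theta=10000.0).to(device)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

for it in range(1, 5001):
    # Set this step's learning rate from the cosine schedule
    lr = lr_cosine_schedule(it, lr_max=3e-4, lr_min=3e-5,
                            warmup_iters=200, cosine_cycle_iters=5000)
    for group in optimizer.param_groups:
        group["lr"] = lr

    x, y = get_batch(data, batch_size=32, context_length=256, device=device)
    logits = model(x)  # (batch_size, seq_len, vocab_size)
    # Flatten to (batch_size * seq_len, vocab_size) to match cross_entropy's indexing
    loss = cross_entropy(logits.reshape(-1, logits.shape[-1]), y.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    gradient_clipping(model.parameters(), max_l2_norm=1.0)
    optimizer.step()

    if it % 1000 == 0:
        save_checkpoint(model, optimizer, it, f"ckpt_{it}.pt")
```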

4. Summary and Extensions

Over the course of this assignment we implemented a complete Transformer language model from scratch, covering the core components of modern LLMs:

  • Transformer blocks with a pre-normalization design for more stable training
  • Rotary position embeddings (RoPE) for positional information, including support for length extrapolation
  • Multi-head self-attention to capture dependencies between tokens
  • A SwiGLU activation to strengthen the feed-forward network's expressiveness
  • Supporting training tools, including an AdamW optimizer and a cosine learning-rate schedule

Finally, I won't paste every one of the adapter.py hooks; since someone asked about them in the comments last time, here is one example:

```python
import torch
from torch import Tensor
from jaxtyping import Float, Int

def run_transformer_lm(
    vocab_size: int,
    context_length: int,
    d_model: int,
    num_layers: int,
    num_heads: int,
    d_ff: int,
    rope_theta: float,
    weights: dict[str, Tensor],
    in_indices: Int[Tensor, " batch_size sequence_length"],
) -> Float[Tensor, " batch_size sequence_length vocab_size"]:
    """Given the weights of a Transformer language model and input indices,
    return the output of running a forward pass on the input indices.

    This function should use RoPE.

    Args:
        vocab_size (int): The number of unique items in the output vocabulary to be predicted.
        context_length (int): The maximum number of tokens to process at once.
        d_model (int): The dimensionality of the model embeddings and sublayer outputs.
        num_layers (int): The number of Transformer layers to use.
        num_heads (int): Number of heads to use in multi-headed attention. `d_model` must be
            evenly divisible by `num_heads`.
        d_ff (int): Dimensionality of the feed-forward inner layer (section 3.3).
        rope_theta (float): The RoPE $\Theta$ parameter.
        weights (dict[str, Tensor]): State dict of our reference implementation. {num_layers} refers to an
            integer between `0` and `num_layers - 1` (the layer index).
            The keys of this dictionary are:
            - `token_embeddings.weight`
                Token embedding matrix. Shape is (vocab_size, d_model).
            - `layers.{num_layers}.attn.q_proj.weight`
                The query projections for all `num_heads` attention heads.
                Shape is (num_heads * (d_model / num_heads), d_model).
                The rows are ordered by matrices of shape (num_heads, d_k),
                so `attn.q_proj.weight == torch.cat([q_heads.0.weight, ..., q_heads.N.weight], dim=0)`.
            - `layers.{num_layers}.attn.k_proj.weight`
                The key projections for all `num_heads` attention heads.
                Shape is (num_heads * (d_model / num_heads), d_model).
                The rows are ordered by matrices of shape (num_heads, d_k),
                so `attn.k_proj.weight == torch.cat([k_heads.0.weight, ..., k_heads.N.weight], dim=0)`.
            - `layers.{num_layers}.attn.v_proj.weight`
                The value projections for all `num_heads` attention heads.
                Shape is (num_heads * (d_model / num_heads), d_model).
                The rows are ordered by matrices of shape (num_heads, d_v),
                so `attn.v_proj.weight == torch.cat([v_heads.0.weight, ..., v_heads.N.weight], dim=0)`.
            - `layers.{num_layers}.attn.output_proj.weight`
                Weight of the multi-head self-attention output projection.
                Shape is ((d_model / num_heads) * num_heads, d_model).
            - `layers.{num_layers}.ln1.weight`
                Weights of affine transform for the first RMSNorm applied in the transformer block.
                Shape is (d_model,).
            - `layers.{num_layers}.ffn.w1.weight`
                Weight of the first linear transformation in the FFN.
                Shape is (d_model, d_ff).
            - `layers.{num_layers}.ffn.w2.weight`
                Weight of the second linear transformation in the FFN.
                Shape is (d_ff, d_model).
            - `layers.{num_layers}.ffn.w3.weight`
                Weight of the third linear transformation in the FFN.
                Shape is (d_model, d_ff).
            - `layers.{num_layers}.ln2.weight`
                Weights of affine transform for the second RMSNorm applied in the transformer block.
                Shape is (d_model,).
            - `ln_final.weight`
                Weights of affine transform for RMSNorm applied to the output of the final transformer block.
                Shape is (d_model,).
            - `lm_head.weight`
                Weights of the language model output embedding.
                Shape is (vocab_size, d_model).
        in_indices (Int[Tensor, "batch_size sequence_length"]): Tensor with input indices to run the
            language model on. Shape is (batch_size, sequence_length), where `sequence_length` is at
            most `context_length`.

    Returns:
        Float[Tensor, "batch_size sequence_length vocab_size"]: Tensor with the predicted unnormalized
        next-word distribution for each token.
    """
    # device is assumed to be defined at module scope in the repo; defined here so the snippet is self-contained
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TransformerLM(
        vocab_size, context_length, num_layers, d_model, num_heads, d_ff, rope_theta
    ).to(device)
    model.load_state_dict(weights)
    return model(in_indices.to(device=device))
```
