
Learning the GPT-2 Architecture (1)

The code logic of the multi-head attention mechanism

Main reference: 《从零开始大模型》 (Build a Large Language Model (From Scratch))

The overall structure is shown in this diagram:
https://www.processon.com/view/link/689d8bc94f0dbc7fcf49883f?cid=689d873239b7227d867825c8


This section covers the token, embedding, and MultiHeadAttention modules.

Terminology

token: a unit of text, not necessarily a whole word; it is what tokenizer.encode() returns.

context_length: the input length (number of tokens).

emb_dim: the dimension of the embedding vectors […]

qkv weights: trained weight matrices; in shape they are usually (emb_dim, emb_dim) square matrices, but here they are written as (emb_dim, qkv_out_dim) to keep the two dimensions distinct.

Text input

1. Input text.
2. Convert it to tokens, then to a vector matrix of shape (n, emb_dim):

(1) token_embedding (X × emb_dim), learned during training (X is the vocabulary size)

(2) pos_embedding (n × emb_dim), learned during training

→ summed to give out_embedding

Converting text to tokens requires a vocabulary mapping (a tokenizer); use the one that matches the model (here, tiktoken's "gpt2" encoding).

The token-related code looks roughly like this:

Token IDs:

    import tiktoken
    import torch

    tokenizer = tiktoken.get_encoding("gpt2")
    encoded = tokenizer.encode(start_context)            # start_context: the prompt string to encode
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add a batch dimension -> shape (1, num_tokens)
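As a quick sanity check (a minimal sketch; the prompt string below is a hypothetical example, not from the original post), encoding and decoding should round-trip:

    prompt = "Hello, I am"                        # hypothetical prompt for illustration
    ids = tokenizer.encode(prompt)
    print(ids)                                    # a short list of integer token IDs
    print(tokenizer.decode(ids) == prompt)        # True -- decode reverses encode
    print(torch.tensor(ids).unsqueeze(0).shape)   # torch.Size([1, num_tokens]); unsqueeze adds the batch dim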

Embedding

    import torch
    import torch.nn as nn

    class GPTModel(nn.Module):
        def __init__(self, cfg):
            super().__init__()
            self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
            self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
            self.drop_emb = nn.Dropout(cfg["drop_rate"])
            # todo attention blocks
            # todo LayerNorm
            # output head
            self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

        def forward(self, in_idx):
            batch_size, seq_len = in_idx.shape
            tok_embeds = self.tok_emb(in_idx)  # token embedding lookup
            pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
            x = tok_embeds + pos_embeds  # shape [batch_size, num_tokens, emb_dim]
            x = self.drop_emb(x)         # embedding dropout
            # todo pass x through the attention blocks & LayerNorm
            logits = self.out_head(x)    # output logits over the vocabulary
            return logits
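A minimal sketch of how this skeleton might be driven. The config values mirror the GPT-2 small (124M) settings used in the book; treat the dict name and the drop_rate as assumptions:

    GPT_CONFIG_124M = {
        "vocab_size": 50257,      # GPT-2 BPE vocabulary size
        "context_length": 1024,   # maximum number of input tokens
        "emb_dim": 768,           # embedding dimension
        "drop_rate": 0.1,         # dropout rate (assumed)
    }

    torch.manual_seed(123)
    model = GPTModel(GPT_CONFIG_124M)
    logits = model(encoded_tensor)   # encoded_tensor from the tokenization snippet above, or any (1, n) tensor of token IDs
    print(logits.shape)              # torch.Size([1, num_tokens, 50257])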

LayerNorm is a normalization technique whose main purpose is to stabilize training, speed up convergence, and improve model performance.
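It is still a todo in the skeleton above. A minimal sketch, assuming the usual formulation (normalize each token's feature vector to zero mean and unit variance, then apply a learnable scale and shift):

    class LayerNorm(nn.Module):
        def __init__(self, emb_dim, eps=1e-5):
            super().__init__()
            self.eps = eps
            self.scale = nn.Parameter(torch.ones(emb_dim))   # learnable gain
            self.shift = nn.Parameter(torch.zeros(emb_dim))  # learnable bias

        def forward(self, x):
            # normalize over the last (feature) dimension, per token
            mean = x.mean(dim=-1, keepdim=True)
            var = x.var(dim=-1, keepdim=True, unbiased=False)
            norm_x = (x - mean) / torch.sqrt(var + self.eps)
            return self.scale * norm_x + self.shift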

Attention module

3. Compute the q/k/v vector matrices, each of shape (n, qkv_out_dim): input matrix (n, emb_dim) @ the corresponding q/k/v weight matrix (emb_dim, qkv_out_dim).
4. Compute the attention scores (attn_scores).

Each attention score is the dot product of one token's q with another token's k: (1, qkv_out_dim) @ the transpose of the key matrix (n, qkv_out_dim) gives a (1, n) row of scores.

The full attention score matrix is therefore an (n, n) square matrix; scaling the scores and applying softmax gives the attention weight matrix (attn_weights).


5. Compute the context vectors (context_vec).

attn_weights @ the value matrix, i.e. (n, n) @ (n, qkv_out_dim) = (n, qkv_out_dim).

6. During training, a causal mask and dropout are applied to the attention weights (a small sketch follows this list); also take care to handle the batch dimension in the matrix multiplications above.
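A minimal sketch of what the mask and dropout in step 6 do to a toy (n, n) score matrix (toy values, purely illustrative):

    n = 4
    mask = torch.triu(torch.ones(n, n), diagonal=1).bool()  # True strictly above the diagonal
    scores = torch.randn(n, n)                               # toy attention scores
    scores = scores.masked_fill(mask, -torch.inf)            # hide future positions
    weights = torch.softmax(scores, dim=-1)                  # each row sums to 1; upper triangle becomes 0
    weights = torch.nn.Dropout(0.1)(weights)                 # during training, dropout is applied to the weights
    print(weights)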

Reference code for steps 3-6 above (note that `batch`, `d_in`, and `d_out` are assumed to be defined beforehand; a toy setup is sketched after the class):

    class CausalAttention(nn.Module):
        def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
            super().__init__()
            self.d_out = d_out
            self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.dropout = nn.Dropout(dropout)  # New
            self.register_buffer(
                'mask',
                torch.triu(torch.ones(context_length, context_length), diagonal=1)
            )  # New

        def forward(self, x):
            b, num_tokens, d_in = x.shape  # New batch dimension b
            # For inputs where `num_tokens` exceeds `context_length`, this will result in errors
            # in the mask creation further below. In practice, this is not a problem since the
            # LLM (chapters 4-7) ensures that inputs do not exceed `context_length` before
            # reaching this forward method.
            keys = self.W_key(x)
            queries = self.W_query(x)
            values = self.W_value(x)

            attn_scores = queries @ keys.transpose(1, 2)  # Changed transpose
            attn_scores.masked_fill_(  # New, _ ops are in-place
                self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
            )  # `:num_tokens` accounts for batches shorter than the supported context_length
            attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
            attn_weights = self.dropout(attn_weights)  # New

            context_vec = attn_weights @ values
            return context_vec

    torch.manual_seed(123)
    context_length = batch.shape[1]
    ca = CausalAttention(d_in, d_out, context_length, 0.0)
    context_vecs = ca(batch)
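A minimal sketch of the toy inputs those last driver lines expect (an assumption; the original post does not show how `batch` is built):

    # run this before the torch.manual_seed(123) / CausalAttention lines above
    d_in, d_out = 3, 2
    inputs = torch.rand(6, d_in)            # 6 toy token vectors of dimension d_in
    batch = torch.stack((inputs, inputs))   # shape (2, 6, 3): a batch of two identical sequences

With these shapes, `context_vecs` comes out as (2, 6, 2): one d_out-dimensional context vector per token in each sequence.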

Multi-head attention

7. The multi-head logic: split d_out into num_heads heads of size head_dim = d_out // num_heads, run steps 3-6 independently in each head, then concatenate the per-head context vectors and pass them through an output projection. The split is done with view/transpose rather than an explicit loop over heads (see the shape sketch below).
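A minimal shape sketch of the view/transpose trick, with toy numbers (b = 2, num_tokens = 6, d_out = 4, num_heads = 2; purely illustrative):

    b, num_tokens, d_out, num_heads = 2, 6, 4, 2
    head_dim = d_out // num_heads                    # 2
    queries = torch.rand(b, num_tokens, d_out)       # stand-in for W_query(x)
    # unroll the last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
    queries = queries.view(b, num_tokens, num_heads, head_dim)
    # swap dims 1 and 2 -> (b, num_heads, num_tokens, head_dim), so each head attends independently
    queries = queries.transpose(1, 2)
    print(queries.shape)                             # torch.Size([2, 2, 6, 2])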


The code for steps 3-7 above:

    class MultiHeadAttention(nn.Module):
        def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
            super().__init__()
            assert (d_out % num_heads == 0), \
                "d_out must be divisible by num_heads"
            self.d_out = d_out
            self.num_heads = num_heads          # number of heads
            self.head_dim = d_out // num_heads  # reduce the projection dim to match the desired output dim
            self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
            self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
            self.dropout = nn.Dropout(dropout)
            self.register_buffer(
                "mask",
                torch.triu(torch.ones(context_length, context_length), diagonal=1)
            )

        def forward(self, x):
            b, num_tokens, d_in = x.shape
            # As in `CausalAttention`, inputs where `num_tokens` exceeds `context_length` would
            # break the mask creation below; in practice the LLM ensures inputs do not exceed
            # `context_length` before reaching this forward method.
            keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
            queries = self.W_query(x)
            values = self.W_value(x)

            # Implicitly split the matrix by adding a `num_heads` dimension.
            # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
            keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
            values = values.view(b, num_tokens, self.num_heads, self.head_dim)
            queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

            # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim),
            # so that num_tokens and head_dim sit in the last two dims for the @ below
            keys = keys.transpose(1, 2)
            queries = queries.transpose(1, 2)
            values = values.transpose(1, 2)

            # Compute scaled dot-product attention (aka self-attention) with a causal mask
            attn_scores = queries @ keys.transpose(2, 3)  # dot product for each head

            # Original mask truncated to the number of tokens and converted to boolean
            mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

            # Use the mask to fill attention scores
            attn_scores.masked_fill_(mask_bool, -torch.inf)

            attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
            attn_weights = self.dropout(attn_weights)

            # Shape: (b, num_tokens, num_heads, head_dim)
            context_vec = (attn_weights @ values).transpose(1, 2)

            # Combine heads, where self.d_out = self.num_heads * self.head_dim
            context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
            context_vec = self.out_proj(context_vec)  # optional projection
            return context_vec

    torch.manual_seed(123)
    batch_size, context_length, d_in = batch.shape
    d_out = 2
    mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
    context_vecs = mha(batch)
    print(context_vecs)
    print("context_vecs.shape:", context_vecs.shape)
