
transformer demo

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import pytest


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super(PositionalEncoding, self).__init__()
        # Create the positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Fill even columns with sine and odd columns with cosine encodings
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # Register as a non-trainable buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the positional encoding to the input tensor
        return x + self.pe[:, :x.size(1)]


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Linear projection layers
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        # Project and reshape to (batch, heads, seq_len, d_k)
        q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Apply the mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Softmax over the key dimension to obtain attention weights
        attn_weights = F.softmax(scores, dim=-1)
        # Apply the attention weights to the values
        attn_output = torch.matmul(attn_weights, v)
        # Reshape back and apply the final linear projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.out_linear(attn_output)
        return output


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))


class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention sublayer with residual connection
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward sublayer with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Masked self-attention sublayer with residual connection
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Encoder-decoder (cross) attention sublayer with residual connection
        cross_attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        # Feed-forward sublayer with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers,
                 num_decoder_layers, d_ff, max_seq_length, dropout=0.1):
        super(Transformer, self).__init__()
        # Token embedding layers
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        # Encoder and decoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        # Output projection to the target vocabulary
        self.output_layer = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model
        # Initialize parameters
        self._init_parameters()

    def _init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Embedding and positional encoding for source and target sequences
        src = self.src_embedding(src) * math.sqrt(self.d_model)
        src = self.positional_encoding(src)
        src = self.dropout(src)
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt = self.positional_encoding(tgt)
        tgt = self.dropout(tgt)
        # Encoder forward pass
        enc_output = src
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)
        # Decoder forward pass
        dec_output = tgt
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)
        # Output projection
        output = self.output_layer(dec_output)
        return output


# Mask creation helper
def create_masks(src, tgt):
    # Source mask (hides padding tokens)
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
    # Target mask (hides padding tokens and future tokens)
    tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
    # Subsequent-token (no-peek) mask for autoregressive decoding
    seq_length = tgt.size(1)
    nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
    # Combine the masks
    tgt_mask = tgt_mask & nopeak_mask
    return src_mask, tgt_mask
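

# Illustrative addition (not in the original demo): a small pytest-style check
# of the shapes create_masks produces. It assumes padding index 0, as the rest
# of this demo does, and only verifies broadcasting-friendly shapes plus the
# no-peek behaviour.
def test_create_masks_shapes():
    src = torch.tensor([[5, 7, 9, 0, 0]])  # one source sequence with padding
    tgt = torch.tensor([[2, 4, 6, 8]])     # one target sequence
    src_mask, tgt_mask = create_masks(src, tgt)
    assert src_mask.shape == (1, 1, 1, 5)  # (batch, 1, 1, src_len)
    assert tgt_mask.shape == (1, 1, 4, 4)  # (batch, 1, tgt_len, tgt_len)
    # Query position 0 must not attend to the future key position 1
    assert not bool(tgt_mask[0, 0, 0, 1])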


# Simple training function
def train_transformer(model, optimizer, criterion, train_loader, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for src, tgt in train_loader:
            # Build the masks (teacher forcing: the decoder sees tgt shifted right)
            src_mask, tgt_mask = create_masks(src, tgt[:, :-1])
            # Forward pass
            output = model(src, tgt[:, :-1], src_mask, tgt_mask)
            # Compute the loss against the shifted target
            loss = criterion(output.contiguous().view(-1, output.size(-1)),
                             tgt[:, 1:].contiguous().view(-1))
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}')
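

# Illustrative addition (not in the original demo): one way to drive
# train_transformer end to end with a toy DataLoader of random token indices.
# The small model sizes and random data are arbitrary choices for a quick
# smoke run, not meaningful training.
def example_training_run():
    from torch.utils.data import TensorDataset, DataLoader
    src_data = torch.randint(1, 1000, (8, 10))  # 8 random source sequences
    tgt_data = torch.randint(1, 1000, (8, 10))  # 8 random target sequences
    train_loader = DataLoader(TensorDataset(src_data, tgt_data), batch_size=2)
    model = Transformer(1000, 1000, d_model=128, num_heads=4, num_encoder_layers=2,
                        num_decoder_layers=2, d_ff=256, max_seq_length=50, dropout=0.1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding tokens
    train_transformer(model, optimizer, criterion, train_loader, epochs=1)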


# Model fixture
@pytest.fixture
def model():
    # Hyperparameters
    d_model = 512
    num_heads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    d_ff = 2048
    max_seq_length = 100
    dropout = 0.1
    # Assumed vocabulary sizes
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    # Build the model
    model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads,
                        num_encoder_layers, num_decoder_layers, d_ff, max_seq_length, dropout)
    return model


# test_loader fixture
@pytest.fixture
def test_loader():
    # Build a small synthetic test dataset
    batch_size = 2
    seq_length = 10
    # Randomly generated test data
    src_data = torch.randint(1, 10000, (batch_size, seq_length))
    tgt_data = torch.randint(1, 10000, (batch_size, seq_length))
    # Wrap it in a DataLoader
    from torch.utils.data import TensorDataset, DataLoader
    dataset = TensorDataset(src_data, tgt_data)
    test_loader = DataLoader(dataset, batch_size=batch_size)
    return test_loader


# Simple test function
def test_transformer(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for src, tgt in test_loader:
            # Build the source mask
            src_mask, _ = create_masks(src, tgt)
            # Predict
            output = model(src, tgt, src_mask, None)
            pred = output.argmax(dim=-1)
            # Accumulate token-level accuracy
            total += tgt.size(0) * tgt.size(1)
            correct += (pred == tgt).sum().item()
    accuracy = correct / total
    print(f'Test Accuracy: {accuracy:.4f}')


# Simple sequence-to-sequence translation example
def translate(model, src_sequence, src_vocab, tgt_vocab, max_length=50):
    model.eval()
    # Convert the source sequence to indices
    src_indices = [src_vocab.get(token, src_vocab['<unk>']) for token in src_sequence]
    src_tensor = torch.LongTensor(src_indices).unsqueeze(0)
    # Build the source mask
    src_mask = (src_tensor != 0).unsqueeze(1).unsqueeze(2)
    # Start the target sequence with the start-of-sequence token
    tgt_indices = [tgt_vocab['<sos>']]
    with torch.no_grad():
        for i in range(max_length):
            tgt_tensor = torch.LongTensor(tgt_indices).unsqueeze(0)
            # Build the target mask
            _, tgt_mask = create_masks(src_tensor, tgt_tensor)
            # Predict the next token
            output = model(src_tensor, tgt_tensor, src_mask, tgt_mask)
            next_token_logits = output[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1).item()
            # Append the predicted token to the target sequence
            tgt_indices.append(next_token)
            # Stop once the end-of-sequence token is predicted
            if next_token == tgt_vocab['<eos>']:
                break
    # Convert the target indices back to tokens via the inverse vocabulary
    inv_tgt_vocab = {index: token for token, index in tgt_vocab.items()}
    tgt_sequence = [inv_tgt_vocab.get(index, '<unk>') for index in tgt_indices]
    return tgt_sequence


# Example usage
if __name__ == "__main__":
    # Hyperparameters
    d_model = 512
    num_heads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    d_ff = 2048
    max_seq_length = 100
    dropout = 0.1
    # Assumed vocabulary sizes
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    # Build the model
    model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads,
                        num_encoder_layers, num_decoder_layers, d_ff, max_seq_length, dropout)
    # Optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding tokens
    # Real data loading code would go here
    # train_loader = ...
    # test_loader = ...
    # Train the model
    # train_transformer(model, optimizer, criterion, train_loader, epochs=10)
    # Evaluate the model
    # test_transformer(model, test_loader)
    # Translation example
    # src_vocab = ...
    # tgt_vocab = ...
    # src_sequence = ["hello", "world", "!"]
    # translation = translate(model, src_sequence, src_vocab, tgt_vocab)
    # print(f"Source: {' '.join(src_sequence)}")
    # print(f"Translation: {' '.join(translation)}")
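
As a quick sanity check before wiring up real data, a single forward pass with random token indices confirms that the model produces logits of shape (batch, tgt_len, tgt_vocab_size). The snippet below is a minimal sketch of that check; the model sizes and sequence lengths are arbitrary choices for illustration, not part of the demo above.

# Minimal shape smoke test (illustrative sizes, random data)
model = Transformer(src_vocab_size=1000, tgt_vocab_size=1000, d_model=64, num_heads=4,
                    num_encoder_layers=2, num_decoder_layers=2, d_ff=128,
                    max_seq_length=50, dropout=0.1)
src = torch.randint(1, 1000, (2, 12))  # (batch, src_len)
tgt = torch.randint(1, 1000, (2, 9))   # (batch, tgt_len)
src_mask, tgt_mask = create_masks(src, tgt)
out = model(src, tgt, src_mask, tgt_mask)
print(out.shape)  # expected: torch.Size([2, 9, 1000])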
