当前位置: 首页 > news >正文

transformer架构解析{模型基本测试}(含代码)-9

目录

前言

模型基本测试

学习目标

cpoy任务介绍

实现模型copy任务的四步曲

 实现模型训练和测试

构建数据集

 实例化模型,优化器和损失函数


前言

        经过前面的学习,我们已经学完了transformer模型的各个组成部分以及实现代码,最后也实现了模型的创建,接下来我们用一个任务来测试一下模型,看它是否能将规律学到。

模型基本测试

学习目标

        了解transformer模型基本测试的copy任务

        掌握实现模型copy任务的四步曲

cpoy任务介绍

        任务描述:针对数字序列进行学习,学习的最终目标是使输出与输入的序列相同,如输入[1,5,8,5,6]输出也是[1,5,8,5,6]

        任务意义:copy任务在模型基础测试中具有重要意义,因为copy操作对于模型来讲是一条明显的规律,因此模型能否在短时间内,小数据集中学会它,可以帮助我们断定模型的所有过程是否正常,是否已具备基本的学习能力。

实现模型copy任务的四步曲

 实现模型训练和测试

def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, epochs, device):
    #模型的训练
    ''' 输入参数:
    model : 模型
    train_loader:训练数据
    vaL_loader: 测试数据
    criterion: 计算损失
    optimizer: 优化器
    epochs: 训练轮数
    device: 加载设备
    '''
    T_Loss = []  #训练的损失
    V_Loss = []  #测试的损失
    model.train() #模型训练
    for epoch in range(epochs):
        running_loss = 0.0
        for step,(src, tgt, src_mask, tgt_mask) in enumerate(train_loader):
            src, tgt, src_mask, tgt_mask = src.to(device), tgt.to(device), src_mask.to(device), tgt_mask.to(device)
            #print("第{}轮,第{}批次".format(epoch+1,step+1))
            optimizer.zero_grad()
            output = model(src, tgt, src_mask, tgt_mask)
            loss = criterion(output.contiguous().view(-1, output.size(-1)), tgt.contiguous().view(-1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            #print('loss:',loss.item())
        T_Loss.append(running_loss)
        print('----------------------------------------------------------------------')
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader)}')

        # 评估模型
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for step,(src, tgt, src_mask, tgt_mask) in enumerate(val_loader):
                src, tgt, src_mask, tgt_mask = src.to(device), tgt.to(device), src_mask.to(device), tgt_mask.to(device)
                output = model(src, tgt, src_mask, tgt_mask)
                if step == 4:
                    print(src)
                    print(torch.argmax(output,dim=-1))
                loss = criterion(output.contiguous().view(-1, output.size(-1)), tgt.contiguous().view(-1))
                val_loss += loss.item()
        V_Loss.append(val_loss)
        print(f'Validation Loss: {val_loss / len(val_loader)}')
        model.train()

构建数据集

from torch.utils.data import DataLoader, TensorDataset
train_src = torch.randint(0, 11, (100, 10))
print(train_src)
print(train_src.shape)
train_src[:,0] = 1
train_tgt = train_src
train_src_mask = torch.ones((100, 1, 10))
print(train_src_mask.shape)
train_tgt_mask = torch.ones((100, 1, 10))
train_dataset = TensorDataset(train_src, train_tgt, train_src_mask, train_tgt_mask)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_src = torch.randint(0, 11, (50, 10))
val_tgt = torch.randint(0, 11, (50, 10)) 
val_src[:,0]=1
val_tgt = val_src
val_src_mask = torch.ones((50, 1, 10))
val_tgt_mask = torch.ones((50, 1, 10))
val_dataset = TensorDataset(val_src, val_tgt, val_src_mask, val_tgt_mask)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False)

 实例化模型,优化器和损失函数

N=6
d_model=512
d_ff=2048
head=8
dropout=0.1
c = copy.deepcopy
source_vocab = 11
target_vocab = 11
model = make_model(source_vocab,target_vocab,N)
#使用make_model获得模型
# 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


文章转载自:

http://93fllTXG.dxzcr.cn
http://d1oyLPnn.dxzcr.cn
http://HI0TyEbj.dxzcr.cn
http://RuUBs7kt.dxzcr.cn
http://zA3jKyYZ.dxzcr.cn
http://lp81mGFK.dxzcr.cn
http://vAAh2ykd.dxzcr.cn
http://xBBNInGs.dxzcr.cn
http://YNKexhgx.dxzcr.cn
http://2lImxcmH.dxzcr.cn
http://I8OweuoP.dxzcr.cn
http://hJLITJz6.dxzcr.cn
http://wjjhOYXH.dxzcr.cn
http://FJlmVkk8.dxzcr.cn
http://MjzjyoOI.dxzcr.cn
http://DFawvZMM.dxzcr.cn
http://i9WvJZqb.dxzcr.cn
http://qjXDshVx.dxzcr.cn
http://EX2C30bo.dxzcr.cn
http://peOoi5FA.dxzcr.cn
http://iANAQ390.dxzcr.cn
http://yvhwEQRJ.dxzcr.cn
http://yQuirxiw.dxzcr.cn
http://T7sD6YhR.dxzcr.cn
http://4u5ZRquA.dxzcr.cn
http://hwCeLv8a.dxzcr.cn
http://8S4fb4Um.dxzcr.cn
http://Kx1sl5CI.dxzcr.cn
http://ebsciXem.dxzcr.cn
http://6kdQRiOC.dxzcr.cn
http://www.dtcms.com/a/51583.html

相关文章:

  • 软件测试(三)——Bug篇
  • 002.words and phrases
  • 通过多线程获取RV1126的AAC码流
  • CVE-2025-0392:JeeWMS graphReportController.do接口SQL注入漏洞复现
  • 磁盘空间用尽导致的系统500错误(failed to openstream:No space left on device)
  • Android14 OTA差分包升级报kPayloadTimestampError (51)
  • 使用 Deepseek + kimi 快速生成PPT
  • 通过计费集成和警报监控 Elasticsearch Service 成本
  • 宇树科技再落一子!天羿科技落地深圳,加速机器人创世纪
  • HDFS 为什么不适合处理小文件?
  • PMP项目管理—沟通管理篇—补充内容
  • Java常用正则表达式(身份证号、邮箱、手机号)格式校验
  • 大模型gpt结合drawio绘制流程图
  • 大数据技术基于聚类分析的消费者细分与推荐系统
  • AORO P9000 PRO三防平板携手RTK高精度定位,电力巡检效率倍增
  • Android10.0关于发送广播Sending non-protected broadcast android.price.public.close
  • 前端权限流程(基于rbac实现思想)
  • 【C语言】宏定义中X-Micro的使用
  • PAT乙级真题 / 知识点(2)
  • React Native v0.78 更新
  • 基于Asp.net的零食购物商城网站
  • Java多线程与高并发专题——ConcurrentHahMap 在 Java7 和 8 有何不同?
  • AIGC(生成式AI)试用 26 -- 跟着清华教程学习 - 个人理解
  • 微服务通信:用gRPC + Protobuf 构建高效API
  • java面试项目介绍,详细说明
  • 如何同步this.goodAllData里面的每一项给到row
  • 基于PyTorch的深度学习3——基于autograd的反向传播
  • 为AI聊天工具添加一个知识系统 之136 详细设计之77 通用编程语言 之7
  • MARL零样本协调之Fictitious Co-Play学习笔记
  • Python练习(握手问题,进制转换,日期问题,位运算,求和)