当前位置: 首页 > news >正文

【第四章:大模型(LLM)】01.神经网络中的 NLP-(2)Seq2Seq 原理及代码解析

第四章:大模型(LLM)

第二部分:神经网络中的 NLP

第二节:Seq2Seq 原理及代码解析

1. Seq2Seq(Sequence-to-Sequence)模型原理

Seq2Seq 是一种处理序列到序列任务(如机器翻译、文本摘要、对话生成等)的深度学习架构,最早由 Google 在 2014 年提出。其核心思想是使用 编码器(Encoder) 将输入序列编码为上下文向量,再通过 解码器(Decoder) 逐步生成输出序列。

1.1 架构组成

  1. 编码器(Encoder)

    • 通常是 RNN、LSTM 或 GRU。

    • 输入:序列 x = (x_1, x_2, ..., x_T)

    • 输出:隐藏状态 h_T​,作为上下文向量。

  2. 解码器(Decoder)

    • 结构类似于编码器。

    • 输入:编码器输出的上下文向量 + 上一步预测的输出。

    • 输出:目标序列 y = (y_1, y_2, ..., y_T)

  3. 上下文向量(Context Vector)

    • 编码器最后一个隐藏状态 h_T​ 作为整个输入序列的信息摘要。


2. 数学公式

  • 编码器:

h_t = f(h_{t-1}, x_t)

  • 解码器:

s_t = f(s_{t-1}, y_{t-1}, c)
P(y_t|y_{<t}, x) = \text{softmax}(W s_t)

其中 c 是上下文向量。


3. 经典 Seq2Seq 训练流程

  1. 输入序列通过编码器,生成上下文向量。

  2. 解码器利用上下文向量和前一时刻的预测结果,逐步生成输出。

  3. 使用 教师强制(Teacher Forcing) 技术,训练时将真实标签输入解码器。


4. 改进:Attention 机制

Seq2Seq 传统模型存在 长序列信息丢失 问题。
Attention 通过在每一步解码时为输入序列不同部分分配权重,解决了这个问题。
公式:

c_t = \sum_{i=1}^{T_x} \alpha_{t,i} h_i

其中 \alpha_{t,i}​ 是注意力权重。


5. PyTorch 代码解析:Seq2Seq 示例

import torch
import torch.nn as nn
import torch.optim as optim# Encoder
class Encoder(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers=1):super(Encoder, self).__init__()self.rnn = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)def forward(self, x):outputs, hidden = self.rnn(x)return hidden# Decoder
class Decoder(nn.Module):def __init__(self, output_dim, hidden_dim, num_layers=1):super(Decoder, self).__init__()self.rnn = nn.GRU(output_dim, hidden_dim, num_layers, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x, hidden):output, hidden = self.rnn(x, hidden)pred = self.fc(output)return pred, hidden# Seq2Seq
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderdef forward(self, src, trg):hidden = self.encoder(src)outputs, _ = self.decoder(trg, hidden)return outputs# Example usage
input_dim, output_dim, hidden_dim = 10, 10, 32
encoder = Encoder(input_dim, hidden_dim)
decoder = Decoder(output_dim, hidden_dim)
model = Seq2Seq(encoder, decoder)src = torch.randn(16, 20, input_dim)  # batch=16, seq_len=20
trg = torch.randn(16, 20, output_dim)
output = model(src, trg)
print(output.shape)  # [16, 20, 10]


6. 应用场景

  • 机器翻译(Google Translate)

  • 文本摘要(新闻摘要生成)

  • 对话系统(聊天机器人)

  • 语音识别(语音到文本)

http://www.dtcms.com/a/301903.html

相关文章:

  • 数据结构 | 队列:从概念到实战
  • nvim cspell
  • Nginx HTTP 反向代理负载均衡实验
  • NAT地址转换,静态NAT,高级NAT,NAPT,easy IP
  • 【Linux指南】Linux粘滞位详解:解决共享目录文件删除安全隐患
  • GaussDB 开发基本规范
  • XML Expat Parser:深入解析与高效应用
  • Python 列表内存存储本质:存储差异原因与优化建议
  • 第4章唯一ID生成器——4.2 单调递增的唯一ID
  • 【Android】卡片式布局 滚动容器ScrollView
  • Go语法入门:变量、函数与基础数据类型
  • 飞算科技重磅出品:飞算 JavaAI 重构 Java 开发效率新标杆
  • JAVA后端开发——用 Spring Boot 实现定时任务
  • 【Spring】Spring Boot启动过程源码解析
  • 鸿蒙打包签名
  • HarmonyOS 6 云开发-用户头像上传云存储
  • 前端工程化常见问题总结
  • Windows|CUDA和cuDNN下载和安装,默认安装在C盘和不安装在C盘的两种方法
  • AI技术革命:产业重塑与未来工作范式转型。
  • 深入解析MIPI C-PHY (四)C-PHY物理层对应的上层协议的深度解析
  • 齐护Ebook科技与艺术Steam教育套件 可图形化micropython Arduino编程ESP32纸电路手工
  • 湖南(源点咨询)市场调研 如何在行业研究中快速有效介入 起头篇
  • Triton编译
  • 【n8n教程笔记——工作流Workflow】文本课程(第一阶段)——5.5 计算预订订单数量和总金额 (Calculating booked orders)
  • Rouge:面向摘要自动评估的召回导向型指标——原理、演进与应用全景
  • 分表分库与分区表
  • Android启动时间优化大全
  • 蛋白质反向折叠模型-ProteinMPNN安装教程
  • 学习日志20 python
  • 【unitrix】 6.18 二进制小数特质(t_decimal.rs)