当前位置: 首页 > news >正文

大模型生成长度预测器

在这里插入图片描述


生成式AI大模型输出长度预测器:S³论文预测器复现手记

为什么需要输出长度预测?

在Transformer大模型推理过程中,**KV缓存(Key-Value Cache)**的内存占用会随着序列长度呈线性增长。传统系统(如HuggingFace Transformers和FasterTransformer)要么频繁分配内存导致延迟,要么预分配最大长度造成资源浪费。S³论文的核心突破在于:通过预测输出序列长度实现精准显存分配,将吞吐量提升6.49倍。


预测器设计解析

1. 模型架构选择

论文采用DistilBERT-base(66M参数)作为基础模型,主要考量:

  • 轻量高效:单次预测仅需3.7ms(A100 GPU)
  • 兼容性强:模型体积小于大语言模型的单层参数(如GPT-J每层214M)
  • 微调潜力:在问答数据集上展现98.6%的分桶准确率
2. 数据准备策略
  • 数据集:Alpaca(指令微调数据集)、Google Natural Questions、The Pile
  • 标签构造:将输出序列长度划分为10个桶(Bucket)

数据集转化

import json
from transformers import DistilBertTokenizer

# 定义最大输出长度和桶的数量
MAX_OUTPUT_LENGTH = 1024
NUM_BUCKETS = 8  # 0 - 7 个类别

# 计算每个桶的大小
bucket_size = MAX_OUTPUT_LENGTH // NUM_BUCKETS

# 加载 DistilBERT 的分词器
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-based-uncased')

def get_bucket(token_length):
    """根据 token 长度返回对应的桶编号"""
    if token_length >= MAX_OUTPUT_LENGTH:
        return NUM_BUCKETS - 1
    return token_length // bucket_size

def transform_data(input_file, output_file):
    """将原始数据集转换为包含 prompt 和 output token 长度桶的新数据集"""
    with open(input_file, 'r') as f:
        data = json.load(f)
    
    transformed_data = []
    
    for item in data:
        prompt = f"{item['instruction']} {item['input']}"
        # 使用分词器计算 output 的 token 长度
        output_ids = tokenizer.encode(item['output'])
        output_ids_length = len(output_ids)
        bucket = get_bucket(output_ids_length)
        
        transformed_data.append({
            "prompt": prompt,
            "output_length_bucket": bucket
        })
    
    with open(output_file, 'w') as f:
        json.dump(transformed_data, f, indent=4)

# 输入文件和输出文件路径
input_file = 'alpaca_data.json'
output_file = 'alpaca_data_postprocess.json'

# 执行转换
transform_data(input_file, output_file)

print(f"数据集已成功转换并保存到 {output_file}")

模型训练(微调)

import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, AdamW
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# 超参数设置
MAX_LENGTH = 512  # 输入的最大长度
BATCH_SIZE = 16   # 批大小
EPOCHS = 10        # 训练轮数
LEARNING_RATE = 2e-5  # 学习率
NUM_BUCKETS = 8
# 加载本地分词器
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-based-uncased')

# 自定义数据集类
class OutputLengthDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = item['prompt']
        label = item['output_length_bucket']
        
        # 对 prompt 进行编码
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # 返回输入和标签
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# 加载数据集
with open('alpaca_data_postprocess.json', 'r') as f:
    data = json.load(f)

# 划分训练集和验证集
train_data, val_data = train_test_split(data, test_size=0.1, random_state=42)

# 创建数据集和数据加载器
train_dataset = OutputLengthDataset(train_data, tokenizer, MAX_LENGTH)
val_dataset = OutputLengthDataset(val_data, tokenizer, MAX_LENGTH)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# 加载本地预训练模型
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-based-uncased',
    num_labels=NUM_BUCKETS  # 分类任务的类别数
)

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 优化器
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# 训练函数
def train(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # 前向传播
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Training loss: {avg_loss}")

# 验证函数
def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            # 前向传播
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits
            
            total_loss += loss.item()
            
            # 计算准确率
            predictions = torch.argmax(logits, dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    print(f"Validation loss: {avg_loss}, Accuracy: {accuracy}")

# 训练和验证循环
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train(model, train_loader, optimizer, device)
    evaluate(model, val_loader, device)

# 保存微调后的模型
model.save_pretrained('distilbert-based-uncased-finetuned')
tokenizer.save_pretrained('distilbert-based-uncased-finetuned')

创建模型服务接口

import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

class LengthPredictor:
    def __init__(self, model_path, max_length=512, bucket_size=128, num_buckets=8):
        """
        初始化 LengthPredictor。
        
        :param model_path: 微调后的模型路径
        :param max_length: 输入的最大长度
        :param bucket_size: 每个桶的大小
        :param num_buckets: 桶的数量
        """
        self.max_length = max_length
        self.bucket_size = bucket_size
        self.num_buckets = num_buckets
        
        # 加载分词器和模型
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_path)
        self.model = DistilBertForSequenceClassification.from_pretrained(model_path)
        
        # 设置设备
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device('cpu')
        self.model.to(self.device)
    
    def predict_length(self, prompt):
        """
        根据 prompt 预测输出长度。
        
        :param prompt: 输入的 prompt 文本
        :return: 预测的长度(向上取整)
        """
        # 对 prompt 进行编码
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # 将输入数据移动到设备
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        # 模型预测
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            predicted_class = torch.argmax(logits, dim=-1).item()
        
        # 将类别映射到长度范围
        if predicted_class == 0:
            return self.bucket_size
        else:
            return (predicted_class + 1) * self.bucket_size

# 示例用法
if __name__ == "__main__":
    # 初始化预测器
    model_path = 'distilbert-based-uncased-finetuned'
    predictor = LengthPredictor(model_path)
    
    # 测试预测
    prompt = "Identify the odd one out. Twitter, Instagram, Telegram"
    predicted_length = predictor.predict_length(prompt)
    print(f"Predicted length: {predicted_length}")

相关文章:

  • Solon AI —— RAG
  • 推流项目的ffmpeg配置和流程重点总结一下
  • 【Elasticsearch】Elasticsearch 中使用 HDFS 存储快照
  • 从vue源码解析Vue.set()和this.$set()
  • Checkpoint 模型与Stable Diffusion XL(SDXL)模型的区别
  • SpringBoot 异常处理
  • 【四.RAG技术与应用】【12.阿里云百炼应用(下):RAG的云端优化与扩展】
  • 靶场之路-VulnHub-DC-6 nmap提权、kali爆破、shell反连
  • 【MySQL】MySQL 复制
  • Git 批量合并 Commit 并且保留之前的 Commit 快速实现的思路
  • 【Jenkins】Pipeline流水线语法解析全集 -- 脚本式流水线、groovy语法
  • 数字后端培训实战项目六大典型后端实现案例
  • DeepSeek:构筑大数据平台底座的最优解
  • Unity3D 刚体动力学(Rigidbody Dynamics)详解
  • LIUNX学习-线程
  • 【3DMAX室内设计】2D转3D平面图插件2Dto3D使用方法
  • TomcatServlet
  • MyBatis-Plus 自定义 SQL 和复杂查询
  • 迭代器模式:遍历集合的艺术
  • flink集成tidb cdc
  • 怎么黑进网站后台/友情链接格式
  • 网站建设集团/正规手游代理平台有哪些
  • 下载的网站模板如何安装/网上营销的方式
  • 门店管理系统软件免费/爱站seo工具包官网
  • 网站主机ip查询/微信加人推码35一单
  • 公司网站建设哪家公司好/重庆关键词优化服务