当前位置: 首页 > news >正文

雏光 网络推广 网站建设网站建设交流发言材料

雏光 网络推广 网站建设,网站建设交流发言材料,新河网站,建网站的步骤和方法生成式AI大模型输出长度预测器:S论文预测器复现手记 为什么需要输出长度预测? 在Transformer大模型推理过程中,**KV缓存(Key-Value Cache)**的内存占用会随着序列长度呈线性增长。传统系统(如HuggingFace…

在这里插入图片描述


生成式AI大模型输出长度预测器:S³论文预测器复现手记

为什么需要输出长度预测?

在Transformer大模型推理过程中,**KV缓存(Key-Value Cache)**的内存占用会随着序列长度呈线性增长。传统系统(如HuggingFace Transformers和FasterTransformer)要么频繁分配内存导致延迟,要么预分配最大长度造成资源浪费。S³论文的核心突破在于:通过预测输出序列长度实现精准显存分配,将吞吐量提升6.49倍。


预测器设计解析

1. 模型架构选择

论文采用DistilBERT-base(66M参数)作为基础模型,主要考量:

  • 轻量高效:单次预测仅需3.7ms(A100 GPU)
  • 兼容性强:模型体积小于大语言模型的单层参数(如GPT-J每层214M)
  • 微调潜力:在问答数据集上展现98.6%的分桶准确率
2. 数据准备策略
  • 数据集:Alpaca(指令微调数据集)、Google Natural Questions、The Pile
  • 标签构造:将输出序列长度划分为10个桶(Bucket)

数据集转化

import json
from transformers import DistilBertTokenizer# 定义最大输出长度和桶的数量
MAX_OUTPUT_LENGTH = 1024
NUM_BUCKETS = 8  # 0 - 7 个类别# 计算每个桶的大小
bucket_size = MAX_OUTPUT_LENGTH // NUM_BUCKETS# 加载 DistilBERT 的分词器
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-based-uncased')def get_bucket(token_length):"""根据 token 长度返回对应的桶编号"""if token_length >= MAX_OUTPUT_LENGTH:return NUM_BUCKETS - 1return token_length // bucket_sizedef transform_data(input_file, output_file):"""将原始数据集转换为包含 prompt 和 output token 长度桶的新数据集"""with open(input_file, 'r') as f:data = json.load(f)transformed_data = []for item in data:prompt = f"{item['instruction']} {item['input']}"# 使用分词器计算 output 的 token 长度output_ids = tokenizer.encode(item['output'])output_ids_length = len(output_ids)bucket = get_bucket(output_ids_length)transformed_data.append({"prompt": prompt,"output_length_bucket": bucket})with open(output_file, 'w') as f:json.dump(transformed_data, f, indent=4)# 输入文件和输出文件路径
input_file = 'alpaca_data.json'
output_file = 'alpaca_data_postprocess.json'# 执行转换
transform_data(input_file, output_file)print(f"数据集已成功转换并保存到 {output_file}")

模型训练(微调)

import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, AdamW
from sklearn.model_selection import train_test_split
from tqdm import tqdm# 超参数设置
MAX_LENGTH = 512  # 输入的最大长度
BATCH_SIZE = 16   # 批大小
EPOCHS = 10        # 训练轮数
LEARNING_RATE = 2e-5  # 学习率
NUM_BUCKETS = 8
# 加载本地分词器
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-based-uncased')# 自定义数据集类
class OutputLengthDataset(Dataset):def __init__(self, data, tokenizer, max_length):self.data = dataself.tokenizer = tokenizerself.max_length = max_lengthdef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]prompt = item['prompt']label = item['output_length_bucket']# 对 prompt 进行编码encoding = self.tokenizer(prompt,truncation=True,padding='max_length',max_length=self.max_length,return_tensors='pt')# 返回输入和标签return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'labels': torch.tensor(label, dtype=torch.long)}# 加载数据集
with open('alpaca_data_postprocess.json', 'r') as f:data = json.load(f)# 划分训练集和验证集
train_data, val_data = train_test_split(data, test_size=0.1, random_state=42)# 创建数据集和数据加载器
train_dataset = OutputLengthDataset(train_data, tokenizer, MAX_LENGTH)
val_dataset = OutputLengthDataset(val_data, tokenizer, MAX_LENGTH)train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)# 加载本地预训练模型
model = DistilBertForSequenceClassification.from_pretrained('distilbert-based-uncased',num_labels=NUM_BUCKETS  # 分类任务的类别数
)# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)# 优化器
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)# 训练函数
def train(model, dataloader, optimizer, device):model.train()total_loss = 0for batch in tqdm(dataloader, desc="Training"):input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)# 前向传播outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.loss# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()avg_loss = total_loss / len(dataloader)print(f"Training loss: {avg_loss}")# 验证函数
def evaluate(model, dataloader, device):model.eval()total_loss = 0correct = 0total = 0with torch.no_grad():for batch in tqdm(dataloader, desc="Evaluating"):input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)# 前向传播outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losslogits = outputs.logitstotal_loss += loss.item()# 计算准确率predictions = torch.argmax(logits, dim=-1)correct += (predictions == labels).sum().item()total += labels.size(0)avg_loss = total_loss / len(dataloader)accuracy = correct / totalprint(f"Validation loss: {avg_loss}, Accuracy: {accuracy}")# 训练和验证循环
for epoch in range(EPOCHS):print(f"Epoch {epoch + 1}/{EPOCHS}")train(model, train_loader, optimizer, device)evaluate(model, val_loader, device)# 保存微调后的模型
model.save_pretrained('distilbert-based-uncased-finetuned')
tokenizer.save_pretrained('distilbert-based-uncased-finetuned')

创建模型服务接口

import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerclass LengthPredictor:def __init__(self, model_path, max_length=512, bucket_size=128, num_buckets=8):"""初始化 LengthPredictor。:param model_path: 微调后的模型路径:param max_length: 输入的最大长度:param bucket_size: 每个桶的大小:param num_buckets: 桶的数量"""self.max_length = max_lengthself.bucket_size = bucket_sizeself.num_buckets = num_buckets# 加载分词器和模型self.tokenizer = DistilBertTokenizer.from_pretrained(model_path)self.model = DistilBertForSequenceClassification.from_pretrained(model_path)# 设置设备# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.device = torch.device('cpu')self.model.to(self.device)def predict_length(self, prompt):"""根据 prompt 预测输出长度。:param prompt: 输入的 prompt 文本:return: 预测的长度(向上取整)"""# 对 prompt 进行编码encoding = self.tokenizer(prompt,truncation=True,padding='max_length',max_length=self.max_length,return_tensors='pt')# 将输入数据移动到设备input_ids = encoding['input_ids'].to(self.device)attention_mask = encoding['attention_mask'].to(self.device)# 模型预测self.model.eval()with torch.no_grad():outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)logits = outputs.logitspredicted_class = torch.argmax(logits, dim=-1).item()# 将类别映射到长度范围if predicted_class == 0:return self.bucket_sizeelse:return (predicted_class + 1) * self.bucket_size# 示例用法
if __name__ == "__main__":# 初始化预测器model_path = 'distilbert-based-uncased-finetuned'predictor = LengthPredictor(model_path)# 测试预测prompt = "Identify the odd one out. Twitter, Instagram, Telegram"predicted_length = predictor.predict_length(prompt)print(f"Predicted length: {predicted_length}")
http://www.dtcms.com/a/480262.html

相关文章:

  • 网站内容如何优化揭阳网站制作费用
  • 秦皇岛公司做网站优创智汇高端网站建设电话怎么样
  • 电商网站建设与管理 教案灵宝市建设局网站
  • 寿光网站建设公司网站开发必须要用js
  • 响应式网站开发昌吉州住房和城乡建设局网站
  • 主流门户网站有哪些西部数码网站管理助手破解版
  • 南通网站建设项目中国最大的外贸平台
  • “中非咖桥 世界湘见”2025首届星沙-非洲咖啡嘉年华系列活动启动
  • 如何注册网站免费的吗婚礼效果图网站
  • Tesseract-OCR软件安装和语言包安装(Windows系统)
  • 订阅号如何做微网站能够做二维码网站
  • 扬州建网站室内设计师联盟官网入口
  • 建设网站需要购买虚拟主机吗wordpress视频缩略图不显示
  • 网页游戏怎么下载windows优化大师免费
  • t想学网站建设山西太原建站怎么做
  • 上海建设网站便宜的网站建设 费用高
  • 高端酒店网站模板免费下载佛山北京网站建设公司哪家好
  • 微信对接网站可以做301跳转吗连城县住房和城乡建设局 网站
  • 汉沽谁做网站找大学生做家教的网站
  • 便利的赣州网站建设凯里哪里有做网站的
  • 虚拟主机部署网站室内设计学校广州
  • wordpress建站需要学什么意思企业网站的推广建议
  • 诸城哪有做公司网站和的套用模板网站
  • 手机商城建站系统网站建设江干建设局网站
  • 网站建设公司 青岛医院 网站建设
  • 新网站建设服务公司给网站加织梦后台
  • 网页怎么做网站地图php 做资讯网站
  • 郑州微信公众号网站建设做笑话网站
  • 网站开发留言板代码公司宣传片ppt模板
  • 网站建设的相关论文设计软件排行榜