# LLM Output Length Predictor

Notes on reproducing the output length predictor from the S³ paper for generative LLM serving.
## Why Predict Output Length?

During Transformer LLM inference, the memory footprint of the **KV cache (Key-Value Cache)** grows linearly with sequence length. Traditional systems (such as HuggingFace Transformers and FasterTransformer) either allocate memory incrementally, which adds latency, or pre-allocate for the maximum length, which wastes memory. The core idea of the S³ paper is to predict the output sequence length so that GPU memory can be allocated precisely, improving throughput by up to 6.49×.
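To see why the predicted length matters, the sketch below estimates the KV-cache footprint of a single request as a function of sequence length. The model dimensions (roughly GPT-J-6B) and fp16 precision are illustrative assumptions, not values taken from the paper:

```python
def kv_cache_bytes(seq_len, num_layers=28, num_heads=16, head_dim=256, dtype_bytes=2):
    """Estimate the KV-cache size for one sequence: a K and a V tensor per layer.
    The layer/head/dim values are illustrative (roughly GPT-J-6B in fp16); adjust for your model."""
    per_token = 2 * num_layers * num_heads * head_dim * dtype_bytes  # K + V across all layers
    return seq_len * per_token

# A request that stops at 128 tokens needs ~8x less cache than one padded out to 1024.
for length in (128, 512, 1024):
    print(f"{length:>5} tokens -> {kv_cache_bytes(length) / 2**20:.1f} MiB")
```

Pre-allocating for the worst case reserves the 1024-token figure for every request; a correct bucket prediction lets the scheduler reserve only what the request will actually use.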
## Predictor Design

### 1. Model Architecture

The paper uses DistilBERT-base (66M parameters) as the predictor backbone, mainly because (a quick size/latency sanity check follows the list):

- Lightweight and fast: a single prediction takes about 3.7 ms on an A100 GPU
- Negligible overhead: the whole model is smaller than a single layer of a large LLM (e.g., roughly 214M parameters per layer for GPT-J)
- Fine-tuning potential: it reaches 98.6% bucketing accuracy on a question-answering dataset
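A minimal sanity check of those size claims, assuming the standard `distilbert-base-uncased` checkpoint and the 8-class head used later in this walkthrough (the measured latency will of course depend on your hardware):

```python
import time
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=8)
model.eval()

# Total parameter count, including the small classification head
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

inputs = tokenizer("Identify the odd one out. Twitter, Instagram, Telegram", return_tensors='pt')
with torch.no_grad():
    start = time.perf_counter()
    model(**inputs)
    print(f"One forward pass: {(time.perf_counter() - start) * 1000:.1f} ms")
```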
### 2. Data Preparation

- Datasets: Alpaca (an instruction-tuning dataset), Google Natural Questions, and The Pile
- Label construction: output sequence lengths are discretized into buckets (the paper describes 10 buckets; this reproduction uses 8 buckets of 128 tokens each, covering outputs up to 1024 tokens)
## Dataset Transformation
```python
import json
from transformers import DistilBertTokenizer

# Maximum output length and number of buckets
MAX_OUTPUT_LENGTH = 1024
NUM_BUCKETS = 8  # classes 0-7

# Size of each bucket (128 tokens)
bucket_size = MAX_OUTPUT_LENGTH // NUM_BUCKETS

# Load the DistilBERT tokenizer (or point this at a local copy of the checkpoint)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

def get_bucket(token_length):
    """Return the bucket index for a given token length."""
    if token_length >= MAX_OUTPUT_LENGTH:
        return NUM_BUCKETS - 1
    return token_length // bucket_size

def transform_data(input_file, output_file):
    """Convert the raw dataset into (prompt, output-length bucket) pairs."""
    with open(input_file, 'r') as f:
        data = json.load(f)

    transformed_data = []
    for item in data:
        prompt = f"{item['instruction']} {item['input']}"
        # Measure the output length with the DistilBERT tokenizer
        # (an approximation of the target LLM's token count; includes special tokens)
        output_ids = tokenizer.encode(item['output'])
        output_ids_length = len(output_ids)
        bucket = get_bucket(output_ids_length)
        transformed_data.append({
            "prompt": prompt,
            "output_length_bucket": bucket
        })

    with open(output_file, 'w') as f:
        json.dump(transformed_data, f, indent=4)

# Input and output file paths
input_file = 'alpaca_data.json'
output_file = 'alpaca_data_postprocess.json'

# Run the conversion
transform_data(input_file, output_file)
print(f"Dataset converted and saved to {output_file}")
```
## Model Training (Fine-Tuning)
```python
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Hyperparameters
MAX_LENGTH = 512       # maximum input length
BATCH_SIZE = 16        # batch size
EPOCHS = 10            # number of training epochs
LEARNING_RATE = 2e-5   # learning rate
NUM_BUCKETS = 8

# Load the tokenizer (or point this at a local copy of the checkpoint)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Custom dataset class
class OutputLengthDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = item['prompt']
        label = item['output_length_bucket']
        # Encode the prompt
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        # Return inputs and label
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Load the transformed dataset
with open('alpaca_data_postprocess.json', 'r') as f:
    data = json.load(f)

# Split into training and validation sets
train_data, val_data = train_test_split(data, test_size=0.1, random_state=42)

# Build datasets and data loaders
train_dataset = OutputLengthDataset(train_data, tokenizer, MAX_LENGTH)
val_dataset = OutputLengthDataset(val_data, tokenizer, MAX_LENGTH)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Load the pretrained model with a classification head
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=NUM_BUCKETS  # number of classes
)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Optimizer (torch.optim.AdamW; the AdamW re-export in transformers is deprecated)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# Training loop for one epoch
def train(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"Training loss: {avg_loss}")

# Validation loop
def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits
            total_loss += loss.item()
            # Accuracy over bucket labels
            predictions = torch.argmax(logits, dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    print(f"Validation loss: {avg_loss}, Accuracy: {accuracy}")

# Train and validate
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train(model, train_loader, optimizer, device)
    evaluate(model, val_loader, device)

# Save the fine-tuned model
model.save_pretrained('distilbert-base-uncased-finetuned')
tokenizer.save_pretrained('distilbert-base-uncased-finetuned')
```
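For the S³ use case, plain accuracy is not the whole story: predicting a bucket that is too small means the reserved KV-cache slot will overflow and the runtime has to handle the misprediction, while predicting too large only wastes some memory. A small post-training check along these lines (reusing `model`, `val_loader`, and `device` from the script above) reports how often the predictor under-shoots:

```python
def underprediction_rate(model, dataloader, device):
    """Fraction of validation examples whose predicted bucket is smaller than the true bucket."""
    model.eval()
    under, total = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            preds = torch.argmax(logits, dim=-1)
            under += (preds < batch['labels'].to(device)).sum().item()
            total += batch['labels'].size(0)
    return under / total

print(f"Under-prediction rate: {underprediction_rate(model, val_loader, device):.2%}")
```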
## Building the Prediction Service Interface
```python
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

class LengthPredictor:
    def __init__(self, model_path, max_length=512, bucket_size=128, num_buckets=8):
        """
        Initialize the LengthPredictor.

        :param model_path: path to the fine-tuned model
        :param max_length: maximum input length
        :param bucket_size: size of each bucket (in tokens)
        :param num_buckets: number of buckets
        """
        self.max_length = max_length
        self.bucket_size = bucket_size
        self.num_buckets = num_buckets
        # Load the tokenizer and model
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_path)
        self.model = DistilBertForSequenceClassification.from_pretrained(model_path)
        # Device setup (CPU here; uncomment the line below to use a GPU when available)
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device('cpu')
        self.model.to(self.device)

    def predict_length(self, prompt):
        """
        Predict the output length for a prompt.

        :param prompt: input prompt text
        :return: predicted length (upper bound of the predicted bucket)
        """
        # Encode the prompt
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        # Move inputs to the device
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        # Run the model
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            predicted_class = torch.argmax(logits, dim=-1).item()
        # Bucket k covers [k * bucket_size, (k + 1) * bucket_size), so return the
        # bucket's upper bound to leave enough room for the generated output
        return (predicted_class + 1) * self.bucket_size

# Example usage
if __name__ == "__main__":
    # Initialize the predictor
    model_path = 'distilbert-base-uncased-finetuned'
    predictor = LengthPredictor(model_path)
    # Run a test prediction
    prompt = "Identify the odd one out. Twitter, Instagram, Telegram"
    predicted_length = predictor.predict_length(prompt)
    print(f"Predicted length: {predicted_length}")
```