使用BART模型和T5模型实现文本改写
BART模型
BART(Bidirectional and Auto-Regressive Transformers)是由 Facebook AI Research(FAIR)在 2019 年提出的序列到序列(seq2seq)预训练模型,论文发表于《BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension》。
它结合了 BERT 的双向编码器 和 GPT 的自回归解码器,专为文本生成任务(如摘要、翻译、对话)设计,同时在理解任务(如分类、问答)上也表现优异。
BART 通过灵活的预训练任务和统一的编解码架构,成为生成与理解任务的通用基础模型,尤其适合需要同时处理输入理解和输出生成的场景。
核心特点
架构:标准 Transformer 编解码器
编码器:双向 Transformer(类似 BERT),理解上下文。
解码器:自回归 Transformer(类似 GPT),从左到右生成文本。
参数规模:从 BART-Base(140M)到 BART-Large(400M)。
预训练任务:文本破坏与还原(Denoising Autoencoder) 通过多种噪声破坏输入文本,再让模型还原原始文本,提升生成与理解能力:
Token Masking(类似 BERT):随机遮盖词(如
[MASK]
)。Token Deletion:随机删除词,需还原位置和内容。
Text Infilling:用单个
[MASK]
替换连续片段(如 SpanBERT),需生成缺失片段。Sentence Permutation:打乱句子顺序,需重排。
Document Rotation:随机选择词作为开头,需还原原文起始点。
微调灵活性:可直接用于下游任务:
生成任务:摘要(CNN/DailyMail)、对话、翻译(需多语言预训练)。
理解任务:文本分类、问答(将输入编码,解码为答案)。
推理示例代码:
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipelineclass ChineseBart:def __init__(self):model_path = "/path/to/bart-base-chinese"self.load_model(model_path)def load_model(self, model_path):# 加载一个中文BART模型(假设已经有微调好的改写模型权重)self.tokenizer = BertTokenizer.from_pretrained(model_path)self.model = BartForConditionalGeneration.from_pretrained(model_path)self.text2text_generator = Text2TextGenerationPipeline(self.model, self.tokenizer, device=0) def rewrite_text(self, text):# text = "机器学习模型在图像识别领域取得了突破性的进展。"# 构造输入(BART可以直接输入文本)ret = self.text2text_generator(text, max_length=512, do_sample=False)if len(ret) > 0:rewritten_texts = []for obj in ret:ret_text = obj.get('generated_text').replace(" ", "")rewritten_texts.append(ret_text)rewritten_text = "\n\n".join(rewritten_texts)print("改写结果:", rewritten_text)return rewritten_textreturn text
T5模型
T5(Text-to-Text Transfer Transformer)是 Google Research 在 2019 年提出的统一文本到文本框架,论文发表于《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》。它将所有 NLP 任务(翻译、摘要、问答、分类等)统一为“文本输入 → 文本输出”的范式,通过大规模预训练 + 微调实现通用能力。
核心特点
统一框架:所有任务都是 Text-to-Text
- 输入和输出均为纯文本,无需任务特定架构。
- 任务前缀:通过在输入前加提示词区分任务,例如:
translate English to German: ...
summarize: ...
cola sentence: ...
(分类任务输出acceptable
或unacceptable
)。
架构:标准 Encoder-Decoder Transformer
- 完全基于原始 Transformer(Vaswani et al., 2017),未做架构创新。
- 规模:从 T5-Small(60M)到 T5-11B(110亿参数,最大版本)。
预训练任务:Span Corruption(改进的 MLM)
- 类似 BERT 的掩码语言模型(MLM),但连续片段(span)被掩码(平均长度3),需解码器还原。
- 预训练数据:C4(Colossal Clean Crawled Corpus),750GB cleaned English text。
微调灵活性
- 单任务微调:针对特定任务(如翻译)微调。
- 多任务微调:混合多个任务前缀联合训练(如翻译+摘要+QA)。
- 零样本/少样本:通过任务前缀泛化到新任务(如未微调的数学题)。
推理示例代码:
from transformers import T5Tokenizer, T5ForConditionalGenerationclass ChineseT5:def __init__(self):print("ChineseT5")model_path = "/path/to/flan-t5-base"self.load_model(model_path)def load_model(self, model_name):# 加载一个中文T5模型(假设已经有微调好的改写模型权重)self.tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False)self.model = T5ForConditionalGeneration.from_pretrained(model_name)def rewrite_text(self, input_text):# 构造输入(添加适当的前缀)input_text = "rewrite: " + input_textinput_ids = self.tokenizer(input_text, return_tensors="pt").input_idsoutputs = self.model.generate(input_ids)if len(outputs) > 0:rewritten_text = self.tokenizer.decode(outputs[0])print("改写结果:", rewritten_text)return rewritten_textreturn input_text