当前位置: 首页 > news >正文

PEFT实战LoRA微调OpenAI Whisper 中文语音识别

OpenAI Whisper模型介绍

OpenAI 的 Whisper 模型是自动语音识别系统。拥有以下特点

  • 大规模训练数据: 使用了从互联网收集的 680,000 小时 多语言、多任务的带标签数据进行训练。

  • 强大鲁棒性: 庞大的数据量使得模型对口音、背景噪音和专业术语具有更好的识别能力。

  • 多功能: 不仅能够将语音转录成文本,还能进行多种语言到英语的翻译。

  • 开源: OpenAI 开源了模型和推理代码,以促进相关应用和进一步的研究。

Whisper自动语音识别模型使用典型的 编码器-解码器 Transformer结构,输入音频被分割为 30秒的块,转换为 对数梅尔频谱图(Log-Mel Spectrogram),然后输入编码器,使用一个解码器来预测文本标题,并夹杂特殊标记(如语言标识、时间戳等),使单一模型能完成多种语音识别、翻译、语言检测等任务。

鲁棒性实验如下
在这里插入图片描述
实验数据解读:Wav2vec 2.0:一个重要的自监督学习语音识别模型。Whisper:本次介绍的主角。表格中的数值代表 词错误率相对降低的百分比。这个百分比是相对于 Wav2vec 2.0 模型的表现来计算的。数值越大越好,正数表示错误率降低(性能提升),负数表示错误率增加(性能下降)举例:在 AMI SDM1(会议录音,单麦克风)数据集上,Whisper 的词错误率比 Wav2vec 2.0 降低了 46.2%,这是一个巨大的提升。

结论:Whisper 在在未经特定数据集训练的情况下(零样本)场景下表现出色:这意味着它没有使用表格中这些特定数据集进行过训练,直接拿来测试,但效果非常好,其性能更接近人类水平:特别是在处理多样化、有挑战性的真实世界语音数据时。

数据集Common Voice介绍

Common Voice 11.0 数据集包含许多不同语言的录音,总时长达数小时。它在语音技术领域扮演着“基础设施”的角色。
在这里插入图片描述

实战

0、介绍

使用 LoRA 在 OpenAI Whisper-large-v2 模型上实现语音识别 (ASR) 任务的微调训练,还结合了 int8 量化进一步降低训练过程资源开销,同时保证了精度几乎不受影响。

1、全参数设置

# 原始模型保存路径
model_name_or_path = "openai/whisper-large-v2"
# 微调后模型保存路径
model_dir = "models/whisper-large-v2-asr-int8"language = "Chinese (China)"
language_abbr = "zh-CN"
# 指定任务为转录
task = "transcribe"
# 指定微调的数据集
dataset_name = "mozilla-foundation/common_voice_11_0"
# 指定批次
batch_size=64

2、下载数据集 Common Voice

Common Voice 11.0 数据集包含许多不同语言的录音,总时长达数小时。当前以中文数据为例,展示如何使用 LoRA 在 Whisper-large-v2 上进行微调训练。

from datasets import load_dataset
from datasets import load_dataset, DatasetDict# 初始化一个 DatasetDict 结构
common_voice = DatasetDict()
# 将训练集(将训练+验证拆分为训练集)和测试集拆分好,按照中文数据集构建配置加载到内存中
common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation")
common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test")

3、预处理训练数据集

from transformers import AutoFeatureExtractor, AutoTokenizer, AutoProcessor# 从预训练模型加载特征提取器
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)# 从预训练模型加载分词器,可以指定语言和任务以获得最适合特定需求的分词器配置
tokenizer = AutoTokenizer.from_pretrained(  model_name_or_path, language=language, task=task)# 从预训练模型加载处理器,处理器通常结合了特征提取器和分词器,为特定任务提供一站式的数据预处理
processor = AutoProcessor.from_pretrained(  model_name_or_path, language=language, task=task)

针对不想要的数据集标题,可以移除

common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"]
)

降采样音频数据

查看 common_voice 数据集介绍,你会发现其音频是以48kHz的采样率进行采样的。而 Whisper 模型是在16kHz的音频输入上预训练的,因此我们需要将音频输入降采样以匹配模型预训练时使用的采样率。通过在音频列上使用 cast_column 方法,并将 sampling_rate 设置为16kHz来对音频进行降采样。

from datasets import Audio
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

定义数据预处理函数:用于将音频和文本数据转换为模型训练所需的格式,

  • 通过加载音频列将音频输入重新采样为16kHZ。
  • 使用特征提取器从音频数组计算输入特征。
  • 将句子列标记化为输入标签。
def prepare_dataset(batch):audio = batch["audio"]batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]batch["labels"] = tokenizer(batch["sentence"]).input_idsreturn batch

给到dataset.map()方法

tokenized_common_voice = common_voice.map(prepare_dataset, num_proc=8)

定义一个针对语音到文本(Seq2Seq) 模型的自定义数据整理器类,特别适用于输入为语音特征、输出为文本序列的数据集。

这个整理器 (DataCollatorSpeechSeq2SeqWithPadding) 旨在将数据点批量打包,将每个批次中的 attention_mask 填充到最大长度,以保持批处理中张量形状的一致性,并用 -100 替换填充值,以便在损失函数中被忽略。这对于神经网络的高效训练至关重要。

import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union# 定义一个针对语音到文本任务的数据整理器类
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:processor: Any  # 处理器结合了特征提取器和分词器# 整理器函数,将特征列表处理成一个批次def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:# 从特征列表中提取输入特征,并填充以使它们具有相同的形状input_features = [{"input_features": feature["input_features"]} for feature in features]batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")# 从特征列表中提取标签特征(文本令牌),并进行填充label_features = [{"input_ids": feature["labels"]} for feature in features]labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")# 使用-100替换标签中的填充区域,-100通常用于在损失计算中忽略填充令牌labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)# 如果批次中的所有序列都以句子开始令牌开头,则移除它if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():labels = labels[:, 1:]# 将处理过的标签添加到批次中batch["labels"] = labelsreturn batch  # 返回最终的批次,准备好进行训练或评估
# 用指定的处理器实例化数据整理器
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

4、模型准备

1、加载预训练模型(int8 精度)

使用 int8 精度加载预训练模型,进一步降低显存需求。

from transformers import AutoModelForSpeechSeq2Seq
model = AutoModeLForSpeechSeq2Seq.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")# 设置模型配置中的forced_decoder_ids属性为None,这通常用于指定在解码(生成文本)过程中必须使用的特定token的ID,设置为None表示没有这样的强制要求
model.config.forced_decoder_ids = None# 设置模型配置中的suppress_tokens列表为空,这用于指定在生成过程中应被抑制(不生成)的token的列表,设置为空列表表示没有要抑制的token
model.config.suppress_tokens = []
2、PEFT 微调前的模型处理

在使用 peft 训练 int8 模型之前,需要进行一些预处理:

  • 将所有非 int8 精度模块转换为全精度(fp32)以保证稳定性
  • 为输入嵌入层添加一个 forward_book,以启用输入隐藏状态的梯度计算
  • 启用梯度检查点以实现更高效的内存训练

使用 peft 库预定义的工具函数 prepare_model_for_int8_training,便可自动完成以上模型处理工作。

from peft import prepare_model_for_int8_training  
model = prepare_model_for_int8_training(model)
3、LoRA Adapter 配置

peft 中使用 LoRA 非常简捷,借助 PeftModel 抽象,我们可以快速使用低秩适配器(LoRA)到任意模型。

通过使用 peft 中的 get_peft_model 工具函数来实现。

关于 LoRA 超参数的说明:

MatWu1(B,A) * Scaling

​ Scaling = LoRA_Alpha / Rank

创建一个LoraConfig对象,用于设置LoRA(Low-Rank Adaptation)的配置参数

from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model# 创建一个LoraConfig对象,用于设置LoRA(Low-Rank Adaptation)的配置参数
config = LoraConfig(r=8, # LoRA的秩,影响LoRA矩阵的大小lora_alpha=64, # LoRA适应的比例因子# 指定将LoRA应用到的模型模块,通常是attention和全连接层的投影。target_modules=["q_proj", "v_proj"],lora_dropout=0.05, # 在LoRA模块中使用的dropout率bias="none", # 设置bias的使用方式,这里没有使用bias
)
4、使用get_peft_model函数和给定的配置来获取一个PEFT模型
# 使用get_peft_model函数和给定的配置来获取一个PEFT模型
peft_model = get_peft_model(model, config)# 打印 LoRA 微调训练的模型参数,可以看到到底要用多少的参数进行训练
# peft_model.print_trainable_parameters()

5、模型训练

1、训练参数

关于设置训练步数和评估步数

# 基于 epochs 设置:
num_train_epochs=3, # 训练的总轮数  
evaluation_strategy="epoch", # 设置评估策略,这里是在每个epoch结束时进行评估  
warmup_steps=50, # 在训练初期增加学习率的步数,有助于稳定训练# 基于 steps 设置:
max_steps=100, # 训练总步数
evaluation_strategy="steps",  # 评估策略
eval_steps=25, # 评估步数

设置序列到序列模型训练的参数

from transforms import Seq2SeqTrainingArguments# 设置序列到序列模型训练的参数
training_args = Seq2SeqTrainingArguments(output_dir=model_dir, # 指定模型输出和保存的目录per_device_train_batch_size=batch_size, # 每个设备上的训练批量大小learning_rate=1e-3, # 学习率num_train_epochs=1, # 训练的总轮数,实际可用3轮evaluation_strategy="epoch", # 设置评估策略,这里是在每个epoch结束时进行评估# warmup_steps=50, # 在训练初期增加学习率的步数,有助于稳定训练# fp16=True, # 启用混合精度训练,可以提高训练速度,同时减少内存使用per_device_eval_batch_size=batch_size, # 每个设备上的评估批量大小generation_max_length=128, # 生成任务的最大长度logging_steps=10, # 指定日志记录的步骤,用于跟踪训练进度remove_unused_columns=False, # 是否删除不使用的列,以减少数据处理开销label_names="labels", # 指定标签列的名称,用于训练过程中# evaluation_strategy="steps",# eval_steps=25,
)
2、实例化 Seq2SeqTrainer 训练器开始训练
from transforms import Seq2SeqTrainertrainer = Seq2SeqTrainer(args=training_args,model=peft_model,train_dataset=tokenized_common_voice["train"],eval_dataset=tokenized_common_voice["validation"],data_collator=data_collator,tokenizer=processor.feature_extractor,
)
peft_model.config.use_cachef= Falsetrainer.train()
3、保存训练的模型
trainer.save_model(model_dir)

6、使用微调好的模型

1、加载模型
  • 使用 PeftConfig 加载 LoRA Adapter 配置参数,使用 PeftModel 加载微调后 Whisper 模型
model_dir = "models/whisper-large-v2-asr-int8"language = "Chinese (China)"
language_abbr = "zh-CN"
language_decode = "chinese"
task = "transcribe"from transformers import AutoMode[ForSpeechSeq2Seq, AutoTokenizer, AutoProcessor
from peft import PeftConfig, PeftModelpeft_config = PeftConfig.from_pretrained(model_dir)
# base_model_name_or_path这是模型自带的常量
base_model = AutoMode[ForSpeechSeq2Seq.from_pretrained(peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)peft_model = PeftModel.from_pretrained(base_model, model_dir)
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
processor = AutoProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
feature_extractor = processor.feature_extractor
2、使用 Pipeline API 部署微调后 Whisper 实现中文语音识别任务
test_audio = "data/audio/test_zh.flac"from transformers import AutomaticSpeechRecognitionPipelinepipeline = AutomaticSpeechRecognitionPipeline(model=peft_model, tokenizer=tokenizer, feature_extractor=feature_extractor)forced_decoder_ids = processor.get_decoder_prompt_ids(language=language_decode, task=task)import torchwith torch.cuda.amp.autocast():text = pipeline(test_audio, max_new_tokens=255)["text"]

7、评估微调好的模型

mode_name_or_path = "openai/whisper-large-v2"
model_dir = "models/whisper-large-v2-asr-int8"language = "Chinese (China)"
language_abbr = "zh-CN"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"batch_size=16
from transformers import AutoModeLForSpeechSeq2Seq, AutoTokenizer, AutoProcessor
from peft import PeftConfig, PeftModelpeft_config = PeftConfig.from_pretrained(model_dir)base_model = AutoModeLForSpeechSeq2Seq.from_pretrained(peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)
# 权重不用做反向传播,很多步骤可以省略
base_model.requires_grad(false)
trainer.save_model(model_dir)
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
processor = AutoProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
feature_extractor = processor.feature_extractor

评估数据集处理

from datasets import load_dataset, DatasetDict, Audiocommon_voice = DatasetDict()
common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", trust_remote_code=True)
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"]
)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
def prepare_dataset(batch):audio = batch["audio"]batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"].input_features[0]batch["labels"] = tokenizer(batch["sentence"].input_idsreturn batchsmall_common_voice = DatasetDict()small_common_voice["test"] = common_voice["test"].shuffle(seed=16).select(range(328))tokenized_common_voice = small_common_voice.map(prepare_dataset)

评估模型

import evaluate# 词错误率 (NER) 是评估ASR模型常用的指标。从 Evaluate 加载 MER 指标
metric = evaluate.load("wer")from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import gceval_dataloader = DataLoader(tokenized_common_voice["test"], batch_size=batch_size, collate_fm=data_collator)
# 遍历评估数据加载器中的所有批次
for step, batch in enumerate(tqdm(eval_dataloader)):# 使用自动混合精度来加速计算,并减少显存使用with torch.cuda.amp.autocast():# 不计算梯度,以节省计算资源,仅用于推理和评估with torch.no_grad():# 生成预测的标记(tokens),这里使用模型的generate函数进行文本生成generated_tokens = (peft_model.generate(input_features=batch["input_features"].to("cuda"),    # 将输入特征移动到GPU上decoder_input_ids=batch["labels"][:,:4].to("cuda"),    # 提供解码器的初始输入max_new_tokens=255,    # 设置生成的最大新标记数量).cpu()    # 将生成的标记移回CPU.numpy()    # 转换为Numpy数组以便进一步处理)# 获取批次中的标签,并将其移回CPUlabels = batch["labels"].cpu().numpy()# 将标签中的-100替换为填充标记的ID,-100通常用于忽略计算损失的标记labels = np.where(labels != -100, labels, tokenizer.pad_token_id)# 使用分词器解码生成的标记和标签,以获得可读的文本decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)# 将预测和参考添加到评估指标中,用于后续的性能评估metric.add_batch(predictions=decoded_preds,references=decoded_labels,)# 删除不同需要的变量以释放内存del generated_tokens, labels, batch# 手动触发垃圾收集,进一步清理内存gc.collect()

使用全量数据微调后,对比 WER 指标降低了多少

# 计算词错误率 (WER) 指标,并将结果转换为百分比形式
wer = 100 * metric.compute()# 打印词错误率,f"{wer=}"是一种格式化字符串的简洁写法,它会展示变量名和值
print(f"{wer=}%")
http://www.dtcms.com/a/434352.html

相关文章:

  • Django第三方扩展详解:提升开发效率的利器
  • 正能量不良网站直接进入自助建站系统模板
  • 考研复习-线性代数强化-向量组和方程组特征值
  • Chromium 138 编译指南 - Android 篇:环境搭建与准备(一)
  • 2023 年真题配套词汇单词笔记(考研真相)
  • Android 窗口结构(三) Home Task 添加Home ActivityRecord
  • 峨边网站建设网站iis安全配置
  • CMU与谷歌提出FM-SIREN:受奈奎斯特定理启发,让神经元“各司其职”,特征冗余降低50%
  • 【软件安全】fgets / strncpy / gets(不安全) / snprintf的对比
  • 济南免费做网站四平网站建设联系方式
  • 向量数据库前沿:Faiss 向量数据库的配置与使用
  • 机床铸铁底座在高端机床行业中的核心作用
  • 我为您整理出了 Coolify 可以添加的所有服务类型,并附上其用途说明。
  • 《回溯 C++98:string 核心机制拆解 —— 从拷贝策略到高效 swap》
  • JAVA过时了吗?
  • fnos安装并更新最新版sunpanel(显示为套件)
  • 资阳网站建设资阳河南郑州地图
  • 【Java初学基础11】Annotation-注解
  • OSPF LSA Type 2(Network LSA)概念及题目
  • OSPF 伪节点(Transit Node) 概念及题目
  • Android 开发 | 提取已安装应用的安装包 apk 的方式
  • 商务网站建设步骤犀牛云网站怎么建设
  • Vite 前端构建工具入门教程
  • 一种利用 qBittorrent 的 WebUI API 实现的检查BT种子的磁力链接是否可用的程序
  • nodejs换源管理工具nrm
  • async/await的基本使用以及fetchAPI的部分细节
  • MySQL新学知识(一)
  • 小迪web自用笔记47
  • 前端如何优雅地生成唯一标识?——一份跨环境 UUID 工具函数的封装与实战
  • iBizModel 应用程序(PSSYSAPP)模型体系详解