
Whisper language detection: study notes

Contents

Inference with transformers:

transformers source code

Language detection usage example found online:

Language detection API


Inference with transformers:

https://github.com/openai/whisper/blob/c0d2f624c09dc18e709e37c2ad90c039a4eb72a2/whisper/decoding.py

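import torch
import torchaudio
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Use a multilingual checkpoint ("openai/whisper-base" is only an example); English-only "*.en" models have no language tokens
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")


def detect_language(file_path):
    waveform, sample_rate = torchaudio.load(file_path)  # shape: (channels, samples)
    waveform = waveform.mean(dim=0)  # mix down to mono so the processor receives a 1-D array
    # Ensure the sample rate is 16000 Hz (Whisper's expected sample rate)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    inputs = processor(waveform.numpy(), return_tensors="pt", sampling_rate=16000)
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_features"])
    # Like the original, this assumes the output starts with <|startoftranscript|><|xx|>,
    # so the token at index 1 is the language token, e.g. "<|en|>"
    return processor.tokenizer.convert_ids_to_tokens(generated_ids[0, 1].item())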
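Whichever way the first tokens are decoded, the language comes back wrapped in a special token such as `<|en|>`, or inside a prefix like `<|startoftranscript|><|en|>`, rather than as a bare code like `en`. Below is a minimal sketch for turning that string into a code. `extract_language_code` is a hypothetical helper, not a transformers API, and it rests on an assumption: Whisper language tokens are `<|...|>` tokens whose body is a 2 or 3 letter lowercase code, while the other special tokens (`<|transcribe|>`, `<|notimestamps|>`, timestamp tokens) do not fit that pattern.

```python
import re
from typing import Optional

# Hypothetical helper (not part of transformers). Assumption: language tokens look like
# "<|en|>", "<|zh|>" or "<|haw|>", whereas tokens such as "<|startoftranscript|>"
# or "<|transcribe|>" have longer names and timestamp tokens contain digits.
_LANG_TOKEN = re.compile(r"<\|([a-z]{2,3})\|>")


def extract_language_code(decoded: str) -> Optional[str]:
    """Return the first 2-3 letter code found inside a <|...|> token, or None."""
    match = _LANG_TOKEN.search(decoded)
    return match.group(1) if match else None


print(extract_language_code("<|startoftranscript|><|zh|>"))  # zh
```

Because it is only a pattern match, it cannot tell a real language token from any other short `<|..|>` token; checking the result against the tokenizer's list of language codes would be stricter.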

transformers source code

https://github.com/huggingface/transformers/blob/05000aefe173bf7a10fa1d90e4c528585b45d3c7/src/transformers/models/whisper/generation_whisper.py#L1622

def detect_language(
    self,
    input_features: Optional[torch.FloatTensor] = None,
    encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
    generation_config: Optional[GenerationConfig] = None,
    num_segment_frames: int = 3000,
) -> torch.Tensor:
    """
    Detects language from log-mel input features or encoder_outputs

    Parameters:
        input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
            Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
            loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
            the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
            [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
            tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which had the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.
        num_segment_frames (`int`, *optional*, defaults to 3000):
            The number of log-mel frames the model expects

    Return:
        A `torch.LongTensor` representing the detected language ids.
    """
    if input_features is None and encoder_outputs is None:
        raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
    elif input_features is not None and encoder_outputs is not None:
        raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!")
    elif input_features is not None:
        inputs = {"input_features": input_features[:, :, :num_segment_frames]}
        batch_size = input_features.shape[0]
    elif encoder_outputs is not None:
        inputs = {"encoder_outputs": encoder_outputs}
        batch_size = (
            encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
        )

    generation_config = generation_config or self.generation_config
    decoder_input_ids = (
        torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
        * generation_config.decoder_start_token_id
    )

    with torch.no_grad():
        logits = self(**inputs, decoder_input_ids=decoder_input_ids, use_cache=False).logits[:, -1]

    non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
    non_lang_mask[list(generation_config.lang_to_id.values())] = False

    logits[:, non_lang_mask] = -np.inf

    lang_ids = logits.argmax(-1)

    return lang_ids
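The essential step in this method is masking: after a single decoder step from `decoder_start_token_id`, every logit that is not a language token is set to `-inf`, so `argmax` can only land on a language id. The sketch below mirrors that step in NumPy on a made-up 8-token vocabulary in which ids 3, 4 and 5 stand in for language tokens. The vocabulary, ids and logit values are toy assumptions, not Whisper's real ones.

```python
from typing import Sequence

import numpy as np


def pick_language_ids(logits: np.ndarray, lang_token_ids: Sequence[int]) -> np.ndarray:
    """Mask every non-language logit, then take the argmax of each row.

    logits has shape (batch_size, vocab_size), like logits[:, -1] in the method above.
    """
    logits = logits.copy()  # copy so the caller's array is not modified
    non_lang_mask = np.ones(logits.shape[-1], dtype=bool)
    non_lang_mask[list(lang_token_ids)] = False
    logits[:, non_lang_mask] = -np.inf  # non-language tokens can never win the argmax
    return logits.argmax(-1)


# Toy batch: 2 clips, 8-token vocabulary, ids 3-5 play the role of language tokens
toy_logits = np.array([
    [9.0, 0.1, 0.2, 1.0, 2.5, 0.3, 7.0, 0.0],  # id 0 is highest overall but is not a language
    [0.0, 0.0, 0.0, 3.0, 1.0, 0.5, 0.0, 8.0],
])
print(pick_language_ids(toy_logits, [3, 4, 5]))  # [4 3]
```

Without the mask, the first row would return id 0 and the second id 7, neither of which is a language token.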

Language detection usage example found online:

import whisper

model = whisper.load_model("base")  # load a pretrained speech recognition model; here the "base" model

# load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio("audio.mp3")
audio = whisper.pad_or_trim(audio)  # pad or trim the loaded audio to exactly one 30-second window

# make log-Mel spectrogram and move to the same device as the model (e.g. a GPU)
# n_mels=model.dims.n_mels keeps this working for large-v3, which uses 128 mel bins instead of 80
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)

# detect the spoken language: returns the most probable language token id and a {language code: probability} dict
_, probs = model.detect_language(mel)
# print the language with the highest probability
print(f"Detected language: {max(probs, key=probs.get)}")

# decode the audio
# set the decoding options, e.g. task, language, beam_size, fp16 (on CPU, fp16=False is usually needed)
options = whisper.DecodingOptions()
# decode the audio with the model to produce the recognition result
result = whisper.decode(model, mel, options)

# print the recognized text
print(result.text)
————————————————
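Copyright notice: This article is an original work by CSDN blogger 「陌上阳光」, published under the CC 4.0 BY-SA license. When reposting, please include the original source link and this notice.
Original link: https://blog.csdn.net/weixin_42831564/article/details/138667560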
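`pad_or_trim` in the example above fixes every input to one 30-second window. At 16 kHz that is 16000 × 30 = 480000 samples, and with a hop of 160 samples between STFT frames the log-Mel spectrogram has 480000 / 160 = 3000 frames, which matches the `num_segment_frames=3000` default in the transformers `detect_language` signature. The NumPy sketch below illustrates this under those constants; the constant names mirror whisper's audio module, but the function is a re-implementation for illustration, not the library's code.

```python
import numpy as np

SAMPLE_RATE = 16000  # Hz expected by Whisper
CHUNK_LENGTH = 30  # seconds per window
HOP_LENGTH = 160  # samples between consecutive STFT frames
N_SAMPLES = SAMPLE_RATE * CHUNK_LENGTH  # 480000 samples per window
N_FRAMES = N_SAMPLES // HOP_LENGTH  # 3000 log-Mel frames per window


def pad_or_trim(audio: np.ndarray, length: int = N_SAMPLES) -> np.ndarray:
    """Zero-pad or cut the last axis so it has exactly `length` samples."""
    if audio.shape[-1] > length:
        return audio[..., :length]
    pad_width = [(0, 0)] * (audio.ndim - 1) + [(0, length - audio.shape[-1])]
    return np.pad(audio, pad_width)  # default mode pads with zeros


short_clip = np.ones(5 * SAMPLE_RATE, dtype=np.float32)  # 5 s of audio
long_clip = np.ones(45 * SAMPLE_RATE, dtype=np.float32)  # 45 s of audio
print(pad_or_trim(short_clip).shape, pad_or_trim(long_clip).shape, N_FRAMES)
# (480000,) (480000,) 3000
```

Trimming means only the first 30 seconds of a longer file take part in language detection, so a clip that starts with silence or music may be misdetected.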

Language detection API

Language detection source code:

https://github.com/openai/whisper/blob/c0d2f624c09dc18e709e37c2ad90c039a4eb72a2/whisper/decoding.py

@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return them as list of strings, along with the ids
    of the most probable language tokens and the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appears after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
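Unlike the transformers method, which returns only ids, this function also applies `softmax` after masking. Since exp(-inf) = 0, all probability mass goes to the language tokens and each dict in `language_probs` sums to 1. The NumPy sketch below uses toy values (ids 3 to 5 mapped to `en`, `zh` and `ja` are assumptions for illustration) and adds a top-k ranking, which shows more than the single `max(probs, key=probs.get)` printed in the example.

```python
import numpy as np


def language_probs(logits: np.ndarray, lang_ids: list, lang_codes: list) -> dict:
    """Softmax over one row of logits after masking out non-language tokens."""
    masked = np.full_like(logits, -np.inf)
    masked[lang_ids] = logits[lang_ids]  # keep only the language-token logits
    exp = np.exp(masked - masked.max())  # subtracting the max avoids overflow; exp(-inf) is 0
    probs = exp / exp.sum()
    return {code: float(probs[i]) for i, code in zip(lang_ids, lang_codes)}


toy_logits = np.array([9.0, 0.1, 0.2, 2.0, 1.0, 0.0, 7.0])  # ids 3-5 are the toy languages
probs = language_probs(toy_logits, [3, 4, 5], ["en", "zh", "ja"])
top_k = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:2]
print([(code, round(p, 3)) for code, p in top_k])  # [('en', 0.665), ('zh', 0.245)]
```

Looking at the gap between the top two probabilities is a simple way to notice uncertain detections, for example on short or code-switched clips.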

