Whisper language detection study notes
Contents
transformers inference:
transformers source code
Language detection example from the web:
Language detection API
transformers inference:
https://github.com/openai/whisper/blob/c0d2f624c09dc18e709e37c2ad90c039a4eb72a2/whisper/decoding.py
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# The original snippet does not name a checkpoint; "openai/whisper-base" is assumed here.
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

def detect_language_hf(file_path: str) -> str:
    waveform, sample_rate = torchaudio.load(file_path)
    # Ensure the sample rate is 16000 Hz (Whisper's expected sample rate)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    inputs = processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_features"])
    # Extract the language token from the model's output: the first generated ids are the
    # special tokens <|startoftranscript|><|xx|>, where xx is the detected language code.
    language_token = processor.tokenizer.decode(generated_ids[0][:2])  # First two tokens
    return language_token
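A minimal usage sketch of the snippet above, with its body wrapped as detect_language_hf() (a name introduced here purely for illustration; "audio.mp3" is a placeholder path):

# Hypothetical call to the wrapper defined above; it prints something like
# "<|startoftranscript|><|en|>", from which the language code can be read off.
if __name__ == "__main__":
    print(detect_language_hf("audio.mp3"))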
transformers source code
https://github.com/huggingface/transformers/blob/05000aefe173bf7a10fa1d90e4c528585b45d3c7/src/transformers/models/whisper/generation_whisper.py#L1622
def detect_language(
    self,
    input_features: Optional[torch.FloatTensor] = None,
    encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
    generation_config: Optional[GenerationConfig] = None,
    num_segment_frames: int = 3000,
) -> torch.Tensor:
    """
    Detects language from log-mel input features or encoder_outputs

    Parameters:
        input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
            Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
            loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via
            the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
            [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
            tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which had the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.
        num_segment_frames (`int`, *optional*, defaults to 3000):
            The number of log-mel frames the model expects

    Return:
        A `torch.LongTensor` representing the detected language ids.
    """
    if input_features is None and encoder_outputs is None:
        raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
    elif input_features is not None and encoder_outputs is not None:
        raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!")
    elif input_features is not None:
        inputs = {"input_features": input_features[:, :, :num_segment_frames]}
        batch_size = input_features.shape[0]
    elif encoder_outputs is not None:
        inputs = {"encoder_outputs": encoder_outputs}
        batch_size = (
            encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
        )

    generation_config = generation_config or self.generation_config
    decoder_input_ids = (
        torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
        * generation_config.decoder_start_token_id
    )

    with torch.no_grad():
        logits = self(**inputs, decoder_input_ids=decoder_input_ids, use_cache=False).logits[:, -1]

    non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
    non_lang_mask[list(generation_config.lang_to_id.values())] = False

    logits[:, non_lang_mask] = -np.inf
    lang_ids = logits.argmax(-1)

    return lang_ids
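Based on the source above, the method can be called directly on a WhisperForConditionalGeneration instance. The sketch below is only an illustration under assumptions: "openai/whisper-base" and "audio.mp3" are placeholders, and it relies on the checkpoint's generation_config carrying lang_to_id (the same mapping the source uses to mask non-language tokens):

import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

# Load and resample a placeholder file to 16 kHz, then extract log-mel features.
waveform, sr = torchaudio.load("audio.mp3")
if sr != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features

with torch.no_grad():
    lang_ids = model.detect_language(input_features=features)  # shape: (batch_size,)

# Invert generation_config.lang_to_id (token string -> id) to turn ids back into tokens like "<|zh|>".
id_to_lang = {v: k for k, v in model.generation_config.lang_to_id.items()}
print([id_to_lang[int(i)] for i in lang_ids])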
Language detection example from the web:
import whisper

model = whisper.load_model("base")  # Load the pretrained speech recognition model; the "base" model is used here.

# load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio("audio.mp3")
audio = whisper.pad_or_trim(audio)  # Pad or trim the audio so it fits Whisper's 30-second window.

# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(model.device)
# Convert the audio to a log-Mel spectrogram and move it to the same device as the model (e.g. the GPU).

# detect the spoken language
_, probs = model.detect_language(mel)  # Run language detection; returns the detected language token and the per-language probabilities.
# Print the detected language, taking the language with the highest probability.
print(f"Detected language: {max(probs, key=probs.get)}")

# decode the audio
# Set the decoding options, such as the language, decoding strategy, etc.
options = whisper.DecodingOptions()
# Decode the audio with the model to produce the recognition result.
result = whisper.decode(model, mel, options)

# print the recognized text
# Print the recognition result, i.e. the text recognized by the model.
print(result.text)
Source: CSDN blog post by 陌上阳光 (CC 4.0 BY-SA): https://blog.csdn.net/weixin_42831564/article/details/138667560
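Note that the probs returned by model.detect_language(mel) in the example above is a dict mapping language codes to probabilities, so the full distribution can be inspected rather than only the argmax; a small continuation sketch:

# Continues from the example above: sort the per-language probabilities
# and show the top three candidate languages.
top3 = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3]
for code, p in top3:
    print(f"{code}: {p:.3f}")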
Language detection API
Language detection source code:
https://github.com/openai/whisper/blob/c0d2f624c09dc18e709e37c2ad90c039a4eb72a2/whisper/decoding.py
@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return them as list of strings, along with the ids
    of the most probable language tokens and the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appears after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
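The key trick in the source above is the boolean mask that pushes every non-language logit to -inf before the argmax/softmax, so only language tokens can win. A toy sketch of the same masking, with made-up token ids purely for illustration:

import numpy as np
import torch

# Toy vocabulary of 10 tokens; pretend ids 3, 5 and 7 are the language tokens.
logits = torch.randn(2, 10)                    # (n_audio, vocab_size)
language_token_ids = [3, 5, 7]

mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[language_token_ids] = False               # keep only the language tokens
logits[:, mask] = -np.inf

print(logits.argmax(dim=-1))                   # most probable language id per audio clip
print(logits.softmax(dim=-1)[:, language_token_ids])  # probability distribution over the languages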