
Streaming Speech Enhancement in Unity with GTCRN

Streaming Speech Enhancement

sherpa-onnx already ships GTCRN for offline speech enhancement, but streaming speech enhancement has never been added. I implemented it in Unity with the official onnxruntime. It still has some issues and the result falls short of the official output, but the enhancement is still reasonably good.
[Figure: original audio]
[Figure: official result]
[Figure: result in Unity]
[Figure: latest result in Unity]
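Before walking through the main script, it can help to confirm the model's input/output names (mix, conv_cache, tra_cache, inter_cache and their *_out counterparts) and the cache tensor shapes that are hard-coded below. This is a minimal sketch of my own, not part of the project: the class name GtcrnModelInfo is hypothetical, and it assumes the same gtcrn_simple.onnx sits in StreamingAssets. It only reads the onnxruntime session metadata and logs it.

using Microsoft.ML.OnnxRuntime;
using UnityEngine;

public class GtcrnModelInfo : MonoBehaviour
{
    void Start()
    {
        // Assumption: the same model file used by the main script below.
        var path = Application.streamingAssetsPath + "/gtcrn_simple.onnx";
        using (var session = new InferenceSession(path))
        {
            // Log every input/output name with its declared shape.
            foreach (var kv in session.InputMetadata)
                Debug.Log($"input  {kv.Key}: [{string.Join(",", kv.Value.Dimensions)}]");
            foreach (var kv in session.OutputMetadata)
                Debug.Log($"output {kv.Key}: [{string.Join(",", kv.Value.Dimensions)}]");
        }
    }
}

If the logged shapes differ from the { 2, 1, 16, 16, 33 } / { 2, 3, 1, 1, 16 } / { 2, 1, 33, 16 } caches used below, the exported model version is different and the cache tensors need to match it.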

The main code is as follows:

using MathNet.Numerics;
using MathNet.Numerics.IntegralTransforms;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using UnityEngine;
using Debug = UnityEngine.Debug;

public class GtcrnTest2 : MonoBehaviour
{
    string simpleModelPath = "your_model_name.onnx";
    string outputPath = "test_wavs/enh_onnx.wav"; // Path to save the enhanced audio.
    int sampleRate = 16000;   // Audio sample rate.
    int n_fft = 512;          // FFT window size.
    int hop_length = 256;     // Hop length between consecutive frames.
    int win_length = 512;     // Window length (usually equals n_fft for STFT).

    // Start is called before the first frame update
    void Start()
    {
        simpleModelPath = Application.streamingAssetsPath + "/gtcrn_simple.onnx";
        outputPath = Application.streamingAssetsPath + "/result.wav";
        float[] rawAudio = ReadWav(Application.streamingAssetsPath + "/mix.wav");

        // sqrt-Hann window, used for both analysis and synthesis.
        var window = Window.Hann(win_length).Select(x => (float)Math.Sqrt(x)).ToArray();
        (float[,,] stftResult, int frames) = ComputeSTFT(rawAudio, n_fft, hop_length, win_length, window);

        // Initialize the ONNX inference session and the three streaming caches.
        var session = new InferenceSession(simpleModelPath);
        var convCache = new DenseTensor<float>(new[] { 2, 1, 16, 16, 33 });
        var traCache = new DenseTensor<float>(new[] { 2, 3, 1, 1, 16 });
        var interCache = new DenseTensor<float>(new[] { 2, 1, 33, 16 });

        var outputs = new List<float[,,]>();
        for (int i = 0; i < frames; i++)
        {
            // Pack the current frame as [1, 257, 1, 2] (real/imag).
            var input = new DenseTensor<float>(new[] { 1, 257, 1, 2 });
            for (int j = 0; j < 257; j++)
            {
                input[0, j, 0, 0] = stftResult[i, j, 0]; // real part
                input[0, j, 0, 1] = stftResult[i, j, 1]; // imaginary part
            }

            // Assemble the inputs: the mixture frame plus the caches from the previous step.
            var inputs = new List<NamedOnnxValue>
            {
                NamedOnnxValue.CreateFromTensor("mix", input),
                NamedOnnxValue.CreateFromTensor("conv_cache", convCache),
                NamedOnnxValue.CreateFromTensor("tra_cache", traCache),
                NamedOnnxValue.CreateFromTensor("inter_cache", interCache)
            };

            // Run inference for this frame.
            using (var results = session.Run(inputs))
            {
                var enh = results.First(t => t.Name == "enh").AsTensor<float>();
                convCache = (DenseTensor<float>)results.First(t => t.Name == "conv_cache_out").AsTensor<float>();
                traCache = (DenseTensor<float>)results.First(t => t.Name == "tra_cache_out").AsTensor<float>();
                interCache = (DenseTensor<float>)results.First(t => t.Name == "inter_cache_out").AsTensor<float>();

                // Store the enhanced frame.
                var frameOutput = new float[1, 257, 2];
                for (int j = 0; j < 257; j++)
                {
                    frameOutput[0, j, 0] = enh[0, j, 0, 0];
                    frameOutput[0, j, 1] = enh[0, j, 0, 1];
                }
                outputs.Add(frameOutput);
            }
        }

        // Merge all frames.
        var allFrames = new float[frames, 257, 2];
        for (int i = 0; i < frames; i++)
        {
            for (int j = 0; j < 257; j++)
            {
                allFrames[i, j, 0] = outputs[i][0, j, 0];
                allFrames[i, j, 1] = outputs[i][0, j, 1];
            }
        }

        // Compute the ISTFT and save the result.
        float[] enhancedAudio = ComputeISTFT(allFrames, n_fft, hop_length, win_length, window);
        SaveClip(1, 16000, enhancedAudio, outputPath);
    }

    // Update is called once per frame
    void Update()
    {
    }

    static (float[,,] result, int frames) ComputeSTFT(float[] audio, int n_fft, int hop, int win, float[] window)
    {
        int frames = (audio.Length - n_fft) / hop + 1;
        var stft = new float[frames, n_fft / 2 + 1, 2]; // [frame, freq, real/imag]
        for (int i = 0; i < frames; i++)
        {
            // Extract the frame and apply the window.
            var frame = new float[n_fft];
            Array.Copy(audio, i * hop, frame, 0, Math.Min(n_fft, audio.Length - i * hop));
            for (int j = 0; j < n_fft; j++) frame[j] *= window[j];

            // Compute the FFT (MathNet.Numerics).
            var complexFrame = new Complex32[n_fft];
            for (int j = 0; j < n_fft; j++)
            {
                complexFrame[j] = new Complex32(frame[j], 0);
            }
            Fourier.Forward(complexFrame, FourierOptions.Default);

            // Keep only the non-negative frequencies (257 bins).
            for (int j = 0; j <= n_fft / 2; j++)
            {
                stft[i, j, 0] = complexFrame[j].Real;
                stft[i, j, 1] = complexFrame[j].Imaginary;
            }
        }
        return (stft, frames);
    }

    static float[] ComputeISTFT(float[,,] stft, int n_fft, int hop, int win, float[] window)
    {
        int frames = stft.GetLength(0);
        int outputLength = (frames - 1) * hop + n_fft;
        var output = new float[outputLength];
        var scale = window.Select(w => w * w).Sum(); // window energy, used for normalization

        for (int i = 0; i < frames; i++)
        {
            // Rebuild the full conjugate-symmetric spectrum.
            var fullSpectrum = new Complex32[n_fft];
            for (int j = 0; j <= n_fft / 2; j++)
            {
                fullSpectrum[j] = new Complex32(stft[i, j, 0], stft[i, j, 1]);
                if (j > 0 && j < n_fft / 2)
                {
                    fullSpectrum[n_fft - j] = fullSpectrum[j].Conjugate();
                }
            }

            // Inverse FFT.
            Fourier.Inverse(fullSpectrum, FourierOptions.Default);

            // Apply the synthesis window and overlap-add.
            int pos = i * hop;
            for (int j = 0; j < n_fft; j++)
            {
                if (pos + j < output.Length)
                {
                    output[pos + j] += fullSpectrum[j].Real * window[j] / scale;
                }
            }
        }
        return output;
    }

    float[] ReadWav(string filePath)
    {
        using (FileStream fs = new FileStream(filePath, FileMode.Open, FileAccess.Read))
        using (BinaryReader reader = new BinaryReader(fs))
        {
            // Read the WAV header.
            string riff = new string(reader.ReadChars(4));    // "RIFF"
            int fileSize = reader.ReadInt32();                // total file size - 8
            string wave = new string(reader.ReadChars(4));    // "WAVE"
            string fmt = new string(reader.ReadChars(4));     // "fmt "
            int fmtSize = reader.ReadInt32();                 // fmt chunk size (at least 16)

            // Read the audio format fields.
            short audioFormat = reader.ReadInt16();           // 1 = PCM
            short numChannels = reader.ReadInt16();           // channel count
            int sampleRate = reader.ReadInt32();              // sample rate
            int byteRate = reader.ReadInt32();                // byte rate
            short blockAlign = reader.ReadInt16();            // block align
            short bitsPerSample = reader.ReadInt16();         // bit depth

            // Validate the file format.
            if (riff != "RIFF" || wave != "WAVE" || fmt != "fmt ")
                throw new Exception("Invalid WAV header");

            // Skip any extra bytes in the fmt chunk.
            if (fmtSize > 16)
                reader.ReadBytes(fmtSize - 16);

            // Find the data chunk.
            string dataChunkId;
            do
            {
                dataChunkId = new string(reader.ReadChars(4));
                if (dataChunkId != "data")
                    reader.ReadBytes(reader.ReadInt32()); // skip non-data chunks
            } while (dataChunkId != "data");

            int dataSize = reader.ReadInt32(); // data chunk size in bytes

            // Validate the audio parameters.
            if (audioFormat != 1)
                throw new Exception("Only PCM format is supported");
            if (numChannels != 1)
                throw new Exception("Only mono audio is supported");
            if (sampleRate != 16000)
                throw new Exception("Only a 16 kHz sample rate is supported");
            if (bitsPerSample != 16)
                throw new Exception("Only 16-bit samples are supported");

            // Read the PCM data and convert it to float.
            int sampleCount = dataSize / 2; // 16 bit = 2 bytes per sample
            float[] floatData = new float[sampleCount];
            for (int i = 0; i < sampleCount; i++)
            {
                // Read a 16-bit little-endian sample.
                byte lowByte = reader.ReadByte();
                byte highByte = reader.ReadByte();
                short pcmValue = (short)((highByte << 8) | lowByte);

                // Convert the 16-bit PCM value to a float in [-1.0, 1.0].
                floatData[i] = pcmValue / 32768.0f;
            }
            return floatData;
        }
    }

    void SaveClip(int channels, int frequency, float[] data, string filePath)
    {
        using (FileStream fileStream = new FileStream(filePath, FileMode.Create))
        {
            using (BinaryWriter writer = new BinaryWriter(fileStream))
            {
                // Write the RIFF header.
                writer.Write("RIFF".ToCharArray());
                // Total file length, patched at the end.
                writer.Write(0);
                writer.Write("WAVE".ToCharArray());

                // Write the fmt chunk.
                writer.Write("fmt ".ToCharArray());
                writer.Write(16);                       // PCM format chunk length
                writer.Write((short)1);                 // PCM encoding
                writer.Write((short)channels);
                writer.Write(frequency);
                writer.Write(frequency * channels * 2); // byte rate
                writer.Write((short)(channels * 2));    // block align
                writer.Write((short)16);                // bit depth

                // Write the data chunk.
                writer.Write("data".ToCharArray());
                writer.Write(data.Length * 2);          // audio data size in bytes

                // Write the PCM data (float to short).
                foreach (float sample in data)
                {
                    // The conversion here may be off; the volume is boosted 100x as a workaround.
                    writer.Write((short)(sample * 32767 * 100));
                }

                // Go back and patch the total file length.
                fileStream.Position = 4;
                writer.Write((int)(fileStream.Length - 8));
            }
        }
    }
}
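The script above processes a whole WAV offline, frame by frame. Because GTCRN carries its state in the conv/tra/inter caches, the same per-frame call can just as well be driven by live audio. Below is a minimal streaming-wrapper sketch of my own, not taken from the project: GtcrnStreamer and the runFrame delegate are hypothetical names, the delegate is expected to wrap the session.Run call shown above (keeping the three caches between calls), and the FFT options and normalization deliberately mirror ComputeSTFT/ComputeISTFT above.

using System;
using MathNet.Numerics;
using MathNet.Numerics.IntegralTransforms;

// Streaming wrapper sketch: keeps a 512-sample analysis buffer and an overlap-add tail,
// and runs one model step per 256-sample hop.
public class GtcrnStreamer
{
    const int NFft = 512, Hop = 256, Bins = 257;
    readonly float[] window;                    // sqrt-Hann, same as in the script above
    readonly float[] inBuf = new float[NFft];   // last 512 input samples
    readonly float[] olaBuf = new float[NFft];  // overlap-add tail of the synthesis
    readonly float olaScale;                    // same window-energy normalization as ComputeISTFT
    readonly Func<float[,], float[,]> runFrame; // hypothetical: [257,2] in -> enhanced [257,2] out

    public GtcrnStreamer(float[] sqrtHann, Func<float[,], float[,]> runFrame)
    {
        window = sqrtHann;
        this.runFrame = runFrame;
        foreach (var w in sqrtHann) olaScale += w * w;
    }

    // Feed one 256-sample hop; returns 256 enhanced samples
    // (each output hop overlaps the current and the previous analysis frame).
    public float[] ProcessHop(float[] hop)
    {
        // Slide the 512-sample analysis buffer by one hop and append the new samples.
        Array.Copy(inBuf, Hop, inBuf, 0, NFft - Hop);
        Array.Copy(hop, 0, inBuf, NFft - Hop, Hop);

        // Windowed FFT (same FourierOptions as the script above).
        var c = new Complex32[NFft];
        for (int n = 0; n < NFft; n++) c[n] = new Complex32(inBuf[n] * window[n], 0);
        Fourier.Forward(c, FourierOptions.Default);

        var spec = new float[Bins, 2];
        for (int k = 0; k < Bins; k++) { spec[k, 0] = c[k].Real; spec[k, 1] = c[k].Imaginary; }

        // One model step; the caller's delegate holds the conv/tra/inter caches between calls.
        var enh = runFrame(spec);

        // Rebuild the conjugate-symmetric spectrum and invert.
        var full = new Complex32[NFft];
        for (int k = 0; k < Bins; k++)
        {
            full[k] = new Complex32(enh[k, 0], enh[k, 1]);
            if (k > 0 && k < NFft / 2) full[NFft - k] = full[k].Conjugate();
        }
        Fourier.Inverse(full, FourierOptions.Default);

        // Synthesis window + overlap-add, then emit the oldest hop.
        for (int n = 0; n < NFft; n++) olaBuf[n] += full[n].Real * window[n] / olaScale;
        var outHop = new float[Hop];
        Array.Copy(olaBuf, outHop, Hop);
        Array.Copy(olaBuf, Hop, olaBuf, 0, NFft - Hop);
        Array.Clear(olaBuf, NFft - Hop, Hop);
        return outHop;
    }
}

Pulling 256-sample hops from a Unity Microphone clip (or any other capture source) and pushing the returned hops into a playback buffer would then give a real-time path without the intermediate WAV files.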

Finally, the project repository:
gtcrn-unity

