在Unity中实现DTLN-AEC处理音频文件的功能
1、首先用 tf2onnx 把 DTLN-aec 预训练的 TFLite 模型转换为 ONNX 模型（预训练模型来自下面的仓库）
https://github.com/breizhn/DTLN-aec
"""Convert the pretrained DTLN-aec TFLite models to ONNX with tf2onnx.

Run it from a directory that contains ./pretrained_models with
dtln_aec_128_1.tflite and dtln_aec_128_2.tflite (for example a checkout of
https://github.com/breizhn/DTLN-aec). The .onnx files are written next to them.
"""
import os
import subprocess
import sys

# --- Configuration ---
model_dir = "./pretrained_models"
model_name = "dtln_aec_128"
opset_version = 13  # A common and stable opset

# --- Define model paths ---
model_1_tflite_path = os.path.join(model_dir, f"{model_name}_1.tflite")
model_1_onnx_path = os.path.join(model_dir, f"{model_name}_1.onnx")
model_2_tflite_path = os.path.join(model_dir, f"{model_name}_2.tflite")
model_2_onnx_path = os.path.join(model_dir, f"{model_name}_2.onnx")


def build_convert_command(tflite_path, onnx_path, opset=opset_version):
    """Return the argv list that runs tf2onnx's built-in TFLite converter.

    The current interpreter (``sys.executable``) is used instead of a bare
    ``python`` so the converter runs in the same environment as this script.
    Passing an argv list (no shell) also keeps paths with spaces or quotes safe.
    """
    return [
        sys.executable, "-m", "tf2onnx.convert",
        "--tflite", tflite_path,
        "--output", onnx_path,
        "--opset", str(opset),
    ]


def convert(tflite_path, onnx_path, opset=opset_version):
    """Convert one TFLite model to ONNX.

    Returns True on success and False on failure (after printing the error).
    Unlike ``os.system``, ``subprocess.run(check=True)`` raises when the
    converter exits with a non-zero status, so failures are actually reported.
    """
    print(f"--- Converting {os.path.basename(tflite_path)} ---")
    if not os.path.isfile(tflite_path):
        print(f"An error occurred: input file not found: {tflite_path}")
        return False
    try:
        subprocess.run(build_convert_command(tflite_path, onnx_path, opset), check=True)
    except (OSError, subprocess.CalledProcessError) as e:
        print(f"An error occurred: {e}")
        return False
    print(f"Successfully converted and saved to {onnx_path}\n")
    return True


def main():
    """Convert both DTLN-aec models; return a process exit status (0 = all converted)."""
    # Fail fast with a clear ImportError if tf2onnx is not installed.
    import tf2onnx  # noqa: F401

    # --- Perform Conversions ---
    ok = True
    for tflite_path, onnx_path in (
        (model_1_tflite_path, model_1_onnx_path),
        (model_2_tflite_path, model_2_onnx_path),
    ):
        # Convert every model even if an earlier one failed.
        ok = convert(tflite_path, onnx_path) and ok
    return 0 if ok else 1


if __name__ == "__main__":
    sys.exit(main())
2、用 C# 实现推理（依赖 Microsoft.ML.OnnxRuntime 和 MathNet.Numerics；输入须为 16 kHz 单声道音频）
using System;
using System.Linq;
using System.Numerics;
using System.Collections.Generic;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using MathNet.Numerics.IntegralTransforms;
using UnityEngine;

/// <summary>
/// Offline DTLN-aec echo canceller for Unity. It mirrors the reference Python
/// implementation: it slides a 512-sample window in 128-sample hops over a
/// whole recording and runs the two ONNX models converted from the pretrained
/// TFLite models on every window.
/// </summary>
/// <remarks>
/// Audio must be mono float samples at 16 kHz (<c>RequiredSampleRate</c>).
/// The rate is not checked here, because <see cref="ProcessAudio"/> only
/// receives raw samples.
/// </remarks>
public class DtlnaecProcessor : IDisposable
{
    // Constants from the Python script
    private const int BlockLen = 512;
    private const int BlockShift = 128;
    private const int FftSize = BlockLen;
    private const int RequiredSampleRate = 16000;
    // RFFT returns (N/2)+1 complex numbers
    private const int FftHalfSize = (FftSize / 2) + 1;

    // ONNX session instances
    private InferenceSession _session1;
    private InferenceSession _session2;

    // Model state tensors: each block's output state is fed back as the next block's input state
    private DenseTensor<float> _states1;
    private DenseTensor<float> _states2;

    // Input/output names, in the order reported by the session metadata
    private List<string> _inputNames1;
    private List<string> _outputNames1;
    private List<string> _inputNames2;
    private List<string> _outputNames2;

    // Sliding input windows (mic and loopback) and the overlap-add accumulator
    private readonly float[] _inBuffer = new float[BlockLen];
    private readonly float[] _inBufferLpb = new float[BlockLen];
    private readonly float[] _outBuffer = new float[BlockLen];

    /// <summary>
    /// Loads both ONNX models and allocates zeroed state tensors. Sessions from a
    /// previous call are released first, so calling this again does not leak them.
    /// </summary>
    /// <param name="model1Path">Path to the first model (spectral mask from mic/loopback magnitudes).</param>
    /// <param name="model2Path">Path to the second model (time-domain refinement).</param>
    /// <returns><c>true</c> on success; <c>false</c> if loading failed (the error is logged).</returns>
    public bool Initialize(string model1Path, string model2Path)
    {
        Dispose();

        try
        {
            // SessionOptions owns a native handle; it is only needed while the
            // sessions are being created, so dispose it afterwards.
            using (var sessionOptions = new SessionOptions())
            {
                // Use recommended session options for performance
                sessionOptions.ExecutionMode = ExecutionMode.ORT_SEQUENTIAL;
                sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
                sessionOptions.InterOpNumThreads = 1;
                sessionOptions.IntraOpNumThreads = 1;

                // Load ONNX models
                _session1 = new InferenceSession(model1Path, sessionOptions);
                _session2 = new InferenceSession(model2Path, sessionOptions);
            }

            // Get input/output names
            _inputNames1 = _session1.InputMetadata.Keys.ToList();
            _outputNames1 = _session1.OutputMetadata.Keys.ToList();
            _inputNames2 = _session2.InputMetadata.Keys.ToList();
            _outputNames2 = _session2.OutputMetadata.Keys.ToList();

            // Input #1 of each model is its state tensor
            _states1 = CreateZeroState(_session1, _inputNames1[1]);
            _states2 = CreateZeroState(_session2, _inputNames2[1]);

            ResetStates();
            Debug.Log("DTLN-AEC processor initialized successfully");
            return true;
        }
        catch (Exception ex)
        {
            Debug.LogError($"Failed to initialize DTLN-AEC processor: {ex.Message}");
            // Do not keep a half-initialized processor (e.g. model 1 loaded, model 2 failed).
            Dispose();
            return false;
        }
    }

    /// <summary>Allocates a zero-filled tensor with the shape of the given model input.</summary>
    private static DenseTensor<float> CreateZeroState(InferenceSession session, string inputName)
    {
        // Non-positive entries denote dynamic/unknown dimensions, which the
        // DenseTensor constructor rejects; treat them as 1 (single batch).
        int[] dims = session.InputMetadata[inputName].Dimensions
            .Select(d => d > 0 ? d : 1)
            .ToArray();
        return new DenseTensor<float>(new ReadOnlySpan<int>(dims), false);
    }

    /// <summary>Zeroes both model states and all sliding/overlap-add buffers.</summary>
    public void ResetStates()
    {
        // Reset state tensors to zeros
        if (_states1 != null)
        {
            _states1.Buffer.Span.Clear();
        }
        if (_states2 != null)
        {
            _states2.Buffer.Span.Clear();
        }

        // Reset buffers
        Array.Clear(_inBuffer, 0, _inBuffer.Length);
        Array.Clear(_inBufferLpb, 0, _inBufferLpb.Length);
        Array.Clear(_outBuffer, 0, _outBuffer.Length);
    }

    /// <summary>
    /// Removes the loopback (far-end) echo from a complete microphone recording.
    /// Every call is independent: states and buffers are reset first.
    /// </summary>
    /// <param name="micAudio">Near-end microphone samples.</param>
    /// <param name="lpbAudio">Far-end/loopback reference samples, time-aligned with <paramref name="micAudio"/>.</param>
    /// <returns>
    /// The echo-cancelled signal, with the length of the shorter input. It is rescaled to a
    /// peak of 0.99 if it would clip. Returns <c>null</c> if the processor is not
    /// initialized, and an empty array for empty input.
    /// </returns>
    /// <exception cref="ArgumentNullException">Either input array is null.</exception>
    public float[] ProcessAudio(float[] micAudio, float[] lpbAudio)
    {
        if (_session1 == null || _session2 == null)
        {
            Debug.LogError("DTLN-AEC processor not initialized");
            return null;
        }
        if (micAudio == null)
        {
            throw new ArgumentNullException(nameof(micAudio));
        }
        if (lpbAudio == null)
        {
            throw new ArgumentNullException(nameof(lpbAudio));
        }

        // Each call processes one whole recording, so it must not inherit the model
        // state or overlap-add tail of a previous call.
        ResetStates();

        // Only the overlapping part of the two signals can be processed.
        int originalLen = Math.Min(micAudio.Length, lpbAudio.Length);
        if (originalLen == 0)
        {
            return Array.Empty<float>();
        }

        // Pad both signals with (BlockLen - BlockShift) zeros at each end.
        // New arrays are zero-initialized, so only the audio has to be copied in.
        int padLen = BlockLen - BlockShift;
        var micPadded = new float[originalLen + 2 * padLen];
        var lpbPadded = new float[originalLen + 2 * padLen];
        Array.Copy(micAudio, 0, micPadded, padLen, originalLen);
        Array.Copy(lpbAudio, 0, lpbPadded, padLen, originalLen);

        // Preallocate output file
        var outFile = new float[micPadded.Length];

        int numBlocks = (micPadded.Length - padLen) / BlockShift;
        for (int idx = 0; idx < numBlocks; idx++)
        {
            int start = idx * BlockShift;

            // Shift each window left by one hop and append the next BlockShift samples
            Array.Copy(_inBuffer, BlockShift, _inBuffer, 0, padLen);
            Array.Copy(micPadded, start, _inBuffer, padLen, BlockShift);
            Array.Copy(_inBufferLpb, BlockShift, _inBufferLpb, 0, padLen);
            Array.Copy(lpbPadded, start, _inBufferLpb, padLen, BlockShift);

            ProcessBlock(outFile, start);
        }

        // Drop the leading padding / processing delay and trim to the original length
        var predictedSpeech = new float[originalLen];
        Array.Copy(outFile, padLen, predictedSpeech, 0, originalLen);

        // Normalize if clipping occurs
        float maxVal = predictedSpeech.Max(x => Math.Abs(x));
        if (maxVal > 1.0f)
        {
            for (int i = 0; i < predictedSpeech.Length; i++)
            {
                predictedSpeech[i] = (predictedSpeech[i] / maxVal) * 0.99f;
            }
        }

        return predictedSpeech;
    }

    /// <summary>
    /// Runs both models on the current windows in <c>_inBuffer</c> / <c>_inBufferLpb</c>,
    /// overlap-adds the result into <c>_outBuffer</c>, and writes the first BlockShift
    /// finished samples to <paramref name="outFile"/> at <paramref name="startIndex"/>.
    /// </summary>
    private void ProcessBlock(float[] outFile, int startIndex)
    {
        // --- FFT ---
        var inBlockFft = PerformRfft(_inBuffer);
        var lpbBlockFft = PerformRfft(_inBufferLpb);

        // Magnitude spectra shaped [1, 1, FftHalfSize] for model 1
        var inMag = new DenseTensor<float>(dimensions: new[] { 1, 1, FftHalfSize });
        var lpbMag = new DenseTensor<float>(dimensions: new[] { 1, 1, FftHalfSize });
        for (int i = 0; i < FftHalfSize; i++)
        {
            inMag[0, 0, i] = (float)inBlockFft[i].Magnitude;
            lpbMag[0, 0, i] = (float)lpbBlockFft[i].Magnitude;
        }

        // --- Run Model 1 ---
        // NOTE(review): inputs are bound by metadata position (0 = mic, 1 = state,
        // 2 = loopback), mirroring the TFLite input order of the Python script.
        // Confirm the converted ONNX models report their inputs in that order.
        var inputs1 = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor(_inputNames1[0], inMag),
            NamedOnnxValue.CreateFromTensor(_inputNames1[2], lpbMag),
            NamedOnnxValue.CreateFromTensor(_inputNames1[1], _states1)
        };
        using var outputs1 = _session1.Run(inputs1);
        var outMask = outputs1.First(v => v.Name == _outputNames1[0]).AsTensor<float>();
        // Copy the state out: output tensors are only valid until outputs1 is disposed.
        _states1 = outputs1.First(v => v.Name == _outputNames1[1]).AsTensor<float>().ToDenseTensor();

        // --- Apply Mask and IFFT ---
        for (int i = 0; i < FftHalfSize; i++)
        {
            inBlockFft[i] = new Complex(
                inBlockFft[i].Real * outMask[0, 0, i],
                inBlockFft[i].Imaginary * outMask[0, 0, i]);
        }
        var estimatedBlockTime = PerformIrfft(inBlockFft);

        // --- Run Model 2 ---
        var estimatedBlockTensor = new DenseTensor<float>(dimensions: new[] { 1, 1, BlockLen });
        var inLpbTensor = new DenseTensor<float>(dimensions: new[] { 1, 1, BlockLen });
        for (int i = 0; i < BlockLen; i++)
        {
            estimatedBlockTensor[0, 0, i] = estimatedBlockTime[i];
            inLpbTensor[0, 0, i] = _inBufferLpb[i];
        }
        var inputs2 = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor(_inputNames2[0], estimatedBlockTensor),
            NamedOnnxValue.CreateFromTensor(_inputNames2[2], inLpbTensor),
            NamedOnnxValue.CreateFromTensor(_inputNames2[1], _states2)
        };
        using var outputs2 = _session2.Run(inputs2);
        // ToDenseTensor() always yields a DenseTensor (an 'as' cast could yield null).
        var outBlock = outputs2.First(v => v.Name == _outputNames2[0]).AsTensor<float>().ToDenseTensor();
        _states2 = outputs2.First(v => v.Name == _outputNames2[1]).AsTensor<float>().ToDenseTensor();

        // --- Overlap-Add ---
        Array.Copy(_outBuffer, BlockShift, _outBuffer, 0, BlockLen - BlockShift);
        Array.Clear(_outBuffer, BlockLen - BlockShift, BlockShift);
        var outBlockSpan = outBlock.Buffer.Span;
        for (int i = 0; i < BlockLen; i++)
        {
            _outBuffer[i] += outBlockSpan[i];
        }

        // Write the finished leading hop to the output
        Array.Copy(_outBuffer, 0, outFile, startIndex, BlockShift);
    }

    /// <summary>
    /// Real FFT of a BlockLen-sample frame. Returns the FftHalfSize non-negative-frequency
    /// bins, with unscaled forward transform (MathNet <c>FourierOptions.Matlab</c>).
    /// </summary>
    private static Complex[] PerformRfft(float[] input)
    {
        var complexInput = new Complex[FftSize];
        for (int i = 0; i < FftSize; i++)
        {
            complexInput[i] = new Complex(input[i], 0);
        }
        Fourier.Forward(complexInput, FourierOptions.Matlab);

        // Return only the first half (N/2 + 1)
        var result = new Complex[FftHalfSize];
        Array.Copy(complexInput, result, FftHalfSize);
        return result;
    }

    /// <summary>
    /// Inverse of <see cref="PerformRfft"/>: rebuilds the Hermitian-symmetric spectrum
    /// and returns the real part of the inverse FFT (scaled by 1/N under
    /// <c>FourierOptions.Matlab</c>).
    /// </summary>
    private static float[] PerformIrfft(Complex[] input)
    {
        // Reconstruct the full spectrum for IFFT
        var fullSpectrum = new Complex[FftSize];
        Array.Copy(input, fullSpectrum, FftHalfSize);

        // Fill the second half with complex conjugates (DC and Nyquist bins are not mirrored)
        for (int i = 1; i < FftHalfSize - 1; i++)
        {
            fullSpectrum[FftSize - i] = Complex.Conjugate(input[i]);
        }
        Fourier.Inverse(fullSpectrum, FourierOptions.Matlab);

        // Return the real part of the result
        var result = new float[FftSize];
        for (int i = 0; i < FftSize; i++)
        {
            result[i] = (float)fullSpectrum[i].Real;
        }
        return result;
    }

    /// <summary>Releases both ONNX sessions. Safe to call more than once.</summary>
    public void Dispose()
    {
        _session1?.Dispose();
        _session2?.Dispose();
        _session1 = null;
        _session2 = null;
    }
}
3、效果图
mic 为近端麦克风录音，lpb 为远端参考信号（即扬声器播放的回采/loopback 信号）
4、工程地址
https://github.com/xue-fei/dtlnaec-unity