使用 TensorFlow Lite Micro 运行 micro_speech 示例,识别语音 yes/no
基于 TensorFlow Lite Micro(TFLM)的关键词识别(KWS)示例,运行在 STM32F407VGT6 上,实现 yes/no 识别。
工程基于官方 micro_speech 示例思路改造,采用官方“两模型管线”:独立 Audio Preprocessor 模型 + Micro Speech 分类模型,并结合嵌入式实际对内存与接口做了适配和优化。
1、下载官方源码
github仓库:https://github.com/tensorflow/tflite-micro/
2、下载 flatbuffers,解压后把目录重命名为 flatbuffers,放到 tflite-micro\third_party 目录下(第 8 步的头文件路径依赖该目录名)
下载地址:https://github.com/google/flatbuffers/archive/v23.5.26.zip
3、下载 pigweed 放到 tflite-micro\third_party 目录下(在 third_party 目录中执行下面的命令)
仓库地址:git clone https://pigweed.googlesource.com/pigweed/pigweed
git checkout 47268dff45019863e20438ca3746c6c62df6ef09
4、下载 kissfft,解压后把目录重命名为 kissfft,放到 tflite-micro\third_party 目录下
下载地址:https://github.com/mborgerding/kissfft/archive/refs/tags/v130.zip
5、下载 gemmlowp,解压后把目录重命名为 gemmlowp,放到 tflite-micro\third_party 目录下
下载地址:https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37f7f98adcc7fc9f425.zip
6、下载 ruy,解压后把目录重命名为 ruy,放到 tflite-micro\third_party 目录下
下载地址:https://github.com/google/ruy/archive/d37128311b445e758136b8602d1bbd2a755e115d.zip
7、将以下源文件加入工程编译(C_SOURCES 为 C 文件,CXX_SOURCES_CC 为 C++ 文件)
# C sources: TFLM "microfrontend" C library.
# NOTE(review): the KWS code below says "C frontend removed; using Audio
# Preprocessor model", so these may be unused by the final pipeline - verify
# before dropping them.
C_SOURCES := \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/window.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/window_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/filterbank.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/noise_reduction.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/log_scale.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/log_scale_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/log_lut.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/frontend.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/frontend_util.c

# C++ sources: TFLM core, the builtin kernels used by both models, and the
# signal library backing the Audio Preprocessor model's Signal* custom ops.
CXX_SOURCES_CC := \
tflite-micro/tensorflow/lite/micro/micro_interpreter.cc \
tflite-micro/tensorflow/lite/micro/micro_utils.cc \
tflite-micro/tensorflow/lite/micro/cortex_m_generic/debug_log.cc \
tflite-micro/tensorflow/lite/micro/micro_allocator.cc \
tflite-micro/tensorflow/lite/micro/micro_allocation_info.cc \
tflite-micro/tensorflow/lite/micro/micro_interpreter_context.cc \
tflite-micro/tensorflow/lite/micro/micro_interpreter_graph.cc \
tflite-micro/tensorflow/lite/micro/memory_helpers.cc \
tflite-micro/tensorflow/lite/micro/flatbuffer_utils.cc \
tflite-micro/tensorflow/lite/micro/micro_log.cc \
tflite-micro/tensorflow/lite/micro/micro_time.cc \
tflite-micro/tensorflow/lite/micro/micro_profiler.cc \
tflite-micro/tensorflow/lite/micro/system_setup.cc \
tflite-micro/tensorflow/lite/micro/micro_resource_variable.cc \
tflite-micro/tensorflow/lite/micro/micro_op_resolver.cc \
tflite-micro/tensorflow/lite/micro/tflite_bridge/flatbuffer_conversions_bridge.cc \
tflite-micro/tensorflow/lite/micro/tflite_bridge/micro_error_reporter.cc \
tflite-micro/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc \
tflite-micro/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.cc \
tflite-micro/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc \
tflite-micro/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc \
tflite-micro/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc \
tflite-micro/tensorflow/lite/micro/memory_planner/linear_memory_planner.cc \
tflite-micro/tensorflow/lite/micro/kernels/kernel_util.cc \
tflite-micro/tensorflow/lite/micro/kernel_util_compat.cc \
tflite-micro/tensorflow/lite/kernels/internal/common.cc \
tflite-micro/tensorflow/lite/kernels/internal/quantization_util.cc \
tflite-micro/tensorflow/lite/micro/kernels/fully_connected_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/softmax_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/elementwise.cc \
tflite-micro/tensorflow/lite/micro/kernels/micro_tensor_utils.cc \
tflite-micro/tensorflow/lite/micro/kernels/activations_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/conv_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/depthwise_conv_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/reshape.cc \
tflite-micro/tensorflow/lite/micro/kernels/reshape_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/fully_connected.cc \
tflite-micro/tensorflow/lite/micro/kernels/depthwise_conv.cc \
tflite-micro/tensorflow/lite/micro/kernels/softmax.cc \
tflite-micro/tensorflow/lite/micro/kernels/cast.cc \
tflite-micro/tensorflow/lite/micro/kernels/add.cc \
tflite-micro/tensorflow/lite/micro/kernels/add_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/div.cc \
tflite-micro/tensorflow/lite/micro/kernels/strided_slice.cc \
tflite-micro/tensorflow/lite/micro/kernels/strided_slice_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/concatenation.cc \
tflite-micro/tensorflow/lite/micro/kernels/mul.cc \
tflite-micro/tensorflow/lite/micro/kernels/mul_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/pad.cc \
tflite-micro/tensorflow/lite/micro/kernels/pad_common.cc \
tflite-micro/tensorflow/lite/core/c/common.cc \
tflite-micro/tensorflow/lite/core/api/flatbuffer_conversions.cc \
tflite-micro/tensorflow/lite/kernels/internal/tensor_ctypes.cc \
tflite-micro/tensorflow/lite/kernels/internal/portable_tensor_utils.cc \
tflite-micro/tensorflow/lite/micro/micro_context.cc \
tflite-micro/tensorflow/compiler/mlir/lite/schema/schema_utils.cc \
tflite-micro/tensorflow/lite/micro/kernels/activations.cc \
tflite-micro/tensorflow/lite/micro/kernels/conv.cc \
tflite-micro/tensorflow/lite/micro/kernels/pooling.cc \
tflite-micro/tensorflow/lite/micro/kernels/pooling_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/dequantize.cc \
tflite-micro/tensorflow/lite/micro/kernels/dequantize_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/quantize.cc \
tflite-micro/tensorflow/lite/micro/kernels/quantize_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/maximum_minimum.cc \
tflite-micro/tensorflow/lite/micro/kernels/logistic.cc \
tflite-micro/tensorflow/lite/micro/kernels/logistic_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/reduce.cc \
tflite-micro/tensorflow/lite/micro/kernels/reduce_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/sub.cc \
tflite-micro/tensorflow/lite/micro/kernels/sub_common.cc \
tflite-micro/signal/micro/kernels/window.cc \
tflite-micro/signal/micro/kernels/fft_auto_scale_kernel.cc \
tflite-micro/signal/micro/kernels/rfft.cc \
tflite-micro/signal/micro/kernels/energy.cc \
tflite-micro/signal/micro/kernels/filter_bank.cc \
tflite-micro/signal/micro/kernels/filter_bank_square_root.cc \
tflite-micro/signal/micro/kernels/filter_bank_square_root_common.cc \
tflite-micro/signal/micro/kernels/filter_bank_spectral_subtraction.cc \
tflite-micro/signal/micro/kernels/filter_bank_log.cc \
tflite-micro/signal/micro/kernels/pcan.cc \
tflite-micro/signal/src/window.cc \
tflite-micro/signal/src/fft_auto_scale.cc \
tflite-micro/signal/src/irfft_int16.cc \
tflite-micro/signal/src/irfft_int32.cc \
tflite-micro/signal/src/irfft_float.cc \
tflite-micro/signal/src/rfft_int16.cc \
tflite-micro/signal/src/rfft_int32.cc \
tflite-micro/signal/src/rfft_float.cc \
tflite-micro/signal/src/energy.cc \
tflite-micro/signal/src/filter_bank.cc \
tflite-micro/signal/src/filter_bank_square_root.cc \
tflite-micro/signal/src/filter_bank_spectral_subtraction.cc \
tflite-micro/signal/src/filter_bank_log.cc \
tflite-micro/signal/src/log.cc \
tflite-micro/signal/src/kiss_fft_wrappers/kiss_fft_int16.cc \
tflite-micro/signal/src/msb_32.cc \
tflite-micro/signal/src/max_abs.cc \
tflite-micro/signal/src/square_root_32.cc \
tflite-micro/signal/src/square_root_64.cc \
tflite-micro/signal/src/pcan_argc_fixed.cc \
tflite-micro/signal/micro/kernels/fft_auto_scale_common.cc \
tflite-micro/signal/src/msb_64.cc \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/fft.cc \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc
8、添加以下头文件路径
# Header search paths: project root, the third-party packages unpacked under
# tflite-micro/third_party, and the tflite-micro tree itself (which resolves
# both "tensorflow/..." and "signal/..." includes).
INCLUDE_DIRS := \
. \
tflite-micro/third_party/kissfft \
tflite-micro/third_party/gemmlowp \
tflite-micro/third_party/flatbuffers/include \
tflite-micro/third_party/ruy \
tflite-micro/third_party \
tflite-micro \
tflite-micro/signal
9、转换预处理模型tflite为C语言数组
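在tflite-micro\tensorflow\lite\micro\examples\micro_speech\models\路径下面有一个audio_preprocessor_int8.tflite的预处理模型数据,
运行: xxd -i audio_preprocessor_int8.tflite > audio_preprocessor_int8_model_data.c
xxd 会根据文件名生成数组 audio_preprocessor_int8_tflite 和长度变量 audio_preprocessor_int8_tflite_len,代码中引用的名字必须与之一致。
建议把数组改为 const,否则模型会被放进 RAM;同时按 16 字节对齐。
xxd 不会生成头文件,需要自己新建 audio_preprocessor_int8_model_data.h 声明这两个变量。由于该 .c 文件按 C 编译,而识别代码是 C++,声明需放在 extern "C" 中。
把audio_preprocessor_int8_model_data.c添加编译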
10、转换识别分类模型tflite为C语言数组
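在tflite-micro\tensorflow\lite\micro\examples\micro_speech\models\路径下面有一个micro_speech_quantized.tflite的识别模型数据,
运行: xxd -i micro_speech_quantized.tflite > micro_speech_quantized_model_data.c
xxd 会生成数组 micro_speech_quantized_tflite 和长度变量 micro_speech_quantized_tflite_len。
同样建议把数组改为 const(放在 Flash 而不是 RAM)并按 16 字节对齐。
另外需要新建 micro_speech_quantized_model_data.h,用 extern "C" 声明这两个变量。
把micro_speech_quantized_model_data.c添加编译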
11、编写模型识别代码
// Public KWS interfaces and minimal pipeline implementation
#include "micro_speech_quantized_model_data.h"
// #include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
// #include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
// (C frontend removed; using Audio Preprocessor model)
// Error reporter
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
// Schema version macro
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "signal/micro/kernels/rfft.h"
// TFLM DebugLog callback registration
#include "tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h"
// #include "audio_frontend.h" // TFLM audio front end
#include "log_debug.h"

#include <stdint.h>
#include <stdio.h>   // printf() used by the TFLM debug-log callback
#include <string.h>
#include <new>       // placement new for the statically stored interpreters

#define KWS_LOG_INFO log_debug
// (uvprojx has included required kernel source files)

// ======== TensorFlow Lite Micro persistent objects ========
// STM32F407 has 192 KB RAM, but the Keil configuration only uses 128 KB.
// The classifier arena is kept at 10 KB. If AllocateTensors() fails, compare
// against the "Arena used bytes" value init_tflm() logs and enlarge it.
constexpr int kTensorArenaSize = 10 * 1024;
alignas(16) static uint8_t tensor_arena[kTensorArenaSize];
static tflite::ErrorReporter* g_error_reporter = nullptr;
static const tflite::Model* g_model = nullptr;
// Room for up to 48 op registrations for the classifier model.
static tflite::MicroMutableOpResolver<48> g_resolver;
static tflite::MicroInterpreter* g_interpreter = nullptr;
static TfLiteTensor* g_input = nullptr;
static TfLiteTensor* g_output = nullptr;
static bool g_inited = false;
// Static storage for the classifier interpreter; init_tflm() constructs the
// object here with placement new so no heap is needed.
alignas(alignof(tflite::MicroInterpreter)) static uint8_t g_interpreter_buffer[sizeof(tflite::MicroInterpreter)];
// Model / feature parameters (micro_speech-like)
static constexpr int kSampleRate = 16000;   // 16 kHz
static constexpr int kFrameLenMs = 30;      // 30 ms window
static constexpr int kFrameStrideMs = 20;   // 20 ms hop
static constexpr int kNumMelBins = 40;      // 40 mel bands

// Frames per inference is determined from the model input shape at runtime;
// init_tflm() may overwrite this fallback.
static int g_frames_per_inference = 49;  // fallback typical value

// Size in bytes of one int8 feature matrix (kNumMelBins values per frame).
static inline int feature_size_bytes() { return kNumMelBins * g_frames_per_inference; }

// ======== Two-model pipeline (Audio Preprocessor + MicroSpeech) ========
// Features are always generated with the Audio Preprocessor int8 model
// (one 40-value frame per Invoke(), 49 frames per inference by default).
#include "audio_preprocessor_int8_model_data.h"

// Arena, resolver and interpreter for the preprocessing model
// (modelled on micro_speech_test.cc).
static constexpr size_t kPreprocArenaSize{12 * 1024};
alignas(16) static uint8_t s_preproc_arena[kPreprocArenaSize];

// 18 slots: init_preprocessor() registers 9 builtin and 9 Signal* custom ops.
typedef tflite::MicroMutableOpResolver<18> PreprocOpResolver;

static const tflite::Model* g_preproc_model = nullptr;
static PreprocOpResolver* g_preproc_resolver = nullptr;
static tflite::MicroInterpreter* g_preproc_interpreter = nullptr;

// Backing resolver object; g_preproc_resolver points here once initialised.
static PreprocOpResolver g_preproc_resolver_inst;
alignas(alignof(tflite::MicroInterpreter)) static uint8_t g_preproc_interpreter_buffer[sizeof(tflite::MicroInterpreter)];static bool init_preprocessor()
{if (g_preproc_interpreter) {KWS_LOG_INFO("preproc interpreter already inited\n");return true;}g_preproc_model = tflite::GetModel(__audio_preprocessor_int8_tflite);if (!g_preproc_model) { KWS_LOG_INFO("error: preproc model null\n"); return false; }g_preproc_resolver = &g_preproc_resolver_inst;LINE_INFO// 注册与 README/test 一致的信号算子g_preproc_resolver->AddReshape();g_preproc_resolver->AddCast();g_preproc_resolver->AddStridedSlice();g_preproc_resolver->AddConcatenation();g_preproc_resolver->AddMul();g_preproc_resolver->AddAdd();g_preproc_resolver->AddDiv();g_preproc_resolver->AddMinimum();g_preproc_resolver->AddMaximum();g_preproc_resolver->AddCustom("SignalWindow", tflite::tflm_signal::Register_WINDOW());g_preproc_resolver->AddCustom("SignalFftAutoScale", tflite::tflm_signal::Register_FFT_AUTO_SCALE());g_preproc_resolver->AddCustom("SignalRfft", tflite::tflm_signal::Register_RFFT());g_preproc_resolver->AddCustom("SignalEnergy", tflite::tflm_signal::Register_ENERGY());g_preproc_resolver->AddCustom("SignalFilterBank", tflite::tflm_signal::Register_FILTER_BANK());g_preproc_resolver->AddCustom("SignalFilterBankSquareRoot", tflite::tflm_signal::Register_FILTER_BANK_SQUARE_ROOT());g_preproc_resolver->AddCustom("SignalFilterBankSpectralSubtraction", tflite::tflm_signal::Register_FILTER_BANK_SPECTRAL_SUBTRACTION());g_preproc_resolver->AddCustom("SignalPCAN", tflite::tflm_signal::Register_PCAN());g_preproc_resolver->AddCustom("SignalFilterBankLog", tflite::tflm_signal::Register_FILTER_BANK_LOG());g_preproc_interpreter = new (g_preproc_interpreter_buffer)tflite::MicroInterpreter(g_preproc_model, *g_preproc_resolver,s_preproc_arena, kPreprocArenaSize);LINE_INFOTfLiteStatus status;KWS_LOG_INFO("preproc AllocateTensors start\n");status = g_preproc_interpreter->AllocateTensors();if (status != kTfLiteOk) {KWS_LOG_INFO("preproc AllocateTensors failed, status %d\n", status);size_t used_bytes = g_preproc_interpreter->arena_used_bytes();KWS_LOG_INFO("preproc Arena used bytes: %d\n", 
used_bytes);g_preproc_interpreter = nullptr;return false;}size_t used_bytes = g_preproc_interpreter->arena_used_bytes();KWS_LOG_INFO("preproc Arena used bytes: %d / %d (%.1f%%)\n", used_bytes, kPreprocArenaSize, (100.0f * used_bytes) / kPreprocArenaSize);KWS_LOG_INFO("preproc AllocateTensors succeeded!\n");return true;
}// 以 30ms(480) 窗/20ms(320) 步生成 49 帧,每帧 40 维,输出按行优先写入 out_features(长度需≥49*40)
static bool generate_features_with_preproc(const int16_t* pcm, int num_samples, int8_t* out_features)
{// 预处理必须已在 init_tflm() 中完成初始化,运行期不再懒初始化if (!g_preproc_interpreter) {KWS_LOG_INFO("error: preproc interpreter null\n");return false;}TfLiteTensor* pin = g_preproc_interpreter->input(0);TfLiteTensor* pout = g_preproc_interpreter->output(0);if (!pin || !pout) {KWS_LOG_INFO("error: preproc input or output null\n");return false;}const int frame_samples = (kSampleRate * kFrameLenMs) / 1000; // 480const int stride_samples = (kSampleRate * kFrameStrideMs) / 1000; // 320const int frames = g_frames_per_inference; // 49int produced = 0;const int16_t* cursor = pcm;int remaining = num_samples;while (remaining >= frame_samples && produced < frames) {// 拷贝单帧 PCMmemcpy(tflite::GetTensorData<int16_t>(pin), cursor, frame_samples * sizeof(int16_t));TfLiteStatus status = g_preproc_interpreter->Invoke();if (status != kTfLiteOk) {KWS_LOG_INFO("preproc Invoke failed, status %d\n", status);return false;}// 拷贝 40 维 int8 特征memcpy(out_features + produced * kNumMelBins,tflite::GetTensorData<int8_t>(pout),kNumMelBins * sizeof(int8_t));produced++;cursor += stride_samples;remaining -= stride_samples;}// 若不足 49 帧,后续填 0for (int f = produced; f < frames; ++f) {memset(out_features + f * kNumMelBins, 0, kNumMelBins);}return true;
}// ======== Feature computation via Audio Preprocessor model ========
static void compute_mfcc(const int16_t* audio, int8_t* mfcc_output, int length)
{// 用预处理模型直接产出 int8(40) 特征,累计 g_frames_per_inference 帧if (!g_preproc_interpreter || !generate_features_with_preproc(audio, length, mfcc_output)) {KWS_LOG_INFO("preproc features failed, fill zeros\n");memset(mfcc_output, 0, kNumMelBins * g_frames_per_inference);}
}// ======== Public interface: one-time init ========
extern "C" void init_tflm()
{KWS_LOG_INFO("init_tflm\n");if (g_inited) {KWS_LOG_INFO("error: init_tflm already inited\n");return;}KWS_LOG_INFO("init_tflm2\n");LINE_INFO// Redirect TFLM DebugLog()/MicroPrintf() to RTTRegisterDebugLogCallback([](const char* s){ printf("%s", s); });
LINE_INFO// 提前初始化音频预处理模型(不在跑数据时初始化)if (!init_preprocessor()) {KWS_LOG_INFO("init preprocessor failed\n");return;}LINE_INFOTfLiteStatus status;// Model and interpreterg_model = tflite::GetModel(micro_speech_quantized_tflite);if (g_model == nullptr) {KWS_LOG_INFO("Failed to load model\n");return;}KWS_LOG_INFO("Model loaded, size: %d bytes\n", micro_speech_quantized_tflite_len);// Check schema versionconst int model_version = g_model->version();if (model_version != TFLITE_SCHEMA_VERSION) {KWS_LOG_INFO("Model schema version %d != supported %d\n", model_version, TFLITE_SCHEMA_VERSION);return;}// Dump operators required by model (builtin codes)if (g_model->operator_codes()) {int num_ops = g_model->operator_codes()->size();KWS_LOG_INFO("Model operator_codes: %d\n", num_ops);for (int i = 0; i < num_ops; ++i) {const auto* oc = g_model->operator_codes()->Get(i);int bcode = static_cast<int>(oc->builtin_code());int8_t ver = oc->version();const auto* cname = oc->custom_code();if (cname) {KWS_LOG_INFO(" opcode[%d]: builtin=%d ver=%d custom=%s\n", i, bcode, ver, cname->c_str());} else {KWS_LOG_INFO(" opcode[%d]: builtin=%d ver=%d\n", i, bcode, ver);}}}LINE_INFOauto add_ok = [&](TfLiteStatus s, const char* name){if (s != kTfLiteOk) { KWS_LOG_INFO("Add %s failed\n", name); return false; }log_info("%s added\n", name); return true;};if (!add_ok(g_resolver.AddConv2D(), "op Conv2D")) return;if (!add_ok(g_resolver.AddDepthwiseConv2D(), "op DepthwiseConv2D")) return;if (!add_ok(g_resolver.AddMaxPool2D(), "op MaxPool2D")) return;if (!add_ok(g_resolver.AddAveragePool2D(), "op AvgPool2D")) return;if (!add_ok(g_resolver.AddFullyConnected(), "op FullyConnected")) return;if (!add_ok(g_resolver.AddReshape(), "op Reshape")) return;if (!add_ok(g_resolver.AddSoftmax(), "op Softmax")) return;if (!add_ok(g_resolver.AddPad(), "op Pad")) return;if (!add_ok(g_resolver.AddMean(), "op Mean")) return;if (!add_ok(g_resolver.AddLogistic(), "op Logistic")) return;if (!add_ok(g_resolver.AddCast(), "op Cast")) 
return;if (!add_ok(g_resolver.AddStridedSlice(), "op StridedSlice")) return;if (!add_ok(g_resolver.AddConcatenation(), "op Concat")) return;if (!add_ok(g_resolver.AddAdd(), "op Add")) return;if (!add_ok(g_resolver.AddDiv(), "op Div")) return;if (!add_ok(g_resolver.AddMul(), "op Mul")) return;if (!add_ok(g_resolver.AddSub(), "op Sub")) return;if (!add_ok(g_resolver.AddMinimum(), "op Minimum")) return;if (!add_ok(g_resolver.AddMaximum(), "op Maximum")) return;if (!add_ok(g_resolver.AddQuantize(), "op Quantize")) return;if (!add_ok(g_resolver.AddDequantize(), "op Dequantize")) return;// Register custom audio frontend operators (consistent with custom names in model)if (!add_ok(g_resolver.AddCustom("SignalWindow", tflite::tflm_signal::Register_WINDOW()), "custom SignalWindow")) return;if (!add_ok(g_resolver.AddCustom("SignalFftAutoScale", tflite::tflm_signal::Register_FFT_AUTO_SCALE()), "custom SignalFftAutoScale")) return;if (!add_ok(g_resolver.AddCustom("SignalRfft", tflite::tflm_signal::Register_RFFT()), "custom SignalRfft")) return;if (!add_ok(g_resolver.AddCustom("SignalEnergy", tflite::tflm_signal::Register_ENERGY()), "custom SignalEnergy")) return;if (!add_ok(g_resolver.AddCustom("SignalFilterBank", tflite::tflm_signal::Register_FILTER_BANK()), "custom SignalFilterBank")) return;if (!add_ok(g_resolver.AddCustom("SignalFilterBankSquareRoot", tflite::tflm_signal::Register_FILTER_BANK_SQUARE_ROOT()), "custom SignalFilterBankSquareRoot")) return;if (!add_ok(g_resolver.AddCustom("SignalFilterBankSpectralSubtraction", tflite::tflm_signal::Register_FILTER_BANK_SPECTRAL_SUBTRACTION()), "custom SignalFilterBankSpectralSubtraction")) return;if (!add_ok(g_resolver.AddCustom("SignalPCAN", tflite::tflm_signal::Register_PCAN()), "custom SignalPCAN")) return;if (!add_ok(g_resolver.AddCustom("SignalFilterBankLog", tflite::tflm_signal::Register_FILTER_BANK_LOG()), "custom SignalFilterBankLog")) return;LINE_INFO// Error reporter must be obtained before constructing 
interpreterg_error_reporter = tflite::GetMicroErrorReporter();LINE_INFO// Print memory info before creating interpreterKWS_LOG_INFO("Creating interpreter with tensor arena size: %d KB\n", kTensorArenaSize / 1024);g_interpreter = new (g_interpreter_buffer)tflite::MicroInterpreter(g_model, g_resolver, tensor_arena, kTensorArenaSize);LINE_INFO// Print more debug infoKWS_LOG_INFO("Allocating tensors...\n");status = g_interpreter->AllocateTensors();if (status != kTfLiteOk) {KWS_LOG_INFO("AllocateTensors failed, status %d (kTfLiteError)\n", status);KWS_LOG_INFO("This usually means insufficient memory or unsupported operations\n");// Try to get more informationsize_t used_bytes = g_interpreter->arena_used_bytes();KWS_LOG_INFO("Arena used bytes: %d\n", used_bytes);return;}// After successful tensor allocation, print memory usageKWS_LOG_INFO("AllocateTensors succeeded!\n");size_t used_bytes = g_interpreter->arena_used_bytes();KWS_LOG_INFO("Arena used bytes: %d / %d (%.1f%%)\n", used_bytes, kTensorArenaSize, (100.0f * used_bytes) / kTensorArenaSize);LINE_INFOg_input = g_interpreter->input(0);LINE_INFOg_output = g_interpreter->output(0);// Print input/output tensor infoif (g_input) {KWS_LOG_INFO("Input tensor info:\n");KWS_LOG_INFO(" Type: %d\n", g_input->type);KWS_LOG_INFO(" Bytes: %d\n", g_input->bytes);if (g_input->dims) {KWS_LOG_INFO(" Dims: [");for (int i = 0; i < g_input->dims->size; ++i) {KWS_LOG_INFO("%d", g_input->dims->data[i]);if (i < g_input->dims->size - 1) r_printf(", ");}KWS_LOG_INFO("]\n");}}if (g_output) {KWS_LOG_INFO("Output tensor info:\n");KWS_LOG_INFO(" Type: %d\n", g_output->type);KWS_LOG_INFO(" Bytes: %d\n", g_output->bytes);if (g_output->dims) {KWS_LOG_INFO(" Dims: [");for (int i = 0; i < g_output->dims->size; ++i) {KWS_LOG_INFO("%d", g_output->dims->data[i]);if (i < g_output->dims->size - 1) KWS_LOG_INFO(", ");}KWS_LOG_INFO("]\n");}}// Derive frames per inference from input tensor shapeif (g_input && g_input->dims) {int elems = 1;for (int i = 0; i < 
g_input->dims->size; ++i) {elems *= g_input->dims->data[i];}if (elems > 0 && (elems % kNumMelBins) == 0) {g_frames_per_inference = elems / kNumMelBins;}KWS_LOG_INFO("input elems=%d, mel=%d, frames=%d\n", elems, kNumMelBins, g_frames_per_inference);}LINE_INFOg_inited = true;}// ======== Extended C-callable KWS interfaces ========
// Return required audio front-end parameters
extern "C" int kws_get_required_sample_rate()
{return kSampleRate;
}extern "C" int kws_get_frame_length_ms()
{return kFrameLenMs;
}extern "C" int kws_get_frame_stride_ms()
{return kFrameStrideMs;
}extern "C" int kws_get_required_samples_per_inference()
{const int frame_samples = (kSampleRate * kFrameLenMs) / 1000;const int stride_samples = (kSampleRate * kFrameStrideMs) / 1000;return frame_samples + (g_frames_per_inference - 1) * stride_samples;
}extern "C" int kws_get_num_classes()
{if (!g_inited) {init_tflm();}if (!g_output || !g_output->dims) {return 0;}int elems = 1;for (int i = 0; i < g_output->dims->size; ++i) {elems *= g_output->dims->data[i];}return elems;
}extern "C" int kws_get_output_shape(int* out_rank, int* out_dims, int max_dims)
{if (!g_inited) {init_tflm();}if (out_rank) {*out_rank = 0;}if (!g_output || !g_output->dims) {return 0;}const int rank = g_output->dims->size;if (out_rank) {*out_rank = rank;}if (out_dims && max_dims > 0) {const int n = (rank < max_dims) ? rank : max_dims;for (int i = 0; i < n; ++i) {out_dims[i] = g_output->dims->data[i];}}return rank;
}// One-shot inference interface: feed a block of PCM and get top-1 result
// Returns 0 on failure, 1 on success
extern "C" int kws_run_inference(const int16_t* pcm, int num_samples, int* out_index, int8_t* out_score)
{if (!g_inited) {KWS_LOG_INFO("error: not inited\n");return 0;}if (!pcm || num_samples <= 0 || !out_index || !out_score) {return 0;}// Compute features directly into input tensorcompute_mfcc(pcm, g_input->data.int8, num_samples);TfLiteStatus status = g_interpreter->Invoke();if (status != kTfLiteOk) {KWS_LOG_INFO("Invoke failed, status %d \n", status);return 0;}// Argmax over output tensor (int8 logits or probabilities)const int8_t* out_data = g_output->data.int8;int out_len = g_output->bytes;int best_i = 0;int8_t best_v = out_data[0];for (int i = 1; i < out_len; ++i) {if (out_data[i] > best_v) {best_v = out_data[i];best_i = i;}}*out_index = best_i;*out_score = best_v;if (out_len >= 4) {KWS_LOG_INFO("scores: s=%d u=%d y=%d n=%d | best=%d(%d)\n", out_data[0], out_data[1], out_data[2], out_data[3], best_i, best_v);} else {KWS_LOG_INFO("best_i %d, best_v %d\n", best_i, best_v);}return 1;
}static int16_t window_buf[16000];
static int filled = 0;extern "C" void on_audio_chunk_samples(const int16_t* audio_data, int num_samples);
extern "C" void on_audio_chunk_10ms(const int16_t* in160)
{// 10ms@16kHz = 160 samples, use common entry to avoid inconsistency with fixed strideon_audio_chunk_samples(in160, 160);
}extern "C" void on_audio_chunk_samples(const int16_t* audio_data, int num_samples)
{if (!g_inited) {KWS_LOG_INFO("error: not inited\n");return;}const int required = kws_get_required_samples_per_inference();const int stride = (kws_get_required_sample_rate() * kws_get_frame_stride_ms()) / 1000;int capacity = 0;// 1) Append new data to FIFOif (num_samples > 0) {// Simple protection: if append would exceed buffer, truncate to window tail available capacitycapacity = (int)(sizeof(window_buf) / sizeof(window_buf[0]));if (filled + num_samples > capacity) {// First try to consume existing data by stride to free spacewhile (filled >= required && (filled + num_samples > capacity)) {int index; int8_t score;(void)kws_run_inference(window_buf, required, &index, &score);memmove(window_buf, window_buf + stride, (filled - stride) * sizeof(int16_t));filled -= stride;}// If still exceeding limit, only keep required-1 data from window tailif (filled + num_samples > capacity) {if (required < capacity) {// Move tail required-1 data to buffer startif (filled > required) {memmove(window_buf, window_buf + (filled - required), required * sizeof(int16_t));filled = required;}} else {// Extreme case: capacity insufficient for one window, truncate to capacityif (filled > capacity) {memmove(window_buf, window_buf + (filled - capacity), capacity * sizeof(int16_t));filled = capacity;}}}}int copy = num_samples;memcpy(&window_buf[filled], audio_data, copy * sizeof(int16_t));filled += copy;}// 2) Fixed stride advance with simple smoothing votestatic const int kVoteWindow = 8;static int s_pred_buf[kVoteWindow];static int s_vote_counts[4] = {0,0,0,0};static int s_write = 0; static int s_filled_win = 0;while (filled >= required) {int index; int8_t score;if (kws_run_inference(window_buf, required, &index, &score)) {if (s_filled_win < kVoteWindow) {s_pred_buf[s_write] = index;if (index >= 0 && index < 4) s_vote_counts[index]++;s_write = (s_write + 1) % kVoteWindow;s_filled_win++;} else {int old = s_pred_buf[s_write];if (old >= 0 && old < 4) s_vote_counts[old]--;s_pred_buf[s_write] = 
index;if (index >= 0 && index < 4) s_vote_counts[index]++;s_write = (s_write + 1) % kVoteWindow;}int smooth_i = 0; int smooth_v = s_vote_counts[0];for (int c = 1; c < 4; ++c) {if (s_vote_counts[c] > smooth_v) { smooth_v = s_vote_counts[c]; smooth_i = c; }}r_printf("smooth_best=%d votes=[%d,%d,%d,%d]\n", smooth_i, s_vote_counts[0], s_vote_counts[1], s_vote_counts[2], s_vote_counts[3]);}memmove(window_buf, window_buf + stride, (filled - stride) * sizeof(int16_t));filled -= stride;}
}// ======== Public interface: release resources ========
extern "C" void kws_close()
{// 在 STM32 上:对象与张量内存均为静态存储,不进行释放/析构// 仅复位运行状态并停止流式处理g_inited = false;filled = 0;g_preproc_interpreter = nullptr;
}
12、传入yes/no的数据进行测试
init_tflm();int frame_length_ms = kws_get_frame_length_ms();int frame_stride_ms = kws_get_frame_stride_ms();int required_samples_per_inference = kws_get_required_samples_per_inference();log_info("SR=%d, frame_len=%dms, frame_stride=%dms, required_samples=%d, wav_samples=%d\n",required_sr, frame_length_ms, frame_stride_ms, required_samples_per_inference, num_samples);int16_t in160pf[320];//uint16_t window_buf[1840];int index = 0;uint8_t score = 0;int get_no_int16_data(uint16_t *s_cnt, int16_t *data, uint16_t points, uint8_t ch);int get_yes_int16_data(uint16_t *s_cnt, int16_t *data, uint16_t points, uint8_t ch);uint16_t yes_s_cnt = 0;for (int i = 0; i < 300; i++) {// get_yes_int16_data(&yes_s_cnt, in160pf, 160, 1);get_no_int16_data(&yes_s_cnt, in160pf, 160, 1);//on_audio_chunk_samples(in160pf, 160);on_audio_chunk_10ms(in160pf);}