// (removed) A stray `int8_to_float(output_tensor->data.int8, ...)` call
// fragment sat here, above the includes and outside any function. It
// referenced identifiers that do not exist at file scope and could never
// compile; the real dequantize call lives inside CPGC_ALG_NILM_RUN_T1.
#include <cmath>
#include <cstdint>
#include <cstring>
#include <new>

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Model data provided by another translation unit (typically generated from
// a .tflite flatbuffer with xxd or a build script).
extern const char g_model_data[];
extern const int g_model_data_len;
// Tensor arena for the TFLite Micro interpreter: 32 KB, placed in the
// ".tensor_arena" linker section and 32-byte aligned.
// Fixes vs original: `attribute` -> `__attribute__`, typographic quotes ->
// plain quotes in the section name, and the redundant second
// `attribute((aligned(32)))` after the declarator is dropped.
__attribute__((section(".tensor_arena"), aligned(32)))
uint8_t g_tensor_arena[32 * 1024] = {0};
extern const int g_tensor_arena_len;
// Quantization parameters (must match the values produced by the Python
// training / conversion pipeline; wrong values silently corrupt predictions).
const float input_scale = 0.05299549177289009f;
const int32_t input_zero_point = 4;
const float output_scale = 0.00390625f;
const int32_t output_zero_point = -128;
// Other declarations
const int kMaxOps = 13; // number of operators registered in RegisterOps
const int load_class_num = 5; // number of output classes
namespace {
// MicroMutableOpResolver is a class template parameterized by the number of
// operator registrations it can hold; the original alias omitted the
// template argument and could not compile. kMaxOps (13) matches the 13
// AddXxx() calls in RegisterOps below.
using OpResolver = tflite::MicroMutableOpResolver<kMaxOps>;
}  // namespace
// Inference context, placement-allocated out of the MemoryPool by
// CPGC_ALG_NILM_RUN_T1.
// NOTE(review): the `resolver` member appears unused — the run function
// builds a separate local OpResolver instead; confirm whether this member
// can be removed.
struct ModelContext {
OpResolver resolver;
tflite::MicroInterpreter* interpreter;
};
// Simple bump allocator over a fixed 64 KB buffer. No per-allocation free;
// Reset() reclaims everything at once.
class MemoryPool {
public:
// Returns a pointer aligned for any object type, or nullptr when the pool
// is exhausted. Alignment matters: callers placement-new structs and the
// MicroInterpreter into this memory, and the original byte-granular bump
// could hand back misaligned storage (UB on placement-new).
void* Allocate(size_t size) {
  const size_t align = alignof(std::max_align_t);
  const size_t rounded = (size + align - 1) & ~(align - 1);
  // `rounded < size` catches wrap-around on huge requests; the subtraction
  // form of the capacity check cannot overflow.
  if (rounded < size || rounded > pool_size_ - current_position_) {
    return nullptr;
  }
  void* ptr = &pool_[current_position_];
  current_position_ += rounded;
  return ptr;
}
// Discard all allocations; previously returned pointers become invalid.
void Reset() { current_position_ = 0; }
private:
static constexpr size_t pool_size_ = 64 * 1024;  // pool capacity in bytes
alignas(std::max_align_t) uint8_t pool_[pool_size_];  // backing storage
size_t current_position_ = 0;  // offset of the next free byte
};
// Internal helper declarations (definitions below).
int n1m_argmax(const float* data, int size);
void n1m_input_float_to_int8(const float* input_f32, int8_t* output_i8, int size);
void n1m_output_int8_to_float(const int8_t* input_i8, float* output_f32, int size);
// 注册支持的算子
TFliteStatus RegisterOps(OpResolver& resolver) {
TF_LITE_ENSURE_STATUS(resolver.AddFullyConnected());
TF_LITE_ENSURE_STATUS(resolver.AddConv2D());
TF_LITE_ENSURE_STATUS(resolver.AddDepthwiseConv2D());
TF_LITE_ENSURE_STATUS(resolver.AddMaxPool2D());
TF_LITE_ENSURE_STATUS(resolver.AddAveragePool2D());
TF_LITE_ENSURE_STATUS(resolver.AddSoftmax());
TF_LITE_ENSURE_STATUS(resolver.AddReshape());
TF_LITE_ENSURE_STATUS(resolver.AddExpandDims());
TF_LITE_ENSURE_STATUS(resolver.AddMean());
TF_LITE_ENSURE_STATUS(resolver.AddShape());
TF_LITE_ENSURE_STATUS(resolver.AddStridedSlice());
TF_LITE_ENSURE_STATUS(resolver.AddPack());
TF_LITE_ENSURE_STATUS(resolver.AddDequantize());
return kTfLiteOk;
}
// Quantize a float buffer to int8 with the model's input scale/zero-point,
// saturating results to the int8 range [-128, 127].
void n1m_input_float_to_int8(const float* input_f32, int8_t* output_i8, int size) {
  for (int idx = 0; idx < size; ++idx) {
    const float scaled = roundf(input_f32[idx] / input_scale + input_zero_point);
    // Clamp before narrowing; the cast only runs on in-range values, so an
    // out-of-range float is never converted (which would be UB).
    const int32_t quantized =
        (scaled > 127.0f) ? 127
                          : (scaled < -128.0f) ? -128
                                               : static_cast<int32_t>(scaled);
    output_i8[idx] = static_cast<int8_t>(quantized);
  }
}
// Dequantize int8 model output back to float:
// out = (raw - output_zero_point) * output_scale.
void n1m_output_int8_to_float(const int8_t* input_i8, float* output_f32, int size) {
  int remaining = size;
  while (remaining-- > 0) {
    const int32_t raw = *input_i8++;
    *output_f32++ = (raw - output_zero_point) * output_scale;
  }
}
// Return the index of the largest element in data[0..num_classes).
// Ties resolve to the lowest index.
// Fix vs original: guard null/non-positive sizes — the old code read data[0]
// unconditionally, which is UB on an empty input. Returns -1 in that case.
int n1m_argmax(const float* data, int num_classes) {
  if (data == nullptr || num_classes <= 0) {
    return -1;  // no valid element to point at
  }
  int max_index = 0;
  float max_value = data[0];
  for (int i = 1; i < num_classes; ++i) {
    if (data[i] > max_value) {
      max_value = data[i];
      max_index = i;
    }
  }
  return max_index;
}
// 主要推理函数
extern “C” int CPGC_ALG_NILM_RUN_T1(const float* VI_input, const int data_len) {
MemoryPool memory_pool; // 使用内存池
// 初始化上下文
ModelContext* ctx = static_cast<ModelContext*>(memory_pool.Allocate(sizeof(ModelContext)));
if (!ctx) {MicroPrintf("Failed to allocate ModelContext");return CPGC_ALG_NILM_ERR_OUT_OF_MEMORY;
}memset(g_tensor_arena, 0, sizeof(g_tensor_arena));// 加载模型
const tflite::Model* model = tflite::GetModel(g_model_data);
if (!model) {MicroPrintf("Failed to load model - GetModel returned null");return CPGC_ALG_NILM_ERR_MODEL_INCOMPLETE;
}if (model->version() != TFLITE_SCHEMA_VERSION) {MicroPrintf("Model version mismatch: got %d, expected %d", model->version(), TFLITE_SCHEMA_VERSION);return CPGC_ALG_NILM_ERR_MODEL_INCOMPLETE;
}// 创建解析器并注册操作
OpResolver resolver;
TFliteStatus resolver_status = RegisterOps(resolver);
if (resolver_status != kTfLiteOk) {MicroPrintf("RegisterOps failed: %d", resolver_status);return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}// 创建解释器
ctx->interpreter = new (memory_pool.Allocate(sizeof(tflite::MicroInterpreter))) tflite::MicroInterpreter(model,resolver,g_tensor_arena,sizeof(g_tensor_arena)
);if (ctx->interpreter == nullptr) {MicroPrintf("Failed to create interpreter - out of memory");return CPGC_ALG_NILM_ERR_OUT_OF_MEMORY;
}// 分配张量
TFliteStatus status = ctx->interpreter->AllocateTensors();
if (status != kTfLiteOk) {MicroPrintf("AllocateTensors failed: %d", status);MicroPrintf("Arena used: %u bytes", ctx->interpreter->arena_used_bytes());return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}MicroPrintf("Model loaded successfully");
MicroPrintf("Arena used: %u bytes", ctx->interpreter->arena_used_bytes());// 获取输入 tensor
TfLiteTensor* input_tensor = ctx->interpreter->input(0);
if (!input_tensor || !input_tensor->data.data) {MicroPrintf("ERROR: Failed to get input tensor");return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}// 量化输入
int8_t quantized_input[256] = {0};
n1m_input_float_to_int8(VI_input, quantized_input, data_len);
memcpy(input_tensor->data.int8, quantized_input, data_len);// 推理
TFliteStatus invoke_status = ctx->interpreter->Invoke();
if (invoke_status != kTfLiteOk) {MicroPrintf("Invoke failed: %d", invoke_status);return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}// 获取输出并反量化
TfLiteTensor* output_tensor = ctx->interpreter->output(0);
if (!output_tensor || !output_tensor->data.data) {MicroPrintf("ERROR: Failed to get output tensor");return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}// 打印原始 int8 输出
MicroPrintf("Raw int8 output (first 5):");
for (int i = 0; i < load_class_num; ++i) {MicroPrintf(" [%d]", output_tensor->data.int8[i]);
}// 反量化为 float
float output_float[10]; // 假设输出大小 <=10
n1m_output_int8_to_float(output_tensor->data.int8, output_float, load_class_num);MicroPrintf("Float output after dequantize (first 5):");
for (int i = 0; i < load_class_num; ++i) {MicroPrintf(" [%d]: %.6f", i, output_float[i]);
}// 获取预测类别
int pred_label = n1m_argmax(output_float, load_class_num);
MicroPrintf("Prediction: Class %d", pred_label);return pred_label;
}
