Converting a YOLOv11 .pt Model to TFLite and NCNN
Given the compatibility issues on Windows, Google Colab is strongly recommended: it provides a Linux environment with CUDA preinstalled and works with ai_edge_litert. Simply upload the model and the script, install the dependencies, and run the conversion there.
Model Conversion
Open Google Colab (https://colab.research.google.com).
Create a new notebook and upload the trained yolo11n.pt model.
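The upload can be done through the Files panel on the left, or from a notebook cell with the google.colab helper; a minimal sketch:

from google.colab import files

# Opens a file picker and saves the selected file(s) to the notebook's working directory
uploaded = files.upload()
print(list(uploaded.keys()))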
Install the dependencies:
!pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu118
!pip install "ultralytics>=8.3.15" "opencv-python>=4.6.0" tensorflow==2.18.0 tf_keras==2.18.0 onnx==1.17.0 "onnx2tf>=1.26.3" "sng4onnx>=1.0.1" "onnx_graphsurgeon>=0.3.26" "sympy>=1.13.3" "protobuf>=5.26.1" "onnxslim>=0.1.59"
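To confirm the environment is usable before converting, ultralytics provides a checks() helper; a quick sanity check:

import ultralytics

# Prints the detected Python, torch, and CUDA setup plus the installed ultralytics version
ultralytics.checks()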
Run the conversion script:
from ultralytics import YOLO

# Load the model
model = YOLO("yolo11n.pt")

# Export to TFLite
model.export(format="tflite", imgsz=320)  # creates 'yolo11n_float32.tflite'

# Export to NCNN
model.export(format="ncnn", imgsz=320)  # creates './yolo11n_ncnn_model'
Download the generated .tflite and NCNN model files.
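In Colab, the exported files can be fetched with the google.colab files helper; a minimal sketch (files.download handles single files, so the NCNN folder is zipped first):

from google.colab import files
import shutil

# Download the TFLite model (adjust the path if the export was written into a *_saved_model folder)
files.download("yolo11n_float32.tflite")

# Zip the NCNN model folder, then download the archive
shutil.make_archive("yolo11n_ncnn_model", "zip", ".", "yolo11n_ncnn_model")
files.download("yolo11n_ncnn_model.zip")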
TFLite Model Inference
Load the exported TFLite model with tflite_runtime.interpreter.
Preprocess the input image (resize, normalize, convert the format).
Run inference and read the output tensors.
The raw output still needs post-processing to parse the detections (bounding boxes, classes, confidence scores); see the sketch below.
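A minimal sketch of that raw flow with the Interpreter API (the test image name is illustrative; the output shape [1, 16, 2100] matches the Android code further down, and post-processing is omitted). The full Python example that follows instead loads the model through the ultralytics YOLO wrapper, which handles pre- and post-processing internally.

import cv2
import numpy as np
from tflite_runtime.interpreter import Interpreter  # or tf.lite.Interpreter with full TensorFlow

# Load the exported model and allocate its tensors
interpreter = Interpreter(model_path="yolo11n_float32.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

# Preprocess: BGR -> RGB, resize to 320x320, scale to [0, 1], add a batch dimension
img = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)
x = cv2.resize(img, (320, 320)).astype(np.float32) / 255.0
x = np.expand_dims(x, axis=0)  # shape (1, 320, 320, 3)

# Run inference and read the raw output tensor
interpreter.set_tensor(inp["index"], x)
interpreter.invoke()
pred = interpreter.get_tensor(out["index"])  # shape (1, 16, 2100): 4 box values + 12 class scores
print(pred.shape)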
Python-side inference code:
# TFLite model inference
import cv2
import pandas as pd
from pathlib import Path
from ultralytics import YOLO


def tflite_infer(model_path, img_path):
    # Load the exported TFLite model for inference
    tflite_model = YOLO(model_path)
    results_tflite = tflite_model(img_path)
    # Extract the results
    output_result(results_tflite)
    # Visualize the detections and save the image
    save_path = Path(img_path).with_name('detected.jpg')
    # Use ultralytics' built-in plotting (boxes drawn with labels)
    annotated = results_tflite[0].plot(labels=True)  # labels=True draws the class names
    cv2.imwrite(str(save_path), annotated)
    print(f'Annotated image saved to: {save_path}')


# Extract the results
def output_result(results):
    boxes = results[0].boxes
    # Build a DataFrame
    df = pd.DataFrame({
        "similar": boxes.conf.cpu().numpy(),                     # confidence
        "rect": [b.tolist() for b in boxes.xyxy.cpu().numpy()],  # [x1, y1, x2, y2]
        "class_id": boxes.cls.cpu().numpy().astype(int)
    })
    # Sort by confidence, descending
    df_sorted = df.sort_values("similar", ascending=False)
    print(df_sorted)
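A possible invocation, assuming the exported model and a test image sit in the working directory (both filenames are illustrative):

if __name__ == "__main__":
    tflite_infer("yolo11n_float32.tflite", "test.jpg")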
Android-side inference code:
package com.magicianguo.mediaprojectiondemo.service;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.util.Log;

import androidx.annotation.NonNull;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

public class YoloV11TFLiteDetector {
    private Interpreter tflite;
    private static final String[] names = {
            "Battle_Identification", "Can_Be_Purchased", "Chess_Pieces", "Current_Player",
            "Fetters", "Player", "Player_Character", "Prepare_For_War",
            "Preparing_Chess_Piece_Area", "Preparing_For_Chess_Pieces", "Rune", "Store"
    };
    private final float threshold = 0.7f;
    private final Context context;
    private final int inputSize = 320;

    // Detection result class
    public static class Detection {
        public float x1, y1, x2, y2; // Bounding box coordinates
        public float conf;           // Confidence score
        public int classId;          // Class ID

        public Detection(float x1, float y1, float x2, float y2, float conf, int classId) {
            this.x1 = x1;
            this.y1 = y1;
            this.x2 = x2;
            this.y2 = y2;
            this.conf = conf;
            this.classId = classId;
        }

        @NonNull
        @Override
        public String toString() {
            return String.format(Locale.US,
                    "Detection: {classId=%d, className=%s conf=%.2f, rect=(%.2f, %.2f, %.2f, %.2f)}",
                    classId, YoloV11TFLiteDetector.names[classId], conf, x1, y1, x2, y2);
        }
    }

    public YoloV11TFLiteDetector(Context context, String modelPath) {
        this.context = context;
        try {
            tflite = new Interpreter(loadModelFile(modelPath), new Interpreter.Options().setNumThreads(1));
            Log.i("YoloV11TFLiteDetector", "Model loaded successfully");
        } catch (IOException e) {
            Log.e("YoloV11TFLiteDetector", "Failed to load model", e);
        }
    }

    // Load TFLite model from assets
    private MappedByteBuffer loadModelFile(String modelPath) throws IOException {
        AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    // Detect objects in the image
    public List<Detection> detect(Bitmap bitmap) {
        if (bitmap == null) {
            Log.e("YoloV11TFLiteDetector", "Provided Bitmap is null");
            return new ArrayList<>(); // Return empty list if bitmap is null
        }
        int width = bitmap.getWidth();   // width in pixels
        int height = bitmap.getHeight(); // height in pixels
        Log.i("YoloV11TFLiteDetector", "width:" + width + ",height:" + height);

        // Load and preprocess image
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, true);
        float[][][][] inputImage = bitmapToFloatArray(resizedBitmap);

        // Run inference
        float[][][] output = new float[1][16][2100];
        tflite.run(inputImage, output);

        // Parse output and scale normalized boxes back to the original image size
        List<Detection> detections = parseOutput(output[0]);
        for (Detection detection : detections) {
            detection.x1 *= width;
            detection.y1 *= height;
            detection.x2 *= width;
            detection.y2 *= height;
        }
        return detections;
    }

    // Convert Bitmap to float array for model input
    private float[][][][] bitmapToFloatArray(Bitmap bitmap) {
        float[][][][] inputImage = new float[1][inputSize][inputSize][3];
        for (int y = 0; y < inputSize; y++) {
            for (int x = 0; x < inputSize; x++) {
                int pixel = bitmap.getPixel(x, y);
                // Channel order here is BGR; swap indices 0 and 2 if the model expects RGB
                // (adjust to match the training data format)
                inputImage[0][y][x][0] = (pixel & 0xFF) / 255.0f;         // B
                inputImage[0][y][x][1] = ((pixel >> 8) & 0xFF) / 255.0f;  // G
                inputImage[0][y][x][2] = ((pixel >> 16) & 0xFF) / 255.0f; // R
            }
        }
        return inputImage;
    }

    // Parse model output
    private List<Detection> parseOutput(float[][] output) {
        List<Detection> detections = new ArrayList<>();
        int numDetections = output[0].length;       // 2100 candidate boxes
        int attributesPerDetection = output.length; // 16 attributes (4 box coords + 12 classes)
        for (int i = 0; i < numDetections; i++) {
            // 1. Parse the normalized box: center x/y and width/height -> x1, y1, x2, y2
            float x = output[0][i];
            float y = output[1][i];
            float w = output[2][i];
            float h = output[3][i];
            float x1 = x - w / 2;
            float x2 = x + w / 2;
            float y1 = y - h / 2;
            float y2 = y + h / 2;

            // 2. Parse the 12 class scores (indices 4-15)
            float[] classProbs = new float[12]; // matches the 12 classes in the metadata
            for (int j = 0; j < 12; j++) {
                int index = 4 + j;
                if (index < attributesPerDetection) {
                    classProbs[j] = output[index][i];
                }
            }

            // 3. Take the highest score and its class
            int classId = argmax(classProbs);
            float conf = classProbs[classId];

            // 4. Drop low-confidence results
            if (conf > threshold) {
                Log.i("YoloV11TFLiteDetector", String.format("%.2f %.2f %.2f %.2f %.2f %d", x1, y1, x2, y2, conf, classId));
                detections.add(new Detection(x1, y1, x2, y2, conf, classId));
            }
        }
        return detections;
    }

    // Make sure argmax covers all 12 classes
    private int argmax(float[] array) {
        int maxIdx = 0;
        for (int i = 1; i < array.length; i++) { // array length is 12
            if (array[i] > array[maxIdx]) {
                maxIdx = i;
            }
        }
        return maxIdx;
    }
}
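A minimal usage sketch, assuming the exported .tflite file has been copied into the app's assets/ folder as yolo11n_float32.tflite; the class name DetectorDemo, the asset name, and where the Bitmap comes from are illustrative:

package com.magicianguo.mediaprojectiondemo.service;

import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;

import java.util.List;

public class DetectorDemo {
    // Runs one detection pass; context and bitmap are supplied by the caller
    // (e.g. a frame captured by the MediaProjection service in this project).
    public static void runOnce(Context context, Bitmap bitmap) {
        YoloV11TFLiteDetector detector =
                new YoloV11TFLiteDetector(context, "yolo11n_float32.tflite");
        List<YoloV11TFLiteDetector.Detection> detections = detector.detect(bitmap);
        for (YoloV11TFLiteDetector.Detection d : detections) {
            Log.i("DetectorDemo", d.toString());
        }
    }
}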
Modify build.gradle to add the dependencies:
dependencies {
    implementation 'org.tensorflow:tensorflow-lite:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.3.1'
    implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.9.0'
}
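One more build note: loadModelFile() memory-maps the model out of assets/ via openFd(), which fails if AAPT stores the asset compressed. A commonly used addition to the app module's build.gradle (the exact block name can differ across Android Gradle Plugin versions):

android {
    aaptOptions {
        noCompress "tflite"
    }
}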