当前位置: 首页 > news >正文

AI: Android 运行ONNX模型

在这里插入图片描述

概述

      ONNX(Open Neural Network Exchange)模型, Android ONNX Runtime 是微软开源的跨平台推理引擎 ONNX Runtime 在 Android 平台上的应用版本,主要用于在 Android 设备上高效运行机器学习模型。

实现方法 (参考信息来源Grok)

1. 使用 ONNX Runtime

ONNX Runtime 是由微软开发的高性能推理引擎,支持在 Android 平台上运行 ONNX 模型。它提供了高效的优化和跨平台支持。

2. 使用 TensorFlow Lite(转换 ONNX 模型)

TensorFlow Lite 是 Android 上常用的轻量级深度学习框架。虽然它原生不支持 ONNX 模型,但可以通过转换工具将 ONNX 模型转换为 TFLite 格式。

3. 使用 PyTorch Mobile

如果 ONNX 模型是从 PyTorch 导出的,可以考虑直接使用 PyTorch Mobile 运行模型,绕过 ONNX 格式(或在必要时转换)。

4. 使用 MNN(Mobile Neural Network)

MNN 是阿里巴巴开发的轻量级推理框架,支持 ONNX 模型,适用于 Android 平台。

5. 使用 NCNN

NCNN 是腾讯优图开发的移动端推理框架,也支持 ONNX 模型。

比较与建议

方法优点缺点适用场景
ONNX Runtime高性能、硬件加速、跨平台需要学习 API通用、高性能推理
TensorFlow Lite移动端优化、广泛支持模型转换复杂轻量级、资源受限设备
PyTorch Mobile适合 PyTorch 模型、优化良好不直接支持 ONNXPyTorch 模型直接部署
MNN轻量级、多格式支持社区较小资源受限设备、跨格式支持
NCNN高性能、低内存占用转换复杂、C++ 接口高性能、低资源需求场景

建议

  • 如果追求简单性和高性能,ONNX Runtime 是首选,适合大多数场景。
  • 如果模型复杂且需要移动端优化,考虑将 ONNX 转换为 TFLiteMNN
  • 如果模型来自 PyTorch,PyTorch Mobile 是更直接的选择。
  • 对于极致性能和低资源占用,NCNN 是不错的选择,但需要更多开发工作。

注意事项

  • 模型优化:运行前可使用 ONNX 优化工具(如 onnx-simplifier)简化模型,减少计算量。
  • 硬件加速:根据设备支持,选择合适的硬件加速选项(如 NNAPI、GPU)。
  • 兼容性测试:不同框架对 ONNX 算子的支持程度不同,需测试模型兼容性。
  • 安全性:确保模型文件存储在安全位置,避免泄露。

尝试ONNX Runtime

实现步骤

  • 引入依赖:在 Android 项目的 build.gradle 文件中添加 ONNX Runtime 的依赖。例如:
    implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.18.0'
    
  • 加载模型:将训练好的 ONNX 模型文件(例如 model.onnx)放入 Android 项目的 assets 目录或存储中,并通过 ONNX Runtime 加载:
    import ai.onnxruntime.OnnxTensor;
    import ai.onnxruntime.OrtEnvironment;
    import ai.onnxruntime.OrtSession;OrtEnvironment env = OrtEnvironment.getEnvironment();
    OrtSession session = env.createSession(modelPath, new OrtSession.SessionOptions());
    
  • 预处理输入:根据模型输入要求,将数据(例如图像或张量)转换为 OnnxTensor 格式。
  • 执行推理:使用 session.run() 方法运行模型,获取输出:
    OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);
    Map<String, OnnxTensor> inputs = new HashMap<>();
    inputs.put("input_name", inputTensor);
    OrtSession.Result outputs = session.run(inputs);
    
  • 后处理输出:解析输出结果,转换为应用需要的格式。
  • 优化:ONNX Runtime 支持硬件加速(如 NNAPI),可以在 SessionOptions 中启用:
    SessionOptions options = new SessionOptions();
    options.addNnapi();
    

解析ONNX

在NETRON上打开ONNX文件, 可以看到如下信息:
在这里插入图片描述
对应在代码中获取的结果如下:

inputNames: x, h, c, 
inputNode:x: [1, 512], FLOAT
inputNode:h: [2, 1, 64], FLOAT
inputNode:c: [2, 1, 64], FLOAT
outputNames: prob, new_h, new_c, 
outputNode: prob: [1, 1], FLOAT
outputNode: new_h: [2, 1, 64], FLOAT
outputNode: new_c: [2, 1, 64], FLOAT

在 ONNX Runtime 中,通过 session.getInputNames() 和 session.getOutputNames() 获取的输入和输出名称是 Set 类型,表示模型可能具有多个输入节点和多个输出节点。
模型支持多个输入节点(Multiple Input Nodes)

  • 含义:ONNX 模型可以定义多个输入节点,每个节点有唯一的名称、形状和数据类型。session.getInputNames() 返回所有输入节点的名称集合。
  • 应用场景
    • 多模态模型:例如,一个模型同时接受图像和文本作为输入。输入名称可能是 ["image_input", "text_input"],分别对应图像张量(如 [1, 3, 224, 224]FLOAT)和文本张量(如 [1, 128]INT32)。
    • 多分支网络:某些网络(如双塔模型)需要不同类型的输入数据(如用户特征和物品特征)。
    • 控制输入:模型可能需要额外的输入(如超参数、权重调整张量)来控制推理行为。
  • 推理时
    • 每次推理需要为所有输入节点提供数据,存储在 Map<String, OnnxTensor> 中。例如:
      val inputs = mapOf("image_input" to imageTensor,"text_input" to textTensor
      )
      val result = session.run(inputs)
      
    • 因此,输入集合表示模型在一次推理中需要多种数据同时送入算法,而不是“支持多种输入类型”。

输出集合:

  • 多输出节点:类似输入,输出集合 (session.getOutputNames()) 表示模型可能产生多个输出节点。例如:
    • 目标检测模型可能输出 ["boxes", "scores", "labels"],分别表示边界框坐标、置信度分数和类别标签。
    • 多任务学习模型可能输出分类和回归结果。
  • 推理时OrtSession.Result 包含所有输出节点的张量,键是输出名称,值是 OnnxTensor。开发者可以选择处理全部或部分输出:
    val outputNames = session.outputNames // 例如 ["boxes", "scores"]
    session.use {it.run(inputs).use { result ->val boxes = result.get("boxes") as OnnxTensorval scores = result.get("scores") as OnnxTensor// 处理 boxes 和 scores}
    }
    

示例解读

基于 Android 平台的 ONNX 运行时的基本对象检测示例应用程序,支持 Ort-Extensions 进行预处理/后处理。该演示应用程序完成了从给定图像中检测对象的任务。此处使用的模型来自 Yolov8 扩展版本,并支持预处理/后处理。
该模型 (Yolov8n) 可以直接输入图像字节,并输出带有边界框的检测到的对象。

完整的示例代码可以参考:Object Detection Android sample
关键文件目录:

│       ├── main
│       │   ├── AndroidManifest.xml
│       │   ├── assets //测试图片
│       │   │   ├── test_object_detection_0.jpg
│       │   │   └── test_object_detection_1.jpg
│       │   ├── java
│       │   │   └── ai
│       │   │       └── onnxruntime
│       │   │           └── example
│       │   │               └── objectdetection
│       │   │                   ├── MainActivity.kt //主界面
│       │   │                   └── ObjectDetector.kt //关键调用模型实现
│       │   └── res
│       │       ├── drawable
│       │       ├── raw
│       │       │   ├── classes.txt  //分类标签
│       │       │   └── yolov8n_with_pre_post_processing.onnx  //模型文件
│       │       ├── values
│       │       │   ├── colors.xml
│       │       │   ├── ids.xml
│       │       │   ├── strings.xml
│       │       │   └── themes.xml
│       │       └── xml
│       │           ├── backup_rules.xml
│       │           └── data_extraction_rules.xml

MainActivity.kt

package ai.onnxruntime.example.objectdetectionimport ai.onnxruntime.*
import ai.onnxruntime.extensions.OrtxPackage
import android.annotation.SuppressLint
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.PorterDuff
import android.graphics.PorterDuffXfermode
import android.os.Bundle
import android.util.Log
import android.widget.Button
import android.widget.ImageView
import android.widget.Toast
import androidx.activity.*
import androidx.appcompat.app.AppCompatActivity
import kotlinx.android.synthetic.main.activity_main.*
import kotlinx.coroutines.*
import java.io.InputStream
import java.util.*class MainActivity : AppCompatActivity() {private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()private lateinit var ortSession: OrtSessionprivate lateinit var inputImage: ImageViewprivate lateinit var outputImage: ImageViewprivate lateinit var objectDetectionButton: Buttonprivate var imageid = 0;private lateinit var classes:List<String>@SuppressLint("UseCompatLoadingForDrawables")override fun onCreate(savedInstanceState: Bundle?) {super.onCreate(savedInstanceState)setContentView(R.layout.activity_main)inputImage = findViewById(R.id.imageView1)outputImage = findViewById(R.id.imageView2)objectDetectionButton = findViewById(R.id.object_detection_button)inputImage.setImageBitmap(BitmapFactory.decodeStream(readInputImage()));imageid = 0classes = readClasses();// Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators.// Note: These are used to decode the input image into the format the original model requires,// and to encode the model output into png formatval sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions()sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())//从raw中读取模型文件进行初始化ortSession = ortEnv.createSession(readModel(), sessionOptions)objectDetectionButton.setOnClickListener {try {//启动算法检测performObjectDetection(ortSession)Toast.makeText(baseContext, "ObjectDetection performed!", Toast.LENGTH_SHORT).show()} catch (e: Exception) {Log.e(TAG, "Exception caught when perform ObjectDetection", e)Toast.makeText(baseContext, "Failed to perform ObjectDetection", Toast.LENGTH_SHORT).show()}}}override fun onDestroy() {super.onDestroy()ortEnv.close()ortSession.close()}private fun updateUI(result: Result) {val mutableBitmap: Bitmap = result.outputBitmap.copy(Bitmap.Config.ARGB_8888, true)val canvas = Canvas(mutableBitmap)val paint = Paint()paint.color = Color.WHITE // Text Colorpaint.textSize = 28f // Text Sizepaint.xfermode = PorterDuffXfermode(PorterDuff.Mode.SRC_OVER) // Text Overlapping Patterncanvas.drawBitmap(mutableBitmap, 0.0f, 0.0f, paint)var boxit = result.outputBox.iterator()while(boxit.hasNext()) {var box_info = boxit.next()canvas.drawText("%s:%.2f".format(classes[box_info[5].toInt()],box_info[4]),box_info[0]-box_info[2]/2, box_info[1]-box_info[3]/2, paint)}outputImage.setImageBitmap(mutableBitmap)}private fun readModel(): ByteArray {val modelID = R.raw.yolov8n_with_pre_post_processingreturn resources.openRawResource(modelID).readBytes()}private fun readClasses(): List<String> {return resources.openRawResource(R.raw.classes).bufferedReader().readLines()}private fun readInputImage(): InputStream {imageid = imageid.xor(1)return assets.open("test_object_detection_${imageid}.jpg")}//调用算法并读取解析结果, 最后更新UIprivate fun performObjectDetection(ortSession: OrtSession) {var objDetector = ObjectDetector()var imagestream = readInputImage()inputImage.setImageBitmap(BitmapFactory.decodeStream(imagestream));imagestream.reset()var result = objDetector.detect(imagestream, ortEnv, ortSession)updateUI(result);}companion object {const val TAG = "ORTObjectDetection"}
}

ObjectDetector.kt 调用ONXX模型.

package ai.onnxruntime.example.objectdetectionimport ai.onnxruntime.OnnxJavaType
import ai.onnxruntime.OrtSession
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import java.io.InputStream
import java.nio.ByteBuffer
import java.util.*internal data class Result(var outputBitmap: Bitmap,var outputBox: Array<FloatArray>
) {}internal class ObjectDetector(
) {fun detect(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result {// Step 1: convert image into byte array (raw image bytes)val rawImageBytes = inputStream.readBytes()// Step 2: get the shape of the byte array and make ort tensorval shape = longArrayOf(rawImageBytes.size.toLong())val inputTensor = OnnxTensor.createTensor(ortEnv,ByteBuffer.wrap(rawImageBytes),shape,OnnxJavaType.UINT8)inputTensor.use {// Step 3: call ort inferenceSession runval output = ortSession.run(Collections.singletonMap("image", inputTensor),setOf("image_out","scaled_box_out_next"))// Step 4: output analysisoutput.use {val rawOutput = (output?.get(0)?.value) as ByteArrayval boxOutput = (output?.get(1)?.value) as Array<FloatArray>val outputImageBitmap = byteArrayToBitmap(rawOutput)// Step 5: set output resultvar result = Result(outputImageBitmap,boxOutput)return result}}}private fun byteArrayToBitmap(data: ByteArray): Bitmap {return BitmapFactory.decodeByteArray(data, 0, data.size)}
}

执行效果:
请添加图片描述
请添加图片描述

扩展: ASR, 本地语音听写的实现(SherpaOnnxVadAsr)

sherpa-onnx
Android build
Android build 2
sherpa-onnx Android 源码和测试apk
asr-models

步骤:

  1. 下载源码
  2. 配置 SDK and NDK
  3. 调用 build-android-arm64-v8a.sh 进行编译, 然而编译失败 转而下载已发布的aardownload sherpa-onnx.aar
  4. 下载模型文件 modelsherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
  5. SherpaOnnxVadAsr最终源码结构如下:
    ├── main
    │   ├── AndroidManifest.xml
    │   ├── assets
    │   │   ├── sherpa-onnx-paraformer-zh-2023-09-14
    │   │   │   ├── model.int8.onnx
    │   │   │   └── tokens.txt
    │   │   └── silero_vad.onnx
    │   ├── java
    │   │   └── com
    │   │       └── k2fsa
    │   │           └── sherpa
    │   │               └── onnx
    │   │                   ├── FeatureConfig.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
    │   │                   ├── HomophoneReplacerConfig.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/HomophoneReplacerConfig.kt
    │   │                   ├── MainActivity.kt
    │   │                   ├── OfflineRecognizer.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt
    │   │                   ├── OfflineStream.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
    │   │                   └── Vad.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt
    │   ├── jniLibs
    │   │   ├── arm64-v8a
    │   │   │   ├── libonnxruntime4j_jni.so
    │   │   │   ├── libonnxruntime.so
    │   │   │   ├── libsherpa-onnx-c-api.so
    │   │   │   ├── libsherpa-onnx-cxx-api.so
    │   │   │   └── libsherpa-onnx-jni.so
    │   │   ├── armeabi-v7a
    │   │   │   ├── libonnxruntime4j_jni.so
    │   │   │   ├── libonnxruntime.so
    │   │   │   ├── libsherpa-onnx-c-api.so
    │   │   │   ├── libsherpa-onnx-cxx-api.so
    │   │   │   └── libsherpa-onnx-jni.so
    │   │   ├── x86
    │   │   │   ├── libonnxruntime4j_jni.so
    │   │   │   ├── libonnxruntime.so
    │   │   │   ├── libsherpa-onnx-c-api.so
    │   │   │   ├── libsherpa-onnx-cxx-api.so
    │   │   │   └── libsherpa-onnx-jni.so
    │   │   └── x86_64
    │   │       ├── libonnxruntime4j_jni.so
    │   │       ├── libonnxruntime.so
    │   │       ├── libsherpa-onnx-c-api.so
    │   │       ├── libsherpa-onnx-cxx-api.so
    │   │       └── libsherpa-onnx-jni.so
    │   └── res
    │       ├── drawable
    │       │   └── ic_launcher_background.xml
    │       ├── drawable-v24
    │       │   └── ic_launcher_foreground.xml
    │       ├── layout
    │       │   └── activity_main.xml
    │       ├── mipmap-anydpi-v26
    │       │   ├── ic_launcher_round.xml
    │       │   └── ic_launcher.xml
    │       ├── mipmap-hdpi
    │       │   ├── ic_launcher_round.webp
    │       │   └── ic_launcher.webp
    │       ├── mipmap-mdpi
    │       │   ├── ic_launcher_round.webp
    │       │   └── ic_launcher.webp
    │       ├── mipmap-xhdpi
    │       │   ├── ic_launcher_round.webp
    │       │   └── ic_launcher.webp
    │       ├── mipmap-xxhdpi
    │       │   ├── ic_launcher_round.webp
    │       │   └── ic_launcher.webp
    │       ├── mipmap-xxxhdpi
    │       │   ├── ic_launcher_round.webp
    │       │   └── ic_launcher.webp
    │       ├── values
    │       │   ├── colors.xml
    │       │   ├── strings.xml
    │       │   └── themes.xml
    │       ├── values-night
    │       │   └── themes.xml
    │       └── xml
    │           ├── backup_rules.xml
    │           └── data_extraction_rules.xml
    
## 参考
1. [android onnx](https://blog.51cto.com/u_16213465/13067353)
2. [Get started with ONNX Runtime Mobile](https://onnxruntime.ai/docs/get-started/with-mobile.html)
3. [ONNX Runtime](https://github.com/microsoft/onnxruntime)
4. [Object Detection Android sample](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/mobile/examples/object_detection/android)

文章转载自:

http://YTjMRX9l.wgcng.cn
http://Htr5W7X5.wgcng.cn
http://oBPiIB7j.wgcng.cn
http://lccqmQZb.wgcng.cn
http://92B7Zn0G.wgcng.cn
http://sHW3xj6X.wgcng.cn
http://8BILlfd7.wgcng.cn
http://wqWPpWGa.wgcng.cn
http://45JBGpzM.wgcng.cn
http://Z5rNg7yl.wgcng.cn
http://JBj5HXjH.wgcng.cn
http://I9LFo9X7.wgcng.cn
http://bBSn2kNB.wgcng.cn
http://6uzKq1lJ.wgcng.cn
http://Hz0DWImh.wgcng.cn
http://C0bpC1Gj.wgcng.cn
http://1WPfqc3y.wgcng.cn
http://9MhauC7P.wgcng.cn
http://ZjNQuG4j.wgcng.cn
http://lWQkGUZz.wgcng.cn
http://iQGjSP23.wgcng.cn
http://ytlwGJBF.wgcng.cn
http://7xLNd2lW.wgcng.cn
http://OmlFE2oY.wgcng.cn
http://MMJqCip6.wgcng.cn
http://Dqgr2pME.wgcng.cn
http://ye6nIoJp.wgcng.cn
http://wh0sl15g.wgcng.cn
http://R64V5ltY.wgcng.cn
http://Onq50Pu7.wgcng.cn
http://www.dtcms.com/a/387059.html

相关文章:

  • transformer各层的输入输出
  • lvgl图形库和qt图形库比较
  • 如何解决 pip install 安装报错 ModuleNotFoundError: No module named ‘PIL’ 问题
  • 搭建 PHP 网站
  • 流式分析:细胞分群方法
  • Redis 底层数据结构之 Dict(字典)
  • UE 最短上手路线
  • 动手学Agent:Agent设计模式——构建有效Agent的7种模型
  • 苍穹外卖day01
  • 《LINUX系统编程》笔记p14
  • 可直接落地的pytest+request+allure接口自动化框架
  • 【精品资料鉴赏】267页政务大数据资源平台建设方案
  • 面试前端遇到的问题
  • 【深度学习计算机视觉】05:多尺度目标检测——从理论到YOLOv5实践
  • STM32 通过USB的Mass Storage Class读写挂载的SD卡出现卡死问题
  • 【Nginx开荒攻略】Nginx基本服务配置:从启动到运维的完整指南
  • 《漫威争锋》公布开发者愿景视频:介绍1.5版本的内容
  • Isight许可管理与其他软件集成的方法
  • 论文提纲:学术写作的“蓝图”,如何用AI工具沁言学术高效构建?
  • 快速解决云服务器的数据库PhpMyAdmin登录问题
  • 知识更新缺乏责任人会带来哪些风险
  • 容器化部署番外篇之Nexus3搭建私有仓库09
  • 计算机视觉(opencv)实战二十四——扫描答题卡打分
  • 居住证申请:线上照片回执办理!
  • Roo Code 的差异_快速编辑功能
  • 【深度学习】基于深度学习算法的图像版权保护数字水印技术
  • mcp初探
  • 深入C++对象生命周期:从构造到析构的奥秘
  • 视频上传以及在线播放
  • Powershell and Python are very similar