AI: Android 运行ONNX模型
概述
ONNX(Open Neural Network Exchange)模型, Android ONNX Runtime 是微软开源的跨平台推理引擎 ONNX Runtime 在 Android 平台上的应用版本,主要用于在 Android 设备上高效运行机器学习模型。
实现方法 (参考信息来源Grok)
1. 使用 ONNX Runtime
ONNX Runtime 是由微软开发的高性能推理引擎,支持在 Android 平台上运行 ONNX 模型。它提供了高效的优化和跨平台支持。
2. 使用 TensorFlow Lite(转换 ONNX 模型)
TensorFlow Lite 是 Android 上常用的轻量级深度学习框架。虽然它原生不支持 ONNX 模型,但可以通过转换工具将 ONNX 模型转换为 TFLite 格式。
3. 使用 PyTorch Mobile
如果 ONNX 模型是从 PyTorch 导出的,可以考虑直接使用 PyTorch Mobile 运行模型,绕过 ONNX 格式(或在必要时转换)。
4. 使用 MNN(Mobile Neural Network)
MNN 是阿里巴巴开发的轻量级推理框架,支持 ONNX 模型,适用于 Android 平台。
5. 使用 NCNN
NCNN 是腾讯优图开发的移动端推理框架,也支持 ONNX 模型。
比较与建议
方法 | 优点 | 缺点 | 适用场景 |
---|---|---|---|
ONNX Runtime | 高性能、硬件加速、跨平台 | 需要学习 API | 通用、高性能推理 |
TensorFlow Lite | 移动端优化、广泛支持 | 模型转换复杂 | 轻量级、资源受限设备 |
PyTorch Mobile | 适合 PyTorch 模型、优化良好 | 不直接支持 ONNX | PyTorch 模型直接部署 |
MNN | 轻量级、多格式支持 | 社区较小 | 资源受限设备、跨格式支持 |
NCNN | 高性能、低内存占用 | 转换复杂、C++ 接口 | 高性能、低资源需求场景 |
建议:
- 如果追求简单性和高性能,ONNX Runtime 是首选,适合大多数场景。
- 如果模型复杂且需要移动端优化,考虑将 ONNX 转换为 TFLite 或 MNN。
- 如果模型来自 PyTorch,PyTorch Mobile 是更直接的选择。
- 对于极致性能和低资源占用,NCNN 是不错的选择,但需要更多开发工作。
注意事项
- 模型优化:运行前可使用 ONNX 优化工具(如
onnx-simplifier
)简化模型,减少计算量。 - 硬件加速:根据设备支持,选择合适的硬件加速选项(如 NNAPI、GPU)。
- 兼容性测试:不同框架对 ONNX 算子的支持程度不同,需测试模型兼容性。
- 安全性:确保模型文件存储在安全位置,避免泄露。
尝试ONNX Runtime
实现步骤:
- 引入依赖:在 Android 项目的
build.gradle
文件中添加 ONNX Runtime 的依赖。例如:implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.18.0'
- 加载模型:将训练好的 ONNX 模型文件(例如
model.onnx
)放入 Android 项目的assets
目录或存储中,并通过 ONNX Runtime 加载:import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtSession;OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession session = env.createSession(modelPath, new OrtSession.SessionOptions());
- 预处理输入:根据模型输入要求,将数据(例如图像或张量)转换为
OnnxTensor
格式。 - 执行推理:使用
session.run()
方法运行模型,获取输出:OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData); Map<String, OnnxTensor> inputs = new HashMap<>(); inputs.put("input_name", inputTensor); OrtSession.Result outputs = session.run(inputs);
- 后处理输出:解析输出结果,转换为应用需要的格式。
- 优化:ONNX Runtime 支持硬件加速(如 NNAPI),可以在
SessionOptions
中启用:SessionOptions options = new SessionOptions(); options.addNnapi();
解析ONNX
在NETRON上打开ONNX文件, 可以看到如下信息:
对应在代码中获取的结果如下:
inputNames: x, h, c,
inputNode:x: [1, 512], FLOAT
inputNode:h: [2, 1, 64], FLOAT
inputNode:c: [2, 1, 64], FLOAT
outputNames: prob, new_h, new_c,
outputNode: prob: [1, 1], FLOAT
outputNode: new_h: [2, 1, 64], FLOAT
outputNode: new_c: [2, 1, 64], FLOAT
在 ONNX Runtime 中,通过 session.getInputNames() 和 session.getOutputNames() 获取的输入和输出名称是 Set 类型,表示模型可能具有多个输入节点和多个输出节点。
模型支持多个输入节点(Multiple Input Nodes)
- 含义:ONNX 模型可以定义多个输入节点,每个节点有唯一的名称、形状和数据类型。
session.getInputNames()
返回所有输入节点的名称集合。 - 应用场景:
- 多模态模型:例如,一个模型同时接受图像和文本作为输入。输入名称可能是
["image_input", "text_input"]
,分别对应图像张量(如[1, 3, 224, 224]
,FLOAT
)和文本张量(如[1, 128]
,INT32
)。 - 多分支网络:某些网络(如双塔模型)需要不同类型的输入数据(如用户特征和物品特征)。
- 控制输入:模型可能需要额外的输入(如超参数、权重调整张量)来控制推理行为。
- 多模态模型:例如,一个模型同时接受图像和文本作为输入。输入名称可能是
- 推理时:
- 每次推理需要为所有输入节点提供数据,存储在
Map<String, OnnxTensor>
中。例如:val inputs = mapOf("image_input" to imageTensor,"text_input" to textTensor ) val result = session.run(inputs)
- 因此,输入集合表示模型在一次推理中需要多种数据同时送入算法,而不是“支持多种输入类型”。
- 每次推理需要为所有输入节点提供数据,存储在
输出集合:
- 多输出节点:类似输入,输出集合 (
session.getOutputNames()
) 表示模型可能产生多个输出节点。例如:- 目标检测模型可能输出
["boxes", "scores", "labels"]
,分别表示边界框坐标、置信度分数和类别标签。 - 多任务学习模型可能输出分类和回归结果。
- 目标检测模型可能输出
- 推理时:
OrtSession.Result
包含所有输出节点的张量,键是输出名称,值是OnnxTensor
。开发者可以选择处理全部或部分输出:val outputNames = session.outputNames // 例如 ["boxes", "scores"] session.use {it.run(inputs).use { result ->val boxes = result.get("boxes") as OnnxTensorval scores = result.get("scores") as OnnxTensor// 处理 boxes 和 scores} }
示例解读
基于 Android 平台的 ONNX 运行时的基本对象检测示例应用程序,支持 Ort-Extensions 进行预处理/后处理。该演示应用程序完成了从给定图像中检测对象的任务。此处使用的模型来自 Yolov8 扩展版本,并支持预处理/后处理。
该模型 (Yolov8n) 可以直接输入图像字节,并输出带有边界框的检测到的对象。
完整的示例代码可以参考:Object Detection Android sample
关键文件目录:
│ ├── main
│ │ ├── AndroidManifest.xml
│ │ ├── assets //测试图片
│ │ │ ├── test_object_detection_0.jpg
│ │ │ └── test_object_detection_1.jpg
│ │ ├── java
│ │ │ └── ai
│ │ │ └── onnxruntime
│ │ │ └── example
│ │ │ └── objectdetection
│ │ │ ├── MainActivity.kt //主界面
│ │ │ └── ObjectDetector.kt //关键调用模型实现
│ │ └── res
│ │ ├── drawable
│ │ ├── raw
│ │ │ ├── classes.txt //分类标签
│ │ │ └── yolov8n_with_pre_post_processing.onnx //模型文件
│ │ ├── values
│ │ │ ├── colors.xml
│ │ │ ├── ids.xml
│ │ │ ├── strings.xml
│ │ │ └── themes.xml
│ │ └── xml
│ │ ├── backup_rules.xml
│ │ └── data_extraction_rules.xml
MainActivity.kt
package ai.onnxruntime.example.objectdetectionimport ai.onnxruntime.*
import ai.onnxruntime.extensions.OrtxPackage
import android.annotation.SuppressLint
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.PorterDuff
import android.graphics.PorterDuffXfermode
import android.os.Bundle
import android.util.Log
import android.widget.Button
import android.widget.ImageView
import android.widget.Toast
import androidx.activity.*
import androidx.appcompat.app.AppCompatActivity
import kotlinx.android.synthetic.main.activity_main.*
import kotlinx.coroutines.*
import java.io.InputStream
import java.util.*class MainActivity : AppCompatActivity() {private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()private lateinit var ortSession: OrtSessionprivate lateinit var inputImage: ImageViewprivate lateinit var outputImage: ImageViewprivate lateinit var objectDetectionButton: Buttonprivate var imageid = 0;private lateinit var classes:List<String>@SuppressLint("UseCompatLoadingForDrawables")override fun onCreate(savedInstanceState: Bundle?) {super.onCreate(savedInstanceState)setContentView(R.layout.activity_main)inputImage = findViewById(R.id.imageView1)outputImage = findViewById(R.id.imageView2)objectDetectionButton = findViewById(R.id.object_detection_button)inputImage.setImageBitmap(BitmapFactory.decodeStream(readInputImage()));imageid = 0classes = readClasses();// Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators.// Note: These are used to decode the input image into the format the original model requires,// and to encode the model output into png formatval sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions()sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())//从raw中读取模型文件进行初始化ortSession = ortEnv.createSession(readModel(), sessionOptions)objectDetectionButton.setOnClickListener {try {//启动算法检测performObjectDetection(ortSession)Toast.makeText(baseContext, "ObjectDetection performed!", Toast.LENGTH_SHORT).show()} catch (e: Exception) {Log.e(TAG, "Exception caught when perform ObjectDetection", e)Toast.makeText(baseContext, "Failed to perform ObjectDetection", Toast.LENGTH_SHORT).show()}}}override fun onDestroy() {super.onDestroy()ortEnv.close()ortSession.close()}private fun updateUI(result: Result) {val mutableBitmap: Bitmap = result.outputBitmap.copy(Bitmap.Config.ARGB_8888, true)val canvas = Canvas(mutableBitmap)val paint = Paint()paint.color = Color.WHITE // Text Colorpaint.textSize = 28f // Text Sizepaint.xfermode = PorterDuffXfermode(PorterDuff.Mode.SRC_OVER) // Text Overlapping Patterncanvas.drawBitmap(mutableBitmap, 0.0f, 0.0f, paint)var boxit = result.outputBox.iterator()while(boxit.hasNext()) {var box_info = boxit.next()canvas.drawText("%s:%.2f".format(classes[box_info[5].toInt()],box_info[4]),box_info[0]-box_info[2]/2, box_info[1]-box_info[3]/2, paint)}outputImage.setImageBitmap(mutableBitmap)}private fun readModel(): ByteArray {val modelID = R.raw.yolov8n_with_pre_post_processingreturn resources.openRawResource(modelID).readBytes()}private fun readClasses(): List<String> {return resources.openRawResource(R.raw.classes).bufferedReader().readLines()}private fun readInputImage(): InputStream {imageid = imageid.xor(1)return assets.open("test_object_detection_${imageid}.jpg")}//调用算法并读取解析结果, 最后更新UIprivate fun performObjectDetection(ortSession: OrtSession) {var objDetector = ObjectDetector()var imagestream = readInputImage()inputImage.setImageBitmap(BitmapFactory.decodeStream(imagestream));imagestream.reset()var result = objDetector.detect(imagestream, ortEnv, ortSession)updateUI(result);}companion object {const val TAG = "ORTObjectDetection"}
}
ObjectDetector.kt 调用ONXX模型.
package ai.onnxruntime.example.objectdetectionimport ai.onnxruntime.OnnxJavaType
import ai.onnxruntime.OrtSession
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import java.io.InputStream
import java.nio.ByteBuffer
import java.util.*internal data class Result(var outputBitmap: Bitmap,var outputBox: Array<FloatArray>
) {}internal class ObjectDetector(
) {fun detect(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result {// Step 1: convert image into byte array (raw image bytes)val rawImageBytes = inputStream.readBytes()// Step 2: get the shape of the byte array and make ort tensorval shape = longArrayOf(rawImageBytes.size.toLong())val inputTensor = OnnxTensor.createTensor(ortEnv,ByteBuffer.wrap(rawImageBytes),shape,OnnxJavaType.UINT8)inputTensor.use {// Step 3: call ort inferenceSession runval output = ortSession.run(Collections.singletonMap("image", inputTensor),setOf("image_out","scaled_box_out_next"))// Step 4: output analysisoutput.use {val rawOutput = (output?.get(0)?.value) as ByteArrayval boxOutput = (output?.get(1)?.value) as Array<FloatArray>val outputImageBitmap = byteArrayToBitmap(rawOutput)// Step 5: set output resultvar result = Result(outputImageBitmap,boxOutput)return result}}}private fun byteArrayToBitmap(data: ByteArray): Bitmap {return BitmapFactory.decodeByteArray(data, 0, data.size)}
}
执行效果:
扩展: ASR, 本地语音听写的实现(SherpaOnnxVadAsr)
sherpa-onnx
Android build
Android build 2
sherpa-onnx Android 源码和测试apk
asr-models
步骤:
- 下载源码
- 配置 SDK and NDK
- 调用 build-android-arm64-v8a.sh 进行编译, 然而编译失败 转而下载已发布的aardownload sherpa-onnx.aar
- 下载模型文件 modelsherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
- SherpaOnnxVadAsr最终源码结构如下:
├── main │ ├── AndroidManifest.xml │ ├── assets │ │ ├── sherpa-onnx-paraformer-zh-2023-09-14 │ │ │ ├── model.int8.onnx │ │ │ └── tokens.txt │ │ └── silero_vad.onnx │ ├── java │ │ └── com │ │ └── k2fsa │ │ └── sherpa │ │ └── onnx │ │ ├── FeatureConfig.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt │ │ ├── HomophoneReplacerConfig.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/HomophoneReplacerConfig.kt │ │ ├── MainActivity.kt │ │ ├── OfflineRecognizer.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt │ │ ├── OfflineStream.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt │ │ └── Vad.kt -> ../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt │ ├── jniLibs │ │ ├── arm64-v8a │ │ │ ├── libonnxruntime4j_jni.so │ │ │ ├── libonnxruntime.so │ │ │ ├── libsherpa-onnx-c-api.so │ │ │ ├── libsherpa-onnx-cxx-api.so │ │ │ └── libsherpa-onnx-jni.so │ │ ├── armeabi-v7a │ │ │ ├── libonnxruntime4j_jni.so │ │ │ ├── libonnxruntime.so │ │ │ ├── libsherpa-onnx-c-api.so │ │ │ ├── libsherpa-onnx-cxx-api.so │ │ │ └── libsherpa-onnx-jni.so │ │ ├── x86 │ │ │ ├── libonnxruntime4j_jni.so │ │ │ ├── libonnxruntime.so │ │ │ ├── libsherpa-onnx-c-api.so │ │ │ ├── libsherpa-onnx-cxx-api.so │ │ │ └── libsherpa-onnx-jni.so │ │ └── x86_64 │ │ ├── libonnxruntime4j_jni.so │ │ ├── libonnxruntime.so │ │ ├── libsherpa-onnx-c-api.so │ │ ├── libsherpa-onnx-cxx-api.so │ │ └── libsherpa-onnx-jni.so │ └── res │ ├── drawable │ │ └── ic_launcher_background.xml │ ├── drawable-v24 │ │ └── ic_launcher_foreground.xml │ ├── layout │ │ └── activity_main.xml │ ├── mipmap-anydpi-v26 │ │ ├── ic_launcher_round.xml │ │ └── ic_launcher.xml │ ├── mipmap-hdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── mipmap-mdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── mipmap-xhdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── mipmap-xxhdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── mipmap-xxxhdpi │ │ ├── ic_launcher_round.webp │ │ └── ic_launcher.webp │ ├── values │ │ ├── colors.xml │ │ ├── strings.xml │ │ └── themes.xml │ ├── values-night │ │ └── themes.xml │ └── xml │ ├── backup_rules.xml │ └── data_extraction_rules.xml
## 参考
1. [android onnx](https://blog.51cto.com/u_16213465/13067353)
2. [Get started with ONNX Runtime Mobile](https://onnxruntime.ai/docs/get-started/with-mobile.html)
3. [ONNX Runtime](https://github.com/microsoft/onnxruntime)
4. [Object Detection Android sample](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/mobile/examples/object_detection/android)