Kotlin与机器学习实战:Android端集成TensorFlow Lite全指南
本文将手把手教你如何在Android应用中集成TensorFlow Lite模型,实现端侧机器学习推理能力。我们以图像分类场景为例,提供可直接运行的完整代码示例。
环境准备
1. 开发环境要求
- Android Studio Arctic Fox以上版本
- AGP 7.0+
- Kotlin 1.6+
- Minimum SDK 21
2. 添加Gradle依赖
// build.gradle.kts (module level)
android {
    androidResources {
        // Keep .tflite assets uncompressed so they can be memory-mapped when the model is loaded.
        // (Recent AGP versions already do this by default; kept explicit for clarity.)
        noCompress += "tflite"
    }
}

dependencies {
    // TFLite core runtime
    implementation("org.tensorflow:tensorflow-lite:2.12.0")
    // GPU delegate
    implementation("org.tensorflow:tensorflow-lite-gpu:2.12.0")
    // Support library (TensorImage, ImageProcessor, ...)
    implementation("org.tensorflow:tensorflow-lite-support:0.4.4")
    // Task Library vision API (ImageClassifier), required by TFLiteImageClassifier
    implementation("org.tensorflow:tensorflow-lite-task-vision:0.4.4")
    // GPU delegate plugin for the Task Library (BaseOptions.useGpu())
    implementation("org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.4")

    // CameraX (optional)
    implementation("androidx.camera:camera-core:1.3.0")
    implementation("androidx.camera:camera-camera2:1.3.0")
    implementation("androidx.camera:camera-lifecycle:1.3.0")
    implementation("androidx.camera:camera-view:1.3.0")

    // Coroutines
    implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
    // lifecycleScope
    implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
}
完整实现流程
步骤1:模型文件处理
将训练好的 `.tflite` 模型文件放入 `app/src/main/assets` 目录,建议同时放入 `labels.txt` 标签文件(每行一个标签,行号即类别索引):
app/src/main/assets/
├── mobilenet_v1_1.0_224_quant.tflite
└── labels.txt
步骤2:核心分类器实现
import android.content.Context
import android.graphics.Bitmap
import java.io.File
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.task.core.BaseOptions
import org.tensorflow.lite.task.vision.classifier.ImageClassifier
import org.tensorflow.lite.task.vision.classifier.ImageClassifier.ImageClassifierOptions

/**
 * Image classifier backed by the TFLite Task Library [ImageClassifier].
 *
 * The Task Library resizes and normalizes the input itself, so callers pass plain bitmaps.
 * It reads the normalization parameters from the model's metadata.
 * NOTE(review): this assumes the model ships with TFLite metadata, which the Task Library
 * needs for float models. Confirm this for custom models.
 *
 * The GPU delegate is tried first. If creating the classifier with GPU options throws, a
 * CPU-only classifier is created instead.
 *
 * Public methods are synchronized because the native classifier is not assumed to be
 * thread-safe. Also, [updateModel] and [close] must not race with an in-flight [classify].
 *
 * @param context used to read the model and label files from assets.
 * @param modelPath asset path of the `.tflite` model.
 * @param labelPath asset path of a newline-separated label file; line i names class index i.
 * @param threadNum number of CPU threads used for inference.
 * @param maxResults maximum number of categories returned by [classify].
 */
class TFLiteImageClassifier(
    context: Context,
    modelPath: String = "mobilenet_v1_1.0_224_quant.tflite",
    labelPath: String = "labels.txt",
    private val threadNum: Int = 4,
    private val maxResults: Int = 3,
) {
    private val labels: List<String> =
        context.assets.open(labelPath).bufferedReader().useLines { it.toList() }

    private var classifier: ImageClassifier? =
        createWithCpuFallback { options ->
            ImageClassifier.createFromFileAndOptions(context, modelPath, options)
        }

    /**
     * Classifies [bitmap] and returns up to `maxResults` (label, score) pairs.
     *
     * Returns an empty list if the classifier has been closed or produced no classification head.
     * A label is looked up in the label file by category index. If that fails, the label from
     * model metadata is used, and otherwise "Unknown".
     */
    @Synchronized
    fun classify(bitmap: Bitmap): List<Pair<String, Float>> {
        val active = classifier ?: return emptyList()
        // TensorImage only loads ARGB_8888 bitmaps (e.g. HARDWARE or RGB_565 bitmaps are rejected).
        val input = if (bitmap.config == Bitmap.Config.ARGB_8888) {
            bitmap
        } else {
            bitmap.copy(Bitmap.Config.ARGB_8888, false) ?: return emptyList()
        }
        // No manual resize/normalize here: the Task Library preprocesses from model metadata,
        // and feeding it a pre-normalized float tensor would break quantized models.
        val categories = active.classify(TensorImage.fromBitmap(input)).firstOrNull()?.categories
            ?: return emptyList()
        return categories.map { category ->
            val label = labels.getOrNull(category.index)
                ?: category.label?.takeIf { it.isNotEmpty() }
                ?: "Unknown"
            label to category.score
        }
    }

    /**
     * Replaces the current model with [modelFile] (e.g. a freshly downloaded model).
     *
     * The new classifier is created before the old one is released, so a failure leaves the
     * previous model in place. The label file loaded at construction is kept as-is.
     */
    @Synchronized
    fun updateModel(modelFile: File) {
        val replacement = createWithCpuFallback { options ->
            ImageClassifier.createFromFileAndOptions(modelFile, options)
        }
        classifier?.close()
        classifier = replacement
    }

    /** Releases native resources; later [classify] calls return an empty list. */
    @Synchronized
    fun close() {
        classifier?.close()
        classifier = null
    }

    /** Builds a classifier with GPU options first and retries with CPU-only options if that throws. */
    private fun createWithCpuFallback(factory: (ImageClassifierOptions) -> ImageClassifier): ImageClassifier =
        try {
            factory(buildOptions(useGpu = true))
        } catch (e: Exception) {
            // GPU delegate failures surface as different runtime exceptions depending on the
            // device/driver, so any failure here triggers the CPU fallback.
            factory(buildOptions(useGpu = false))
        }

    /** Options are immutable once built, so each attempt gets a freshly built instance. */
    private fun buildOptions(useGpu: Boolean): ImageClassifierOptions {
        val baseOptions = BaseOptions.builder()
            .setNumThreads(threadNum)
            .apply { if (useGpu) useGpu() }
            .build()
        return ImageClassifierOptions.builder()
            .setBaseOptions(baseOptions)
            .setMaxResults(maxResults)
            .build()
    }
}
步骤3:UI界面实现
activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<!-- Camera preview (left) and the captured/picked image (right) form a horizontal chain of
     square views, so they share the width instead of overlapping on narrow screens.
     Results are shown below the preview; the two action buttons sit at the bottom. -->
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    android:layout_width="match_parent"
    android:layout_height="match_parent">

    <androidx.camera.view.PreviewView
        android:id="@+id/cameraPreview"
        android:layout_width="0dp"
        android:layout_height="0dp"
        app:layout_constraintDimensionRatio="1:1"
        app:layout_constraintTop_toTopOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintEnd_toStartOf="@+id/ivPreview" />

    <ImageView
        android:id="@+id/ivPreview"
        android:layout_width="0dp"
        android:layout_height="0dp"
        android:contentDescription="识别图片"
        app:layout_constraintDimensionRatio="1:1"
        app:layout_constraintTop_toTopOf="parent"
        app:layout_constraintStart_toEndOf="@id/cameraPreview"
        app:layout_constraintEnd_toEndOf="parent" />

    <Button
        android:id="@+id/btnCapture"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="拍照识别"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintStart_toStartOf="parent" />

    <Button
        android:id="@+id/btnSelect"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="图库选择"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent" />

    <TextView
        android:id="@+id/tvResult"
        android:layout_width="0dp"
        android:layout_height="wrap_content"
        android:padding="16dp"
        android:textSize="18sp"
        app:layout_constraintTop_toBottomOf="@id/cameraPreview"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintEnd_toEndOf="parent" />

</androidx.constraintlayout.widget.ConstraintLayout>
步骤4:主Activity实现(CameraX集成版)
import android.Manifest
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import android.os.Bundle
import android.util.Log
import android.widget.Button
import android.widget.ImageView
import android.widget.TextView
import android.widget.Toast
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.camera.core.CameraSelector
import androidx.camera.core.ImageCapture
import androidx.camera.core.ImageCaptureException
import androidx.camera.core.Preview
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.view.PreviewView
import androidx.core.app.ActivityCompat
import androidx.core.content.ContextCompat
import androidx.lifecycle.lifecycleScope
import java.io.File
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

/**
 * Demo screen with two input paths: a CameraX photo, or an image picked from the gallery.
 *
 * Either image is run through [TFLiteImageClassifier] and the top results are printed.
 * No @RequiresApi is needed: images are decoded with BitmapFactory, which works down to minSdk 21.
 */
class MainActivity : AppCompatActivity() {
    private lateinit var classifier: TFLiteImageClassifier
    private lateinit var cameraExecutor: ExecutorService
    private var imageCapture: ImageCapture? = null

    // Explicit view lookups: kotlinx.android.synthetic accessors are no longer available.
    private val cameraPreview: PreviewView by lazy { findViewById(R.id.cameraPreview) }
    private val ivPreview: ImageView by lazy { findViewById(R.id.ivPreview) }
    private val tvResult: TextView by lazy { findViewById(R.id.tvResult) }

    // Activity Result API: the previous startActivityForResult call had no onActivityResult
    // override, so picked images were never processed.
    private val pickImage = registerForActivityResult(ActivityResultContracts.GetContent()) { uri: Uri? ->
        uri?.let { processImage(it) }
    }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        cameraExecutor = Executors.newSingleThreadExecutor()
        classifier = TFLiteImageClassifier(this)

        // Request the camera permission
        if (allPermissionsGranted()) {
            startCamera()
        } else {
            ActivityCompat.requestPermissions(this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS)
        }

        findViewById<Button>(R.id.btnCapture).setOnClickListener { takePhoto() }
        findViewById<Button>(R.id.btnSelect).setOnClickListener { pickImage.launch("image/*") }
    }

    /** Captures a photo into a cache file and classifies it; the file is deleted afterwards. */
    private fun takePhoto() {
        val capture = imageCapture ?: return
        val photoFile = File.createTempFile("ML_TEMP", ".jpg", cacheDir)
        val outputOptions = ImageCapture.OutputFileOptions.Builder(photoFile).build()
        capture.takePicture(
            outputOptions,
            ContextCompat.getMainExecutor(this),
            object : ImageCapture.OnImageSavedCallback {
                override fun onImageSaved(output: ImageCapture.OutputFileResults) {
                    // savedUri may be null for File-backed output, so fall back to the file itself.
                    val uri = output.savedUri ?: Uri.fromFile(photoFile)
                    processImage(uri, photoFile)
                }

                override fun onError(exc: ImageCaptureException) {
                    photoFile.delete()
                    Log.e(TAG, "拍照失败: ${exc.message}", exc)
                }
            },
        )
    }

    /**
     * Decodes [uri] off the main thread, classifies it on a CPU dispatcher and shows the result.
     *
     * @param tempFile optional file to delete once processing finishes (used for captured photos).
     */
    private fun processImage(uri: Uri, tempFile: File? = null) {
        lifecycleScope.launch {
            try {
                val bitmap = withContext(Dispatchers.IO) { decodeSampledBitmap(uri, MODEL_INPUT_SIZE) }
                if (bitmap == null) {
                    Log.e(TAG, "图片处理失败: $uri")
                    return@launch
                }
                val results = withContext(Dispatchers.Default) { classifier.classify(bitmap) }
                ivPreview.setImageBitmap(bitmap)
                showResults(results)
            } catch (e: CancellationException) {
                // Let coroutine cancellation propagate (e.g. when the activity is destroyed).
                throw e
            } catch (e: Exception) {
                Log.e(TAG, "图片处理失败", e)
            } finally {
                tempFile?.delete()
            }
        }
    }

    /**
     * Decodes [uri] downsampled by a power of two while both sides stay >= [targetSize].
     *
     * Works for content:// and file:// URIs on every API level, unlike
     * ContentResolver.loadThumbnail, which requires API 29.
     * NOTE(review): EXIF orientation is not applied, so camera photos may reach the model rotated.
     */
    private fun decodeSampledBitmap(uri: Uri, targetSize: Int): Bitmap? {
        val bounds = BitmapFactory.Options().apply { inJustDecodeBounds = true }
        contentResolver.openInputStream(uri)?.use { BitmapFactory.decodeStream(it, null, bounds) }
        if (bounds.outWidth <= 0 || bounds.outHeight <= 0) return null

        var sampleSize = 1
        while (bounds.outWidth / (sampleSize * 2) >= targetSize &&
            bounds.outHeight / (sampleSize * 2) >= targetSize
        ) {
            sampleSize *= 2
        }
        val decodeOptions = BitmapFactory.Options().apply {
            inSampleSize = sampleSize
            inPreferredConfig = Bitmap.Config.ARGB_8888
        }
        return contentResolver.openInputStream(uri)?.use { BitmapFactory.decodeStream(it, null, decodeOptions) }
    }

    private fun showResults(results: List<Pair<String, Float>>) {
        val output = buildString {
            append("识别结果:\n")
            results.forEach { (label, confidence) ->
                append("${label}: ${"%.2f".format(confidence * 100)}%\n")
            }
        }
        tvResult.text = output
    }

    /** Binds the preview and image-capture use cases to this activity's lifecycle. */
    private fun startCamera() {
        val cameraProviderFuture = ProcessCameraProvider.getInstance(this)
        cameraProviderFuture.addListener({
            val cameraProvider = cameraProviderFuture.get()
            val preview = Preview.Builder().build().also {
                it.setSurfaceProvider(cameraPreview.surfaceProvider)
            }
            // A non-null local is required: bindToLifecycle's varargs do not accept ImageCapture?.
            val capture = ImageCapture.Builder()
                .setCaptureMode(ImageCapture.CAPTURE_MODE_MINIMIZE_LATENCY)
                .build()
            imageCapture = capture
            try {
                cameraProvider.unbindAll()
                cameraProvider.bindToLifecycle(this, CameraSelector.DEFAULT_BACK_CAMERA, preview, capture)
            } catch (exc: Exception) {
                Log.e(TAG, "相机初始化失败", exc)
            }
        }, ContextCompat.getMainExecutor(this))
    }

    // Permission handling
    override fun onRequestPermissionsResult(requestCode: Int, permissions: Array<String>, grantResults: IntArray) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                startCamera()
            } else {
                Toast.makeText(this, "需要相机权限", Toast.LENGTH_SHORT).show()
                finish()
            }
        }
    }

    // Instance method: checking permissions needs this Context, which a companion object lacks.
    private fun allPermissionsGranted(): Boolean = REQUIRED_PERMISSIONS.all {
        ContextCompat.checkSelfPermission(this, it) == PackageManager.PERMISSION_GRANTED
    }

    override fun onDestroy() {
        super.onDestroy()
        // Stop background work before releasing the native classifier it might use.
        cameraExecutor.shutdown()
        classifier.close()
    }

    companion object {
        private const val TAG = "MLDemo"
        private const val REQUEST_CODE_PERMISSIONS = 10
        private const val MODEL_INPUT_SIZE = 224
        private val REQUIRED_PERMISSIONS = arrayOf(Manifest.permission.CAMERA)
    }
}
高级优化技巧
1. 性能监控
import android.graphics.Bitmap
import android.os.SystemClock
import android.util.Log

/**
 * Measures the average inference latency of a [TFLiteImageClassifier].
 *
 * Inference is blocking, so call [measureInference] from a background thread, never the main thread.
 */
class BenchmarkHelper(private val classifier: TFLiteImageClassifier) {
    /**
     * Runs [warmupRuns] untimed and then [benchmarkRuns] timed classifications of [bitmap].
     *
     * @return mean latency per inference in milliseconds (also logged under tag "Benchmark").
     * @throws IllegalArgumentException if [warmupRuns] is negative or [benchmarkRuns] is not positive.
     */
    fun measureInference(bitmap: Bitmap, warmupRuns: Int = 10, benchmarkRuns: Int = 100): Float {
        require(warmupRuns >= 0) { "warmupRuns must be >= 0, was $warmupRuns" }
        require(benchmarkRuns > 0) { "benchmarkRuns must be > 0, was $benchmarkRuns" }

        // Warm-up: early runs typically include one-time delegate/kernel setup that would skew the mean.
        repeat(warmupRuns) { classifier.classify(bitmap) }

        // Nanosecond clock: millisecond resolution is too coarse for single-digit-ms inferences.
        val startNs = SystemClock.elapsedRealtimeNanos()
        repeat(benchmarkRuns) { classifier.classify(bitmap) }
        val avgMs = (SystemClock.elapsedRealtimeNanos() - startNs) / 1_000_000f / benchmarkRuns

        Log.d("Benchmark", "平均推理时间: ${avgMs}ms")
        return avgMs
    }
}
2. 模型动态更新
/**
 * Downloads a replacement model from [modelUrl] and hot-swaps it into `classifier`.
 *
 * The whole download runs on Dispatchers.IO inside `lifecycleScope`, so it is cancelled with
 * the activity. The classifier is updated only when both conditions hold:
 *  - the server answers HTTP 200, and
 *  - when [expectedSha256] is given, the file's SHA-256 matches it.
 * Otherwise the downloaded file is deleted.
 * A model is executable input to the app: prefer an https:// URL together with a pinned hash.
 *
 * NOTE(review): assumes `classifier` exposes `updateModel(File)` — confirm against the classifier class.
 */
private fun downloadAndUpdateModel(modelUrl: String, expectedSha256: String? = null) {
    lifecycleScope.launch(Dispatchers.IO) {
        var tempFile: File? = null
        var installed = false
        try {
            // App-private cache dir rather than the JVM default temp dir.
            val target = File.createTempFile("model", ".tflite", cacheDir).also { tempFile = it }
            val connection = java.net.URL(modelUrl).openConnection() as java.net.HttpURLConnection
            try {
                connection.connectTimeout = 15_000
                connection.readTimeout = 30_000
                val code = connection.responseCode
                if (code != java.net.HttpURLConnection.HTTP_OK) {
                    Log.e("ModelUpdate", "下载失败: HTTP $code")
                    return@launch
                }
                connection.inputStream.use { input ->
                    target.outputStream().use { output -> input.copyTo(output) }
                }
            } finally {
                connection.disconnect()
            }
            if (expectedSha256 != null && !sha256Hex(target).equals(expectedSha256, ignoreCase = true)) {
                Log.e("ModelUpdate", "更新失败: SHA-256 mismatch")
                return@launch
            }
            classifier.updateModel(target)
            installed = true
        } catch (e: Exception) {
            Log.e("ModelUpdate", "更新失败", e)
        } finally {
            // Don't leave partial or rejected downloads behind in the cache.
            if (!installed) tempFile?.delete()
        }
    }
}

/** Lower-case hex SHA-256 of [file], streamed so large models are never fully loaded into memory. */
private fun sha256Hex(file: File): String {
    val digest = java.security.MessageDigest.getInstance("SHA-256")
    file.inputStream().use { input ->
        val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
        var read = input.read(buffer)
        while (read >= 0) {
            digest.update(buffer, 0, read)
            read = input.read(buffer)
        }
    }
    return digest.digest().joinToString("") { "%02x".format(it) }
}
常见问题解决方案
问题1:输入尺寸不匹配
解决方案:
// NOTE: input tensors are only exposed by the Interpreter API (org.tensorflow.lite.Interpreter).
// The Task Library ImageClassifier does this preprocessing itself from model metadata.
val inputTensor = interpreter.getInputTensor(0)
val inputShape = inputTensor.shape() // presumably [1, height, width, channels] for image models — confirm for your model
val dataType = inputTensor.dataType()
// The input data type decides normalization, not the resize strategy.
// Always resize to the model's spatial size; only float models get normalized.
val imageProcessor = ImageProcessor.Builder()
    .add(ResizeOp(inputShape[1], inputShape[2], ResizeOp.ResizeMethod.BILINEAR))
    .apply {
        when (dataType) {
            // Quantized models consume raw 0..255 pixel values.
            DataType.UINT8 -> Unit
            // Maps 0..255 to [-1, 1]; adjust mean/std to match how the model was trained.
            DataType.FLOAT32 -> add(NormalizeOp(127.5f, 127.5f))
            else -> throw IllegalArgumentException("不支持的输入类型")
        }
    }
    .build()
问题2:内存泄漏
预防措施:
/** Releases camera and inference resources when the activity is destroyed. */
override fun onDestroy() {
    super.onDestroy()
    // Stop accepting background work first, so no new task starts using the classifier after it is released.
    if (::cameraExecutor.isInitialized) cameraExecutor.shutdown()
    // lateinit guards: if onCreate failed early, these were never created and there is nothing to release.
    if (::classifier.isInitialized) classifier.close()
}
扩展应用方向
实时视频流处理
import android.graphics.Bitmap
import android.graphics.Matrix
import androidx.camera.core.ImageAnalysis
import androidx.camera.core.ImageProxy
import java.util.concurrent.atomic.AtomicInteger

/**
 * CameraX analyzer that classifies every [skipFrame]-th frame.
 *
 * Inference runs synchronously on the executor passed to `ImageAnalysis.setAnalyzer`.
 * The frame is always closed in `finally`: CameraX delivers no further frames until the current
 * one is closed, so an exception must not leak it.
 *
 * @param skipFrame process one frame out of every `skipFrame` (must be > 0).
 * @param onResults receives results on the analyzer thread; post to the main thread before touching views.
 */
class VideoAnalyzer(
    private val classifier: TFLiteImageClassifier,
    private val skipFrame: Int = 3,
    private val onResults: (List<Pair<String, Float>>) -> Unit = {},
) : ImageAnalysis.Analyzer {
    private val frameCounter = AtomicInteger(0)

    init {
        require(skipFrame > 0) { "skipFrame must be > 0, was $skipFrame" }
    }

    override fun analyze(imageProxy: ImageProxy) {
        try {
            if (frameCounter.getAndIncrement() % skipFrame != 0) return
            // ImageProxy.toBitmap() is built into CameraX 1.3+. The buffer may not be upright,
            // so apply the reported rotation before classification.
            val bitmap = imageProxy.toBitmap().rotatedBy(imageProxy.imageInfo.rotationDegrees)
            onResults(classifier.classify(bitmap))
        } finally {
            imageProxy.close()
        }
    }

    /** Returns this bitmap rotated clockwise by [degrees], or itself when no rotation is needed. */
    private fun Bitmap.rotatedBy(degrees: Int): Bitmap {
        if (degrees == 0) return this
        val matrix = Matrix().apply { postRotate(degrees.toFloat()) }
        return Bitmap.createBitmap(this, 0, 0, width, height, matrix, true)
    }
}
最佳实践建议
- 模型优化:
  - 使用 TensorFlow Model Optimization Toolkit 进行量化
  - 使用 TFLite Metadata Writer 为模型添加元数据(Task Library 依赖元数据中的标签和归一化参数)
- 性能平衡:
  - 根据设备性能动态选择推理后端(CPU/GPU/NNAPI)
  - 针对低端设备启用XNNPACK优化(较新版本的TFLite在CPU推理时默认启用XNNPACK;使用Interpreter API时也可显式开启):
// The Task Library has no XNNPACK delegate option. With the Interpreter API, XNNPACK can be
// requested explicitly on the CPU path.
val interpreterOptions = Interpreter.Options()
    .setNumThreads(4)
    .setUseXNNPACK(true)
- 安全防护:
// Model integrity check
/**
 * Returns true if [file] exists and its SHA-256 digest equals [expectedHash] (hex, case-insensitive).
 *
 * The default hash is a placeholder: replace it with the pre-computed SHA-256 of your model.
 * The file is streamed, so large models are not loaded into memory. The comparison uses
 * MessageDigest.isEqual.
 */
fun verifyModel(file: File, expectedHash: String = "a1b2c3d4..."): Boolean {
    if (!file.isFile) return false
    val digest = java.security.MessageDigest.getInstance("SHA-256")
    file.inputStream().use { input ->
        val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
        var read = input.read(buffer)
        while (read >= 0) {
            digest.update(buffer, 0, read)
            read = input.read(buffer)
        }
    }
    val actualHash = digest.digest().joinToString("") { "%02x".format(it) }
    return java.security.MessageDigest.isEqual(
        actualHash.toByteArray(),
        expectedHash.lowercase().toByteArray(),
    )
}
建议结合具体业务需求选择合适的模型,并通过性能分析工具持续优化推理流程。