在 Kotlin Multiplatform (KMP) 跨平台环境中使用 SQLDelight
欢迎访问我的主页: https://heeheeaii.github.io/
1. 项目结构
SQLDelightKMPDemo/
├── shared/
│ ├── src/
│ │ ├── commonMain/kotlin/
│ │ ├── androidMain/kotlin/
│ │ ├── desktopMain/kotlin/
│ │ └── commonMain/sqldelight/
│ └── build.gradle.kts
├── androidApp/
│ └── build.gradle.kts
├── desktopApp/
│ └── build.gradle.kts
└── build.gradle.kts
2. 根目录 build.gradle.kts
// Root build script: declares plugin versions once; subprojects apply them without versions.
// Keep the SQLDelight plugin version in sync with the app.cash.sqldelight:* runtime/driver
// dependencies declared in shared/build.gradle.kts.
plugins {
    id("com.android.application") version "8.1.4" apply false
    id("com.android.library") version "8.1.4" apply false
    id("org.jetbrains.kotlin.multiplatform") version "1.9.20" apply false
    id("org.jetbrains.kotlin.android") version "1.9.20" apply false
    id("org.jetbrains.compose") version "1.5.4" apply false
    id("app.cash.sqldelight") version "2.0.2" apply false
}
3. shared/build.gradle.kts
plugins {
    id("org.jetbrains.kotlin.multiplatform")
    id("com.android.library")
    id("app.cash.sqldelight")
}

// Must match the app.cash.sqldelight plugin version declared in the root build.gradle.kts.
val sqlDelightVersion = "2.0.2"

kotlin {
    androidTarget {
        compilations.all {
            kotlinOptions {
                // Matches android.compileOptions below (Java 1.8).
                jvmTarget = "1.8"
            }
        }
    }
    jvm("desktop")

    sourceSets {
        commonMain.dependencies {
            implementation("app.cash.sqldelight:runtime:$sqlDelightVersion")
            implementation("app.cash.sqldelight:coroutines-extensions:$sqlDelightVersion")
            implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3")
            implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.5.0")
        }
        androidMain.dependencies {
            implementation("app.cash.sqldelight:android-driver:$sqlDelightVersion")
            implementation("androidx.lifecycle:lifecycle-viewmodel-ktx:2.7.0")
        }
        val desktopMain by getting {
            dependencies {
                implementation("app.cash.sqldelight:sqlite-driver:$sqlDelightVersion")
            }
        }
        commonTest.dependencies {
            implementation("org.jetbrains.kotlin:kotlin-test:1.9.20")
            implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.7.3")
        }
    }
}

android {
    namespace = "com.example.sqldelightkmp.shared"
    compileSdk = 34
    defaultConfig {
        minSdk = 24
    }
    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_1_8
        targetCompatibility = JavaVersion.VERSION_1_8
    }
}

sqldelight {
    databases {
        // Generates the database class com.treevalue.beself.io.BeselfDatabase.
        create("BeselfDatabase") {
            packageName.set("com.treevalue.beself.io")
        }
    }
}
SQL Schema定义
shared/src/commonMain/sqldelight/database/BeselfDatabase.sq
-- Tasks. Timestamps (created_at, updated_at, due_date) are epoch milliseconds;
-- completed is stored as 0/1; priority stores Priority.value from the Kotlin model.
CREATE TABLE IF NOT EXISTS Task (
    id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
    title TEXT NOT NULL,
    description TEXT,
    completed INTEGER NOT NULL DEFAULT 0,
    priority INTEGER NOT NULL DEFAULT 0,
    created_at INTEGER NOT NULL,
    updated_at INTEGER NOT NULL,
    due_date INTEGER
);

-- Insert a task.
insertTask:
INSERT INTO Task(title, description, completed, priority, created_at, updated_at, due_date)
VALUES (?, ?, ?, ?, ?, ?, ?);

-- All tasks, highest priority first, newest first within the same priority.
selectAllTasks:
SELECT * FROM Task
ORDER BY priority DESC, created_at DESC;

-- A single task by id.
selectTaskById:
SELECT * FROM Task WHERE id = ?;

-- Tasks filtered by completion state (0 or 1).
selectTasksByCompleted:
SELECT * FROM Task WHERE completed = ?
ORDER BY priority DESC, created_at DESC;

-- Substring search over title and description (both parameters receive the same term;
-- '%' and '_' inside the term act as LIKE wildcards).
searchTasks:
SELECT * FROM Task
WHERE title LIKE '%' || ? || '%' OR description LIKE '%' || ? || '%'
ORDER BY priority DESC, created_at DESC;

-- Overwrite every editable column of a task.
updateTask:
UPDATE Task
SET title = ?, description = ?, completed = ?, priority = ?, updated_at = ?, due_date = ?
WHERE id = ?;

-- Mark a task as completed.
markTaskCompleted:
UPDATE Task SET completed = 1, updated_at = ? WHERE id = ?;

-- Delete a task.
deleteTask:
DELETE FROM Task WHERE id = ?;

-- Delete every completed task.
deleteCompletedTasks:
DELETE FROM Task WHERE completed = 1;

-- Task counters. SUM(...) yields NULL on an empty table, so callers must default to 0.
getTaskStats:
SELECT
    COUNT(*) AS total,
    SUM(CASE WHEN completed = 1 THEN 1 ELSE 0 END) AS completed,
    SUM(CASE WHEN completed = 0 THEN 1 ELSE 0 END) AS pending
FROM Task;
通用代码实现
1. 数据模型
package com.treevalue.beself.io

import kotlinx.datetime.Instant

/**
 * Domain model for a task, independent of the SQLDelight-generated row class.
 *
 * @property id database id; `0` for a task that has not been persisted yet.
 * @property dueDate optional deadline; `null` means no deadline.
 */
data class Task(
    val id: Long = 0,
    val title: String,
    val description: String? = null,
    val completed: Boolean = false,
    val priority: Priority = Priority.MEDIUM,
    val createdAt: Instant,
    val updatedAt: Instant,
    val dueDate: Instant? = null,
)

/**
 * Task priority. [value] is the integer persisted in the `priority` column.
 */
enum class Priority(val value: Int, val displayName: String) {
    LOW(0, "Low"),
    MEDIUM(1, "Medium"),
    HIGH(2, "High"),
    URGENT(3, "Urgent");

    companion object {
        /** Maps a stored integer back to a [Priority]; unknown values fall back to [MEDIUM]. */
        fun fromValue(value: Int): Priority = entries.find { it.value == value } ?: MEDIUM
    }
}

/**
 * Aggregate task counters.
 */
data class TaskStats(
    val total: Long,
    val completed: Long,
    val pending: Long,
) {
    /** Fraction of completed tasks in `[0.0, 1.0]`; `0.0` when there are no tasks. */
    val completionRate: Double
        get() = if (total > 0) completed.toDouble() / total else 0.0
}
2. 数据库驱动工厂
package com.treevalue.beself.io

import app.cash.sqldelight.db.SqlDriver

/**
 * Platform-specific factory for the SQLDelight [SqlDriver] used by the `Database` wrapper.
 *
 * @param context platform handle. On Android this must be an `android.content.Context`;
 *   the desktop implementation does not use it, so it may be omitted there.
 */
expect class DatabaseDriverFactory(context: Any? = null) {
    /** Opens (and prepares the schema of) a driver for the `BeselfDatabase` database. */
    fun createDriver(): SqlDriver
}
`BeselfDatabase` 是由 SQLDelight 根据配置和 `.sq` 文件自动生成的类。要让它正常工作，需要确认以下几个步骤：
1. 确保 SQLDelight 配置正确
sqldelight {
    databases {
        // The name passed to create() becomes the generated database class name;
        // packageName is the package of that class.
        create("BeselfDatabase") {
            packageName.set("com.treevalue.beself.io")
        }
    }
}
2. 确保 SQL 文件位置正确
SQL 文件应该位于:
shared/src/commonMain/sqldelight/database/BeselfDatabase.sq
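注意：`create("BeselfDatabase")` 中的名称决定生成的数据库类名（`com.treevalue.beself.io.BeselfDatabase`），它并不要求与 `.sq` 文件名一致。
`.sq` 文件名决定查询类的名字：`BeselfDatabase.sq` 会生成 `BeselfDatabaseQueries`，并在数据库对象上暴露为 `beselfDatabaseQueries` 属性（下文代码中使用的就是它）。
而 `.sq` 文件所在的目录（这里是 `database/`）决定查询类和行数据类所在的包，因此生成的行类型是 `database.Task`。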
3. 构建项目生成代码
执行以下命令来构建项目并生成 SQLDelight 代码：
./gradlew :shared:build
如果只需要生成代码，也可以运行更快的 `./gradlew :shared:generateSqlDelightInterface`。
或者在 Android Studio/IntelliJ 中:
- 点击 “Build” → “Rebuild Project”
- 或者运行 “Sync Project with Gradle Files”
4. 验证生成的代码
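构建成功后，SQLDelight 会在以下位置生成数据库类 `BeselfDatabase`：
shared/build/generated/sqldelight/code/BeselfDatabase/commonMain/com/treevalue/beself/io
查询类 `BeselfDatabaseQueries` 和行数据类 `Task` 则按 `.sq` 文件所在目录生成，通常位于同一根目录下的 `commonMain/database/`（包名 `database`）。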
3. 数据库包装类
package com.treevalue.beself.io

import app.cash.sqldelight.coroutines.asFlow
import app.cash.sqldelight.coroutines.mapToList
import app.cash.sqldelight.coroutines.mapToOneOrNull
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant

/**
 * Wrapper around the SQLDelight-generated [BeselfDatabase].
 *
 * Converts generated `database.Task` rows into the [Task] domain model. [Instant]s are stored as
 * epoch milliseconds and booleans as `0L`/`1L`. Flow-returning functions are built with
 * SQLDelight's `asFlow()` and map rows on [Dispatchers.IO]. Suspend functions also run their
 * blocking SQL on [Dispatchers.IO], so they are safe to call from the main thread.
 */
class Database(databaseDriverFactory: DatabaseDriverFactory) {
    private val database = BeselfDatabase(databaseDriverFactory.createDriver())
    private val dbQuery = database.beselfDatabaseQueries

    /** Maps a generated row to the domain model; unknown priority values go through [Priority.fromValue]. */
    private fun database.Task.toModel(): Task = Task(
        id = id,
        title = title,
        description = description,
        completed = completed != 0L,
        priority = Priority.fromValue(priority.toInt()),
        createdAt = Instant.fromEpochMilliseconds(created_at),
        updatedAt = Instant.fromEpochMilliseconds(updated_at),
        dueDate = due_date?.let { Instant.fromEpochMilliseconds(it) },
    )

    /**
     * Inserts a new, not-completed task with `created_at`/`updated_at` set to now.
     *
     * @return the id of the newly inserted row.
     */
    suspend fun insertTask(
        title: String,
        description: String? = null,
        priority: Priority = Priority.MEDIUM,
        dueDate: Instant? = null,
    ): Long = withContext(Dispatchers.IO) {
        val nowMillis = Clock.System.now().toEpochMilliseconds()
        dbQuery.transactionWithResult {
            dbQuery.insertTask(
                title = title,
                description = description,
                completed = 0L,
                priority = priority.value.toLong(),
                created_at = nowMillis,
                updated_at = nowMillis,
                due_date = dueDate?.toEpochMilliseconds(),
            )
            // `id` is INTEGER PRIMARY KEY AUTOINCREMENT, so the row inserted in this transaction
            // holds the largest id in the table. Do NOT take the last element of selectAllTasks():
            // that query is ordered by priority/created_at, not by insertion order.
            // (A dedicated `SELECT last_insert_rowid();` query in the .sq file would avoid
            // loading every row.)
            dbQuery.selectAllTasks().executeAsList().maxOfOrNull { it.id } ?: 0L
        }
    }

    /** All tasks ordered by priority desc, then creation time desc; re-emits when the table changes. */
    fun getAllTasksFlow(): Flow<List<Task>> =
        dbQuery.selectAllTasks()
            .asFlow()
            .mapToList(Dispatchers.IO)
            .map { rows -> rows.map { it.toModel() } }

    /** One-shot snapshot of [getAllTasksFlow]. */
    suspend fun getAllTasks(): List<Task> = withContext(Dispatchers.IO) {
        dbQuery.selectAllTasks().executeAsList().map { it.toModel() }
    }

    /** @return the task with [id], or `null` if none exists. */
    suspend fun getTaskById(id: Long): Task? = withContext(Dispatchers.IO) {
        dbQuery.selectTaskById(id).executeAsOneOrNull()?.toModel()
    }

    /** Observes a single task; emits `null` while no task with [id] exists. */
    fun getTaskByIdFlow(id: Long): Flow<Task?> =
        dbQuery.selectTaskById(id)
            .asFlow()
            .mapToOneOrNull(Dispatchers.IO)
            .map { it?.toModel() }

    /** Observes the tasks whose completion flag equals [completed]. */
    fun getTasksByCompletedFlow(completed: Boolean): Flow<List<Task>> =
        dbQuery.selectTasksByCompleted(if (completed) 1L else 0L)
            .asFlow()
            .mapToList(Dispatchers.IO)
            .map { rows -> rows.map { it.toModel() } }

    /** Tasks whose title or description contains [query] (SQL `LIKE`; `%`/`_` act as wildcards). */
    suspend fun searchTasks(query: String): List<Task> = withContext(Dispatchers.IO) {
        dbQuery.searchTasks(query, query).executeAsList().map { it.toModel() }
    }

    /** Overwrites every editable column of [task] (matched by id) and bumps `updated_at`. */
    suspend fun updateTask(task: Task) {
        withContext(Dispatchers.IO) {
            dbQuery.updateTask(
                title = task.title,
                description = task.description,
                completed = if (task.completed) 1L else 0L,
                priority = task.priority.value.toLong(),
                updated_at = Clock.System.now().toEpochMilliseconds(),
                due_date = task.dueDate?.toEpochMilliseconds(),
                id = task.id,
            )
        }
    }

    /** Sets `completed = 1` for the task with [id] and bumps `updated_at`. */
    suspend fun markTaskCompleted(id: Long) {
        withContext(Dispatchers.IO) {
            dbQuery.markTaskCompleted(
                updated_at = Clock.System.now().toEpochMilliseconds(),
                id = id,
            )
        }
    }

    /** Deletes the task with [id]; a missing id is a no-op. */
    suspend fun deleteTask(id: Long) {
        withContext(Dispatchers.IO) {
            dbQuery.deleteTask(id)
        }
    }

    /** Deletes every completed task. */
    suspend fun deleteCompletedTasks() {
        withContext(Dispatchers.IO) {
            dbQuery.deleteCompletedTasks()
        }
    }

    /** Observes aggregate counters; an empty table yields `TaskStats(0, 0, 0)`. */
    fun getTaskStatsFlow(): Flow<TaskStats> =
        dbQuery.getTaskStats()
            .asFlow()
            .mapToOneOrNull(Dispatchers.IO)
            .map { row ->
                // SUM(...) is NULL when the table is empty, hence the 0L fallbacks.
                row?.let {
                    TaskStats(
                        total = it.total,
                        completed = it.completed ?: 0L,
                        pending = it.pending ?: 0L,
                    )
                } ?: TaskStats(0, 0, 0)
            }
}
4. Repository层
package com.treevalue.beself.io

import kotlinx.coroutines.flow.Flow
import kotlinx.datetime.Instant

/**
 * Repository facade over [Database] for the presentation layer.
 *
 * Observation methods return [Flow]s that re-emit on data changes; mutations are suspend functions.
 */
class TaskRepository(private val database: Database) {

    fun getAllTasks(): Flow<List<Task>> = database.getAllTasksFlow()

    fun getCompletedTasks(): Flow<List<Task>> = database.getTasksByCompletedFlow(completed = true)

    fun getPendingTasks(): Flow<List<Task>> = database.getTasksByCompletedFlow(completed = false)

    fun getTaskById(id: Long): Flow<Task?> = database.getTaskByIdFlow(id)

    fun getTaskStats(): Flow<TaskStats> = database.getTaskStatsFlow()

    /** Creates a new pending task and returns its id. */
    suspend fun createTask(
        title: String,
        description: String? = null,
        priority: Priority = Priority.MEDIUM,
        dueDate: Instant? = null,
    ): Long = database.insertTask(title, description, priority, dueDate)

    suspend fun updateTask(task: Task) {
        database.updateTask(task)
    }

    /**
     * Flips the completion flag of [task].
     *
     * Note: this writes back every field of the given [task] snapshot, so pass an up-to-date copy.
     */
    suspend fun toggleTaskCompleted(task: Task) {
        database.updateTask(task.copy(completed = !task.completed))
    }

    suspend fun deleteTask(id: Long) {
        database.deleteTask(id)
    }

    suspend fun deleteAllCompletedTasks() {
        database.deleteCompletedTasks()
    }

    suspend fun searchTasks(query: String): List<Task> = database.searchTasks(query)
}
平台特定实现
1. Android驱动实现
package com.treevalue.beself.io

import android.content.Context
import app.cash.sqldelight.db.SqlDriver
import app.cash.sqldelight.driver.android.AndroidSqliteDriver

/**
 * Android implementation backed by [AndroidSqliteDriver] and the database file `task_database.db`.
 *
 * @param context must be an [android.content.Context]; only its application context is retained.
 * @throws IllegalArgumentException if [context] is `null` or not a [Context].
 */
actual class DatabaseDriverFactory actual constructor(context: Any?) {
    // Keep the application context rather than whatever was passed in (often an Activity), so a
    // long-lived factory/driver cannot leak a UI component. Fall back to the given context when
    // applicationContext is null.
    private val androidContext: Context = requireNotNull(context as? Context) {
        "DatabaseDriverFactory on Android requires an android.content.Context, but got: $context"
    }.let { it.applicationContext ?: it }

    actual fun createDriver(): SqlDriver = AndroidSqliteDriver(
        schema = BeselfDatabase.Schema,
        context = androidContext,
        name = "task_database.db",
    )
}
2. 桌面驱动实现
package com.treevalue.beself.io

import app.cash.sqldelight.db.SqlDriver
import app.cash.sqldelight.driver.jdbc.sqlite.JdbcSqliteDriver
import java.io.File
import java.util.Properties

/**
 * Desktop (JVM) implementation storing the database at `<user.home>/.taskapp/task_database.db`.
 *
 * @param context unused on desktop; accepted only to satisfy the common `expect` constructor.
 */
actual class DatabaseDriverFactory actual constructor(context: Any?) {
    actual fun createDriver(): SqlDriver {
        val databaseFile = File(System.getProperty("user.home"), ".taskapp/task_database.db")
        databaseFile.parentFile?.mkdirs()
        // The schema-aware JdbcSqliteDriver overload tracks BeselfDatabase.Schema.version in
        // PRAGMA user_version. A fresh file gets Schema.create(), and an older file gets
        // Schema.migrate(). Calling Schema.create() unconditionally on every launch would never
        // apply migrations once the schema evolves.
        // Databases written without a version (user_version = 0) are still safe: the table is
        // declared CREATE TABLE IF NOT EXISTS.
        return JdbcSqliteDriver(
            "jdbc:sqlite:${databaseFile.absolutePath}",
            Properties(),
            BeselfDatabase.Schema,
        )
    }
}
桌面测试
1 依赖
val desktopTest by getting {
    dependencies {
        // DatabaseTest is written against JUnit 4 (org.junit.Test, @Before/@After, Assert),
        // not TestNG.
        implementation("junit:junit:4.13.2")
        implementation(libs.kotlinx.coroutines.test)
    }
}
package com.beself.io

import com.treevalue.beself.io.*
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.test.runTest
import org.junit.After
import org.junit.Assert.*
import org.junit.Before
import org.junit.Test
import java.io.File
import java.nio.file.Files

/**
 * Desktop (JVM) tests for [Database] and [TaskRepository].
 *
 * `DatabaseDriverFactory` is a final class with a non-open `createDriver()`, so it cannot be
 * subclassed to inject an in-memory driver. The desktop factory resolves its database file from
 * the `user.home` system property. Each test therefore points `user.home` at a fresh temporary
 * directory, which gives every test its own empty database file. The property is restored in
 * [tearDown].
 */
class DatabaseTest {
    private lateinit var database: Database
    private lateinit var repository: TaskRepository
    private lateinit var tempHome: File
    private var originalUserHome: String? = null

    @Before
    fun setUp() {
        originalUserHome = System.getProperty("user.home")
        tempHome = Files.createTempDirectory("beself-db-test").toFile()
        System.setProperty("user.home", tempHome.absolutePath)
        database = Database(DatabaseDriverFactory())
        repository = TaskRepository(database)
    }

    @After
    fun tearDown() {
        originalUserHome?.let { System.setProperty("user.home", it) }
        // Best effort: the SQLite file may still be open (e.g. on Windows) and not deletable.
        tempHome.deleteRecursively()
    }

    /** Fails the test with [message] when [value] is null, otherwise returns it non-null. */
    private fun <T : Any> assertPresent(message: String, value: T?): T =
        value ?: throw AssertionError(message)

    @Test
    fun testInsertAndRetrieveTask() = runTest {
        val taskId = database.insertTask(
            title = "Test Task",
            description = "This is a test task",
            priority = Priority.HIGH,
        )
        assertTrue("Task ID should be greater than 0", taskId > 0)

        val retrievedTask = assertPresent("Retrieved task should not be null", database.getTaskById(taskId))
        assertEquals("Test Task", retrievedTask.title)
        assertEquals("This is a test task", retrievedTask.description)
        assertEquals(Priority.HIGH, retrievedTask.priority)
        assertFalse(retrievedTask.completed)
    }

    @Test
    fun testGetAllTasks() = runTest {
        repeat(3) { index ->
            database.insertTask(
                title = "Task $index",
                description = "Description $index",
                priority = Priority.entries[index % Priority.entries.size],
            )
        }

        val allTasks = database.getAllTasks()
        assertEquals("Should have exactly 3 tasks", 3, allTasks.size)

        // Expected order: priority descending, then creation time descending.
        val sortedTasks = allTasks.sortedWith(
            compareByDescending<Task> { it.priority.value }.thenByDescending { it.createdAt },
        )
        assertEquals("Tasks should be sorted correctly", sortedTasks.map { it.id }, allTasks.map { it.id })
    }

    @Test
    fun testUpdateTask() = runTest {
        val taskId = database.insertTask(
            title = "Original Title",
            description = "Original Description",
            priority = Priority.LOW,
        )
        val originalTask = assertPresent("Original task should exist", database.getTaskById(taskId))

        database.updateTask(
            originalTask.copy(
                title = "Updated Title",
                description = "Updated Description",
                priority = Priority.URGENT,
                completed = true,
            ),
        )

        val retrievedTask = assertPresent("Updated task should exist", database.getTaskById(taskId))
        assertEquals("Updated Title", retrievedTask.title)
        assertEquals("Updated Description", retrievedTask.description)
        assertEquals(Priority.URGENT, retrievedTask.priority)
        assertTrue(retrievedTask.completed)
    }

    @Test
    fun testMarkTaskCompleted() = runTest {
        val taskId = database.insertTask(title = "Incomplete Task", priority = Priority.MEDIUM)
        val originalTask = assertPresent("Original task should exist", database.getTaskById(taskId))
        assertFalse("Task should be incomplete initially", originalTask.completed)

        database.markTaskCompleted(taskId)

        val completedTask = assertPresent("Completed task should exist", database.getTaskById(taskId))
        assertTrue("Task should be marked as completed", completedTask.completed)
    }

    @Test
    fun testDeleteTask() = runTest {
        val taskId = database.insertTask(title = "Task to Delete", priority = Priority.LOW)
        assertNotNull("Task should exist before deletion", database.getTaskById(taskId))

        database.deleteTask(taskId)

        assertNull("Task should be deleted", database.getTaskById(taskId))
    }

    @Test
    fun testGetTasksByCompleted() = runTest {
        val completedId = database.insertTask("Completed Task", priority = Priority.LOW)
        database.insertTask("Pending Task", priority = Priority.HIGH)
        database.markTaskCompleted(completedId)

        val completedTasks = repository.getCompletedTasks().first()
        val pendingTasks = repository.getPendingTasks().first()

        assertEquals("Should have 1 completed task", 1, completedTasks.size)
        assertEquals("Should have 1 pending task", 1, pendingTasks.size)
        assertEquals("Completed Task", completedTasks.first().title)
        assertEquals("Pending Task", pendingTasks.first().title)
    }

    @Test
    fun testSearchTasks() = runTest {
        database.insertTask("Learn Kotlin", "Study Kotlin multiplatform")
        database.insertTask("Learn Swift", "Study iOS development")
        database.insertTask("Build App", "Create amazing mobile app")

        // Matches in the title.
        val learnTasks = database.searchTasks("Learn")
        assertEquals("Should find 2 tasks with 'Learn'", 2, learnTasks.size)
        assertTrue("All found tasks should contain 'Learn'", learnTasks.all { it.title.contains("Learn") })

        // Matches in the description.
        val mobileTasks = database.searchTasks("mobile")
        assertEquals("Should find 1 task with 'mobile'", 1, mobileTasks.size)
        assertEquals("Build App", mobileTasks.first().title)

        // No matches.
        val noResults = database.searchTasks("NonExistent")
        assertTrue("Should find no results for non-existent term", noResults.isEmpty())
    }

    @Test
    fun testDeleteCompletedTasks() = runTest {
        val taskIds = (1..5).map { index -> database.insertTask("Task $index") }
        taskIds.take(3).forEach { id -> database.markTaskCompleted(id) }

        assertEquals("Should have 5 tasks initially", 5, database.getAllTasks().size)
        assertEquals("Should have 3 completed tasks", 3, repository.getCompletedTasks().first().size)

        database.deleteCompletedTasks()

        val remainingTasks = database.getAllTasks()
        assertEquals("Should have 2 tasks remaining", 2, remainingTasks.size)
        assertTrue("All remaining tasks should be incomplete", remainingTasks.none { it.completed })
    }

    @Test
    fun testTaskStats() = runTest {
        repeat(5) { index ->
            val taskId = database.insertTask("Task $index")
            if (index < 2) {
                database.markTaskCompleted(taskId)
            }
        }

        val stats = repository.getTaskStats().first()
        assertEquals("Should have 5 total tasks", 5L, stats.total)
        assertEquals("Should have 2 completed tasks", 2L, stats.completed)
        assertEquals("Should have 3 pending tasks", 3L, stats.pending)
        assertEquals("Completion rate should be 0.4", 0.4, stats.completionRate, 0.01)
    }

    @Test
    fun testRepositoryOperations() = runTest {
        val taskId = repository.createTask(
            title = "Repository Test",
            description = "Testing repository functionality",
            priority = Priority.HIGH,
        )
        assertTrue("Task ID should be positive", taskId > 0)

        val task = assertPresent("Task should exist", repository.getTaskById(taskId).first())
        assertEquals("Repository Test", task.title)

        // First toggle: pending -> completed.
        repository.toggleTaskCompleted(task)
        val updatedTask = assertPresent("Task should exist after toggle", repository.getTaskById(taskId).first())
        assertTrue("Task should be completed after toggle", updatedTask.completed)

        // Second toggle: completed -> pending.
        repository.toggleTaskCompleted(updatedTask)
        val toggledTask = repository.getTaskById(taskId).first()
        assertFalse("Task should be incomplete after second toggle", toggledTask?.completed == true)
    }

    @Test
    fun testTaskPriorityMapping() = runTest {
        Priority.entries.forEach { priority ->
            val taskId = database.insertTask(title = "Priority ${priority.displayName}", priority = priority)
            val task = assertPresent(
                "Task with priority ${priority.displayName} should exist",
                database.getTaskById(taskId),
            )
            assertEquals("Priority should match", priority, task.priority)
            assertEquals("Priority display name should match", priority.displayName, task.priority.displayName)
        }

        assertEquals("Priority.LOW should map from value 0", Priority.LOW, Priority.fromValue(0))
        assertEquals("Priority.MEDIUM should map from value 1", Priority.MEDIUM, Priority.fromValue(1))
        assertEquals("Priority.HIGH should map from value 2", Priority.HIGH, Priority.fromValue(2))
        assertEquals("Priority.URGENT should map from value 3", Priority.URGENT, Priority.fromValue(3))
        assertEquals("Invalid value should default to MEDIUM", Priority.MEDIUM, Priority.fromValue(999))
    }

    @Test
    fun testEmptyDatabaseStats() = runTest {
        // setUp created a brand-new database file, so it starts empty.
        assertEquals("Database should be empty", 0, database.getAllTasks().size)

        val stats = repository.getTaskStats().first()
        assertEquals("Total should be 0 for empty database", 0L, stats.total)
        assertEquals("Completed should be 0 for empty database", 0L, stats.completed)
        assertEquals("Pending should be 0 for empty database", 0L, stats.pending)
        assertEquals("Completion rate should be 0.0 for empty database", 0.0, stats.completionRate, 0.01)
    }

    @Test
    fun testFlowUpdates() = runTest {
        val allTasksFlow = repository.getAllTasks()

        // Initially empty.
        assertEquals("Initial tasks should be empty", 0, allTasksFlow.first().size)

        // An insert becomes visible through the Flow.
        val taskId = repository.createTask("Flow Test Task")
        val tasksAfterInsert = allTasksFlow.first()
        assertEquals("Should have 1 task after insert", 1, tasksAfterInsert.size)
        assertEquals("Flow Test Task", tasksAfterInsert.first().title)

        // A delete becomes visible through the Flow.
        repository.deleteTask(taskId)
        assertEquals("Should be empty after delete", 0, allTasksFlow.first().size)
    }
}