当前位置: 首页 > news >正文

技巧|SwanLab记录混淆矩阵攻略

绘制混淆矩阵(Confusion Matrix),用于评估分类模型的性能。混淆矩阵展示了模型预测结果与真实标签之间的对应关系,能够直观地显示各类别的预测准确性和错误类型。

混淆矩阵是评估分类模型性能的基础工具,特别适用于多分类问题。

你可以使用swanlab.confusion_matrix来记录混淆矩阵。

Demo链接:ComputeMetrics - SwanLab

在这里插入图片描述

基本用法

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab# 加载鸢尾花数据集
iris_data = load_iris()
X = iris_data.data
y = iris_data.target
class_names = iris_data.target_names.tolist()# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 训练模型
model = xgb.XGBClassifier(objective='multi:softmax', num_class=len(class_names))
model.fit(X_train, y_train)# 获取预测结果
y_pred = model.predict(X_test)# 初始化SwanLab
swanlab.init(project="Confusion-Matrix-Demo", experiment_name="Confusion-Matrix-Example")# 记录混淆矩阵
swanlab.log({"confusion_matrix": swanlab.confusion_matrix(y_test, y_pred, class_names)
})swanlab.finish()

使用自定义类别名称

# 定义自定义类别名称
custom_class_names = ["类别A", "类别B", "类别C"]# 记录混淆矩阵
confusion_matrix = swanlab.confusion_matrix(y_test, y_pred, custom_class_names)
swanlab.log({"confusion_matrix_custom": confusion_matrix})

不使用类别名称

# 不指定类别名称,将使用数字索引
confusion_matrix = swanlab.confusion_matrix(y_test, y_pred)
swanlab.log({"confusion_matrix_default": confusion_matrix})

二分类示例

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab# 生成二分类数据
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 训练模型
model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss')
model.fit(X_train, y_train)# 获取预测结果
y_pred = model.predict(X_test)# 记录混淆矩阵
swanlab.log({"confusion_matrix": swanlab.confusion_matrix(y_test, y_pred, ["负类", "正类"])
})

注意事项

  1. 数据格式: y_truey_pred可以是列表或numpy数组
  2. 多分类支持: 此函数支持二分类和多分类问题
  3. 类别名称: class_names的长度应该与类别数量一致
  4. 依赖包: 需要安装scikit-learnpyecharts
  5. 坐标轴: sklearn的confusion_matrix左上角为(0,0),在pyecharts的heatmap中是左下角,函数会自动处理坐标转换
  6. 矩阵解读: 混淆矩阵中,行表示真实标签,列表示预测标签
http://www.dtcms.com/a/312277.html

相关文章:

  • express-jwt报错:Error: algorithms should be set
  • 【智能体cooragent】不同的单智能体调用的大模型的推理的输入与输出
  • 笔试——Day26
  • 【LLM】如何在Cursor中调用Dify工作流
  • Makefile 从入门到精通:自动化构建的艺术
  • 【Java基础知识 16】 数组详解
  • 微积分思想的严密性转变 | 极限、逼近与程序化
  • 计算机技术与软件专业技术资格(水平)考试简介
  • 【Pytorch✨】LSTM01 入门
  • 集成电路学习:什么是HAL硬件抽象层
  • 【设计模式】 3.设计模式基本原则
  • 对于考研数学的理解
  • 【攻防实战】记一次DOUBLETROUBLE攻防实战
  • build文件夹下面的主要配置文件
  • win10任务栏出问题了,原来是wincompressbar导致的
  • 扫雷游戏完整代码
  • RK3399 启动流程 --从复位到系统加载
  • Munge 安全认证和授权服务的工作原理,以及与 Slurm 的配合
  • 【python】转移本地安装的python包
  • vue3 新手学习入门
  • 【LeetCode 热题 100】(三)滑动窗口
  • 在线任意长度大整数计算器
  • 轻量级鼠标右键增强工具 MousePlus
  • 数据链路层、NAT、代理服务、内网穿透
  • 变频器实习DAY20 测试经验总结
  • WinForm之NumericUpDown控件
  • Noob靶机攻略
  • 力扣刷题日常(11-12)
  • linux编译基础知识-头文件标准路径
  • NX947NX955美光固态闪存NX962NX966