当前位置: 首页 > news >正文

机器学习 (1) 监督学习

目录

  • 简介
  • 核心概念
    • 训练数据结构
    • 两大类型
      • 1. 分类问题(Classification)
      • 2. 回归问题(Regression)

简介

  • 监督学习的核心特点是使用带标签的训练数据来训练模型,简单来说,就是通过输入xxx(特征)推导出输出yyy(标签),算法的目标是学习从输入到输出的映射关系,以便于对新的、未见过的数据进行预测。

核心概念

训练数据结构

  1. 特征(features):输入变量,用xxx表示
  2. 标签(labels):目标变量,用yyy表示
  3. 训练集:包含特征和标签的数据集合

两大类型

1. 分类问题(Classification)

  • 分类问题是要找到一条拟合的边界线 ,把两种类别的数据划分开来
  • 下面使用一个测试用的乳腺癌数据集来演示分类问题,想通过平均半径和平均纹理来寻找判断良性和恶性肿瘤的划分方式
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score# 加载乳腺癌数据集
data = load_breast_cancer()
X = data.data[:, :2]  # 只使用前两个特征:平均半径和平均纹理
y = data.target       # 0: 恶性,1: 良性# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 训练逻辑回归模型
model = LogisticRegression()
model.fit(X_train, y_train)# 进行预测
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)# 可视化数据
plt.figure(figsize=(10, 8))# 绘制恶性肿瘤(红色 x)
malignant = X[y == 0]
plt.scatter(malignant[:, 0], malignant[:, 1],c='red', marker='x', s=100, linewidth=2, label='Malignant')# 绘制良性肿瘤(蓝色圆圈)
benign = X[y == 1]
plt.scatter(benign[:, 0], benign[:, 1],c='blue', marker='o', s=80, alpha=0.7,facecolors='none', edgecolors='blue', linewidth=2, label='Benign')# 创建决策边界
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),np.arange(y_min, y_max, 0.1))# 预测网格点的类别
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)# 绘制决策边界
plt.contour(xx, yy, Z, levels=[0.5], colors='purple', linewidths=2)# 设置图表标签(英文避免乱码)
plt.xlabel('Tumor Size (mean radius)', fontsize=12)
plt.ylabel('Age (mean texture)', fontsize=12)
plt.title('Breast Cancer Classification\nTwo or more inputs', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)# 添加准确率文本
plt.text(0.02, 0.98, f'Accuracy: {accuracy:.3f}',transform=plt.gca().transAxes, fontsize=12,verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))plt.tight_layout()
plt.show()# 打印结果信息
print(f"数据集形状: {X.shape}")
print(f"恶性样本数量: {np.sum(y == 0)}")
print(f"良性样本数量: {np.sum(y == 1)}")
print(f"模型准确率: {accuracy:.3f}")

在这里插入图片描述

2. 回归问题(Regression)

  • 下面是一个回归的例子,通过一些数据,得到xxxyyy之间的映射关系
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression# 生成回归数据
x_reg, y_reg = make_regression(n_samples=1000, n_features=1, noise=10, random_state=42)
# 创建并训练线性回归模型
model = LinearRegression()
model.fit(x_reg, y_reg)
x_line = np.linspace(x_reg.min(), x_reg.max(), 100).reshape(-1, 1)
y_line = model.predict(x_line)# 可视化回归数据
plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 2)
plt.scatter(x_reg, y_reg, alpha=0.6, color='blue', label='data point')
plt.plot(x_line, y_line, color='red', linewidth=2, label=f'line of regression (y = {model.coef_[0]:.2f}x + {model.intercept_:.2f})')
plt.title('data distribution')
plt.xlabel('feature value')
plt.ylabel('target value')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print('数据形状:', x_reg.shape, y_reg.shape)
print(f'回归线方程: y = {model.coef_[0]:.2f}x + {model.intercept_:.2f}')
print(f'模型R²分数: {model.score(x_reg, y_reg):.4f}')

在这里插入图片描述

http://www.dtcms.com/a/532498.html

相关文章:

  • 从哪里找网络推广公司网站优化 毕业设计
  • Java如何将数据写入到PDF文件
  • 开发板网络配置
  • 14天备考软考-day1: 计组、操作系统(仅自用)
  • 企业网站模板包含什么有什么软件可以做网站
  • .gitignore 不生效问题——删除错误追踪的文件
  • 深度学习优化器详解
  • 做企业公示的数字证书网站wordpress有识图接口吗
  • 中国商标注册申请官网百度蜘蛛池自动收录seo
  • GitHub 热榜项目 - 日榜(2025-10-26)
  • 数据分析:指标拆解、异动归因类题目
  • 做网站需要那些软件设计建网站
  • Gorm(十二)乐观锁和悲观锁
  • neo4j图数据库笔记
  • 网页网站设计公司有哪些网站排名有什么用
  • 泉州做网站优化哪家好微信推广平台哪里找
  • 如何制作收费网站百度收录个人网站是什么怎么做
  • VsCode + Wsl:终极开发环境搭建指南
  • 深度学习——Logistic回归中的梯度下降法
  • 中国住房和城乡建设网网站学习网站大全
  • 【Android】ViewPager2实现手/自动轮播图
  • 产品营销网站可以做英语翻译兼职的网站
  • jQuery Mobile 图标:全面解析与应用指南
  • Java(File)
  • AI 翻译入门指南:机器如何理解语言
  • 怎样上传网站程序网站数据库怎么配置
  • MySQL相关知识查询表中的内容(第三次作业)
  • h5游戏免费下载:过马路小游戏
  • 昆山建设局网站深圳企业有限公司
  • LangGraph 官方教程:聊天机器人之三