当前位置: 首页 > news >正文

SHAP可视化代码详细讲解

SHAP可视化代码详细讲解

目录

  1. 代码概述
  2. 环境准备与导入
  3. 数据生成
  4. 模型训练
  5. SHAP值计算
  6. 可视化图表详解
  7. 交互特征分析
  8. 总结与应用建议

1. 代码概述

本代码是一个完整的SHAP (SHapley Additive exPlanations) 分析示例,展示了如何:

  • 生成模拟数据集
  • 训练机器学习模型(随机森林)
  • 计算SHAP值来解释模型预测
  • 生成19种不同类型的SHAP可视化图表

SHAP的核心作用:解释"黑盒"机器学习模型的预测结果,量化每个特征对预测的贡献程度。


2. 环境准备与导入

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
import shap

关键库说明

  • numpy/pandas:数据处理
  • matplotlib:图表绘制
  • sklearn:机器学习模型
  • shap:SHAP值计算和可视化(需要安装:pip install shap

3. 数据生成

3.1 创建特征

X = pd.DataFrame({'feature_1': np.random.randn(n_samples),          # 标准正态分布'feature_2': np.random.randn(n_samples),          # 标准正态分布'feature_3': np.random.randn(n_samples),          # 标准正态分布'feature_4': np.random.uniform(0, 10, n_samples), # 0-10均匀分布'feature_5': np.random.choice([0,1,2,3], n_samples), # 分类变量'N_dist': np.random.uniform(1.10, 1.30, n_samples),  # 距离特征'N_d': np.random.choice(range(11), n_samples)        # 离散特征
})

特征设计说明

  • 连续特征(feature_1-4):模拟实数值特征
  • 分类特征(feature_5):模拟类别数据
  • N_dist:模拟化学中的原子间距离(单位:埃Å)
  • N_d:模拟离散的配位数或计数特征

3.2 构建目标变量

y = (2 * X['feature_1'] +                    # 线性关系3 * X['feature_2']**2 +                 # 非线性关系X['feature_3'] * X['feature_4'] +       # 交互效应np.sin(X['N_dist'] * 10) * 5 +          # 周期性关系X['N_d'] * 0.5 +                        # 线性关系np.random.randn(n_samples) * 0.5)       # 随机噪声

关系类型

  1. 线性:feature_1, N_d
  2. 非线性:feature_2(平方关系)
  3. 交互:feature_3 × feature_4
  4. 周期性:N_dist(正弦函数)

这种复杂关系使得SHAP分析更有意义!


4. 模型训练

rf_model = RandomForestRegressor(n_estimators=100,    # 100棵决策树max_depth=10,        # 最大深度10random_state=42,     # 随机种子n_jobs=-1            # 使用所有CPU核心
)
rf_model.fit(X_train, y_train)

模型选择理由

  • 随机森林能捕捉非线性和交互关系
  • SHAP的TreeExplainer对树模型进行了优化,计算速度快
  • R²分数用于评估模型质量(接近1表示拟合良好)

5. SHAP值计算

explainer = shap.TreeExplainer(rf_model)
shap_values = explainer(X_test)

核心概念

  • explainer.expected_value:基准预测值(所有特征未知时的平均预测)
  • shap_values:每个样本每个特征的SHAP值
  • SHAP值解读
    • 正值:特征使预测值增大
    • 负值:特征使预测值减小
    • 绝对值大小:影响程度

数学原理

预测值 = 基准值 + SHAP值1 + SHAP值2 + ... + SHAP值n

6. 可视化图表详解

图1: Summary Plot (点图)

shap.summary_plot(shap_values, X_test, show=False)

功能:整体特征重要性可视化

  • Y轴:特征名称(按重要性降序排列)
  • X轴:SHAP值(影响大小)
  • 颜色:特征值高低(红色=高,蓝色=低)
  • 解读:可看出哪些特征最重要,以及特征值高低如何影响预测

在这里插入图片描述

图2: Summary Plot (柱状图)

shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)

功能:平均绝对SHAP值排名

  • 显示每个特征的平均影响程度
  • 快速识别最重要的特征
  • 注意:只显示重要性大小,不显示正负方向

在这里插入图片描述

图3: Waterfall Plot (瀑布图)

shap.plots.waterfall(shap_values[0], show=False)

功能:单个样本的预测解释

  • 从基准值开始,逐步展示每个特征的贡献
  • 红色箭头:增加预测值
  • 蓝色箭头:减少预测值
  • 最终到达实际预测值

应用场景:解释某个具体预测结果给业务人员


在这里插入图片描述

图4: Force Plot (力图)

shap.plots.force(shap_values[0], matplotlib=True, show=False)

功能:紧凑型单样本解释

  • 红色区域:推高预测值的特征
  • 蓝色区域:降低预测值的特征
  • 区域宽度=影响大小

优势:适合快速浏览多个预测


在这里插入图片描述

图5: Decision Plot (决策图)

shap.decision_plot(explainer.expected_value,shap_values.values[:50],X_test.iloc[:50],show=False
)

功能:多个样本的决策路径可视化

  • 每条线代表一个样本
  • 从基准值(底部)到最终预测(顶部)
  • 可以看出不同样本的相似决策路径

适用:比较多个预测的差异


在这里插入图片描述

图6-12: Dependence Plot (依赖图)

for feature in X_test.columns:shap.plots.scatter(shap_values[:, feature], show=False)

功能:单特征与SHAP值的关系

  • X轴:特征值
  • Y轴:该特征的SHAP值
  • 颜色:自动选择交互最强的另一个特征

关键发现

  • 线性关系:点呈直线分布
  • 非线性关系:点呈曲线或其他模式
  • 交互效应:不同颜色的点有不同趋势

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

图13-14: 自定义散点图

# 图13: N_dist特征
ax.scatter(X_test['N_dist'].values, shap_values.values[:, feature_idx],c='blue', alpha=0.6)
ax.set_xlabel('N≡N (Å)')
ax.set_ylabel('SHAP value for N≡N')

目的

  • 自定义样式(颜色、标签、格式)
  • 适配特定领域需求(如化学、物理)
  • 添加参考线、背景色等辅助元素

图14:类似处理离散特征N_d,用红色表示


在这里插入图片描述
在这里插入图片描述

图15: Interaction Scatter Plot

shap.plots.scatter(shap_values[:, "feature_3"],color=shap_values[:, "feature_4"],show=False
)

功能:探索两个特征的交互

  • 查看feature_3的SHAP值如何受feature_4影响
  • 颜色深浅表示feature_4的SHAP值

应用:发现特征间的协同或拮抗效应


在这里插入图片描述

图16: Beeswarm Plot (蜂群图)

shap.plots.beeswarm(shap_values, show=False)

功能:改进版的summary plot

  • 更清晰的点分布
  • 避免点重叠
  • 保留summary plot的所有信息

推荐:替代传统summary plot,视觉效果更好


在这里插入图片描述

图17: Heatmap Plot (热力图)

shap.plots.heatmap(shap_values[:100], show=False)

功能:样本-特征SHAP值矩阵

  • :样本
  • :特征
  • 颜色:SHAP值大小

用途

  • 识别样本聚类
  • 发现特征模式
  • 适合大规模分析

在这里插入图片描述

7. 交互特征分析

7.1 计算交互值

shap_interaction_values = explainer.shap_interaction_values(X_test[:100])

注意

  • 计算量大,建议用小样本(这里用100个)
  • 返回3D数组:[样本, 特征i, 特征j]
  • [i, j]位置表示特征i和特征j的交互效应

图18: Interaction Summary

shap.summary_plot(shap_interaction_values[:, :, :],X_test.iloc[:100],show=False
)

功能:交互特征重要性总览

  • 显示哪些特征对之间交互最强
  • 帮助理解复杂的非线性关系

在这里插入图片描述

图19: Specific Interaction Dependence

shap.dependence_plot(("feature_3", "feature_4"),shap_interaction_values,X_test.iloc[:100],show=False
)

功能:查看特定特征对的交互

  • 专门研究feature_3和feature_4的协同效应
  • 验证数据生成时设计的交互关系

在这里插入图片描述

8. 总结与应用建议

8.1 核心要点

  1. SHAP值是可加的:所有特征SHAP值之和 + 基准值 = 预测值
  2. 公平性:基于博弈论的Shapley值,保证公平分配贡献
  3. 一致性:如果改变模型使某特征影响增大,其SHAP值绝对值也增大

8.2 图表选择指南

目的推荐图表
整体特征重要性Summary Plot / Beeswarm Plot
单个预测解释Waterfall Plot / Force Plot
特征-预测关系Dependence Plot
多样本比较Decision Plot / Heatmap
交互效应探索Interaction plots

8.3 实际应用场景

医疗诊断:解释为什么模型判断某患者高风险
金融风控:说明为什么拒绝某笔贷款申请
推荐系统:展示为什么推荐某商品
科学研究:发现物理/化学变量间的关系

8.4 注意事项

  1. 计算成本:交互值计算很慢,先用小样本测试
  2. 模型依赖:TreeExplainer适合树模型,其他模型用KernelExplainer
  3. 解读谨慎:SHAP显示相关性,不一定是因果关系
  4. 数据质量:垃圾进垃圾出,确保训练数据质量

8.5 扩展学习

  • SHAP官方文档:https://shap.readthedocs.io
  • 原始论文:Lundberg & Lee (2017) - “A Unified Approach to Interpreting Model Predictions”
  • 进阶:SHAP在深度学习中的应用(DeepExplainer)

代码运行要求

# 安装依赖
pip install numpy pandas matplotlib scikit-learn shap# 运行代码
python shap画图.py# 输出结果
# - 19张PNG图片
# - 控制台统计信息

预期输出时间:1-3分钟(取决于硬件配置)


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
import shap# 设置随机种子
np.random.seed(42)# ==================== 1. 生成模拟数据 ====================
n_samples = 1000# 创建特征
X = pd.DataFrame({'feature_1': np.random.randn(n_samples),'feature_2': np.random.randn(n_samples),'feature_3': np.random.randn(n_samples),'feature_4': np.random.uniform(0, 10, n_samples),'feature_5': np.random.choice([0, 1, 2, 3], n_samples),  # 分类变量'N_dist': np.random.uniform(1.10, 1.30, n_samples),  # 类似附件(c)图的距离特征'N_d': np.random.choice(range(11), n_samples)  # 类似附件(d)图的离散特征
})# 创建目标变量(带有非线性关系和交互效应)
y = (2 * X['feature_1'] +3 * X['feature_2']**2 +X['feature_3'] * X['feature_4'] +  # 交互效应np.sin(X['N_dist'] * 10) * 5 +X['N_d'] * 0.5 +np.random.randn(n_samples) * 0.5)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42
)# ==================== 2. 训练随机森林模型 ====================
print("训练随机森林模型...")
rf_model = RandomForestRegressor(n_estimators=100,max_depth=10,random_state=42,n_jobs=-1
)
rf_model.fit(X_train, y_train)print(f"训练集 R² Score: {rf_model.score(X_train, y_train):.4f}")
print(f"测试集 R² Score: {rf_model.score(X_test, y_test):.4f}")# ==================== 3. 计算SHAP值 ====================
print("\n计算SHAP值...")
# 使用TreeExplainer(专为树模型优化)
explainer = shap.TreeExplainer(rf_model)
shap_values = explainer(X_test)# ==================== 4. 绘制所有SHAP图 ====================
plt.style.use('default')# ------------------------ 图1: Summary Plot (点图) ------------------------
print("\n绘制图1: Summary Plot (特征重要性概览)...")
plt.figure(figsize=(10, 6))
shap.summary_plot(shap_values, X_test, show=False)
plt.title("SHAP Summary Plot - Feature Importance Overview", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_1_summary_plot.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图2: Summary Plot (柱状图) ------------------------
print("绘制图2: Summary Plot (平均绝对SHAP值)...")
plt.figure(figsize=(10, 6))
shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
plt.title("SHAP Summary Plot - Mean Absolute SHAP Values", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_2_summary_bar.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图3: Waterfall Plot (单样本) ------------------------
print("绘制图3: Waterfall Plot (单个预测解释)...")
plt.figure(figsize=(10, 6))
shap.plots.waterfall(shap_values[0], show=False)
plt.title("SHAP Waterfall Plot - Single Prediction Explanation", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_3_waterfall.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图4: Force Plot (单样本) ------------------------
print("绘制图4: Force Plot (单个预测)...")
plt.figure(figsize=(14, 3))
shap.plots.force(shap_values[0], matplotlib=True, show=False)
plt.title("SHAP Force Plot - Single Prediction", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_4_force_plot.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图5: Decision Plot ------------------------
print("绘制图5: Decision Plot (多个样本决策路径)...")
plt.figure(figsize=(10, 8))
shap.decision_plot(explainer.expected_value,shap_values.values[:50],X_test.iloc[:50],show=False
)
plt.title("SHAP Decision Plot - Multiple Predictions", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_5_decision_plot.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图6-12: Dependence Plot (每个特征) ------------------------
print("绘制图6-12: Dependence Plots (单特征依赖图)...")
for i, feature in enumerate(X_test.columns):plt.figure(figsize=(10, 6))shap.plots.scatter(shap_values[:, feature], show=False)plt.title(f"SHAP Dependence Plot - {feature}", fontsize=14, pad=20)plt.tight_layout()plt.savefig(f'shap_6_{i}_dependence_{feature}.png', dpi=300, bbox_inches='tight')plt.show()# ------------------------ 图13: 自定义散点图 ------------------------
print("绘制图13: 自定义散点图")
fig, ax = plt.subplots(figsize=(8, 6))# 获取N_dist特征的索引
feature_idx = list(X_test.columns).index('N_dist')# 绘制散点图
scatter = ax.scatter(X_test['N_dist'].values,shap_values.values[:, feature_idx],c='blue',alpha=0.6,s=30,edgecolors='none'
)# 添加参考线
ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.7)# 设置标签和标题
ax.set_xlabel('N≡N (Å)', fontsize=12)
ax.set_ylabel('SHAP value for N≡N', fontsize=12)
ax.set_title('SHAP Dependence: N≡N Distance', fontsize=14, pad=20)
ax.grid(True, alpha=0.3)plt.tight_layout()
plt.savefig('shap_13_custom_N_dist.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图14: 自定义散点图 ------------------------
print("绘制图14: 自定义散点图")
fig, ax = plt.subplots(figsize=(8, 6))# 获取N_d特征的索引
feature_idx = list(X_test.columns).index('N_d')# 绘制散点图
scatter = ax.scatter(X_test['N_d'].values,shap_values.values[:, feature_idx],c='red',alpha=0.6,s=30,edgecolors='none'
)# 添加参考线
ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.7)# 添加背景色渐变(可选)
ax.fill_between([X_test['N_d'].min(), X_test['N_d'].max()],-1, 1,alpha=0.1,color='red'
)# 设置标签和标题
ax.set_xlabel('$N_d$', fontsize=12)
ax.set_ylabel('SHAP value for $N_d$', fontsize=12)
ax.set_title('SHAP Dependence: Discrete Feature $N_d$', fontsize=14, pad=20)
ax.grid(True, alpha=0.3)plt.tight_layout()
plt.savefig('shap_14_custom_N_d.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图15: Interaction Scatter Plot ------------------------
print("绘制图15: 交互特征散点图...")
# 选择两个可能有交互的特征
fig, ax = plt.subplots(figsize=(10, 6))
shap.plots.scatter(shap_values[:, "feature_3"],color=shap_values[:, "feature_4"],show=False
)
plt.title("SHAP Interaction: feature_3 colored by feature_4", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_15_interaction_scatter.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图16: Beeswarm Plot ------------------------
print("绘制图16: Beeswarm Plot...")
plt.figure(figsize=(10, 6))
shap.plots.beeswarm(shap_values, show=False)
plt.title("SHAP Beeswarm Plot", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_16_beeswarm.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图17: Heatmap Plot ------------------------
print("绘制图17: Heatmap Plot...")
plt.figure(figsize=(12, 8))
shap.plots.heatmap(shap_values[:100], show=False)
plt.title("SHAP Heatmap - Feature Contributions", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_17_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()# ==================== 5. 交互特征分析 ====================
print("\n计算SHAP交互值(这可能需要一些时间)...")
# 注意:交互值计算较慢,建议使用较小的样本集
shap_interaction_values = explainer.shap_interaction_values(X_test[:100])# ------------------------ 图18: Interaction Summary ------------------------
print("绘制图18: 交互特征重要性汇总...")
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_interaction_values[:, :, :],X_test.iloc[:100],show=False
)
plt.title("SHAP Interaction Summary", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_18_interaction_summary.png', dpi=300, bbox_inches='tight')
plt.show()# ------------------------ 图19: Specific Interaction Dependence ------------------------
print("绘制图19: 特定交互依赖图...")
# 查看feature_3和feature_4的交互
fig, ax = plt.subplots(figsize=(10, 6))
shap.dependence_plot(("feature_3", "feature_4"),shap_interaction_values,X_test.iloc[:100],show=False
)
plt.title("SHAP Interaction: feature_3 × feature_4", fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('shap_19_specific_interaction.png', dpi=300, bbox_inches='tight')
plt.show()# ==================== 6. 打印统计信息 ====================
print("\n" + "="*60)
print("SHAP分析统计信息")
print("="*60)
print(f"样本数量: {len(X_test)}")
print(f"特征数量: {X_test.shape[1]}")
print(f"基准值 (Expected Value): {explainer.expected_value:.4f}")
print("\n特征SHAP值统计:")
print("-"*60)for i, feature in enumerate(X_test.columns):mean_abs_shap = np.mean(np.abs(shap_values.values[:, i]))print(f"{feature:15s} - 平均绝对SHAP值: {mean_abs_shap:.4f}")print("\n所有SHAP图已生成完毕!")
print("="*60)
http://www.dtcms.com/a/467363.html

相关文章:

  • 网站建设小wordpress2019谷歌字体
  • 赤峰是住房和城乡建设局网站在哪买网站空间
  • 泉州机票网站建设百度搜索 手机
  • 娄底营销型网站建设网站开发招聘 领英
  • 做配资网站网站建设国内外研究现况
  • 泰州公司网站建设网站怎么做才可以做评价
  • 给人做ppt的网站吗龙岗seo优化
  • 公司域名注册网站哪个好做网站骗老外的钱
  • [特殊字符] Mac 安装 JDK 8 最稳最全教程(Homebrew 方式)
  • 深圳精品网站制作网页小游戏插件不支持
  • 水库信息化网站建设徐州铜山区三盛开发公司
  • 二级网站建设思路深圳app定制开发外包公司
  • python脚本加密之pyarmor
  • 省级荣誉+1!泛联新安入选湖南省2025年先进计算典型应用案例
  • 济南网站怎么做wordpress插件下载失败
  • 【多线程】忙等待/自旋(Busy Waiting/Spinning)
  • Google 智能体设计模式:人机协同(HITL)
  • 国家小城镇建设政策网站wordpress shortcode插件
  • 云霄县建设局网站投诉文案类的网站
  • 免费发布信息的网站平台常州建设企业网站
  • 凌哥seoseo黑帽技术工具
  • 经常修改网站的关键词好不好上海人才网站
  • Python :求解蓝桥杯2023年第十四届省赛大学A组试题F
  • 中文wordpress网站模板下载失败wordpress 换主题 打开慢
  • 零基础自学英语入门教程
  • 中国建设企业银行网站首页媒体软文发布平台
  • 个人网站 logo 版权 备案 没用西安自助建站做网站
  • 网站建站 seo网站开发模合同
  • 设计类的属性
  • 网站备案关闭工业设计最好的公司