当前位置: 首页 > news >正文

模型解释与可解释AI实战

一、为什么需要模型解释?

模型解释技术帮助:

  1. 理解模型决策依据(特征重要性)
  2. 调试模型错误预测
  3. 满足监管合规要求(金融/医疗)
  4. 提升用户对AI的信任
    本章使用Captum实现CV/NLP模型的可视化解释

二、环境准备与工具安装

!pip install captum torchvision matplotlib
import torch
import numpy as np
from captum.attr import IntegratedGradients, LayerGradCam
import matplotlib.pyplot as plt

三、图像分类解释实战(CIFAR-10)

1. 加载预训练模型

from torchvision.models import resnet18

model = resnet18(pretrained=True)
model.eval()

2. 准备测试图像

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 加载示例图像(类别:ship)
from PIL import Image
img = Image.open("test_ship.jpg").convert("RGB")
input_tensor = transform(img).unsqueeze(0)

3. 集成梯度解释

def visualize_attr(attr, title):
    attr = attr.squeeze().cpu().detach().numpy()
    plt.imshow(attr, cmap='hot')
    plt.colorbar()
    plt.title(title)
    plt.show()

# 计算特征重要性
integrated_grad = IntegratedGradients(model)
attr_ig = integrated_grad.attribute(input_tensor, target=8)  # ship类别ID为8
visualize_attr(attr_ig.mean(dim=1), "Integrated Gradients")

4. Grad-CAM可视化

# 选择目标卷积层
target_layer = model.layer4.conv2

# 计算Grad-CAM
layer_gradcam = LayerGradCam(model, target_layer)
attr_gc = layer_gradcam.attribute(input_tensor, target=8)

# 可视化叠加效果
heatmap = np.clip(attr_gc.squeeze().cpu().detach().numpy(), 0, None)
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

orig_img = input_tensor.squeeze().permute(1,2,0).cpu().detach().numpy()
plt.imshow(orig_img * 0.5 + heatmap * 0.5)
plt.title("Grad-CAM Visualization")
plt.show()

四、文本分类解释实战(IMDB情感分析)

1. 加载情感分析模型

from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")

2. 构建解释器

from captum.attr import LayerIntegratedGradients

# 定义输入处理函数
def model_forward(input_ids, attention_mask=None):
    return model(input_ids, attention_mask).logits

# 初始化解释器
lig = LayerIntegratedGradients(
    model_forward,
    model.bert.embeddings
)

3. 计算词元重要性

text = "This movie is a complete disaster, full of terrible acting and pointless scenes."
inputs = tokenizer(text, return_tensors="pt")

# 计算基准值(空输入)
ref_input_ids = torch.tensor([tokenizer.cls_token_id] + [tokenizer.pad_token_id]*(inputs.input_ids.shape-2) + [tokenizer.sep_token_id], 
                            device='cpu').unsqueeze(0)

# 计算归因值
attributions, delta = lig.attribute(
    inputs=inputs.input_ids,
    baselines=ref_input_ids,
    additional_forward_args=(inputs.attention_mask,),
    return_convergence_delta=True,
    target=0  # 负面情感对应的类别
)

# 可视化结果
token_attributions = attributions.sum(dim=2).squeeze(0)
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids)

plt.figure(figsize=(12, 3))
plt.bar(range(len(tokens)), token_attributions.detach().numpy())
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.title("Token Importance Scores")
plt.show()

五、高级解释技巧

1. 对比解释(对比不同类别)

# 对比飞机(类别0)与鸟类(类别2)的解释差异
attr_plane = integrated_grad.attribute(input_tensor, target=0)
attr_bird = integrated_grad.attribute(input_tensor, target=2)

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.imshow(attr_plane.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Airplane Attribution')

plt.subplot(1,2,2)
plt.imshow(attr_bird.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Bird Attribution')
plt.show()

2. 层次相关性传播(LRP)

from captum.attr import LRP

lrp = LRP(model)
attr_lrp = lrp.attribute(input_tensor, target=8)

plt.imshow(attr_lrp.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Layer-wise Relevance Propagation')
plt.show()

六、常见问题解答

Q1:如何安装最新版Captum?

pip install git+https://github.com/pytorch/captum.git

Q2:归因结果全为0怎么办?

  • 检查输入是否经过正确的归一化
  • 尝试不同的基线值(Baseline)
  • 验证模型是否真的使用该特征
import matplotlib
matplotlib.use('Agg')  # 无GUI模式

plt.ioff()
fig = plt.figure()
# ...生成图像...
fig.savefig('explanation.png', bbox_inches='tight')
plt.close(fig)

相关文章:

  • 涨薪技术|k8s设计原理
  • Python高级——实现简单名片管理系统
  • 【sql靶场】过滤绕过第26-27a关保姆级教程
  • AVL(平衡二叉树)
  • 【前端】 el-form-item的label由于字数多自行换行调整
  • 常考计算机操作系统面试习题(二)(下)
  • Spring Boot深度解析:从核心原理到最佳实践
  • C语言字符函数,字符串函数以及内存函数
  • 腾讯云大模型知识引擎x deepseek:打造智能服装搭配新体验
  • Kubernetes 故障排查指南
  • Linux启动之__vet_atags
  • 23种设计模式-外观(Facade)设计模式
  • unix网络编程
  • annoy编译安装问题及解决
  • 嵌入式八股文学习笔记——C++学习笔记面向对象相关
  • Python第九章节——异常,模块与包
  • leetcode128.最长连续序列
  • Objects.equals() 和 Object.equals() 的区别:
  • 信号处理中的窗
  • 《Python实战进阶》第30集:Scikit-learn 入门:分类与回归模型
  • 南京大屠杀幸存者刘贵祥去世,享年95岁
  • 一代名伶程砚秋经典影像:一箱旧影,芳华满堂
  • 五一假期多地政府食堂对外开放:部分机关食堂饭菜“秒没”
  • 美“群聊泄密门”始作俑者沃尔兹将离职
  • 国铁集团:5月1日全国铁路预计发送旅客2250万人次
  • 49:49白热化,美参议院对新关税政策产生巨大分歧