# Day32
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from pdpbox import pdp
import matplotlib.pyplot as plt
# Partial-dependence interaction plots for two iris petal features.
#
# NOTE(review): written against the pdpbox 0.2.x functional API
# (pdp.pdp_interact / pdp.pdp_interact_plot). In that API the plot function
# builds its own figure and is styled through `plot_params`; it has no
# `contour_kw`, `pdp_line_kw`, `cell_type` or `center` arguments and no '3d'
# plot type, and a class is selected with `which_classes` on the plot call
# rather than `which_class` on pdp_interact. Confirm against the installed
# pdpbox version.

# Registers the '3d' projection on older matplotlib releases; harmless on
# newer ones where it is registered automatically.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

# Font used for every Chinese label/title so the glyphs render.
CJK_FONT = 'SimHei'

# 1. Load the data and train the model.
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X, y)

# 2. Compute the interaction data once; every plot below reuses it.
interact_features = ['petal length (cm)', 'petal width (cm)']
interact_labels = ['花瓣长度', '花瓣宽度']
interact_out = pdp.pdp_interact(
    model=model,
    dataset=X,
    model_features=iris.feature_names,
    features=interact_features,
)

# 3. Draw the four interaction views. Each pdp_interact_plot call creates its
#    own figure, so no plt.figure()/plt.subplot() scaffolding is set up here.

# 3.1 Contour plot (all classes).
fig, axes = pdp.pdp_interact_plot(
    pdp_interact_out=interact_out,
    feature_names=interact_labels,
    plot_type='contour',
    x_quantile=True,
    plot_pdp=True,
    plot_params={'title': '等高线图', 'font_family': CJK_FONT, 'cmap': 'Blues'},
)

# 3.2 Grid plot (all classes).
fig, axes = pdp.pdp_interact_plot(
    pdp_interact_out=interact_out,
    feature_names=interact_labels,
    plot_type='grid',
    x_quantile=True,
    plot_pdp=True,
    plot_params={'title': '网格图', 'font_family': CJK_FONT},
)

# 3.3 3D surface, drawn directly with matplotlib because pdp_interact_plot
#     offers no 3D mode. Axes show raw feature values, not quantile positions.
# Assumes pdp_interact yields one result per class for a multi-class model and
# a single object otherwise, each exposing a `.pdp` DataFrame with the two
# feature columns plus 'preds' -- TODO confirm for the installed version.
class_results = interact_out if isinstance(interact_out, list) else [interact_out]
fig_3d = plt.figure(figsize=(6 * len(class_results), 5))
for class_idx, class_result in enumerate(class_results):
    ax_3d = fig_3d.add_subplot(1, len(class_results), class_idx + 1, projection='3d')
    grid_frame = class_result.pdp
    ax_3d.plot_trisurf(
        grid_frame[interact_features[0]].values,
        grid_frame[interact_features[1]].values,
        grid_frame['preds'].values,
        cmap='viridis',
    )
    ax_3d.set_xlabel(interact_labels[0], fontproperties=CJK_FONT)
    ax_3d.set_ylabel(interact_labels[1], fontproperties=CJK_FONT)
    ax_3d.set_title('class_%d' % class_idx, fontsize=10)
fig_3d.suptitle('3D图', fontproperties=CJK_FONT, fontsize=12)
fig_3d.tight_layout()

# 3.4 Interaction plot restricted to class 1, reusing the data from step 2
#     instead of recomputing it.
fig, axes = pdp.pdp_interact_plot(
    pdp_interact_out=interact_out,
    feature_names=interact_labels,
    plot_type='contour',
    x_quantile=True,
    plot_pdp=True,
    which_classes=[1],
    plot_params={'title': '类别1的交互效应', 'font_family': CJK_FONT, 'cmap': 'RdYlBu'},
)

plt.show()