day 32
官方文档阅读
绘制pdpbox库中的InteractTargetPlot实例
# Third-party dependencies: pandas for the tabular data, scikit-learn for the
# dataset loader, the splitter and the model.
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# ---------------------------------------------------------------------------
# Data: the Iris dataset as a pandas DataFrame.
# ---------------------------------------------------------------------------
iris = load_iris()

# The four input columns (sepal length, sepal width, petal length,
# petal width) and the name of the label column.
features = iris.feature_names
target = 'target'

df = pd.DataFrame(iris.data, columns=features)
# Label column: class index 0-2 (setosa, versicolor, virginica).
df[target] = iris.target

# ---------------------------------------------------------------------------
# Split: hold out 20% of the rows for testing, with a fixed seed.
# ---------------------------------------------------------------------------
X_train, X_test, y_train, y_test = train_test_split(
    df[features],
    df[target],
    test_size=0.2,
    random_state=42,
)

# ---------------------------------------------------------------------------
# Model: a 100-tree random forest, seeded for reproducibility.
# ---------------------------------------------------------------------------
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
实例化时需要传入：原始数据（DataFrame）、包含多个特征的列表、这些特征的名称、标签（目标列）以及分桶方式等参数。
# Interaction target plot from pdpbox: how the label varies over the joint
# bucket grid of two features. No model is passed -- the plot is built from
# `df` and its 'target' column only.
import pdpbox
from pdpbox.info_plots import InteractTargetPlot  # import the InteractTargetPlot class

# The pair of features to analyze (petal length and sepal width).
feature = ['petal length (cm)', 'sepal width (cm)']
# Display names used for the plot labels; here simply the raw column names.
feature_name = feature

# Keep one argument per line: a trailing comment on a shared line would
# silently swallow every argument that follows it.
interact_target_plot = InteractTargetPlot(
    df=df,                       # raw data; must contain the feature and target columns
    features=feature,            # the two feature columns to bucket
    feature_names=feature_name,  # names used for the axis labels
    # NOTE(review): 'target' holds class indices 0/1/2. pdpbox appears to treat a
    # single non-{0,1} column as a regression target (mean of the values per
    # bucket); for per-class plots pass a list of one-hot columns instead --
    # confirm against the pdpbox docs for the installed version.
    target='target',
    grid_types='percentile',     # bucketing strategy: percentile-based grid
    num_grid_points=10,          # grid points per feature (presumably bucket edges -- confirm)
)

# plot() returns a tuple whose first element is the figure; as the last
# expression of a notebook cell it is rendered inline.
interact_target_plot.plot()[0]
@浙大疏锦行