Python打卡DAY31
今日的示例代码包含2个部分
- notebook文件夹内的ipynb文件,介绍下今天的思路
- 项目文件夹中其他部分:拆分后的信贷项目,学习下如何拆分的,未来你看到的很多大项目都是类似的拆分方法
知识点回顾
- 规范的文件命名
- 规范的文件夹管理
- 机器学习项目的拆分
- 编码格式和类型注解
作业:尝试针对之前的心脏病项目ipynb,将他按照今天的示例项目整理成规范的形式,思考下哪些部分可以未来复用。
src/data/data_loader.py
import pandas as pd
from sklearn.model_selection import train_test_splitdef load_and_split_data(file_path, target_column, test_size=0.2, random_state=42):"""加载数据并划分训练集和测试集"""data = pd.read_csv(file_path)X = data.drop(target_column, axis=1)y = data[target_column]X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)return X_train, X_test, y_train, y_test
src/models/random_forest.py
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import timedef train_random_forest(X_train, y_train, X_test, y_test, random_state=42):"""训练随机森林模型并评估性能"""start_time = time.time()model = RandomForestClassifier(random_state=random_state)model.fit(X_train, y_train)y_pred = model.predict(X_test)end_time = time.time()print(f"训练与预测耗时: {end_time - start_time:.4f} 秒")print("\n默认随机森林 在测试集上的分类报告:")print(classification_report(y_test, y_pred))print("默认随机森林 在测试集上的混淆矩阵:")print(confusion_matrix(y_test, y_pred))return model
src/utils/visualization.py
import shap
import matplotlib.pyplot as pltdef plot_shap_values(model, X_test):"""绘制SHAP值的条形图、蜂巢图和依赖图"""explainer = shap.TreeExplainer(model)shap_values = explainer.shap_values(X_test)print("shap_values[0] shape:", shap_values[0].shape)print("X_test shape:", X_test.shape)# SHAP特征重要性条形图print("--- 1. SHAP 特征重要性条形图 ---")shap.summary_plot(shap_values[0], X_test, plot_type="bar", show=False)plt.title("SHAP Feature Importance (Bar Plot)")plt.show()# SHAP特征重要性蜂巢图print("--- 2. SHAP 特征重要性蜂巢图 ---")shap.summary_plot(shap_values[0], X_test, plot_type="violin", show=False, max_display=10)plt.title("SHAP Feature Importance (Violin Plot)")plt.show()# SHAP特征重要性依赖图print("--- 3. SHAP 特征重要性依赖图 ---")shap.dependence_plot('Years in current job', shap_values[0], X_test, show=False)plt.title("SHAP Feature Importance (dependence plot)")plt.show()
src/main.py
from src.data.data_loader import load_and_split_data
from src.models.random_forest import train_random_forest
from src.utils.visualization import plot_shap_valuesif __name__ == "__main__":# 数据加载与划分file_path = "data/raw/heart.csv"target_column = "target"X_train, X_test, y_train, y_test = load_and_split_data(file_path, target_column)# 模型训练与评估model = train_random_forest(X_train, y_train, X_test, y_test)# SHAP值可视化plot_shap_values(model, X_test)
@浙大疏锦行