当前位置: 首页 > news >正文

【Python训练营打卡】day31 @浙大疏锦行

DAY 31 文件的规范拆分和写法

知识点回顾

1.  规范的文件命名

2.  规范的文件夹管理

3.  机器学习项目的拆分

4.  编码格式和类型注解

作业:尝试针对之前的心脏病项目,准备拆分的项目文件,思考下哪些部分可以未来复用。

 

机器学习项目流程

一个典型的机器学习项目通常包含以下阶段:

  • 数据加载:从文件、数据库、API 等获取原始数据。
    • 命名参考:load_data.py 、data_loader.py
  • 数据探索与可视化:了解数据特性,初期可用 Jupyter Notebook,成熟后固化绘图函数。
    • 命名参考:eda.py 、visualization_utils.py
  • 数据预处理:处理缺失值、异常值,进行标准化、归一化、编码等操作。
    • 命名参考:preprocess.py 、data_cleaning.py 、data_transformation.py
  • 特征工程:创建新特征,选择、优化现有特征。
    • 命名参考:feature_engineering.py
  • 模型训练:构建模型架构,设置超参数并训练,保存模型。
    • 命名参考:model.py 、train.py
  • 模型评估:用合适指标评估模型在测试集上的性能,生成报告。
    • 命名参考:evaluate.py
  • 模型预测:用训练好的模型对新数据预测。
    • 命名参考:predict.py 、inference.py

文件的组织

1. 项目核心代码组织

  • src/(source 的缩写):存放项目的核心源代码。按照机器学习项目阶段进一步细分:
    • src/data/:放置与数据相关的代码。
      • src/data/load_data.py:负责从各类数据源(如文件系统、数据库、API 等)读取原始数据。
      • src/data/preprocess.py:进行数据清洗(处理缺失值、异常值)、数据转换(标准化、归一化、编码等)操作。
      • src/data/feature_engineering.py:根据业务和数据特点,创建新特征或对现有特征进行选择、优化。
    • src/models/:关于模型的代码。
      • src/models/model.py:定义模型架构,比如神经网络结构、机器学习算法模型设定等。
      • src/models/train.py:设置模型超参数,并执行训练过程,保存训练好的模型。
      • src/models/evaluate.py:使用合适的评估指标(如准确率、召回率、均方误差等),在测试集上评估模型性能,生成评估报告。
      • src/models/predict.py 或 src/models/inference.py:利用训练好的模型对新数据进行预测。
    • src/utils/:存放通用辅助函数代码,可进一步细分:
      • src/utils/io_utils.py:包含文件读写相关帮助函数,比如读取特定格式文件、保存数据到文件等。
      • src/utils/logging_utils.py:实现日志记录功能,方便记录项目运行过程中的信息,便于调试和监控。
      • src/utils/math_utils.py:特定的数值计算函数,像自定义的矩阵运算、统计计算等。
      • src/utils/plotting_utils.py:绘图工具函数,用于生成数据可视化图表(如绘制损失函数变化曲线、特征分布直方图等 )。

2. 配置文件管理

  • config/ 目录:集中存放项目的配置文件,方便管理和切换不同环境(开发、测试、生产)的配置。
    • config/config.py 或 config/settings.py:以 Python 代码形式定义配置参数。
    • config/config.yaml 或 config/config.json:采用 YAML 或 JSON 格式,清晰列出文件路径、模型超参数、随机种子、API 密钥等可配置参数。
    • .env 文件:通常放在项目根目录,用于存储敏感信息(如数据库密码、API 密钥等),在代码中通过环境变量的方式读取,一般会被 .gitignore 忽略,防止敏感信息泄露。

3. 实验与探索代码

  • notebooks/ 或 experiments/ 目录:用于初期的数据探索、快速实验、模型原型验证。
    • notebooks/initial_eda.ipynb:在项目初期,使用 Jupyter Notebook 进行数据探索与可视化,了解数据特性,分析数据分布、相关性等。
    • experiments/model_experimentation.py:编写脚本对不同模型架构、超参数组合进行快速实验,对比实验结果,寻找最优模型设置。
      这部分往往是最开始的探索阶段,后面跑通了后拆分成了完整的项目,留作纪念用。

4. 项目产出物管理

  • data/ 目录:存放项目相关数据。
    • data/raw/:放置从外部获取的未经处理的原始数据,保持数据原始状态。
    • data/processed/:存放经过预处理(清洗、转换、特征工程等操作)后的数据,供模型训练和评估使用。
    • data/interim/:(可选)保存中间处理结果,比如数据清洗过程中生成的临时文件、特征工程中间步骤产生的数据等。
  • models/ 目录:专门存放训练好的模型文件,根据模型保存格式不同,可能是 .pkl(Python pickle 格式,常用于保存 sklearn 模型 )、.h5(常用于保存 Keras 模型 )、.joblib 等。
  • reports/ 或 output/ 目录:存储项目运行产生的各类报告和输出文件。
    • reports/evaluation_report.txt:记录模型评估的详细结果,包括各项评估指标数值、模型性能分析等。
    • reports/visualizations/:存放数据可视化图片,如损失函数收敛图、预测结果对比图等。
    • output/logs/:保存项目运行日志文件,记录项目从开始到结束过程中的关键信息,如训练开始时间、训练过程中的损失值变化、预测时间等。

通用的拆分起步思路:

  1. 首先,按照机器学习的主要工作流程(数据处理、训练、评估等)将代码分离到不同的 .py 文件中。 这是最基本也是最有价值的一步。
  2. 然后,创建一个 utils.py 来存放通用的辅助函数。
  3. 考虑将所有配置参数集中到一个 config.py 文件中。
  4. 为你的数据和模型产出物创建专门的顶层目录,如 data/ 和 models/,将它们与你的源代码(通常放在 src/ 目录)分开。

当遵循这些通用的拆分思路和原则时,项目结构自然会变得清晰。

注意事项

if __name__ == __main__

常常会看到 if __name__ == "__main__" 这个写法,实际上,每个文件都是一个对象,对象就会有属性和方法。

如果直接运行这个文件,则 __name__ 等于 __main__,若这个文件被其他模块导入,则 __name__ 不等于 __main__

这个写法有如下好处:

  1. 明确程序起点:一个 Python 项目往往由多个模块组成。if __name__ == "__main__" 可清晰界定程序执行的起始位置。比如一个包含数据处理模块 data_processing.py、模型训练模块 model_training.py 的机器学习项目,在 model_training.py 中用 if __name__ == "__main__" 包裹训练相关的主逻辑代码,运行该文件时就知道需要从这里开始执行(其他文件都是附属文件),让项目结构和执行流程更清晰。(大多时候如此)

  2. 避免执行不必要的代码:Python 遵从 “模块导入即执行” 机制,当使用 import xxx 导入一个模块时,Python 会执行该模块中的所有顶层代码(即不在任何函数或类内部的代码)。如果顶层代码中定义了全局变量或执行了某些操作(如读取文件、初始化数据库连接),这些操作会在导入时立即生效,并可能影响整个程序的状态。为了避免执行不必要的代码,我们可以使用 if __name__ == "__main__" 来避免在导入时执行相关代码。这样,只有当模块被直接运行时(即执行 python xxx.py),才会执行顶层代码,而导入时则不会执行。这能确保在导入模块时不触发非必要操作,从而提高程序的性能和可维护性。

  3. 合理的资源管理if __name__ == "__main__" 常与定义 main 函数结合使用,函数内的变量在函数执行完毕后会被释放,能及时回收内存资源,避免内存泄漏,保证程序高效运行。

编码格式

规范的 py 文件,首行会有:# -*- coding: utf-8 -*-

主要目的是显式声明文件的编码格式,确保 Python 解释器能正确读取和解析文件中的非 ASCII 字符(如中文、日文、特殊符号等)。也就是说这个是写给解释器看的。

在 Python 2.x 时代,默认编码是 ASCII,不支持直接在代码中写入非 ASCII 字符(如中文注释、字符串中的中文),否则会报错(SyntaxError: Non-UTF-8 code starting with...)。虽然 Python 3.x 默认为 UTF-8 编码,理论上可以省略编码声明,但实际开发中,为了兼容旧代码、明确文件编码规则,或在团队协作中避免因编辑器 / 环境配置不同导致的乱码问题,许多开发者仍会保留这一行声明。

注意:

  1. 编码声明必须出现在文件的前两行(通常是首行),否则会被忽略。
  2. 如果编码格式没问题,可能是 vscode 的编码格式不是 utf-8,可以尝试修改编码格式。
  3. 常见的编码报错是因为字符串编码问题,可以尝试显式转化,即读取的时候转化为 utf-8 编码。

类型注解

Python 的类型注解是在 Python 3.5+ 引入的特性,用于为变量、函数参数、返回值和类属性等添加类型信息。虽然 Python 仍是动态类型语言,但类型注解可以提高代码可读性、可维护性,并支持静态类型检查工具(如 mypy)。

其次在安装 Python 插件时,附带安装了 2 个插件:

  1. python debugger:用于断点调试,我们已经介绍过。
  2. pylance:用于代码提示和类型检查,该插件会根据代码中的类型注解给出相应提示和检查。例如,若定义一个函数,其参数类型为 int,当传入字符串时,它会提示参数类型不正确。

变量类型注解语法为 变量名: 类型

# 变量的类型注解
name: str = "Alice"
age: int = 30
height: float = 1.75
is_student: bool = False

函数类型注解为函数参数和返回值指定类型,语法为 def 函数名(参数: 类型) -> 返回类型。

def add(a: int, b: int) -> int:return a + bdef greet(name: str) -> None:print(f"Hello, {name}")

类属性与方法的类型注解:为类的属性和方法添加类型信息。

# 定义一个矩形类
class Rectangle:width: float      # 矩形宽度(浮点数),类属性的类型注解(不初始化值)height: float     # 矩形高度(浮点数)def __init__(self, width: float, height: float):self.width = widthself.height = heightdef area(self) -> float:# 计算面积(宽度 × 高度)return self.width * self.height

上述的“width: float  # 矩形宽度(浮点数)”这个写法由于没有对变量赋值,所以是一种类型注解写法

 \src\data\preprocessing.py

import pandas as pd
import numpy as np
from typing import Tuple, Dict
from sklearn.preprocessing import MinMaxScaler, StandardScalerdef load_data(file_path: str) -> pd.DataFrame:"""加载数据文件Args:file_path: 数据文件路径Returns:加载的数据框"""return pd.read_csv(file_path)def encode_categorical_features(data: pd.DataFrame) -> Tuple[pd.DataFrame, Dict]:"""对分类特征进行编码Args:data: 原始数据框Returns:编码后的数据框和编码映射字典"""# 使用字典映射进行标签编码feature_mappings = {'cp': {0: 0, 1: 1, 2: 2, 3: 3},'slope': {0: 0, 1: 1, 2: 2}}data_encoded = data.copy()for feature, mapping in feature_mappings.items():data_encoded[feature] = data[feature].map(mapping)# 独热编码thal_mapping = {1: 0, 2: 1, 3: 2}data_encoded['thal'] = data['thal'].map(thal_mapping)data_encoded = pd.get_dummies(data_encoded, columns=['thal'], prefix='thal', dtype=int)mappings = {'feature_mappings': feature_mappings,'thal_mapping': thal_mapping}return data_encoded, mappingsdef handle_missing_values(data: pd.DataFrame) -> pd.DataFrame:"""处理缺失值Args:data: 包含缺失值的数据框Returns:处理后的数据框"""data_clean = data.copy()discrete_features = ['sex', 'cp', 'fbs', 'restecg', 'exang', 'slope', 'ca', 'thal', 'target']continuous_features = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']# 离散特征用众数补全for feature in discrete_features:if feature in data.columns and data[feature].isnull().any():mode_value = data[feature].mode()[0]data_clean[feature].fillna(mode_value, inplace=True)# 连续特征用中位数补全for feature in continuous_features:if feature in data.columns and data[feature].isnull().any():median_value = data[feature].median()data_clean[feature].fillna(median_value, inplace=True)return data_cleanif __name__ == "__main__":# 测试代码data = load_data("data/raw/heart.csv")data_encoded, mappings = encode_categorical_features(data)data_clean = handle_missing_values(data_encoded)data_scaled = scale_features(data_clean)print("数据预处理完成!")

\src\models\train.py

# -*- coding: utf-8 -*-import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import time
import joblib # 用于保存模型
from typing import Tuple # 用于类型注解from data.preprocessing import load_data, encode_categorical_features, handle_missing_valuesdef prepare_data() -> Tuple:"""准备训练数据Returns:训练集和测试集的特征和标签"""# 加载和预处理数据data = load_data("data/raw/heart.csv")data_encoded, _ = encode_categorical_features(data)data_clean = handle_missing_values(data_encoded)# 分离特征和标签X = data_clean.drop(['target'], axis=1)y = data_clean['target']# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)return X_train, X_test, y_train, y_testdef train_model(X_train, y_train, model_params=None) -> RandomForestClassifier:"""训练随机森林模型Args:X_train: 训练特征y_train: 训练标签model_params: 模型参数字典Returns:训练好的模型"""if model_params is None:model_params = {'random_state': 42}model = RandomForestClassifier(**model_params)model.fit(X_train, y_train)return modeldef evaluate_model(model, X_test, y_test) -> None:"""评估模型性能Args:model: 训练好的模型X_test: 测试特征y_test: 测试标签"""y_pred = model.predict(X_test)print("\n分类报告:")print(classification_report(y_test, y_pred))print("\n混淆矩阵:")print(confusion_matrix(y_test, y_pred))def save_model(model, model_path: str) -> None:"""保存模型Args:model: 训练好的模型model_path: 模型保存路径"""os.makedirs(os.path.dirname(model_path), exist_ok=True)joblib.dump(model, model_path)print(f"\n模型已保存至: {model_path}")if __name__ == "__main__":# 准备数据X_train, X_test, y_train, y_test = prepare_data()# 记录开始时间start_time = time.time()# 训练模型model = train_model(X_train, y_train)# 记录结束时间end_time = time.time()print(f"\n训练耗时: {end_time - start_time:.4f} 秒")# 评估模型evaluate_model(model, X_test, y_test)# 保存模型save_model(model, "models/random_forest_model.joblib") 

\src\visualization\plots.py

import matplotlib.pyplot as plt
import seaborn as sns
import shap
import numpy as np
from typing import Any
from sklearn.metrics import confusion_matrix  def plot_feature_importance_shap(model: Any, X_test, save_path: str = None) -> None:"""绘制SHAP特征重要性图Args:model: 训练好的模型X_test: 测试数据save_path: 图片保存路径"""# 初始化SHAP解释器explainer = shap.TreeExplainer(model)shap_values = explainer.shap_values(X_test)# 绘制特征重要性条形图plt.figure(figsize=(12, 8))shap.summary_plot(shap_values[:, :, 0], X_test, plot_type="bar", show=False)plt.title("SHAP特征重要性")if save_path:plt.savefig(save_path)print(f"特征重要性图已保存至: {save_path}")plt.show()def plot_confusion_matrix(y_true, y_pred, save_path: str = None) -> None:"""绘制混淆矩阵热力图Args:y_true: 真实标签y_pred: 预测标签save_path: 图片保存路径"""plt.figure(figsize=(8, 6))cm = confusion_matrix(y_true, y_pred)sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')plt.title('混淆矩阵')plt.ylabel('真实标签')plt.xlabel('预测标签')if save_path:plt.savefig(save_path)print(f"混淆矩阵图已保存至: {save_path}")plt.show()def set_plot_style():"""设置绘图样式"""plt.style.use('seaborn')plt.rcParams['font.sans-serif'] = ['SimHei']plt.rcParams['axes.unicode_minus'] = Falseif __name__ == "__main__":# 设置绘图样式set_plot_style()# 这里可以添加测试代码print("可视化模块加载成功!") 

@浙大疏锦行

相关文章:

  • 第六天的尝试
  • 游戏开发实战(二):Python复刻「崩坏星穹铁道」嗷呜嗷呜事务所---源码级解析该小游戏背后的算法与设计模式【纯原创】
  • TripGenie:畅游济南旅行规划助手:个人工作纪实(十八)
  • 单端IO和差分IO标准
  • 飞致云旗下开源项目GitHub Star总数突破150,000个
  • 告别格式不兼容!画质无损 RainCrack 免费无广告转码软件
  • 解决Linux服务器MXNet安装与`npx`模块问题
  • SymPy | 获取表达式自由变量方法与因式分解
  • 模板引擎:FreeMarker
  • ES6核心特性与语法
  • 04 接口自动化-框架封装思想建立之httprunner框架(上)
  • 【图像大模型】Stable Diffusion 3 Medium:多模态扩散模型的技术突破与实践指南
  • 第9天-Python数据爬取实战:从入门到进阶完整指南
  • 学习日记-day11-5.20
  • IEEEtran中文献中的作者大于3个时,用et al.省略
  • 第十六届C++B组easyQuestions
  • 大模型会话窗口为什么对最新和最久记忆表现较好
  • 如何保存解析后的商品信息?
  • Cribl 对数据源进行过滤-01
  • Unity自定义shader打包SpriteAtlas图集问题
  • 做纱窗修水管的一个网站/新闻热点事件2021(最新)
  • 小程序建站平台哪个好/互联网营销师培训机构哪家好
  • bs网站做映射/专业seo公司
  • 服务好的网站开发/网站需要怎么优化比较好
  • 长沙找人做网站/青岛网站权重提升
  • 深圳哪个做网站好优化/东莞网站优化