当前位置: 首页 > news >正文

Python打卡训练营day31-文件拆分

知识点回顾
  1. 规范的文件命名
  2. 规范的文件夹管理
  3. 机器学习项目的拆分
  4. 编码格式和类型注解

作业:尝试针对之前的心脏病项目ipynb,将他按照今天的示例项目整理成规范的形式,思考下哪些部分可以未来复用。

在有多级目录时,相对导入仅在同一包内有效,尤其在下级文件导入上级文件夹中的文件

# src/config.pyCONFIG = {"data_path": PROJECT_ROOT / "data/raw/heart.csv","test_size": 0.2,"random_state": 42,"models": {"random_forest": {"n_estimators": 100,"max_depth": 5},"xgboost": {"learning_rate": 0.1,"max_depth": 3,"n_estimators": 200}}
}

# src/data/loader.py
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split
from src.config import CONFIGdef load_data() -> tuple:"""加载并拆分数据集"""df = pd.read_csv(CONFIG["data_path"])# 假设最后一列是目标变量X = df.iloc[:, :-1]y = df.iloc[:, -1]return train_test_split(X, y,test_size=CONFIG["test_size"],random_state=CONFIG["random_state"])

# src/models/base_model.py
from abc import ABC, abstractmethod
import pandas as pdclass BaseModel(ABC):"""所有模型的统一接口"""@abstractmethoddef train(self, X_train: pd.DataFrame, y_train: pd.Series):pass@abstractmethoddef predict(self, X_test: pd.DataFrame) -> pd.Series:pass@abstractmethoddef save(self, path: str):pass

# src/models/random_forest.py
from sklearn.ensemble import RandomForestClassifier
from .base_model import BaseModel
from src.config import CONFIGclass RandomForestModel(BaseModel):def __init__(self):self.model = RandomForestClassifier(n_estimators=CONFIG["models"]["random_forest"]["n_estimators"],max_depth=CONFIG["models"]["random_forest"]["max_depth"],random_state=CONFIG["random_state"])def train(self, X_train, y_train):self.model.fit(X_train, y_train)def predict(self, X_test):return self.model.predict(X_test)def save(self, path):joblib.dump(self.model, path)

# src/models/train.py
from .random_forest import RandomForestModel
from .xgboost_model import XGBoostModel
from src.data import loader
from src.evaluation import metrics
from src.utils import save_resultsdef train_all_models():X_train, X_test, y_train, y_test = loader.load_data()models = {"RandomForest": RandomForestModel(),"XGBoost": XGBoostModel()}results = {}for name, model in models.items():model.train(X_train, y_train)preds = model.predict(X_test)results[name] = metrics.calculate_all_metrics(y_test, preds)model.save(f"models/{name}_model.pkl")save_results(results)

# src/evaluation/metrics.py
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_scoredef calculate_all_metrics(y_true, y_pred) -> dict:return {"accuracy": accuracy_score(y_true, y_pred),"precision": precision_score(y_true, y_pred),"recall": recall_score(y_true, y_pred),"f1": f1_score(y_true, y_pred)}

# scripts/train_model.py
from src.models import trainif __name__ == "__main__":train.train_all_models()

相关文章:

  • 网络视频教学怎么做seo的实现方式
  • 微信小程序是免费的吗广州网站快速排名优化
  • 武汉外包seo公司包头整站优化
  • wordpress建政府网站网络营销的三大基础
  • 黑群晖wordpress建站aso优化哪家好
  • 南京响应式网站设计百度网页排名怎么提升
  • 【AS32X601驱动系列教程】PLIC_中断应用详解
  • ELK服务搭建-0-1搭建记录
  • [ACTF新生赛2020]easyre
  • 分词算法BPE详解和CLIP的应用
  • Springboot怎么解决循环依赖
  • 针对vue项目的webpack优化攻略
  • 字节跳动2025年校招笔试手撕真题教程(二)
  • 神经网络学习-Day35
  • C语言指针进阶:通过地址,直接修改变量的值
  • 基于Python的全卷积网络(FCN)实现路径损耗预测
  • 为什么hash函数能减少哈希冲突
  • 内存管理 : 03多级页表和快表
  • 简单血条于小怪攻击模板
  • 开源项目跨平台桌宠 BongoCat,为桌面增添乐趣!
  • Java文件操作:从“Hello World”到“Hello File”
  • 打卡第28天:装饰器
  • 数据结构第2章绪论 (竟成)
  • CVE-2017-5645源码分析与漏洞复现(反序列化)
  • P1104 生日
  • go1.24 通过汇编深入学习map引入swiss table后的源码