多脚本大批量训练
多脚本大批量训练
- 0、提问
- 1、示例1
- 📂 目录结构
- 📄 示例配置文件 `config/config1.json`
- 📄 示例训练脚本 `train/train1.py`
- 📄 示例工具脚本 `util/train1_util.py`
- 📄 主控脚本 `main.py`
- ✅ 特点
- 修改后的 `main.py`(顺序执行 + 配置文件名命名 log + logger)
- ✅ 特点
- 2、示例2
- 1\. 设计思路 (v3: 配置驱动)
- 2\. 项目结构
- 3\. 代码实现 (v3)
- a. `config/experiment_A.json` (新的配置结构)
- b. `config/experiment_B.json`
- c. `train/train1.py` 和 `train/train2.py`
- d. `main.py` (全新重构)
- 4\. 如何运行
- 5\. 预期输出
- 6\. 最终版框架优势
- 3、完整示例(ok✅)
- 📂 目录结构
- 📄 main.py
- 📄 train/train1.py
- 📄 train/train2.py
- 📄 util/train1\_util.py
- 📄 util/train2\_util.py
- 📄 config/config1.json
- 📄 config/config2.json
- ✅ 特点
- 修改后的 `main.py`(支持从指定配置开始训练)
- ✅ 新增功能
- 4、包内部动态导包(OK✅改成util)
- 📂 目录结构
- 📄 main.py
- 📄 train/train1.py
- 📄 train/train2.py
- 📄 util/train1\_util.py
- 📄 util/train2\_util.py
- 📄 config/config1.json
- 📄 config/config2.json
- ✅ 效果
- 4.2 配置流程
- 5、包内部动态导包(❌,不是tmp,是util)
- 📂 目录结构
- 🧠 main.py
- 🏋️ train/train1.py
- 🏋️ train/train2.py
- 🛠 util/train1\_util.py
- 🛠 util/train2\_util.py
- 🧩 tmp/tmp1.py
- 🧩 tmp/tmp2.py
- ⚙️ config/config1.json
- ⚙️ config/config2.json
- ✅ 使用说明小结
- 6、问题:(函数初始化参数不能被覆盖)
- 原因
- 测试代码
- 解决方法 ✅
- 解决方法 ✅(推荐)
- 7、属性操作:(`hasattr`、`getattr`、`setattr`、`delattr` )
- 1. `hasattr(obj, name)`
- 2. `getattr(obj, name, default=None)`
- 3. `setattr(obj, name, value)`
- 4. `delattr(obj, name)`
- 🔑 对比总结
- 8、logging
- 示例:每个目录一个 Logger
- 🔑 拓展技巧
0、提问
设计一个脚本实现:多脚本大批量训练,每个训练脚本包含train、model、config文件,每个train和model文件的超参数(常量)都由config配置
1 .
设计一个脚本实现:多脚本大批量训练
例如:
-- train-- train1.py-- train2.py...
-- util-- train1_util.py-- train2_util.py
-- config-- config1.json-- config2.json
main()#控制脚本,协调多个训练,记录训练过程和记录异常-- 控制导包,及配置包中的超参数(覆盖原有超参数),train和util中设置一个方法覆盖-- 其中所有train包用as相同名字-- 其中所有util包用as相同名字其中train文件包含一下超参:
SEED = 42
NUM_DAYS = 2000 # 生成更多数据以适应更长的序列
CONTEXT_LEN = 128 # 输入序列长度 (lookback window),对于金字塔模型,建议为2的N次方
HORIZON = 5 # 预测未来N天
BATCH_SIZE = 64 # 批处理大小
EPOCHS = 1000 # 训练轮次 (为快速演示设为较小值)eval文件包含一下超参:
D_MODEL = 64 # 模型主维度 (嵌入维度)
DATE_INPUT_DIM = 3 # 日期特征维度 (年,月,日)
2 .✅
设计一个脚本实现:多脚本大批量训练
例如:
-- train-- train1.py-- train2.py...
-- util-- train1_util.py-- train2_util.py
-- config-- config1.json-- config2.json
main()#控制脚本,协调多个训练,记录训练过程和记录异常-- 控制导包,及配置包中的超参数(覆盖原有超参数),train和util中设置一个方法覆盖-- 其中所有train包用as相同名字,导哪个包由配置来控制(import trian_x as train#trian_x在配置文件中配置 )-- 其中所有util包用as相同名字,同上其中train文件包含一下超参:
SEED = 42
NUM_DAYS = 2000 # 生成更多数据以适应更长的序列
CONTEXT_LEN = 128 # 输入序列长度 (lookback window),对于金字塔模型,建议为2的N次方
HORIZON = 5 # 预测未来N天
BATCH_SIZE = 64 # 批处理大小
EPOCHS = 1000 # 训练轮次 (为快速演示设为较小值)eval文件包含一下超参:
D_MODEL = 64 # 模型主维度 (嵌入维度)
DATE_INPUT_DIM = 3 # 日期特征维度 (年,月,日)
3 .(❌,不是tmp,是util)
设计一个脚本实现:多脚本大批量训练
例如:
-- train-- train1.py-- train2.py...
-- util-- train1_util.py-- train2_util.py
-- config-- config1.json-- config2.json
-- tmp-- tmp1.py-- tmp2.py...
main()#控制脚本,协调多个训练,记录训练过程和记录异常-- 控制导包,及配置包中的超参数(覆盖原有超参数),train和util中设置一个方法覆盖-- 其中所有train包用as相同名字,导哪个包由配置来控制(import trian_x as train#trian_x在配置文件中配置 )-- 其中所有util包用as相同名字,同上-- 同控制train内部动态导入tmp,(import tmp_x as tmp #tmp_x可在train内部动态更换 )其中train文件包含一下超参:
SEED = 42
NUM_DAYS = 2000 # 生成更多数据以适应更长的序列
CONTEXT_LEN = 128 # 输入序列长度 (lookback window),对于金字塔模型,建议为2的N次方
HORIZON = 5 # 预测未来N天
BATCH_SIZE = 64 # 批处理大小
EPOCHS = 1000 # 训练轮次 (为快速演示设为较小值)eval文件包含一下超参:
D_MODEL = 64 # 模型主维度 (嵌入维度)
DATE_INPUT_DIM = 3 # 日期特征维度 (年,月,日)
4 .(是util)
设计一个脚本实现:多脚本大批量训练
例如:
-- train-- train1.py-- train2.py...
-- util-- train1_util.py-- train2_util.py...
-- config-- config1.json-- config2.jsonmain()#控制脚本,协调多个训练,记录训练过程和记录异常-- 控制导包,及配置包中的超参数(覆盖原有超参数),train和util中设置一个方法覆盖-- 其中所有train包用as相同名字,导哪个包由配置来控制(import trian_x as train#trian_x在配置文件中配置 )-- 其中所有util包用as相同名字,同上-- 同控制train内部动态导入util,(import tutil_x as util #tutil_x可在train内部动态更换 )其中train文件包含一下超参:
SEED = 42
NUM_DAYS = 2000 # 生成更多数据以适应更长的序列
CONTEXT_LEN = 128 # 输入序列长度 (lookback window),对于金字塔模型,建议为2的N次方
HORIZON = 5 # 预测未来N天
BATCH_SIZE = 64 # 批处理大小
EPOCHS = 1000 # 训练轮次 (为快速演示设为较小值)eval文件包含一下超参:
D_MODEL = 64 # 模型主维度 (嵌入维度)
DATE_INPUT_DIM = 3 # 日期特征维度 (年,月,日)
1、示例1
👌我理解了,你想要的是一个可扩展的多脚本批量训练框架,核心要求是:
-
train/*.py
和util/*.py
文件里定义默认超参; -
config/*.json
配置文件中除了覆盖超参外,还指定使用哪个 train_x 和 util_x; -
在
main.py
里:- 动态导入指定的 train/util 模块,统一用
as train
和as util
; - 调用模块内部的
override_config(config)
方法覆盖默认超参; - 执行训练,并捕获异常、记录日志。
- 动态导入指定的 train/util 模块,统一用
📂 目录结构
project/
│── main.py
│
├── train/
│ ├── train1.py
│ ├── train2.py
│ └── ...
│
├── util/
│ ├── train1_util.py
│ ├── train2_util.py
│ └── ...
│
└── config/├── config1.json├── config2.json└── ...
📄 示例配置文件 config/config1.json
{"train_module": "train.train1","util_module": "util.train1_util","SEED": 123,"NUM_DAYS": 2500,"BATCH_SIZE": 32,"EPOCHS": 300,"D_MODEL": 128,"DATE_INPUT_DIM": 3
}
📄 示例训练脚本 train/train1.py
import random
import numpy as np
import torch
from util import train1_util as util# ==== 默认超参 ====
SEED = 42
NUM_DAYS = 2000
CONTEXT_LEN = 128
HORIZON = 5
BATCH_SIZE = 64
EPOCHS = 1000
D_MODEL = 64
DATE_INPUT_DIM = 3def override_config(cfg: dict):"""覆盖默认超参"""globals_dict = globals()for k, v in cfg.items():if k in globals_dict:globals_dict[k] = vdef set_seed(seed: int):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)def run(log_path: str = None):"""执行训练"""set_seed(SEED)for epoch in range(EPOCHS):loss = util.train_one_epoch(epoch, BATCH_SIZE, D_MODEL)if epoch % 100 == 0:msg = f"[train1] Epoch {epoch}, Loss: {loss:.4f}"print(msg)if log_path:with open(log_path, "a") as f:f.write(msg + "\n")print("[train1] Training complete.")
📄 示例工具脚本 util/train1_util.py
import random# 默认 util 内部参数
D_MODEL = 64def override_config(cfg: dict):"""覆盖 util 内部超参"""globals_dict = globals()for k, v in cfg.items():if k in globals_dict:globals_dict[k] = vdef train_one_epoch(epoch, batch_size, d_model):"""模拟训练过程,返回一个loss值"""base = 1.0 / (epoch + 1)noise = random.random() * 0.01return base * (d_model / 64) + noise
📄 主控脚本 main.py
import importlib
import json
import os
import traceback
from datetime import datetime# === 加载配置 ===
def load_config(path: str) -> dict:with open(path, "r") as f:return json.load(f)def run_task(config_path: str, log_dir="logs"):config = load_config(config_path)train_module_name = config["train_module"]util_module_name = config["util_module"]try:print(f"\n[MAIN] Running {train_module_name} with {config_path}")# === 动态导入模块 ===train = importlib.import_module(train_module_name) # as trainutil = importlib.import_module(util_module_name) # as util# === 覆盖超参 ===if hasattr(train, "override_config"):train.override_config(config)if hasattr(util, "override_config"):util.override_config(config)# === 日志文件 ===os.makedirs(log_dir, exist_ok=True)log_file = os.path.join(log_dir,f"{train_module_name.replace('.', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")# === 执行训练 ===train.run(log_file)except Exception as e:err_msg = f"[ERROR] {train_module_name} failed: {e}\n{traceback.format_exc()}"print(err_msg)with open(os.path.join(log_dir, "errors.log"), "a") as f:f.write(err_msg + "\n")def main():# 可以在这里批量列出需要运行的配置文件config_files = ["config/config1.json","config/config2.json"]for cfg in config_files:run_task(cfg)if __name__ == "__main__":main()
✅ 特点
config.json
里显式指定要跑的 train_x 和 util_x。main.py
动态导入模块,并且都用相同别名train
/util
。- 覆盖超参是通过
override_config(config)
方法完成的,既能覆盖 train,也能覆盖 util。 - 日志独立保存,异常集中记录。
要不要我再给你扩展一个版本:支持 并行跑多个配置文件(多进程/多GPU)?
明白了,你希望 日志文件以配置文件名命名,并且 每处理完一个配置文件就更新 logger 来记录任务进度。我帮你改造 main.py
:
修改后的 main.py
(顺序执行 + 配置文件名命名 log + logger)
import importlib
import json
import os
import traceback
import logging
from datetime import datetime# === 设置 logger ===
def setup_logger(log_file_path):"""创建一个logger,用于单个配置文件的训练日志"""logger = logging.getLogger(log_file_path)logger.setLevel(logging.INFO)# 防止重复添加Handlerif logger.hasHandlers():logger.handlers.clear()# 文件处理器fh = logging.FileHandler(log_file_path, mode="a", encoding="utf-8")fh.setLevel(logging.INFO)formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')fh.setFormatter(formatter)logger.addHandler(fh)# 控制台处理器ch = logging.StreamHandler()ch.setLevel(logging.INFO)ch.setFormatter(formatter)logger.addHandler(ch)return logger# === 加载配置文件 ===
def load_config(path: str) -> dict:with open(path, "r") as f:return json.load(f)def run_task(config_path: str, task_index: int, total_tasks: int, log_dir="logs"):# 获取配置文件名作为 log 文件名cfg_name = os.path.splitext(os.path.basename(config_path))[0]os.makedirs(log_dir, exist_ok=True)log_file_path = os.path.join(log_dir, f"{cfg_name}.log")logger = setup_logger(log_file_path)config = load_config(config_path)train_module_name = config["train_module"]util_module_name = config["util_module"]try:logger.info(f"({task_index}/{total_tasks}) Start {train_module_name} using {config_path}")# 动态导入 train 和 util 模块train = importlib.import_module(train_module_name)util = importlib.import_module(util_module_name)# 覆盖超参if hasattr(train, "override_config"):train.override_config(config)if hasattr(util, "override_config"):util.override_config(config)# 执行训练train.run(log_file_path)logger.info(f"({task_index}/{total_tasks}) Finished {train_module_name}")except Exception as e:logger.error(f"({task_index}/{total_tasks}) {train_module_name} failed: {e}\n{traceback.format_exc()}")def main():# 配置文件按顺序排列config_files = ["config/config1.json","config/config2.json"]total_tasks = len(config_files)for idx, cfg in enumerate(config_files, start=1):run_task(cfg, idx, total_tasks)if __name__ == "__main__":main()
✅ 特点
-
日志文件以配置文件名命名
config1.json
→logs/config1.log
config2.json
→logs/config2.log
-
顺序执行
- 配置文件列表按顺序排列,任务按顺序执行。
-
Logger 记录处理进度
- 每处理完一个配置文件就更新日志(
Start
→Finished
)。 - 同时控制台打印进度。
- 每处理完一个配置文件就更新日志(
-
异常捕获
- 如果训练脚本报错,异常会写入同一 log 文件。
如果你希望,我可以再加一个 整体进度 logger(单独一个 logs/overall_progress.log
记录所有任务的完成进度),同时每个训练任务也写自己的 log 文件。
你希望我加这个吗?
2、示例2
好的,我们再次对脚本进行重构,以满足您最新、更具体的要求。这次的核心变化是:将“要运行哪个脚本”的控制权完全交给配置文件。
这是一个更高级、更灵活的设计。之前的版本依赖于文件名约定(例如 ID ‘1’ 对应 train1.py
),而新版本则通过配置文件来明确声明要使用哪个训练模块和工具模块。这使得组合不同的模块进行实验变得非常容易。
1. 设计思路 (v3: 配置驱动)
- 配置即实验 (Configuration as Experiment):每个
.json
配置文件定义了一个完整的、独立的实验。这包括实验名称、要使用的训练脚本、要使用的工具脚本,以及需要覆盖的超参数。 - 主控脚本的角色:
main.py
的角色转变为一个“配置执行器”。它不再关心实验ID,而是简单地扫描config
目录,并为找到的每一个配置文件执行一个训练任务。 - 完全解耦:训练脚本 (
trainX.py
) 和工具脚本 (utilX.py
) 之间不再有命名的硬性关联。你可以自由组合,例如让一个实验使用train1.py
和util2.py
。
2. 项目结构
目录结构保持不变,依然清晰。
/your_project_root
├── main.py # 核心控制脚本 (已更新)
|
├── train/
│ ├── __init__.py
│ ├── train1.py
│ └── train2.py
|
├── util/
│ ├── __init__.py
│ ├── util1.py
│ └── util2.py
|
└── config/├── experiment_A.json # 配置文件 (结构已更新)└── experiment_B.json # 另一个配置文件
注意:我们把配置文件名改得更有描述性,比如 experiment_A.json
,但这并非强制要求。
3. 代码实现 (v3)
a. config/experiment_A.json
(新的配置结构)
这是本次设计的核心变化。配置文件现在包含了所有实验定义信息。
{"experiment_name": "Standard_Model_with_High_Epochs","description": "Run train script 1 with util script 1, overriding epochs and batch size.","module_settings": {"train_module": "train.train1","util_module": "util.util1"},"hyperparameters": {"EPOCHS": 50,"BATCH_SIZE": 128,"D_MODEL": 128}
}
b. config/experiment_B.json
这个配置演示了如何混合搭配模块,并为 train2.py
指定了不同的超参数。
{"experiment_name": "Advanced_Model_Deliberate_Fail","description": "Run train script 2 with util script 1. This experiment is designed to fail.","module_settings": {"train_module": "train.train2","util_module": "util.util1"},"hyperparameters": {"SEED": 2025,"LEARNING_RATE": 0.007,"HORIZON": 15}
}
c. train/train1.py
和 train/train2.py
这两个脚本无需任何修改,可以直接复用我们上一版的设计(带有 update_params
方法)。这证明了上次的设计具有良好的封装性。为了完整性,这里再次展示 train1.py
的结构。
# train/train1.py (无需修改)
# ... (默认超参数定义) ...
SEED = 42
EPOCHS = 1000
# ... etc ...def update_params(config: dict):"""使用配置字典更新模块的全局超参数。"""print(f"[{__name__}] Updating hyperparameters...")g = globals()for key, value in config.items():if key in g and key.isupper():g[key] = valueprint(f" -> '{key}' updated to: {value}")def run_training(util_module):"""训练任务的入口函数。"""print(f"[{__name__}] >>> Starting training task...")# ... (打印超参数和模拟训练的代码) ...print(f"[{__name__}] <<< Training task finished successfully.")return {"status": "success"}
d. main.py
(全新重构)
主脚本的逻辑完全重写,以实现配置驱动的训练流程。
# main.py
import os
import sys
import json
import glob
import logging
import importlib
import traceback
from datetime import datetime# --- 配置区 ---
CONFIG_DIR = "config"
LOG_FILE = f"training_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"# --- 日志设置 ---
def setup_logging():"""配置日志记录器"""logging.basicConfig(level=logging.INFO,format='%(asctime)s - [%(levelname)s] - %(message)s',handlers=[logging.FileHandler(LOG_FILE, encoding='utf-8'),logging.StreamHandler(sys.stdout)])# --- 主逻辑 ---
def main():"""配置驱动的主协调函数"""setup_logging()logging.info("=============================================")logging.info("=== Configuration-Driven Batch Training ===")logging.info("=============================================")# 1. 扫描配置目录下的所有.json文件config_files = glob.glob(os.path.join(CONFIG_DIR, '*.json'))if not config_files:logging.warning(f"No configuration files found in '{CONFIG_DIR}' directory. Exiting.")returnlogging.info(f"Found {len(config_files)} experiments to run: {config_files}")summary = []for config_path in config_files:try:# 2. 加载和解析配置文件with open(config_path, 'r') as f:config = json.load(f)exp_name = config.get("experiment_name", os.path.basename(config_path))logging.info(f"--- Running Experiment: {exp_name} (from {config_path}) ---")# 从配置中获取模块名train_module_name = config["module_settings"]["train_module"]util_module_name = config["module_settings"]["util_module"]hyperparams_override = config.get("hyperparameters", {})# 3. 动态导入指定的模块 (实现 as 的效果)# 这里的变量 train 和 util 就是别名 (alias)train = importlib.import_module(train_module_name)util = importlib.import_module(util_module_name)logging.info(f"Imported train module: '{train_module_name}' as 'train'")logging.info(f"Imported util module: '{util_module_name}' as 'util'")# 4. 调用模块自身的方法覆盖超参数if hyperparams_override:if hasattr(train, 'update_params'):train.update_params(hyperparams_override)else:logging.warning(f"Module '{train_module_name}' has no 'update_params' method. Skipping override.")# 5. 执行训练result = train.run_training(util)logging.info(f"Experiment '{exp_name}' completed successfully. Result: {result}")summary.append({"experiment": exp_name, "status": "SUCCESS", "details": result})except KeyError as e:error_msg = f"Configuration file '{config_path}' is missing a required key: {e}"logging.error(error_msg)summary.append({"experiment": exp_name, "status": "CONFIG_ERROR", "details": error_msg})except Exception as e:# 捕获所有其他异常error_details = traceback.format_exc()logging.error(f"An unexpected error occurred in experiment '{exp_name}':\n{error_details}")summary.append({"experiment": exp_name, "status": "RUNTIME_ERROR", "details": str(e)})finally:logging.info(f"--- Finished Experiment: {exp_name} ---\n")# 打印最终总结logging.info("=============================================")logging.info("=== Batch Training Process Finished ===")logging.info("=============================================")logging.info("Final Summary:")for report in summary:logging.info(f" - Experiment: {report['experiment']}, Status: {report['status']}")if __name__ == "__main__":main()
4. 如何运行
- 创建好所有文件和目录。
- 在终端中导航到项目根目录。
- 执行脚本:
python main.py
脚本将自动查找 config
文件夹中的所有 .json
文件,并依次执行它们定义的实验。
5. 预期输出
输出将清晰地反映是哪个配置文件驱动了哪个实验的运行。
2025-09-09 11:30:00,100 - [INFO] - =============================================
2025-09-09 11:30:00,100 - [INFO] - === Configuration-Driven Batch Training ===
2025-09-09 11:30:00,100 - [INFO] - =============================================
2025-09-09 11:30:00,101 - [INFO] - Found 2 experiments to run: ['config/experiment_A.json', 'config/experiment_B.json']
2025-09-09 11:30:00,101 - [INFO] - --- Running Experiment: Standard_Model_with_High_Epochs (from config/experiment_A.json) ---
2025-09-09 11:30:00,115 - [INFO] - Imported train module: 'train.train1' as 'train'
2025-09-09 11:30:00,118 - [INFO] - Imported util module: 'util.util1' as 'util'
[train.train1] Updating hyperparameters...-> 'EPOCHS' updated to: 50-> 'BATCH_SIZE' updated to: 128-> 'D_MODEL' updated to: 128
[train.train1] >>> Starting training task...
... (训练过程)
[train.train1] <<< Training task finished successfully.
2025-09-09 11:30:02,500 - [INFO] - Experiment 'Standard_Model_with_High_Epochs' completed successfully. Result: {'status': 'success'}
2025-09-09 11:30:02,500 - [INFO] - --- Finished Experiment: Standard_Model_with_High_Epochs ---2025-09-09 11:30:02,501 - [INFO] - --- Running Experiment: Advanced_Model_Deliberate_Fail (from config/experiment_B.json) ---
2025-09-09 11:30:02,505 - [INFO] - Imported train module: 'train.train2' as 'train'
2025-09-09 11:30:02,507 - [INFO] - Imported util module: 'util.util1' as 'util'
[train.train2] Updating hyperparameters...-> 'SEED' updated to: 2025-> 'LEARNING_RATE' updated to: 0.007-> 'HORIZON' updated to: 15
[train.train2] >>> Starting training task...
...
!!! Injecting a deliberate error for demonstration !!!
2025-09-09 11:30:02,750 - [ERROR] - An unexpected error occurred in experiment 'Advanced_Model_Deliberate_Fail':
Traceback (most recent call last):File "main.py", line 68, in mainresult = train.run_training(util)File "your_project_root/train/train2.py", line 47, in run_trainingresult = 1 / 0
ZeroDivisionError: division by zero
2025-09-09 11:30:02,750 - [INFO] - --- Finished Experiment: Advanced_Model_Deliberate_Fail ---... (最终总结) ...
2025-09-09 11:30:02,751 - [INFO] - Final Summary:
2025-09-09 11:30:02,751 - [INFO] - - Experiment: Standard_Model_with_High_Epochs, Status: SUCCESS
2025-09-09 11:30:02,751 - [INFO] - - Experiment: Advanced_Model_Deliberate_Fail, Status: RUNTIME_ERROR
6. 最终版框架优势
这个版本是迄今为止最强大和最灵活的:
- 完全配置化:新增或修改实验只需添加或编辑
json
文件,无需触碰任何 Python 代码。 - 高度可组合:可以任意组合不同的训练脚本和工具脚本,只需在配置文件中指定即可。例如,你可以创建一个
experiment_C.json
来测试train1
和util2
的组合。 - 自解释性:配置文件本身就清晰地描述了每个实验的目的和设置,可读性极强。
- 易于扩展:当你有新的
train3.py
或util3.py
时,只需将它们放入相应目录,然后在新的配置文件中引用它们即可,主程序main.py
无需任何改动。
3、完整示例(ok✅)
好的,我给你整理一个完整的 多脚本批量训练框架示例,包括目录结构、训练脚本、工具脚本、配置文件和 main.py
,支持:
- 自动读取
config/
目录下的配置文件并按文件名排序 - 每处理完一个配置文件就更新日志
- 日志文件以配置文件名命名
- 动态导入
train
和util
模块 - 覆盖默认超参
📂 目录结构
project/
│── main.py
│
├── train/
│ ├── train1.py
│ └── train2.py
│
├── util/
│ ├── train1_util.py
│ └── train2_util.py
│
└── config/├── config1.json└── config2.json
📄 main.py
import importlib
import json
import os
import traceback
import loggingCONFIG_DIR = "config"
LOG_DIR = "logs"# === 设置 logger ===
def setup_logger(log_file_path):logger = logging.getLogger(log_file_path)logger.setLevel(logging.INFO)if logger.hasHandlers():logger.handlers.clear()# 文件 handlerfh = logging.FileHandler(log_file_path, mode="a", encoding="utf-8")fh.setLevel(logging.INFO)formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')fh.setFormatter(formatter)logger.addHandler(fh)# 控制台 handlerch = logging.StreamHandler()ch.setLevel(logging.INFO)ch.setFormatter(formatter)logger.addHandler(ch)return logger# === 加载配置文件 ===
def load_config(path: str) -> dict:with open(path, "r", encoding="utf-8") as f:return json.load(f)def run_task(config_path: str, task_index: int, total_tasks: int, log_dir=LOG_DIR):cfg_name = os.path.splitext(os.path.basename(config_path))[0]os.makedirs(log_dir, exist_ok=True)log_file_path = os.path.join(log_dir, f"{cfg_name}.log")logger = setup_logger(log_file_path)config = load_config(config_path)train_module_name = config["train_module"]util_module_name = config["util_module"]try:logger.info(f"({task_index}/{total_tasks}) Start {train_module_name} using {config_path}")# 动态导入 train 和 utiltrain = importlib.import_module(train_module_name)util = importlib.import_module(util_module_name)# 覆盖超参if hasattr(train, "override_config"):train.override_config(config)if hasattr(util, "override_config"):util.override_config(config)# 执行训练train.run(log_file_path)logger.info(f"({task_index}/{total_tasks}) Finished {train_module_name}")except Exception as e:logger.error(f"({task_index}/{total_tasks}) {train_module_name} failed: {e}\n{traceback.format_exc()}")def main():config_files = sorted([os.path.join(CONFIG_DIR, f)for f in os.listdir(CONFIG_DIR)if f.endswith(".json")])total_tasks = len(config_files)for idx, cfg in enumerate(config_files, start=1):run_task(cfg, idx, total_tasks)if __name__ == "__main__":main()
📄 train/train1.py
import random
import numpy as np
import torch
from util import train1_util as util# 默认超参
SEED = 42
NUM_DAYS = 2000
CONTEXT_LEN = 128
HORIZON = 5
BATCH_SIZE = 64
EPOCHS = 1000
D_MODEL = 64
DATE_INPUT_DIM = 3def override_config(cfg: dict):for k, v in cfg.items():if k in globals():globals()[k] = vdef set_seed(seed: int):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)def run(log_path: str = None):set_seed(SEED)for epoch in range(EPOCHS):loss = util.train_one_epoch(epoch, BATCH_SIZE, D_MODEL)if epoch % 100 == 0:msg = f"[train1] Epoch {epoch}, Loss: {loss:.4f}"print(msg)if log_path:with open(log_path, "a") as f:f.write(msg + "\n")print("[train1] Training complete.")
📄 train/train2.py
import random
import numpy as np
import torch
from util import train2_util as utilSEED = 42
NUM_DAYS = 2000
CONTEXT_LEN = 128
HORIZON = 5
BATCH_SIZE = 64
EPOCHS = 1000
D_MODEL = 64
DATE_INPUT_DIM = 3def override_config(cfg: dict):for k, v in cfg.items():if k in globals():globals()[k] = vdef set_seed(seed: int):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)def run(log_path: str = None):set_seed(SEED)for epoch in range(EPOCHS):loss = util.train_one_epoch(epoch, BATCH_SIZE, D_MODEL)if epoch % 100 == 0:msg = f"[train2] Epoch {epoch}, Loss: {loss:.4f}"print(msg)if log_path:with open(log_path, "a") as f:f.write(msg + "\n")print("[train2] Training complete.")
📄 util/train1_util.py
import randomD_MODEL = 64def override_config(cfg: dict):for k, v in cfg.items():if k in globals():globals()[k] = vdef train_one_epoch(epoch, batch_size, d_model):base = 1.0 / (epoch + 1)noise = random.random() * 0.01return base * (d_model / 64) + noise
📄 util/train2_util.py
import randomD_MODEL = 64def override_config(cfg: dict):for k, v in cfg.items():if k in globals():globals()[k] = vdef train_one_epoch(epoch, batch_size, d_model):base = 1.0 / (epoch + 1)noise = random.random() * 0.02return base * (d_model / 64) + noise
📄 config/config1.json
{"train_module": "train.train1","util_module": "util.train1_util","SEED": 123,"NUM_DAYS": 2500,"BATCH_SIZE": 32,"EPOCHS": 300,"D_MODEL": 128,"DATE_INPUT_DIM": 3
}
📄 config/config2.json
{"train_module": "train.train2","util_module": "util.train2_util","SEED": 456,"NUM_DAYS": 2000,"BATCH_SIZE": 64,"EPOCHS": 200,"D_MODEL": 64,"DATE_INPUT_DIM": 3
}
✅ 特点
- 自动扫描
config/
下所有.json
配置文件并按文件名排序执行。 - 每个配置文件生成独立日志:
logs/config1.log
、logs/config2.log
,实时记录进度。 - 动态导入训练脚本和工具脚本,并调用
override_config
覆盖默认超参。 - 异常也会写入对应日志文件,控制台同时输出。
这个框架可以直接跑批量训练任务,非常适合多配置文件、大批量训练场景。
如果你希望,我可以帮你加一个 整体进度日志 和 并行执行多任务版本,可以同时跑多个配置文件,节省训练时间。
你希望我加吗?
明白了,你希望在批量训练框架中增加 从指定配置文件索引开始训练 的功能,以便在之前任务失败时可以继续训练,而不必从头跑。下面我帮你改造 main.py
,并保持日志、按配置文件排序、以配置文件名命名 log 的特性。
修改后的 main.py
(支持从指定配置开始训练)
import importlib
import json
import os
import traceback
import loggingCONFIG_DIR = "config"
LOG_DIR = "logs"# === 设置 logger ===
def setup_logger(log_file_path):logger = logging.getLogger(log_file_path)logger.setLevel(logging.INFO)if logger.hasHandlers():logger.handlers.clear()# 文件 handlerfh = logging.FileHandler(log_file_path, mode="a", encoding="utf-8")fh.setLevel(logging.INFO)formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')fh.setFormatter(formatter)logger.addHandler(fh)# 控制台 handlerch = logging.StreamHandler()ch.setLevel(logging.INFO)ch.setFormatter(formatter)logger.addHandler(ch)return logger# === 加载配置文件 ===
def load_config(path: str) -> dict:with open(path, "r", encoding="utf-8") as f:return json.load(f)def run_task(config_path: str, task_index: int, total_tasks: int, log_dir=LOG_DIR):cfg_name = os.path.splitext(os.path.basename(config_path))[0]os.makedirs(log_dir, exist_ok=True)log_file_path = os.path.join(log_dir, f"{cfg_name}.log")logger = setup_logger(log_file_path)config = load_config(config_path)train_module_name = config["train_module"]util_module_name = config["util_module"]try:logger.info(f"({task_index}/{total_tasks}) Start {train_module_name} using {config_path}")# 动态导入 train 和 utiltrain = importlib.import_module(train_module_name)util = importlib.import_module(util_module_name)# 覆盖超参if hasattr(train, "override_config"):train.override_config(config)if hasattr(util, "override_config"):util.override_config(config)# 执行训练train.run(log_file_path)logger.info(f"({task_index}/{total_tasks}) Finished {train_module_name}")except Exception as e:logger.error(f"({task_index}/{total_tasks}) {train_module_name} failed: {e}\n{traceback.format_exc()}")def main(start_index: int = 0):"""start_index: 从第几个配置文件开始训练(0 表示第一个文件)"""# 获取 config 目录下所有 json 配置文件,按文件名排序config_files = sorted([os.path.join(CONFIG_DIR, f)for f in os.listdir(CONFIG_DIR)if f.endswith(".json")])total_tasks = len(config_files)# 只从 start_index 开始执行for idx, cfg in enumerate(config_files[start_index:], start=start_index + 1):run_task(cfg, idx, total_tasks)if __name__ == "__main__":# 可以在这里设置从哪个配置文件开始训练,例如 start_index=1 从第二个开始main(start_index=0)
✅ 新增功能
-
支持从指定配置文件开始训练
- 通过
main(start_index=...)
设置开始索引(从 0 开始计数)。 - 如果前面配置文件训练失败,可以从失败位置继续执行。
- 通过
-
其他特性保持不变
- 自动按文件名排序读取
config/
下配置文件 - 日志文件以配置文件名命名
- 每处理完一个配置文件就写日志并更新控制台
- 异常写入对应日志
- 自动按文件名排序读取
例如:
if __name__ == "__main__":main(start_index=1) # 从 config2.json 开始训练
我可以帮你再加一个 自动检测失败配置文件并继续训练 的功能,这样无需手动指定 start_index
,框架会自己跳过已完成的配置文件,只跑失败或未完成的。
你希望我加吗?
4、包内部动态导包(OK✅改成util)
👌 我来帮你把完整示例改好,加上控制台和日志对比打印,方便你确认每一步执行效果。
📂 目录结构
project/
│── main.py
│
├── train/
│ ├── train1.py
│ └── train2.py
│
├── util/
│ ├── train1_util.py
│ └── train2_util.py
│
└── config/├── config1.json└── config2.json
📄 main.py
import importlib
import json
import os
import traceback
import loggingCONFIG_DIR = "config"
LOG_DIR = "logs"# === 设置 logger ===
def setup_logger(log_file_path):logger = logging.getLogger(log_file_path)logger.setLevel(logging.INFO)if logger.hasHandlers():logger.handlers.clear()# 文件 handlerfh = logging.FileHandler(log_file_path, mode="a", encoding="utf-8")fh.setLevel(logging.INFO)formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')fh.setFormatter(formatter)logger.addHandler(fh)# 控制台 handlerch = logging.StreamHandler()ch.setLevel(logging.INFO)ch.setFormatter(formatter)logger.addHandler(ch)return logger# === 加载配置文件 ===
def load_config(path: str) -> dict:with open(path, "r", encoding="utf-8") as f:return json.load(f)def run_task(config_path: str, task_index: int, total_tasks: int, log_dir=LOG_DIR):cfg_name = os.path.splitext(os.path.basename(config_path))[0]os.makedirs(log_dir, exist_ok=True)log_file_path = os.path.join(log_dir, f"{cfg_name}.log")logger = setup_logger(log_file_path)config = load_config(config_path)train_module_name = config["train_module"]util_module_name = config["util_module"]try:logger.info(f"({task_index}/{total_tasks}) Start {train_module_name} using {config_path}")# 动态导入 train 和 utiltrain = importlib.import_module(train_module_name)util = importlib.import_module(util_module_name)# 覆盖超参if hasattr(train, "override_config"):train.override_config(config)if hasattr(util, "override_config"):util.override_config(config)# 让 train 知道 utilif hasattr(train, "set_util"):train.set_util(util)# 执行训练train.run(logger)logger.info(f"({task_index}/{total_tasks}) Finished {train_module_name}")except Exception as e:logger.error(f"({task_index}/{total_tasks}) {train_module_name} failed: {e}\n{traceback.format_exc()}")def main(start_index=1):config_files = sorted([os.path.join(CONFIG_DIR, f)for f in os.listdir(CONFIG_DIR)if f.endswith(".json")])total_tasks = len(config_files)for idx, cfg in enumerate(config_files, start=1):if idx < start_index:continuerun_task(cfg, idx, total_tasks)if __name__ == "__main__":# 如果需要从中途恢复,可以修改 start_indexmain(start_index=1)
📄 train/train1.py
import random
import numpy as np
import torch# 默认超参
SEED = 42
NUM_DAYS = 2000
CONTEXT_LEN = 128
HORIZON = 5
BATCH_SIZE = 64
EPOCHS = 1000
D_MODEL = 64
DATE_INPUT_DIM = 3_util = None # 动态 util 模块def set_util(util_module):global _util_util = util_moduledef override_config(cfg: dict):for k, v in cfg.items():if k in globals():globals()[k] = vdef set_seed(seed: int):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)def run(logger=None):set_seed(SEED)logger.info(f"[train1] Config: SEED={SEED}, NUM_DAYS={NUM_DAYS}, BATCH_SIZE={BATCH_SIZE}, "f"EPOCHS={EPOCHS}, D_MODEL={D_MODEL}")for epoch in range(EPOCHS):loss = _util.train_one_epoch(epoch, BATCH_SIZE, D_MODEL)if epoch % 100 == 0:msg = f"[train1] Epoch {epoch}, Loss={loss:.4f}"logger.info(msg)logger.info("[train1] Training complete.")
📄 train/train2.py
import random
import numpy as np
import torchSEED = 42
NUM_DAYS = 2000
CONTEXT_LEN = 128
HORIZON = 5
BATCH_SIZE = 64
EPOCHS = 1000
D_MODEL = 64
DATE_INPUT_DIM = 3_util = Nonedef set_util(util_module):global _util_util = util_moduledef override_config(cfg: dict):for k, v in cfg.items():if k in globals():globals()[k] = vdef set_seed(seed: int):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)def run(logger=None):set_seed(SEED)logger.info(f"[train2] Config: SEED={SEED}, NUM_DAYS={NUM_DAYS}, BATCH_SIZE={BATCH_SIZE}, "f"EPOCHS={EPOCHS}, D_MODEL={D_MODEL}")for epoch in range(EPOCHS):loss = _util.train_one_epoch(epoch, BATCH_SIZE, D_MODEL)if epoch % 100 == 0:msg = f"[train2] Epoch {epoch}, Loss={loss:.4f}"logger.info(msg)logger.info("[train2] Training complete.")
📄 util/train1_util.py
import randomD_MODEL = 64def override_config(cfg: dict):for k, v in cfg.items():if k in globals():globals()[k] = vdef train_one_epoch(epoch, batch_size, d_model):base = 1.0 / (epoch + 1)noise = random.random() * 0.01return base * (d_model / 64) + noise
📄 util/train2_util.py
import randomD_MODEL = 64def override_config(cfg: dict):for k, v in cfg.items():if k in globals():globals()[k] = vdef train_one_epoch(epoch, batch_size, d_model):base = 1.0 / (epoch + 1)noise = random.random() * 0.02return base * (d_model / 64) + noise
📄 config/config1.json
{"train_module": "train.train1","util_module": "util.train1_util","SEED": 123,"NUM_DAYS": 2500,"BATCH_SIZE": 32,"EPOCHS": 300,"D_MODEL": 128,"DATE_INPUT_DIM": 3
}
📄 config/config2.json
{"train_module": "train.train2","util_module": "util.train2_util","SEED": 456,"NUM_DAYS": 2000,"BATCH_SIZE": 64,"EPOCHS": 200,"D_MODEL": 64,"DATE_INPUT_DIM": 3
}
✅ 效果
控制台(同时写入 log 文件)会看到:
2025-09-09 12:00:00 - INFO - (1/2) Start train.train1 using config/config1.json
2025-09-09 12:00:00 - INFO - [train1] Config: SEED=123, NUM_DAYS=2500, BATCH_SIZE=32, EPOCHS=300, D_MODEL=128
2025-09-09 12:00:00 - INFO - [train1] Epoch 0, Loss=2.0123
...
2025-09-09 12:00:01 - INFO - [train1] Training complete.
2025-09-09 12:00:01 - INFO - (1/2) Finished train.train1
2025-09-09 12:00:01 - INFO - (2/2) Start train.train2 using config/config2.json
2025-09-09 12:00:01 - INFO - [train2] Config: SEED=456, NUM_DAYS=2000, BATCH_SIZE=64, EPOCHS=200, D_MODEL=64
2025-09-09 12:00:01 - INFO - [train2] Epoch 0, Loss=1.0456
...
2025-09-09 12:00:02 - INFO - [train2] Training complete.
2025-09-09 12:00:02 - INFO - (2/2) Finished train.train2
要不要我再帮你加一个功能:训练失败时自动记录失败配置,下一次运行时可以跳过已完成的配置,从失败的地方继续?
4.2 配置流程
整个框架:一个main加载多个config配置文件和对应模块然后注入train,一个train对应多个modul或model,train内部动态导包。先搭建好整个框架,然后适配不同模块,模块适配如下:
-
修改配置文件
注意,配置文件中模块的路径,及模块内部类的名字要统一
修改描述信息,对应具体参数 -
修改module模块文件
加入相关的配置超参数
覆盖函数初始化时的超参数 -
修改main要测试的文件
只针对修改的文件验证config_files=["./SB3_11tunning3/多股/config/config7_4.py"] # SB3_11tunning3/多股/config/config7_1.pyprint('配置文件:',config_files)total_tasks = len(config_files)
-
验证修改(train中打印模型结构)
print(PPO_model.policy) # 打印策略网络结构print("="* 50)print("模型摘要:")from torchinfo import summaryvec_env = PPO_model.get_env()obs = vec_env.reset()print('初始obs:',obs,obs.shape)# 方法 1:直接使用真实 obs# summary(model.policy, input_data=(torch.as_tensor(obs, dtype=torch.float32), ))# 方法 2:直接指定输入维度,不用真实 obs# summary(model.policy, input_size=(1, 4))summary(PPO_model.policy, input_size=obs.shape)
函数初始化参数不能被覆盖
TRANSFORMER_LAYERS = 2def test(value_layers=22):# 直接覆盖value_layers = TRANSFORMER_LAYERS# TRANSFORMER_LAYERS=TRANSFORMER_LAYERS 注意:❌这样错误,不能同名print(value_layers)
main示例:
import importlib
import json
import os
import traceback
import loggingfrom utils_v2.log_util import EvalLogger # 导入日志记录器# CONFIG_DIR = "config"
# LOG_DIR = "logs"CONFIG_DIR = "./SB3_11tunning3/多股/config"
LOG_DIR = "./logs_train_test"# # 0.设置保存路径
# from datetime import datetime
# timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# base_dir = f'./logs_ppo_gate/{timestamp}/'
# os.makedirs(f'{base_dir}/charts', exist_ok=True) #./logs/{timestamp}/charts,创建多级目录# # 设置日志记录器
# # from utils_v2.log_util import EvalLogger # 导入日志记录器
# log_path = os.path.join(base_dir, f"ppo_log_{timestamp}.txt")
# logger = EvalLogger(log_path=log_path) # 初始化日志记录器
# logger.info(f"Eval logger 初始化完成 ✅,时间:{timestamp}")# === 设置 logger ===
# def setup_logger(log_file_path):
# logger = logging.getLogger(log_file_path)
# logger.setLevel(logging.INFO)
# if logger.hasHandlers():
# logger.handlers.clear()# # 文件 handler
# fh = logging.FileHandler(log_file_path, mode="a", encoding="utf-8")
# fh.setLevel(logging.INFO)
# formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
# fh.setFormatter(formatter)
# logger.addHandler(fh)# # 控制台 handler
# ch = logging.StreamHandler()
# ch.setLevel(logging.INFO)
# ch.setFormatter(formatter)
# logger.addHandler(ch)# return logger# === 加载配置文件 ===
# def load_config(path: str) -> dict:
# with open(path, "r", encoding="utf-8") as f:
# return json.load(f)# import json
# import osimport importlib.util
from utils_v2.log_util import EvalLogger # 导入日志记录器def load_config(path: str) -> dict:# 如果传入的是 .py 文件,就当作纯文本 JSON 来解析if path.endswith(".py"):try:spec = importlib.util.spec_from_file_location("new_module", path)new_module = importlib.util.module_from_spec(spec)spec.loader.exec_module(new_module)print(new_module.config)return new_module.configexcept Exception as e:raise ValueError(f"无法加载配置文件 {path}:{e}")else:with open(path, "r", encoding="utf-8") as f:return json.load(f)# 0.设置保存路径
from datetime import datetime
# timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
def run_task(config_path: str, task_index: int, total_tasks: int, logger=None,log_dir=LOG_DIR):timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')cfg_name = os.path.splitext(os.path.basename(config_path))[0]base_dir = f'{log_dir}/{cfg_name}_{timestamp}' # train中的base_dir,覆盖超参os.makedirs(base_dir, exist_ok=True)# os.makedirs(f'{log_dir}/{cfg_name}_timestamp', exist_ok=True)# log_file_path = os.path.join(log_dir, f"{cfg_name}.log")# log_file_path = os.path.join(base_dir, f"{cfg_name}_log_{timestamp}.log")# logger = setup_logger(log_file_path)config = load_config(config_path)train_module_name = config["train_module"]model_mudule_name = config["model_module"]util_modules_name = config["utils_modules"]print("模型model_mudule_name:",model_mudule_name)config["base_dir"]=base_dir # 覆盖超参base_dirprint("工具模块util_modules_name:",util_modules_name)try:logger.info(f"*************任务:{task_index}*****************")logger.info(f"({task_index}/{total_tasks}) Start {train_module_name} using {config_path}")logger.info(f"开始时间:{timestamp}.\n功能描述:\n {config['describe']}")# 动态导入 train 和 utiltrain = importlib.import_module(train_module_name)model = importlib.import_module(model_mudule_name)# utils = importlib.import_module(util_modules_name)utils = {name: importlib.import_module(alias) for name, alias in util_modules_name.items()}# 覆盖超参# if hasattr(train, "override_config"):# train.override_config(config)# if hasattr(model, "override_config"):# model.override_config(config)assert hasattr(train, "override_config"), "train 必须实现 override_config 方法"train.override_config(config)assert hasattr(model, "override_config"), "model 必须实现 override_config 方法"model.override_config(config)# 让 train 知道 util# if hasattr(train, "set_module"):# train.set_module(utils)assert hasattr(train, "set_module"), "train 必须实现 set_module 方法"train.set_module(utils)# 执行训练train.run()logger.info(f"({task_index}/{total_tasks}) Finished {train_module_name}\n\n\n")except Exception as e:logger.error(f"({task_index}/{total_tasks}) {train_module_name} failed: {e}\n{traceback.format_exc()}.\n\n\n")def main(start_index:int=1):# 初始化全局logger,global_logger = EvalLogger(log_path=log_path)timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')os.makedirs(LOG_DIR, exist_ok=True)log_file_path = os.path.join(LOG_DIR, f"global_log_{timestamp}.log")# logger = setup_logger(log_file_path)logger = EvalLogger(log_path=log_file_path) # 初始化日志记录器logger.info(f"Global logger 初始化完成 ✅,时间:{timestamp}\n\n")# config_files = sorted([# os.path.join(CONFIG_DIR, f)# for f in os.listdir(CONFIG_DIR)# if f.endswith(".json")# ])config_files = sorted([os.path.join(CONFIG_DIR, f)for f in os.listdir(CONFIG_DIR)if f.endswith(".py")])# 验证测试时,指定要测试的文件即可# config_files=["./SB3_11tunning3/多股/config/config0.py"] # SB3_11tunning3/多股/config/config7_1.py# config_files=["./SB3_11tunning3/多股/config/config7_4.py"] # SB3_11tunning3/多股/config/config7_1.pyprint('配置文件:',config_files)total_tasks = len(config_files)# 只从 start_index 开始执行# for idx, cfg in enumerate(config_files, start=1):for idx, cfg in enumerate(config_files[start_index:], start=start_index + 1):if idx < start_index:continuerun_task(cfg, idx, total_tasks, logger,LOG_DIR)if __name__ == "__main__":# 如果需要从中途恢复,可以修改 start_indexmain(start_index=0)
配置:
config = {"describe": "sb3_mlp 配置文件config0:\n""训练文件, sbt_dou2_7_pag\n""模型文件, sb3_mlp\n""说明: sb3_mlp ,SB3基础MLP网络\n""config主要修改:","save_prefix": "config0", "train_module": "sbt_dou2_7_pag","model_module": "police_model_pag.sb3_mlp", "utils_modules": {# "metrics_plot": "utils_v2.metrics_plot",# "EvalLogger": "utils_v2.log_util.EvalLogger",# "date_util": "utils_v2.date_util","StockTradingEnv": "env_dou.Double2_2_finrl", # "PPO_FeatureExtractor": "police_model_pag.sb3_mlp", # MambaTransformerExtractor"PPO_Policy": "police_model_pag.sb3_mlp", #},# -------------------------# 1、PPO超参数# -------------------------"Features_Dim": 128, # 特征提取器维度,用不上"BATCH_SIZE": 64, # 批次大小"Total_Timesteps": 1000, # 训练轮数,采样总次数"LR": 2e-4, # 学习率"Ent_coef": 0.01, # 熵系数"N_steps": 64, # 一次训练的步数,一次采样的步数"N_epochs": 3, # 采样中的更新次数# -------------------------# 2、评估超参数# -------------------------"WINDOW_SIZE": 1, # 时间窗口大小"Threshold": 0.8, # 奖励阈值,用于保存模型"N_tests": 30, # 测试次数
}
5、包内部动态导包(❌,不是tmp,是util)
太棒了,这里给你一套可直接落地的完整示例,满足你全部要求:
- 自动读取
config/
目录下的配置文件(按文件名排序批量执行) - 每个配置文件单独日志(以配置文件名命名),处理完一个就更新进度
main.py
动态导入:train
和util
(统一别名),并调用override_config
覆盖默认超参train
内部再动态导入tmp
(统一别名tmp
),支持在训练过程中按计划更换 tmp 模块- 提供
start_index
以便失败后从指定配置继续
📂 目录结构
project/
│── main.py
│
├── train/
│ ├── train1.py
│ └── train2.py
│
├── util/
│ ├── train1_util.py
│ └── train2_util.py
│
├── tmp/
│ ├── tmp1.py
│ └── tmp2.py
│
└── config/├── config1.json└── config2.json
🧠 main.py
- 自动扫描
config/*.json
,按文件名排序 - 以配置文件名命名日志
logs/{cfg}.log
- 统一导入
train
/util
,并调用它们的override_config
- 支持
start_index
(从第几个配置开始)
import importlib
import json
import os
import traceback
import loggingCONFIG_DIR = "config"
LOG_DIR = "logs"def setup_logger(log_file_path: str):logger = logging.getLogger(log_file_path)logger.setLevel(logging.INFO)if logger.hasHandlers():logger.handlers.clear()fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')fh = logging.FileHandler(log_file_path, mode="a", encoding="utf-8")fh.setLevel(logging.INFO)fh.setFormatter(fmt)logger.addHandler(fh)ch = logging.StreamHandler()ch.setLevel(logging.INFO)ch.setFormatter(fmt)logger.addHandler(ch)return loggerdef load_config(path: str) -> dict:with open(path, "r", encoding="utf-8") as f:return json.load(f)def run_task(config_path: str, task_index: int, total_tasks: int, log_dir=LOG_DIR):cfg_name = os.path.splitext(os.path.basename(config_path))[0]os.makedirs(log_dir, exist_ok=True)log_file_path = os.path.join(log_dir, f"{cfg_name}.log")logger = setup_logger(log_file_path)try:config = load_config(config_path)train_module_name = config["train_module"]util_module_name = config["util_module"]logger.info(f"({task_index}/{total_tasks}) START {train_module_name} using {config_path}")# 动态导入 train / util,并统一别名train = importlib.import_module(train_module_name) # as trainutil = importlib.import_module(util_module_name) # as util# 覆盖超参(train & util 各自维护默认超参)if hasattr(train, "override_config"):train.override_config(config)if hasattr(util, "override_config"):util.override_config(config)# 执行训练(内部会动态导入 tmp 并可按计划更换)train.run(log_path=log_file_path, ext_logger=logger)logger.info(f"({task_index}/{total_tasks}) FINISH {train_module_name}")except Exception as e:logger.error(f"({task_index}/{total_tasks}) FAILED: {e}\n{traceback.format_exc()}")def main(start_index: int = 0):# 扫描 config 目录config_files = sorted([os.path.join(CONFIG_DIR, f) for f in os.listdir(CONFIG_DIR) if f.endswith(".json")])total = len(config_files)if total == 0:print("No config files found in ./config")returnfor idx, cfg in enumerate(config_files[start_index:], start=start_index + 1):run_task(cfg, idx, total)if __name__ == "__main__":# 失败续跑:把 start_index 改成失败时的索引(0 表示第一个)main(start_index=0)
🏋️ train/train1.py
- 默认超参(可被
override_config
覆盖) - 支持从 config 中读取
TMP_MODULE
或TMP_SCHEDULE
(按 epoch 动态更换 tmp) - 统一用
importlib
动态导入 tmp 为别名tmp
run()
支持外部传入 logger(也会写入独立 log 文件)
import random
import numpy as np
import torch
import importlibfrom util import train1_util as util# ==== 默认超参(可被 config 覆盖) ====
SEED = 42
NUM_DAYS = 2000
CONTEXT_LEN = 128
HORIZON = 5
BATCH_SIZE = 64
EPOCHS = 1000D_MODEL = 64
DATE_INPUT_DIM = 3# tmp 动态导入控制
TMP_MODULE = "tmp.tmp1" # 初始 tmp 模块
TMP_SCHEDULE = [] # 例如: [{"epoch": 200, "module": "tmp.tmp2"}]def override_config(cfg: dict):"""覆盖默认超参(包含 TMP_MODULE/TMP_SCHEDULE)"""g = globals()for k, v in cfg.items():if k in g:g[k] = vdef set_seed(seed: int):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)def _import_tmp(module_name: str):"""按名字动态导入 tmp 模块,并统一别名为 tmp"""return importlib.import_module(module_name)def _maybe_switch_tmp(epoch: int, current_tmp_mod, logger=None):"""根据 TMP_SCHEDULE 在特定 epoch 更换 tmp 模块"""for item in TMP_SCHEDULE:if epoch == int(item.get("epoch", -1)):new_mod_name = item.get("module")if new_mod_name:if logger: logger.info(f"[train1] Switch tmp -> {new_mod_name} at epoch {epoch}")return _import_tmp(new_mod_name)return current_tmp_moddef run(log_path: str = None, ext_logger=None):"""log_path: 独立日志文件路径(main 已写 console 和文件)ext_logger: 由 main 传入的 logger,优先使用它记录"""logger = ext_loggerdef log(msg):if logger:logger.info(msg)if log_path:with open(log_path, "a") as f:f.write(msg + "\n")else:print(msg)set_seed(SEED)# 初次导入 tmptmp = _import_tmp(TMP_MODULE)log(f"[train1] Using tmp: {TMP_MODULE}")for epoch in range(EPOCHS):# 按计划更换 tmptmp = _maybe_switch_tmp(epoch, tmp, logger=logger)# 使用 util + tmp 执行一次 epoch 训练loss = util.train_one_epoch(epoch, BATCH_SIZE, D_MODEL, tmp=tmp)# 可选择让 tmp 参与度量或正则(示例:tmp.regularize(loss))# loss = tmp.regularize(loss)if epoch % 100 == 0:log(f"[train1] Epoch {epoch}, Loss: {loss:.6f}")log("[train1] Training complete.")
🏋️ train/train2.py
与 train1 类似,展示多脚本共存。你可以按需改动:
import random
import numpy as np
import torch
import importlibfrom util import train2_util as utilSEED = 42
NUM_DAYS = 2000
CONTEXT_LEN = 128
HORIZON = 5
BATCH_SIZE = 64
EPOCHS = 1000D_MODEL = 64
DATE_INPUT_DIM = 3TMP_MODULE = "tmp.tmp2"
TMP_SCHEDULE = [{"epoch": 150, "module": "tmp.tmp1"}]def override_config(cfg: dict):g = globals()for k, v in cfg.items():if k in g:g[k] = vdef set_seed(seed: int):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)def _import_tmp(name: str):return importlib.import_module(name)def _maybe_switch_tmp(epoch: int, current_tmp_mod, logger=None):for item in TMP_SCHEDULE:if epoch == int(item.get("epoch", -1)):new_mod = item.get("module")if new_mod:if logger: logger.info(f"[train2] Switch tmp -> {new_mod} at epoch {epoch}")return _import_tmp(new_mod)return current_tmp_moddef run(log_path: str = None, ext_logger=None):logger = ext_loggerdef log(msg):if logger:logger.info(msg)if log_path:with open(log_path, "a") as f:f.write(msg + "\n")else:print(msg)set_seed(SEED)tmp = _import_tmp(TMP_MODULE)log(f"[train2] Using tmp: {TMP_MODULE}")for epoch in range(EPOCHS):tmp = _maybe_switch_tmp(epoch, tmp, logger=logger)loss = util.train_one_epoch(epoch, BATCH_SIZE, D_MODEL, tmp=tmp)if epoch % 100 == 0:log(f"[train2] Epoch {epoch}, Loss: {loss:.6f}")log("[train2] Training complete.")
🛠 util/train1_util.py
- 示例里让
tmp
参与训练(例如数据扰动、损失修饰等)
import random# 工具层默认参数(可被覆盖)
D_MODEL = 64def override_config(cfg: dict):g = globals()for k, v in cfg.items():if k in g:g[k] = vdef train_one_epoch(epoch: int, batch_size: int, d_model: int, tmp):"""tmp: 动态导入的 tmp 模块,需至少提供 tmp.transform(x) 或 tmp.loss_scale(...)这里用一个简单的 loss 衰减 + tmp 的 scale 演示"""base = 1.0 / (epoch + 1)noise = random.random() * 0.01loss = base * (d_model / 64) + noisescale = tmp.loss_scale(epoch) if hasattr(tmp, "loss_scale") else 1.0return loss * scale
🛠 util/train2_util.py
import randomD_MODEL = 64def override_config(cfg: dict):g = globals()for k, v in cfg.items():if k in g:g[k] = vdef train_one_epoch(epoch: int, batch_size: int, d_model: int, tmp):base = 1.0 / (epoch + 1)noise = random.random() * 0.02loss = base * (d_model / 64) + noise# 如果 tmp 有 transform,就用它“扰动”一下损失if hasattr(tmp, "transform"):loss = tmp.transform(loss, epoch=epoch)return loss
🧩 tmp/tmp1.py
def loss_scale(epoch: int) -> float:"""举例:前期更强正则(scale < 1),后期逐步放开(→ 1.0)"""if epoch < 100:return 0.9elif epoch < 300:return 0.95else:return 1.0def transform(x: float, epoch: int = 0) -> float:# 演示:轻微缩放return x * loss_scale(epoch)
🧩 tmp/tmp2.py
def loss_scale(epoch: int) -> float:"""另一种策略:中期更激进"""if epoch < 100:return 1.0elif epoch < 200:return 0.85elif epoch < 400:return 0.92else:return 1.0def transform(x: float, epoch: int = 0) -> float:return x * loss_scale(epoch)
⚙️ config/config1.json
- 指定
train
、util
、初始TMP_MODULE
,并提供一个切换计划TMP_SCHEDULE
{"train_module": "train.train1","util_module": "util.train1_util","SEED": 123,"NUM_DAYS": 2500,"CONTEXT_LEN": 256,"HORIZON": 5,"BATCH_SIZE": 32,"EPOCHS": 450,"D_MODEL": 128,"DATE_INPUT_DIM": 3,"TMP_MODULE": "tmp.tmp1","TMP_SCHEDULE": [{ "epoch": 200, "module": "tmp.tmp2" },{ "epoch": 400, "module": "tmp.tmp1" }]
}
⚙️ config/config2.json
{"train_module": "train.train2","util_module": "util.train2_util","SEED": 456,"NUM_DAYS": 2000,"CONTEXT_LEN": 128,"HORIZON": 5,"BATCH_SIZE": 64,"EPOCHS": 300,"D_MODEL": 64,"DATE_INPUT_DIM": 3,"TMP_MODULE": "tmp.tmp2","TMP_SCHEDULE": [{ "epoch": 150, "module": "tmp.tmp1" }]
}
✅ 使用说明小结
- 把以上文件按目录放好即可运行:
python main.py
- 批量任务按
config/*.json
名称排序依次执行 - 日志写入
logs/{配置文件名}.log
,每完成一个配置就写“FINISH”进度 - 失败续跑:在
main(start_index=...)
里设置从第几个配置继续(0 起始) - 在配置中可切换
train/util/tmp
的具体实现;训练中还可按TMP_SCHEDULE
动态切换tmp
策略
如果你想把并行/多进程、多 GPU 绑定、或**断点续训(记录完成到某个 epoch)**一起加上,我可以在这个框架上继续扩展。
6、问题:(函数初始化参数不能被覆盖)
你这个问题非常关键 ⚡:
如果你在 Python 里这样写:
TRANSFORMER_LAYERS = 2def test(value_layers: int = TRANSFORMER_LAYERS):print(value_layers)
然后在运行时修改:
TRANSFORMER_LAYERS = 8
test()
输出依然是 2。
原因
函数定义时,默认参数的值就已经绑定好了,不会随着全局变量更新而变化。
所以 def test(...=TRANSFORMER_LAYERS)
在定义那一刻就把 2
作为默认值记住了。
测试代码
你可以跑下面的:
# train.py
TRANSFORMER_LAYERS = 2def test(value_layers: int = TRANSFORMER_LAYERS):print("value_layers =", value_layers)def override_config(cfg: dict):globals().update(cfg)
然后 main 脚本:
import traintrain.test() # 输出 2# 覆盖
train.override_config({"TRANSFORMER_LAYERS": 8})train.test() # 仍然输出 2 !!!
train.test(train.TRANSFORMER_LAYERS) # 输出 8 ✅
输出:
value_layers = 2
value_layers = 2
value_layers = 8
解决方法 ✅
不要在函数签名里用全局变量做默认值,改成运行时获取:
def test(value_layers: int = None):if value_layers is None:value_layers = globals()["TRANSFORMER_LAYERS"]print("value_layers =", value_layers)
TRANSFORMER_LAYERS = 2def test(value_layers=None):if value_layers is None:value_layers = TRANSFORMER_LAYERSprint(value_layers)
这样覆盖后就能正确读取最新值。
要不要我帮你写一个 标准模板,把所有超参都做成这种「可覆盖、动态读取」的安全写法?
解决方法 ✅(推荐)
TRANSFORMER_LAYERS = 2def test(value_layers=22):# 直接覆盖value_layers = TRANSFORMER_LAYERSprint(value_layers)
提问:
TRANSFORMER_LAYERS = 2 PI_LAYERS = 2 def test(value_layers=TRANSFORMER_LAYERS,pi_layers=22): #用全局变量重新赋值 value_layers = TRANSFORMER_LAYERS pi_layers = PI_LAYERS
下面代码中的函数参数初始化是全局参数或者数字的,改成上面形式,标注要修改的地方,中文回答:
注意⚠️:下面情况直接报错
TRANSFORMER_LAYERS = 2
def test(value_layers=22):# 直接覆盖TRANSFORMER_LAYERS = TRANSFORMER_LAYERSprint('直接覆盖:',TRANSFORMER_LAYERS)
test() # 报错:UnboundLocalError: local variable 'TRANSFORMER_LAYERS' referenced before assignment
7、属性操作:(hasattr
、getattr
、setattr
、delattr
)
👌 我帮你整理一下 hasattr
和 getattr
的区别,并顺便对比一下常见的几个内置函数:
1. hasattr(obj, name)
作用:检查对象是否有某个属性
参数:
obj
:对象name
:属性名(字符串)
返回:布尔值(True
/False
)
例子:
class A:x = 10a = A()
print(hasattr(a, "x")) # True
print(hasattr(a, "y")) # False
2. getattr(obj, name, default=None)
作用:获取对象的属性值
参数:
obj
:对象name
:属性名(字符串)default
:可选,属性不存在时返回的默认值
返回:属性值(或default
)
例子:
class A:x = 10a = A()
print(getattr(a, "x")) # 10
print(getattr(a, "y", "N/A")) # "N/A"
3. setattr(obj, name, value)
作用:给对象设置属性
setattr(a, "y", 20)
print(a.y) # 20
4. delattr(obj, name)
作用:删除对象的属性
delattr(a, "y")
print(hasattr(a, "y")) # False
🔑 对比总结
函数 | 作用 | 常用场景 |
---|---|---|
hasattr | 判断对象是否有某个属性 | 避免 AttributeError |
getattr | 获取属性值(可设默认值) | 动态访问属性 |
setattr | 动态设置属性 | 运行时修改对象 |
delattr | 删除属性 | 运行时移除对象属性 |
要不要我帮你写个综合示例,展示 hasattr
、getattr
、setattr
、delattr
在同一个类里的用法?
8、logging
明白了,你想实现 在多个目录中创建不同的日志文件,并且每个目录对应一个日志对象。可以用 logging
配合多个 Logger
或多个 FileHandler
来实现。下面给你一个完整示例:
示例:每个目录一个 Logger
import logging
import osdef setup_logger(name, log_dir, log_file="log.txt"):os.makedirs(log_dir, exist_ok=True) # 确保目录存在logger = logging.getLogger(name)logger.setLevel(logging.INFO)# 避免重复添加 handlerif not logger.handlers:file_path = os.path.join(log_dir, log_file)fh = logging.FileHandler(file_path)formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')fh.setFormatter(formatter)logger.addHandler(fh)return logger# 创建多个目录对应的日志对象
logger_a = setup_logger("logger_a", "logs/dir_a")
logger_b = setup_logger("logger_b", "logs/dir_b")
logger_c = setup_logger("logger_c", "logs/dir_c")# 写入日志
logger_a.info("This is a log in dir_a")
logger_b.info("This is a log in dir_b")
logger_c.info("This is a log in dir_c")
✅ 运行后,会生成三个目录:
logs/dir_a/log.txtdir_b/log.txtdir_c/log.txt
每个目录都有自己独立的日志文件。
🔑 拓展技巧
- 动态批量创建:如果有大量目录,可以用循环动态创建 logger:
log_dirs = ["logs/a", "logs/b", "logs/c"]
loggers = {d: setup_logger(d, d) for d in log_dirs}loggers["logs/a"].info("Message A")
loggers["logs/b"].info("Message B")
- 每个目录多个文件:可以在
setup_logger
中加入时间戳或序号,让每个 logger 每次运行生成新的文件:
from datetime import datetime
log_file = f"log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
如果你需要,我可以帮你写一个函数封装版本:
- 输入多个目录
- 自动创建 logger
- 每次写日志生成独立文件
- 支持批量调用,非常适合多任务训练/多环境日志收集
你希望我直接写吗?