当前位置: 首页 > news >正文

YOLOv11-ultralytics-8.3.67部分代码阅读笔记-tuner.py

tuner.py

ultralytics\utils\tuner.py

目录

tuner.py

1.所需的库和模块

2.def run_ray_tune(model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args,): 


1.所需的库和模块

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks

2.def run_ray_tune(model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args,): 

# 这段代码定义了一个名为 run_ray_tune 的函数,用于使用 Ray Tune 框架进行超参数优化。它支持对模型的训练参数进行自动调优,并提供了与 WandB 集成的日志记录功能。
# 定义了函数 run_ray_tune ,接受以下参数 :
# 1.model :要优化的模型对象。
# 2.space :超参数搜索空间,一个字典,定义了需要优化的参数及其范围。默认为 None 。
# 3.grace_period :ASHA 调度器的宽限期,表示在多少个 epoch 内不会停止试验。默认为 10 。
# 4.gpu_per_trial :每个试验分配的 GPU 数量。默认为 None 。
# 5.max_samples :最大试验次数。默认为 10 。
# 6.**train_args :其他传递给训练函数的参数,以关键字参数的形式传递。
def run_ray_tune(
    model,
    space: dict = None,
    grace_period: int = 10,
    gpu_per_trial: int = None,
    max_samples: int = 10,
    **train_args,
):
    # 使用 Ray Tune 运行超参数调整。
    # 示例:
    # ```python
    # from ultralytics import YOLO
    # # Load a YOLO11n model
    # model = YOLO("yolo11n.pt")
    # # 开始调整YOLO11n 在 COCO8 数据集上训练的超参数
    # result_grid = model.tune(data="coco8.yaml", use_ray=True)
    """
    Runs hyperparameter tuning using Ray Tune.

    Args:
        model (YOLO): Model to run the tuner on.
        space (dict, optional): The hyperparameter search space. Defaults to None.
        grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
        gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
        max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
        train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.

    Returns:
        (dict): A dictionary containing the results of the hyperparameter search.

    Example:
        ```python
        from ultralytics import YOLO

        # Load a YOLO11n model
        model = YOLO("yolo11n.pt")

        # Start tuning hyperparameters for YOLO11n training on the COCO8 dataset
        result_grid = model.tune(data="coco8.yaml", use_ray=True)
        ```
    """
    # 在日志中输出一条提示信息,引导用户了解 Ray Tune 的相关文档。
    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")    # 💡 了解 RayTune,请访问 https://docs.ultralytics.com/integrations/ray-tune 。
    # 检查 train_args 是否为空。
    if train_args is None:
        # 如果为空,则初始化为空字典。
        train_args = {}

    # 这段代码的作用是检查和导入运行 Ray Tune 所需的依赖项,并验证 wandb 是否可用。
    # 开始一个 try 块,用于捕获导入过程中可能出现的异常。
    try:
        # 调用 checks.check_requirements 函数,检查是否安装了 ray[tune] 。如果未安装,该函数会抛出异常。
        # def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
        # -> 用于检查和安装Python项目的依赖项。返回 False ,表示自动安装失败。如果未启用自动安装功能( install 为 False 或 AUTOINSTALL 为 False ),直接返回 False ,表示未安装缺失的依赖项。如果 pkgs 列表为空(即没有缺失的依赖项),返回 True ,表示所有依赖项都已满足。
        # -> return False / return False / return True
        checks.check_requirements("ray[tune]")

        # 导入 ray 模块,这是 Ray Tune 的核心依赖项。
        import ray
        # 从 ray 模块中导入 tune 子模块,用于 超参数优化 。
        from ray import tune
        # 从 ray.air 模块中导入 RunConfig ,用于 配置 Ray Tune 的运行参数 。
        from ray.air import RunConfig
        # 从 ray.air.integrations.wandb 模块中导入 WandbLoggerCallback ,用于 将训练日志同步到 Weights & Biases (WandB) 。
        from ray.air.integrations.wandb import WandbLoggerCallback
        # 从 ray.tune.schedulers 模块中导入 ASHAScheduler ,用于 实现超参数优化的调度策略 。
        from ray.tune.schedulers import ASHAScheduler
    # 捕获导入过程中可能出现的 ImportError 异常。
    except ImportError:
        # 如果捕获到 ImportError ,抛出一个 ModuleNotFoundError ,提示用户安装 ray[tune] 。
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"')    # 需要 Ray Tune,但未找到。要安装,请运行:pip install "ray[tune]"。

    # 开始第二个 try 块,用于检查 wandb 是否可用。
    try:
        # 尝试导入 wandb 模块。
        import wandb

        # 验证导入的 wandb 模块是否包含 __version__ 属性,以确保其正确导入。如果 wandb 未正确安装或导入失败,会抛出 AssertionError 。
        assert hasattr(wandb, "__version__")
    # 捕获 ImportError 或 AssertionError 异常。
    except (ImportError, AssertionError):
        # 如果捕获到异常,将 wandb 设置为 False ,表示 wandb 不可用。
        wandb = False
    # 这段代码的作用是。检查和导入 Ray Tune 所需的依赖项:如果未安装 ray[tune] ,抛出异常并提示用户安装。导入 Ray Tune 的核心模块和工具。验证 wandb 是否可用:如果 wandb 未安装或导入失败,将 wandb 设置为 False ,避免后续代码中使用 wandb 功能。这种设计确保了代码在运行前具备必要的依赖项,并且能够灵活处理依赖项缺失的情况,避免因未安装的模块导致程序崩溃。

    # 这段代码的作用是检查 Ray 的版本是否满足要求,并定义了一个默认的超参数搜索空间 default_space ,最后将模型对象存储到 Ray 的对象存储中。
    # 调用 checks.check_version 函数,检查 Ray 的版本是否满足 >=2.0.0 。 如果版本不符合要求,会抛出异常并提示用户升级 Ray。
    checks.check_version(ray.__version__, ">=2.0.0", "ray")
    # 定义了一个默认的超参数搜索空间 default_space ,包含多个超参数及其范围。
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        # lr0 :初始学习率,范围为 [1e-5, 1e-1] 。
        "lr0": tune.uniform(1e-5, 1e-1),
        # lrf :最终学习率因子( lr0 * lrf ),范围为 [0.01, 1.0] 。
        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        # momentum :动量(SGD)或 Adam 的 beta1,范围为 [0.6, 0.98] 。
        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
        # weight_decay :优化器的权重衰减,范围为 [0.0, 0.001] 。
        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay 5e-4
        # warmup_epochs :预热阶段的轮数(可以是分数),范围为 [0.0, 5.0] 。
        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        # warmup_momentum :预热阶段的初始动量,范围为 [0.0, 0.95] 。
        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
        # box :边界框损失权重,范围为 [0.02, 0.2] 。
        "box": tune.uniform(0.02, 0.2),  # box loss gain
        # cls :分类损失权重,范围为 [0.2, 4.0] 。
        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
        # hsv_h :HSV 色调增强(分数),范围为 [0.0, 0.1] 。
        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        # hsv_s :HSV 饱和度增强(分数),范围为 [0.0, 0.9] 。
        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        # hsv_v :HSV 明度增强(分数),范围为 [0.0, 0.9] 。
        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        # degrees :图像旋转角度(正负),范围为 [0.0, 45.0] 。
        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        # translate :图像平移比例(正负),范围为 [0.0, 0.9] 。
        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        # scale :图像缩放比例(正负),范围为 [0.0, 0.9] 。
        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
        # shear :图像剪切角度(正负),范围为 [0.0, 10.0] 。
        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        # perspective :图像透视增强(分数),范围为 [0.0, 0.001] 。
        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        # flipud :图像上下翻转概率,范围为 [0.0, 1.0] 。
        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        # fliplr :图像左右翻转概率,范围为 [0.0, 1.0] 。
        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        # bgr :图像通道顺序为 BGR 的概率,范围为 [0.0, 1.0] 。
        "bgr": tune.uniform(0.0, 1.0),  # image channel BGR (probability)
        # mosaic :马赛克数据增强概率,范围为 [0.0, 1.0] 。
        "mosaic": tune.uniform(0.0, 1.0),  # image mixup (probability)
        # mixup :混合数据增强概率,范围为 [0.0, 1.0] 。
        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
        # copy_paste :复制粘贴增强概率,范围为 [0.0, 1.0] 。
        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
    }

    # Put the model in ray store
    # 从模型对象中获取任务类型(例如 task 可能是 detect 、 segment 等)。
    task = model.task
    # 使用 ray.put 将模型对象存储到 Ray 的对象存储中。 这样可以确保模型对象在多个试验中共享,避免重复加载。
    model_in_store = ray.put(model)
    # 这段代码的作用是。版本检查:确保 Ray 的版本满足 >=2.0.0 。定义超参数搜索空间:提供了一个默认的超参数搜索空间 default_space ,包含多个与训练相关的超参数及其范围。模型共享:将模型对象存储到 Ray 的对象存储中,以便在多个试验中共享,提高效率。这种设计使得超参数优化过程更加自动化和高效,同时提供了灵活的超参数配置选项。

    # 这段代码定义了一个名为 _tune 的内部函数,用于执行单次超参数优化试验的训练过程。它是 Ray Tune 框架中用于训练模型的核心函数。
    # 定义了一个名为 _tune 的函数,接受一个参数。
    # 1.config :一个字典,包含当前试验的超参数配置。
    def _tune(config):
        # 使用指定的超参数和其他参数训练 YOLO 模型。
        """
        Trains the YOLO model with the specified hyperparameters and additional arguments.

        Args:
            config (dict): A dictionary of hyperparameters to use for training.

        Returns:
            None
        """
        # 使用 ray.get 从 Ray 的对象存储中获取模型对象 model_in_store 。 model_in_store 是在外部通过 ray.put 存储的模型对象,用于在多个试验中共享同一个模型实例。
        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
        # 调用模型的 reset_callbacks 方法,重置模型的回调函数。 这一步是为了确保每次试验时,模型的回调函数(如日志记录、早停等)都是干净的,避免回调函数在多次试验中重复使用。
        model_to_train.reset_callbacks()
        # 将外部传入的 train_args 更新到当前试验的配置字典 config 中。 train_args 是一个字典,包含用户指定的训练参数(如数据路径、训练轮数等)。这一步确保了用户指定的参数能够覆盖默认的超参数配置。
        config.update(train_args)
        # 使用更新后的配置字典 config 调用模型的 train 方法进行训练。 **config 将字典中的键值对作为关键字参数传递给 train 方法。 results 是训练过程的返回值,通常包含训练结果和评估指标。
        results = model_to_train.train(**config)
        # 返回训练结果的字典 results.results_dict 。 这个字典通常包含训练过程中的关键指标(如损失值、准确率等),用于 Ray Tune 框架评估当前试验的性能。
        return results.results_dict
    # 这段代码定义了一个名为 _tune 的函数,用于执行单次超参数优化试验的训练过程。它的主要功能包括。模型共享:从 Ray 的对象存储中获取模型对象,确保多个试验共享同一个模型实例。回调重置:重置模型的回调函数,避免回调函数在多次试验中重复使用。配置更新:将用户指定的训练参数更新到当前试验的配置中,确保用户设置的参数能够生效。模型训练:使用更新后的配置调用模型的 train 方法进行训练,并返回训练结果。结果返回:返回训练结果的字典,供 Ray Tune 框架评估当前试验的性能。这种设计使得超参数优化过程更加高效和灵活,同时确保了每次试验的独立性和用户配置的优先级。

    # 这段代码的作用是获取超参数搜索空间和数据集路径,并确保它们被正确设置。
    # Get search space    注释说明接下来的代码块用于获取超参数搜索空间。
    # 检查是否提供了用户自定义的超参数搜索空间 space 。如果 space 为 None 或为空字典,则使用默认的搜索空间。
    if not space:
        # 如果用户没有提供超参数搜索空间,则将 space 设置为默认的搜索空间 default_space 。
        space = default_space
        # 如果使用了默认的搜索空间,通过日志记录器 LOGGER 输出警告信息,提示用户未提供搜索空间,正在使用默认值。
        LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")    # 警告⚠️未提供搜索空间,使用默认搜索空间。

    # Get dataset    注释说明接下来的代码块用于获取数据集路径。
    # 从 train_args 中获取数据集路径 data 。 如果 train_args 中提供了 data 键,则使用其值。 如果未提供,则从 TASK2DATA 字典中根据任务类型 task 获取默认数据集路径。
    data = train_args.get("data", TASK2DATA[task])
    # 将数据集路径 data 添加到 超参数搜索空间 space 中。这一步确保在超参数优化过程中,每个试验都使用相同的训练数据集。
    space["data"] = data
    # 检查 train_args 中是否提供了 data 键。如果未提供,则输出警告信息。
    if "data" not in train_args:
        # 如果未提供数据集路径,通过日志记录器 LOGGER 输出警告信息,提示用户未提供数据集路径,正在使用默认值。
        LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')    # 警告⚠️未提供数据,使用默认的“data={data}”。
    # 这段代码的作用是。获取超参数搜索空间:如果用户未提供超参数搜索空间,则使用默认的搜索空间。如果使用了默认搜索空间,输出警告信息提示用户。获取数据集路径:从 train_args 中获取数据集路径,如果未提供,则使用默认值。将数据集路径添加到超参数搜索空间中,确保每个试验使用相同的训练数据集。如果未提供数据集路径,输出警告信息提示用户。这种设计确保了超参数优化过程中的灵活性和健壮性:用户可以自定义超参数搜索空间,也可以使用默认值。数据集路径可以通过用户指定或使用默认值,避免因未提供数据集路径而导致错误。

    # 这段代码的作用是定义了 Ray Tune 的训练函数、调度器和回调,用于执行超参数优化。
    # Define the trainable function with allocated resources    注释说明接下来的代码块用于定义训练函数,并为其分配资源。
    # 使用 tune.with_resources 函数将 _tune 函数与资源分配绑定。
    # 每个试验将分配以下资源 :
    # CPU : NUM_THREADS 个线程。
    # GPU : gpu_per_trial 个 GPU,如果未指定,则默认为 0 。
    # 这一步确保了每个试验在运行时能够获得足够的计算资源。
    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

    # Define the ASHA scheduler for hyperparameter search    注释说明接下来的代码块用于定义 ASHA 调度器。

    # ASHAScheduler(time_attr='training_iteration', metric='metric', mode='max', max_t=100, grace_period=1, reduction_factor=4, brackets=1)
    # ASHAScheduler 是 Ray Tune 中的一个调度器,它实现了异步连续减半(Async Successive Halving)算法。这种算法可以提前终止性能较差的试验,节省计算资源,并将资源分配给表现较好的试验。
    # 参数说明 :
    # time_attr :(默认为 'training_iteration')用于比较时间的属性,可以是任何单调递增的属性,例如训练轮次。
    # metric :要优化的指标名称,即在训练结果字典中用于衡量优化目标的值。
    # mode :(默认为 'max')优化模式,可以是 'min' 或 'max',表示是最小化还是最大化指标。
    # max_t :(默认为 100)试验可以运行的最大时间单位,单位由 time_attr 决定。
    # grace_period :(默认为 1)在至少这个时间之后才考虑停止试验,单位与 time_attr 相同。
    # reduction_factor :(默认为 4)用于设置减半率和数量的无量纲标量。
    # brackets :(默认为 1)分组的数量,每个分组有不同的减半率,由 reduction_factor 指定。
    # ASHAScheduler 与其他调度器相比,提供了更好的并行性,并且在淘汰过程中避免了落后问题。因此,Ray Tune 推荐使用 ASHAScheduler 而不是标准的 HyperBandScheduler 。

    # 创建了一个 ASHA 调度器实例,用于超参数优化的调度策略。
    asha_scheduler = ASHAScheduler(
        # 使用训练的 epoch 数作为时间属性。
        time_attr="epoch",
        # 指定优化的目标指标(如准确率、mAP 等),根据任务类型从 TASK2METRIC 字典中获取。
        metric=TASK2METRIC[task],
        # 优化目标是最大化指定的指标。
        mode="max",
        # 最大训练轮数,优先使用 train_args 中的 epochs ,如果没有则使用默认配置或默认值 100 。
        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
        # 宽限期,表示在多少个 epoch 内不会停止表现不佳的试验。
        grace_period=grace_period,
        # 每次减少试验数量的因子。
        reduction_factor=3,
    )

    # Define the callbacks for the hyperparameter search    注释说明接下来的代码块用于定义超参数优化的回调。
    # 定义了一个回调列表 tuner_callbacks ,用于记录训练过程的日志。
    # 如果 wandb 可用,则添加 WandbLoggerCallback ,将日志同步到 WandB 项目 YOLOv8-tune 。
    # 如果 wandb 不可用,则回调列表为空。
    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []
    # 这段代码的作用是。定义训练函数与资源分配:将 _tune 函数与资源分配绑定,确保每个试验能够获得足够的 CPU 和 GPU 资源。定义 ASHA 调度器:使用 ASHA 调度器动态调整试验,优化目标是最大化指定的指标。调度器的参数包括最大训练轮数、宽限期和减少因子,这些参数确保了优化过程的高效性和灵活性。定义回调:如果 wandb 可用,添加 WandbLoggerCallback ,将训练日志同步到 WandB。如果 wandb 不可用,则不使用任何回调。这种设计使得超参数优化过程更加高效和灵活,同时提供了详细的日志记录功能,便于用户监控和分析训练过程。

    # 这段代码的作用是创建一个 Ray Tune 超参数优化调优器( tuner ),并配置相关的保存目录、参数空间和运行配置。
    # Create the Ray Tune hyperparameter search tuner    注释说明接下来的代码块用于创建 Ray Tune 的超参数优化调优器。
    # 调用 get_save_dir 函数生成保存目录路径。
    # 使用 get_cfg(DEFAULT_CFG, train_args) 获取配置对象。
    # 如果 train_args 中提供了 name ,则使用其值作为保存目录的名称;否则,默认为 "tune" 。
    # 调用 .resolve() 确保生成的路径是绝对路径。
    # tune_dir 是 保存超参数优化结果的目录 。
    tune_dir = get_save_dir(
        get_cfg(DEFAULT_CFG, train_args), name=train_args.pop("name", "tune")
    ).resolve()  # must be absolute dir
    # 使用 mkdir 方法创建保存目录。 parents=True 如果父目录不存在,则一并创建。 exist_ok=True 如果目录已存在,则不会抛出异常。
    tune_dir.mkdir(parents=True, exist_ok=True)
    # 创建一个 Ray Tune 调优器实例 tuner ,用于执行超参数优化。
    tuner = tune.Tuner(
        # 绑定资源的训练函数(之前定义的 _tune 函数)。
        trainable_with_resources,
        # 超参数搜索空间,定义了需要优化的参数及其范围。
        param_space=space,
        # 配置超参数优化的调度器和试验数量。
        # scheduler=asha_scheduler :使用 ASHA 调度器动态调整试验。
        # num_samples=max_samples :最大试验次数。
        tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
        # 配置运行参数。
        # callbacks=tuner_callbacks :回调列表,用于日志记录(如 WandB)。
        # storage_path=tune_dir :保存超参数优化结果的目录路径。
        run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
    )
    # 这段代码的作用是。生成保存目录:使用 get_save_dir 函数生成保存超参数优化结果的目录。确保目录路径是绝对路径,并创建目录(如果不存在)。创建 Ray Tune 调优器:配置训练函数,并为其分配资源(CPU 和 GPU)。定义超参数搜索空间 space ,指定需要优化的参数及其范围。使用 ASHA 调度器动态调整试验,优化目标是最大化指定的指标。配置运行参数,包括回调(如 WandB)和保存路径。这种设计使得超参数优化过程更加高效和灵活,同时提供了详细的日志记录功能,便于用户监控和分析训练过程。

    # 这段代码的作用是执行超参数优化过程,获取优化结果,并清理 Ray 的资源。
    # Run the hyperparameter search    注释说明接下来的代码块用于执行超参数优化过程。
    # 调用 tuner.fit() 方法启动超参数优化过程。
    # 这一步会根据之前定义的配置(如调度器、搜索空间、资源分配等)运行多个试验,并自动调整超参数以寻找最优解。
    # 在优化过程中,Ray Tune 会根据调度器的策略(如 ASHA)动态停止表现不佳的试验,以节省资源。
    tuner.fit()

    # Get the results of the hyperparameter search    注释说明接下来的代码块用于获取超参数优化的结果。
    # 调用 tuner.get_results() 方法获取超参数优化的结果。
    # results 是一个对象,包含了所有试验的结果,包括每个试验的超参数配置、训练指标、最佳试验等信息。
    # 这些结果可以用于分析和评估超参数优化的效果,例如查看最佳试验的超参数配置和性能指标。
    results = tuner.get_results()

    # Shut down Ray to clean up workers    注释说明接下来的代码块用于关闭 Ray,清理资源。
    # 调用 ray.shutdown() 方法关闭 Ray。
    # 这一步会清理所有与 Ray 相关的资源,包括关闭后台进程和释放 GPU 资源。
    # 这是一个重要的步骤,尤其是在脚本运行结束后,以避免资源泄漏。
    ray.shutdown()

    # 返回 超参数优化的结果 results 。 这个结果对象可以被进一步分析,例如提取最佳试验的超参数配置,或者用于后续的模型训练和评估。
    return results
    # 这段代码的作用是。执行超参数优化:使用 tuner.fit() 启动超参数优化过程,根据定义的调度器和搜索空间运行多个试验。优化过程会动态调整超参数,寻找最优解。获取优化结果:使用 tuner.get_results() 获取超参数优化的结果,包括所有试验的超参数配置和性能指标。清理资源:使用 ray.shutdown() 关闭 Ray,清理所有与 Ray 相关的资源,确保资源不会泄漏。这种设计使得超参数优化过程更加高效和灵活,同时提供了完整的生命周期管理,从启动优化到清理资源,确保了整个过程的健壮性。
# 这段代码定义了一个函数 run_ray_tune ,用于使用 Ray Tune 框架进行超参数优化。它支持以下功能。超参数搜索空间:用户可以自定义超参数搜索空间,或者使用默认的搜索空间。资源分配:支持为每个试验分配 CPU 和 GPU 资源。调度策略:使用 ASHA 调度器动态调整试验,提高优化效率。日志记录:支持与 WandB 集成,记录训练过程和结果。保存目录管理:动态生成保存目录,避免路径冲突。灵活性:支持用户通过 train_args 传递额外的训练参数。这种设计使得超参数优化过程更加自动化、高效,并且易于扩展和集成。
# def run_ray_tune(model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args,):
# -> 用于使用 Ray Tune 框架进行超参数优化。它支持对模型的训练参数进行自动调优,并提供了与 WandB 集成的日志记录功能。返回 超参数优化的结果 results 。 这个结果对象可以被进一步分析,例如提取最佳试验的超参数配置,或者用于后续的模型训练和评估。
# -> return results

http://www.dtcms.com/a/20033.html

相关文章:

  • CAS单点登录(第7版)8.委托和代理
  • (PC+WAP) PbootCMS中小学教育培训机构网站模板 – 绿色小学学校网站源码下载
  • 1219:马走日
  • android studio下载安装汉化-Flutter安装
  • Shader示例 6: 卡渲基础 - 描边 + 着色
  • VisualStudio 2012 fatal error C1083: 无法打开包括文件:“stdio.h 找不到 sdkddkver.h
  • 【算法与数据结构】并查集详解+题目
  • CF91B Queue
  • 数组_有序数组的平方
  • 基于 ollama 在linux 私有化部署DeepSeek-R1以及使用RESTful API的方式使用模型
  • 机器学习:k均值
  • x-restormer——restormer+SSA
  • 【算法】【区间和】acwing算法基础 802. 区间和 【有点复杂,但思路简单】
  • 本地部署MindSearch(开源 AI 搜索引擎框架),然后上传到 hugging face的Spaces——L2G6
  • E卷-特殊的加密算法-(200分)
  • SIP中常见的服务器类型
  • vue项目 Axios创建拦截器
  • Win11 远程 连接 Ubuntu20.04(局域网)
  • lvs的DR模式
  • 易支付精美设计的支付收银台模板源码
  • typecho快速发布文章
  • Oracle序列(基础操作)
  • Spring Bean的生命周期
  • wordpress主题插件开发中高频使用的38个函数
  • 了解rtc_time64_to_tm()和rtc_tm_to_time64()
  • WEB安全--SQL注入--二次注入
  • TCP/UDP 简介,三次握手与四次挥手
  • 使用瑞芯微RK3588的NPU进行模型转换和推理
  • Python使用OpenCV图片去水印多种方案实现
  • Redis问题排查常用命令