TorchInductor - Autotune
Triton Autotune
Triton kernels support parameters of type tl.constexpr, and the values of these parameters can be searched for automatically with Triton's autotune machinery.
Triton provides several decorators for automatic tuning:
- triton.autotune: supply a list of triton.Config objects; Triton runs the kernel once per configuration, measures each one, and selects the best Config:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=1),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=2),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=4),
    ],
    key=['n_elements'],  # choose the best config based on the number of input elements
)
@triton.jit
def add_kernel(
    x_ptr,       # pointer to input x
    y_ptr,       # pointer to input y
    output_ptr,  # pointer to the output
    n_elements,  # number of elements
    BLOCK_SIZE: tl.constexpr,  # block size, a compile-time constant
):
    # body: the standard vector-add from the Triton tutorial
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)
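A minimal launch sketch (tensor names and sizes are illustrative, not from the original): the first call for a given n_elements benchmarks all three configs and caches the winner for that key value.
import torch
import triton

x = torch.rand(98432, device='cuda')
y = torch.rand_like(x)
output = torch.empty_like(x)
n_elements = x.numel()

# the launch grid is derived from the BLOCK_SIZE the autotuner picks
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements)  # autotuning happens on this first call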
- triton.heuristics: supply functions that compute the values of meta-parameters from the kernel arguments. Unlike triton.autotune, no benchmarking takes place: each value is computed directly at launch time by the given heuristic (e.g. for x_size = 1000 the kernel below runs with BLOCK_SIZE = 1024):
@triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
@triton.jit
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    ...
- perf_model (via prune_configs_by): supply an analytic performance-estimation function; triton.autotune uses its predictions to prune the candidate list ahead of time, so only the most promising configs are actually benchmarked:
def perf_model(config, args):
    # naive performance model: the larger the block size, the better;
    # the returned value is an estimated run time, so lower is better
    return 1 / config.kwargs['BLOCK_SIZE']

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=1),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=2),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=4),
    ],
    key=['n_elements'],
    prune_configs_by={'perf_model': perf_model, 'top_k': 2},
)
@triton.jit
def add_kernel(
    x_ptr,       # pointer to input x
    y_ptr,       # pointer to input y
    output_ptr,  # pointer to the output
    n_elements,  # number of elements
    BLOCK_SIZE: tl.constexpr,  # block size, a compile-time constant
):
    ...  # same body as the add_kernel above
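With top_k=2, only the two configs with the smallest predicted run time (BLOCK_SIZE = 512 and 256 under this model) are compiled and benchmarked; BLOCK_SIZE = 128 is pruned without ever being run.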
PyTorch triton_heuristics
When torch.compile generates Triton kernels, it likewise emits parameters of type tl.constexpr, and the Inductor runtime (torch/_inductor/runtime) implements a lightweight counterpart to triton.heuristics for them, called triton_heuristics:
@triton_heuristics.pointwise(
    size_hints={'x': 2097152},
    filename=__file__,
    triton_meta={
        'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'},
        'device': DeviceProperties(
            type='cuda', index=0, multi_processor_count=76, cc=89, major=8,
            regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32,
        ),
        'constants': {},
        'configs': [
            AttrsDescriptor.from_dict({
                'arg_properties': {'tt.divisibility': (0, 1, 2, 3), 'tt.equal_to': ()},
                'cls': 'AttrsDescriptor',
            })
        ],
    },
    inductor_meta={
        'autotune_hints': set(),
        'kernel_name': 'triton_poi_fused_mul_0',
        'mutated_arg_names': [],
        'optimize_mem': False,
        'no_x_dim': False,
        'num_load': 2,
        'num_reduction': 0,
        'backend_hash': '36505BED74EF047D8A064498F4B963488377F2A195BA2DC67799189257CDE669',
        'are_deterministic_algorithms_enabled': False,
        'assert_indirect_indexing': True,
        'autotune_local_cache': True,
        'autotune_pointwise': True,
        'autotune_remote_cache': None,
        'force_disable_caches': False,
        'dynamic_scale_rblock': True,
        'max_autotune': False,
        'max_autotune_pointwise': False,
        'min_split_scan_rblock': 256,
        'spill_threshold': 16,
        'store_cubin': False,
    },
    min_elem_per_thread=0,
)
@triton.jit
def triton_poi_fused_mul_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):
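In the generated wrapper code this kernel is then launched through the .run interface described later in this post; the call site looks roughly like this (a sketch: the buffer names, the grid helper, and the stream handling are illustrative):
stream0 = get_raw_stream(0)
triton_poi_fused_mul_0.run(arg0_1, arg1_1, buf0, 2097152, grid=grid(2097152), stream=stream0)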
This post focuses on the implementation of triton_heuristics, which consists of two parts:
- triton.Config generation
- triton.Config selection
triton.Config generation
triton_heuristics.pointwise
For pointwise Triton kernels, the Inductor runtime uses triton_heuristics.pointwise for autotuning.
triton_heuristics.pointwise generates candidate Configs from two sources:
- autotune_hints: compute the total element count numel and a base block size bs from size_hints, fetch the device properties from triton_meta["device"], and derive hinted candidate Configs from size_hints, bs, and the device.
- built-in defaults: compute numel and bs from size_hints; for 1D kernels this tries num_elements_per_warp values of 256 and 64 paired with block sizes bs and bs // 2, while 2D/3D kernels get a fixed set of tile shapes (see the worked example below).
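As a worked example (using the size_hints={'x': 2097152} kernel from above, and assuming no autotune hints and autotuning enabled), the arithmetic plays out as follows:
# size_hints = {'x': 2097152}
numel = 2097152
bs = max(256, min(numel // 128, 1024))  # = max(256, min(16384, 1024)) = 1024

# 1D case, autotuning enabled -> two candidate configs:
#   triton_config({'x': 2097152}, 1024, num_elements_per_warp=256)
#   triton_config({'x': 2097152}, 512,  num_elements_per_warp=64)
The full heuristic (from torch/_inductor/runtime/triton_heuristics.py):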
def pointwise(
    size_hints,
    triton_meta,
    tile_hint=None,
    filename=None,
    min_elem_per_thread=0,
    inductor_meta=None,
):
    """Construct @triton.heuristics() based on size_hints."""
    inductor_meta = {} if inductor_meta is None else inductor_meta
    assert not inductor_meta.get("no_x_dim")

    numel = functools.reduce(operator.mul, size_hints.values())
    bs = max(256, min(numel // 128, 1024))
    hinted_configs = autotune_hints_to_configs(
        inductor_meta.get("autotune_hints", set()),
        size_hints,
        bs,
        triton_meta["device"],
    )
    triton_config_with_settings = functools.partial(
        triton_config, min_elem_per_thread=min_elem_per_thread
    )

    configs = None
    if len(size_hints) == 1:
        if disable_pointwise_autotuning(inductor_meta) and not (
            inductor_meta.get("max_autotune")
            or inductor_meta.get("max_autotune_pointwise")
        ):
            configs = [triton_config_with_settings(size_hints, bs)]
        else:
            configs = [
                triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
                triton_config_with_settings(size_hints, bs // 2, num_elements_per_warp=64),
                *hinted_configs,
            ]
    if len(size_hints) == 2:
        if (
            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
        ) and not (
            inductor_meta.get("max_autotune")
            or inductor_meta.get("max_autotune_pointwise")
        ):
            configs = [triton_config_with_settings(size_hints, 32, 32)]
        else:
            configs = [
                triton_config_with_settings(size_hints, 32, 32),
                triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                triton_config_with_settings(size_hints, 256, 16),
                triton_config_with_settings(size_hints, 16, 256),
                triton_config_with_settings(size_hints, bs, 1),
                triton_config_with_settings(size_hints, 1, bs),
                *hinted_configs,
            ]
    if len(size_hints) == 3:
        if disable_pointwise_autotuning(inductor_meta):
            configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
        else:
            configs = [
                triton_config_with_settings(size_hints, 16, 16, 16),
                triton_config_with_settings(size_hints, 64, 8, 8),
                triton_config_with_settings(size_hints, 8, 64, 8),
                triton_config_with_settings(size_hints, 8, 8, 64),
                triton_config_with_settings(size_hints, bs, 1, 1),
                triton_config_with_settings(size_hints, 1, bs, 1),
                triton_config_with_settings(size_hints, 1, 1, bs),
                *hinted_configs,
            ]

    if not configs:
        raise NotImplementedError(f"size_hints: {size_hints}")
    return cached_autotune(
        size_hints,
        configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.POINTWISE,
        filename=filename,
    )
triton_heuristics.template
For matmul-style (template) kernels, the Inductor runtime uses triton_heuristics.template for autotuning: the template parameters are passed straight through to Triton as a single Config.
return cached_autotune(
    None,
    [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
    triton_meta=triton_meta,
    inductor_meta=inductor_meta,
    heuristic_type=HeuristicType.TEMPLATE,
    filename=filename,
)
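Because the configs list here contains exactly one triton.Config, no benchmarking search takes place for template kernels: the num_stages and num_warps chosen by Inductor's template selection are simply forwarded to Triton.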
triton.Config selection
triton_heuristics calls cached_autotune to select the best triton.Config:
- If a best_config is already present in the cache, it is used directly.
- Otherwise the Triton kernel function is wrapped: cached_autotune returns a decorator that hands the kernel to a DebugAutotuner or CachingAutotuner, which performs the actual tuning.
def cached_autotune(
    size_hints: Optional[List[int]],
    configs: List[Config],
    triton_meta,
    heuristic_type,
    filename=None,
    inductor_meta=None,
    custom_kernel=False,
):
    configs = unique_configs(configs)
    inductor_meta = {} if inductor_meta is None else inductor_meta

    disabled = inductor_meta.get("force_disable_caches", False)

    autotune_cache = None
    if (
        not disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
    ):
        configs_hash = hash_configs(configs)
        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
        if autotune_cache:
            if best_config := autotune_cache.read_best(inductor_meta, configs):
                configs = [best_config]
    else:
        if disabled:
            log.debug("autotune caching is disabled by config.force_disable_caches")

    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
    optimize_mem = inductor_meta.pop("optimize_mem", True)

    if "restore_value" in triton_meta:
        mutated_arg_names += triton_meta.pop("restore_value")

    reset_to_zero_arg_names: List[str] = []
    if "reset_to_zero" in triton_meta:
        reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))

    def decorator(fn):
        # Remove XBLOCK from config if it's not a function argument.
        # This way, coordinate descent tuning will not try to tune it.
        #
        # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
        import inspect

        if "XBLOCK" not in inspect.signature(fn.fn).parameters:
            for tconfig in configs:
                if "XBLOCK" in tconfig.kwargs:
                    assert tconfig.kwargs["XBLOCK"] == 1
                    tconfig.kwargs.pop("XBLOCK")

        if inductor_meta.get("profile_bandwidth"):
            return DebugAutotuner(
                fn,
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                regex_filter=inductor_meta["profile_bandwidth_regex"],
                with_profiler=inductor_meta["profile_bandwidth_with_do_bench_using_profiling"],
                configs=configs,
                save_cache_hook=autotune_cache and autotune_cache.save,
                mutated_arg_names=mutated_arg_names,
                reset_to_zero_arg_names=reset_to_zero_arg_names,
                optimize_mem=optimize_mem,
                heuristic_type=heuristic_type,
                size_hints=size_hints,
                custom_kernel=custom_kernel,
                filename=filename,
                with_bandwidth_info=True,
            )
        return CachingAutotuner(
            fn,
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            configs=configs,
            save_cache_hook=autotune_cache and autotune_cache.save,
            mutated_arg_names=mutated_arg_names,
            reset_to_zero_arg_names=reset_to_zero_arg_names,
            optimize_mem=optimize_mem,
            heuristic_type=heuristic_type,
            size_hints=size_hints,
            custom_kernel=custom_kernel,
            filename=filename,
        )

    return decorator
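Note the interaction with the cache: on a cache hit, configs has already been reduced to [best_config], so the CachingAutotuner constructed by the decorator has nothing left to compare and simply compiles and launches that single config.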
All of Triton's autotuners inherit from KernelInterface, which defines the calling interface of a Triton JIT kernel:
class KernelInterface(Generic[T]):
    run: T

    def __getitem__(self, grid) -> T:
        """A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        """
        return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
This interface offers two equivalent ways to launch a JIT kernel (illustrated below):
- fn[grid](*args, **kwargs)
- fn.run(*args, **kwargs, grid=grid)
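Concretely, for the add_kernel defined earlier, the two forms below launch the same kernel (a sketch; warmup=False mirrors what __getitem__ forwards):
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements)                           # fn[grid](*args, **kwargs)
add_kernel.run(x, y, output, n_elements, grid=grid, warmup=False)    # fn.run(*args, **kwargs, grid=grid)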
Autotuning is implemented on top of this interface via decorators:
- CachingAutotuner: uses a local/remote cache to store the best Config, which is fetched directly on subsequent calls. Compared with Triton's native Autotuner, the CachingAutotuner implemented in PyTorch omits the invalidation-key mechanism, i.e. changed conditions (such as the kernel's call arguments) never re-trigger autotuning. A simplified sketch of the first-call behavior follows this list.
- DebugAutotuner: builds on CachingAutotuner, adding bandwidth profiling.
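The sketch below is ours, not the actual PyTorch implementation (only run mirrors a real method name; everything else is invented for illustration): given several candidate launchers, the first call benchmarks each, keeps the fastest, and persists it through the cache hook; because there is no invalidation key, every later call reuses the winner unconditionally.
import time

class SimplifiedCachingAutotuner:
    """Toy model of CachingAutotuner: benchmark once, then always reuse the winner."""

    def __init__(self, launchers, save_cache_hook=None):
        self.launchers = launchers              # one compiled launcher per triton.Config
        self.save_cache_hook = save_cache_hook  # e.g. AutotuneCache.save

    def _bench(self, launcher, *args, **kwargs):
        # the real code uses triton.testing.do_bench / the profiler;
        # naive wall-clock timing is enough for this sketch
        start = time.perf_counter()
        launcher(*args, **kwargs)
        return time.perf_counter() - start

    def run(self, *args, **kwargs):
        if len(self.launchers) > 1:  # first call: pick the fastest config
            timings = {l: self._bench(l, *args, **kwargs) for l in self.launchers}
            best = min(self.launchers, key=timings.get)
            self.launchers = [best]
            if self.save_cache_hook:
                self.save_cache_hook(best)  # persist the winner (real code stores its triton.Config)
        return self.launchers[0](*args, **kwargs)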