GenerationMixin:generate
generate
以下是对您提供的 generate
方法的详细解释。这个方法用于大型语言模型(LLM)中的文本生成,尤其是具有语言模型头的模型。该方法包含了多个复杂的逻辑,支持多种生成模式,如贪心搜索、采样、束搜索等。
方法定义
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
...
参数说明
inputs
:可选,输入张量,形状可能因模态而异。用于作为生成的提示或编码器的输入。generation_config
:可选,GenerationConfig
对象,包含生成时的参数配置。logits_processor
:可选,LogitsProcessorList
对象,自定义的 logits 处理器列表,用于在生成过程中调整 logits。stopping_criteria
:可选,StoppingCriteriaList
对象,自定义的停止标准列表,用于在满足条件时终止生成。prefix_allowed_tokens_fn
:可选,函数,用于在每一步生成时限制允许的 token。synced_gpus
:可选,布尔值,指示是否在多 GPU 环境下同步运行以避免死锁。assistant_model
:可选,用于加速生成的辅助模型,必须具有相同的 tokenizer。streamer
:可选,用于流式输出生成序列的对象。negative_prompt_ids
:可选,torch.LongTensor
,形状为(batch_size, sequence_length)
,用于一些处理器(如 CFG)的负提示。negative_prompt_attention_mask
:可选,torch.LongTensor
,形状为(batch_size, sequence_length)
,对应negative_prompt_ids
的 attention mask。kwargs
:其他参数,可用于覆盖generation_config
中的设置,或传递给模型的forward
方法。
返回值
GenerateOutput
或torch.LongTensor
:根据参数设置,返回生成的序列或者包含生成过程详细信息的输出对象。
方法逻辑解析
总体流程
- 处理生成配置和参数验证:确保
generation_config
和kwargs
的正确性。 - 设置生成参数:根据传入的配置或默认值,设置生成所需的参数。
- 准备模型输入:处理输入张量,生成
input_ids
,并管理模型需要的其他关键字参数。 - 确定生成模式:根据配置,选择合适的生成模式,例如贪心搜索、采样、束搜索等。
- 生成序列:调用相应的生成函数,生成目标序列。
- 返回结果:根据参数设置,返回生成的序列或包含更多信息的对象。
详细步骤
1. 处理生成配置和参数验证
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None)
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
_validate_model_class
:检查模型的类型是否支持生成。_validate_model_class- 提取
tokenizer
:用于停止条件等,非必要参数。 - 准备
generation_config
和model_kwargs
:处理传入的配置和额外参数。_prepare_generation_config - 验证模型关键字参数:确保传入的参数与模型的预期一致。_validate_model_kwargs
- 验证辅助模型:如果使用了
assistant_model
,确保其与主模型兼容。
2. 设置生成参数
if synced_gpus is None:
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
synced_gpus
:在多 GPU 环境下,判断是否需要同步。logits_processor
和stopping_criteria
:初始化或使用默认的处理器和停止标准。- StoppingCriteria代码分析
- LogitsProcessor代码分析
- 注意力掩码相关变量:检查模型的
forward
方法是否接受attention_mask
,以及当前参数中是否提供。
3. 准备模型输入
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
_prepare_model_inputs
:处理输入张量,获取模型输入名称,并更新model_kwargs
。_prepare_model_inputsbatch_size
和device
:获取批次大小和设备信息。- 准备特殊 token:如
bos_token_id
、eos_token_id
等。_prepare_special_tokens
4. 检查并处理注意力掩码
# decoder-only models must use left-padding for batched generation.
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
generation_config._pad_token_tensor is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
这段代码是关于为仅解码器架构(decoder-only models)处理输入时的填充方式建议。它检查是否使用了右填充(right-padding),在这种情况下给出警告。
-
架构类型检查:
self.config.is_encoder_decoder
用于检查模型是否属于编码器-解码器架构。- 如果模型不是编码器-解码器架构(即它是仅解码器架构),并且不是在TorchDynamo编译模式下,代码继续进行。
-
输入条件检查:
- 确保批处理大小大于1,即有多个序列在一起进行处理。
- 检查输入张量
inputs_tensor
的形状满足条件:它是二维张量,一般形式为[batch_size, sequence_length]
。 - 检查序列中的最后一个标识符是否为
pad_token_id
,指定的generation_config._pad_token_tensor
不为None
。 torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
用于确定至少有一个序列的最后一个令牌是填充令牌。
-
警告日志:
- 如果满足以上条件,显示警告信息。
- 提示在仅解码器模型中检测到右填充,它建议使用左填充(
padding_side='left'
),以获取正确的生成结果。
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
generation_config.use_cache = True
这段代码针对仅解码器架构(decoder-only models)进行了一种配置调整,特别是在使用 inputs_embeds
作为输入的时候。以下是代码的简单说明:
-
架构类型检查:
not self.config.is_encoder_decoder
用于判断模型是否是仅解码器架构。- 如果模型是仅解码器架构,代码继续进行。
-
输入类型检查:
model_input_name == "inputs_embeds"
检查输入类型是否为嵌入层(embeddings)。inputs_embeds
通常表示已经经过词嵌入层的输入,这意味着模型接收的不是直接的令牌ID,而是对应的词向量。
-
缓存使用设置:
generation_config.use_cache = True
设置生成配置的use_cache
属性为True
。- 启用缓存对于跟踪生成的序列特别重要,尤其是在无法确定新生成的第一个令牌是否已生成时。
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config, model_kwargs
)
elif kwargs_has_attention_mask:
if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
raise ValueError("`attention_mask` passed to `generate` must be 2D.")
- 如果需要
attention_mask
:且未提供,则生成默认的attention_mask
。 - 验证提供的
attention_mask
:确保其形状正确。 - _prepare_attention_mask_for_generation
5. 为编码器-解码器模型准备输入
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name, generation_config
)
- 对于编码器-解码器模型:如果未提供
encoder_outputs
,则进行编码器的前向计算,并更新model_kwargs
。计算编码器部分
6. 准备用于自回归生成的 input_ids
if self.config.is_encoder_decoder:
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config._decoder_start_token_tensor,
device=inputs_tensor.device,
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
- 编码器-解码器模型:准备解码器的
input_ids
。 - _prepare_decoder_input_ids_for_generation
- _decoder_start_token_tensor
- 解码器模型:直接使用
inputs_tensor
作为input_ids
。
7. 处理特殊的生成配置
if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:
streamer.put(input_ids.cpu())
token_healing
:如果启用了此项配置,修复输入中的 tokens。token_healingstreamer
:如果提供了流式处理器,传递当前的input_ids
。
8. 准备生成长度相关的参数
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
- 计算输入序列长度。
- 确定是否使用默认的最大和最小长度。
- 准备生成长度配置,可能会根据输入长度进行调整。_prepare_generated_length
9. 准备缓存和其他模型参数
if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
model_kwargs["logits_to_keep"] = 1
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
logits_to_keep
:如果模型支持,仅保留需要的 logits,减少内存占用。_supports_logits_to_keep- 验证生成长度:确保生成长度的合法性。_validate_generated_length
# 7. Prepare the cache.
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
# - different models have a different cache name expected by the model (default = "past_key_values")
# - `max_length`, prepared above, is used to determine the maximum cache length
max_cache_length = generation_config.max_length - 1
if (
inputs_tensor.shape[1] != input_ids_length
and model_input_name == "inputs_embeds"
and not self.config.is_encoder_decoder
):
max_cache_length += inputs_tensor.shape[1]
self._prepare_cache_for_generation(
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
)
- 准备缓存:为生成过程中的缓存(如注意力缓存)分配空间。
_prepare_cache_for_generation
10. 确定生成模式
generation_mode = generation_config.get_generation_mode(assistant_model)
- 根据生成配置和辅助模型,确定生成模式,例如:
- 辅助生成(Assisted Generation)
- DoLa 生成(DOLA Generation)
- 对比搜索(Contrastive Search)
- 采样或贪心搜索
- 束搜索(Beam Search)
- 组束搜索(Group Beam Search)
- 受限束搜索(Constrained Beam Search)
11. 准备 logits 处理器和停止标准
prepared_logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=inputs_tensor.device,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
model_kwargs["use_cache"] = generation_config.use_cache
- 获取 logits 处理器:整合默认和自定义的 logits 处理器,用于在生成过程中调整 logits。
- 获取停止标准:整合默认和自定义的停止标准,用于在满足条件时终止生成。
- 设置
use_cache
:根据配置,决定是否在生成过程中使用缓存。
_get_logits_processor
_get_stopping_criteria
12. 根据生成模式调用相应的生成函数
- 辅助生成
if generation_mode == GenerationMode.ASSISTED_GENERATION:
# 验证条件
# 获取候选生成器
# 执行辅助生成
- DoLa 生成
elif generation_mode == GenerationMode.DOLA_GENERATION:
# 执行 DoLa 解码
- 对比搜索
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
# 执行对比搜索
- 采样或贪心搜索
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 扩展 input_ids
# 执行采样或贪心搜索
- 束搜索
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 准备束搜索评分器
# 扩展 input_ids
# 执行束搜索
GenerationMixin:_sample方法(GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH)
- 组束搜索
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
# 准备组束搜索评分器
# 扩展 input_ids
# 执行组束搜索
- 受限束搜索
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
# 准备约束条件
# 准备受限束搜索评分器
# 扩展 input_ids
# 执行受限束搜索
13. 处理生成结果
# 如果需要,将缓存转换为传统格式
if (
generation_config.return_legacy_cache is True
and not is_torchdynamo_compiling()
and hasattr(result, "past_key_values")
and getattr(result.past_key_values, "to_legacy_cache") is not None
):
result.past_key_values = result.past_key_values.to_legacy_cache()
return result
- 转换缓存格式:如果配置需要,将生成过程中使用的缓存转换为传统格式。
- 返回结果:最终将生成的结果返回。
_validate_model_class
函数功能概述
这个函数名为_validate_model_class
,用于验证当前的模型类是否支持生成(generation)操作。如果不支持生成,则会抛出一个异常,提示用户使用合适的模型类。
-
条件判断
if not is_torchdynamo_compiling() and not self.can_generate():
- 这里有一个
if
条件,用于检查两个条件是否同时满足:not is_torchdynamo_compiling()
not self.can_generate()
is_torchdynamo_compiling()
:- 这是一个函数,检查当前是否处于 TorchDynamo 的编译环境中。
- TorchDynamo 是 PyTorch 的一个 JIT 编译器框架,可以对模型进行编译优化。
- 在编译过程中,某些检查可能需要被跳过,因此在编译时不进行此验证。
self.can_generate()
:- 这是模型类的一个方法,返回布尔值,表示模型是否支持生成操作。
- 如果模型具有语言模型头(language model head),则通常支持生成,即
self.can_generate()
返回True
。
因此,只有在不处于编译环境且模型不能生成的情况下,才会进入
if
语句内部,抛出异常。 - 这里有一个
-
定义支持生成的模型类名后缀列表
terminations_with_generation_support = [ "ForCausalLM", "ForConditionalGeneration", "ForSpeechSeq2Seq", "ForVision2Seq", ]
- 这是一个列表,包含了通常支持生成操作的模型类名称的后缀。
- 这些后缀包括:
"ForCausalLM"
:用于自回归语言模型,如 GPT-2、GPT-3 等。"ForConditionalGeneration"
:用于条件生成模型,如 BART、T5 等。"ForSpeechSeq2Seq"
:用于语音序列到序列模型。"ForVision2Seq"
:用于视觉到序列的模型,如图像描述生成。
-
抛出异常
raise TypeError( f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " "it doesn't have a language model head. Classes that support generation often end in one of these " f"names: {terminations_with_generation_support}." )
- 如果条件满足,说明当前模型不支持生成,则抛出一个
TypeError
异常。 - 异常信息包括:
- 当前模型类的名称:
{self.__class__.__name__}
。 - 说明模型不兼容
.generate()
方法,因为它没有语言模型头。 - 提示支持生成的模型类通常以哪些后缀结尾:
{terminations_with_generation_support}
。
- 当前模型类的名称:
- 如果条件满足,说明当前模型不支持生成,则抛出一个
_prepare_generation_config
函数功能概述
这个函数名为_prepare_generation_config
,它的作用是准备生成所需的配置对象generation_config
,并应用从kwargs
(关键字参数)中传入的任何生成配置选项。这个函数还处理了与模型配置文件(config
)的向后兼容性。
参数和返回值
-
参数:
self
:类的实例。generation_config
:可选的GenerationConfig
对象,表示生成配置。如果为None
,则需要从其他地方获取或生成。**kwargs
:其他关键字参数,可能包含生成配置的参数,将用于更新generation_config
。
-
返回值:
generation_config
:最终准备好的生成配置对象。model_kwargs
:用于模型的关键字参数字典。
处理逻辑详解
-
初始设置
# 设置一个标志,指示是否使用模型的默认生成配置 using_model_generation_config = False
-
处理
generation_config
为None
的情况if generation_config is None: ...
当用户没有提供
generation_config
时,需要从模型中获取。但在处理之前,先考虑到可能的向后兼容性问题。-
遗留(Legacy)行为的处理
# 遗留支持:用户可能修改了模型的配置来控制生成。要触发这种遗留行为,需要满足以下条件: # 1) generation_config 是从模型配置创建的(`_from_model_config`字段为True) # 2) generation_config 自创建以来没有被修改过(哈希值相同) # 3) 模型配置中有非默认的生成参数 # 4) 用户在模型配置中设置了新的生成参数 # 注意:`torch.compile`无法编译`hash`函数,因此在编译时,这种遗留支持被禁用 if ( not is_torchdynamo_compiling() and self.generation_config._from_model_config # 条件1 and self.generation_config._original_object_hash == hash(self.generation_config) # 条件2 and len(self.config._get_non_default_generation_parameters()) > 0 # 条件3 ): new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: # 条件4 warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed in v5." " Please use and modify the model generation configuration (see" " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )", UserWarning, ) self.generation_config = new_generation_config
-
解释条件:
self.generation_config._from_model_config
:- 检查
generation_config
是否是从模型配置创建的。
- 检查
self.generation_config._original_object_hash == hash(self.generation_config)
:- 检查
generation_config
自创建以来是否没有被修改过。 - 由于
torch.compile
无法编译hash
函数,因此在编译时无法进行此检查。
- 检查
len(self.config._get_non_default_generation_parameters()) > 0
:- 检查模型配置中是否有非默认的生成参数。
new_generation_config != self.generation_config
:- 检查新的生成配置是否与当前的不同,表示用户在模型配置中设置了新的生成参数。
-
操作:
- 如果以上条件都满足,表示用户通过修改模型配置来控制生成。这是一种已被弃用的做法。
- 发出警告,提示此策略将在v5版本中移除,建议用户使用并修改生成配置对象
generation_config
。 - 更新
self.generation_config
为新的生成配置new_generation_config
。
-
-
设置
generation_config
generation_config = self.generation_config using_model_generation_config = True
- 如果没有提供
generation_config
,则使用模型的默认生成配置self.generation_config
。 - 设置标志
using_model_generation_config = True
,表示正在使用模型的默认生成配置。
- 如果没有提供
-
-
处理
torch.compile
相关的问题# `torch.compile`无法编译`copy.deepcopy`等函数,因此需要根据是否在编译中,决定如何处理 if not is_torchdynamo_compiling(): ... else: model_kwargs = kwargs
-
非编译环境下的处理
if not is_torchdynamo_compiling(): generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) ...
-
深拷贝
generation_config
:- 使用
copy.deepcopy
创建generation_config
的深拷贝,以避免修改原始对象。 - 由于
torch.compile
无法编译copy.deepcopy
,因此在编译环境下无法进行此操作。
- 使用
-
更新
generation_config
:- 调用
generation_config.update(**kwargs)
方法,用传入的kwargs
更新生成配置。 - 这个方法返回未被
generation_config
使用的参数,即那些不属于生成配置的参数,存储在model_kwargs
中。
- 调用
-
处理特殊的Token ID:
# 如果提供了`generation_config`,需要确保所有特殊的Token ID都有默认值 if not using_model_generation_config: if generation_config.bos_token_id is None: generation_config.bos_token_id = self.generation_config.bos_token_id if generation_config.eos_token_id is None: generation_config.eos_token_id = self.generation_config.eos_token_id if generation_config.pad_token_id is None: generation_config.pad_token_id = self.generation_config.pad_token_id if generation_config.decoder_start_token_id is None: generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
- 如果用户提供了自己的
generation_config
(即不使用模型的默认生成配置),需要确保特殊的Token ID(开始、结束、填充、解码器开始)有默认值。 - 如果这些ID在用户提供的
generation_config
中为None
,则使用模型默认的self.generation_config
中的值。
- 如果用户提供了自己的
-
-
编译环境下的处理
else: model_kwargs = kwargs
- 在编译环境下,由于无法使用
copy.deepcopy
和hash
,直接将传入的kwargs
赋值给model_kwargs
。 - 不进行深拷贝和更新操作。
- 在编译环境下,由于无法使用
-
-
返回结果
return generation_config, model_kwargs
- 函数返回准备好的
generation_config
和model_kwargs
。 model_kwargs
包含了模型需要的其他参数。
- 函数返回准备好的
_validate_model_kwargs
函数功能概述
这个函数名为_validate_model_kwargs
,用于在生成(generation)过程中对传入的模型关键字参数model_kwargs
进行验证。它的主要作用是:
- 验证传入的
model_kwargs
是否被模型的生成方法使用,以确保参数的正确性,防止参数拼写错误或传入不支持的参数。 - 给出明确的错误信息,帮助用户及时发现并修正问题。
1. 检查past_key_values
是否为Cache
实例,并验证模型是否支持
if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
raise ValueError(
f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
"check the model documentation for supported cache formats."
)
-
目的:如果
model_kwargs
中包含past_key_values
,并且它是一个Cache
实例,而模型不支持Cache
类型的缓存,则抛出错误。 -
解释:
model_kwargs.get("past_key_values", None)
:尝试获取past_key_values
的值,如果不存在则为None
。isinstance(..., Cache)
:检查past_key_values
是否是Cache
类的实例。not self._supports_cache_class
:检查当前模型是否不支持Cache
类型的缓存。
-
错误处理:如果以上条件满足,则抛出
ValueError
,提示用户当前模型不支持Cache
实例作为past_key_values
,并建议查看模型文档以了解支持的缓存格式。
2. 对于编码器-解码器模型,移除特定的参数
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
model_kwargs.pop(key, None)
-
目的:在生成过程中,不需要直接传入
decoder_input_ids
,因此从model_kwargs
中移除该参数,防止后续产生错误。 -
解释:
self.config.is_encoder_decoder
:检查模型是否为编码器-解码器结构。model_kwargs.pop(key, None)
:从model_kwargs
中移除指定的键,如果不存在则返回None
。
3. 初始化未使用的模型参数列表和获取模型方法参数
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-
unused_model_args
:用于收集未被模型方法使用的参数名。 -
model_args
:获取prepare_inputs_for_generation
方法的参数名集合,表示模型在生成过程中可能使用的参数。 -
解释:
inspect.signature(...)
:获取函数的完整参数签名。.parameters
:获取参数名组成的有序字典。set(...)
:将参数名转换为集合,方便后续操作。
4. 扩展模型参数集合,包含forward
方法的参数
if "kwargs" in model_args or "model_kwargs" in model_args:
model_args |= set(inspect.signature(self.forward).parameters)
-
目的:如果
prepare_inputs_for_generation
方法接受kwargs
(可变关键字参数),则需要将self.forward
方法的参数也包含进来,因为这些参数可能会通过kwargs
传递。 -
解释:
- 检查
model_args
中是否包含"kwargs"
或"model_kwargs"
,表示接受可变关键字参数。 - 使用
set.union
(|=
)操作,将self.forward
方法的参数集合并入model_args
。
- 检查
5. 对于编码器-解码器模型,包含编码器和解码器的参数
if self.config.is_encoder_decoder:
base_model = getattr(self, self.base_model_prefix, None)
# 允许编码器的参数
encoder = getattr(self, "encoder", None)
# 特殊情况处理(如Musicgen模型)
if encoder is None and base_model is not None:
encoder = getattr(base_model, "encoder", None)
if encoder is not None:
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
model_args |= encoder_model_args
# 允许解码器的参数
decoder = getattr(self, "decoder", None)
if decoder is None and base_model is not None:
decoder = getattr(base_model, "decoder", None)
if decoder is not None:
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}
-
目的:将编码器和解码器的
forward
方法的参数名添加到model_args
中,以便在验证model_kwargs
时考虑到这些参数。 -
详细解释:
-
base_model = getattr(self, self.base_model_prefix, None)
:获取基础模型实例,self.base_model_prefix
通常是模型的前缀名。 -
处理编码器:
encoder = getattr(self, "encoder", None)
:尝试直接获取self.encoder
属性。- 如果
encoder
为None
且base_model
存在,则尝试从base_model
中获取encoder
。 - 获取
encoder.forward
方法的参数名集合encoder_model_args
。 - 将
encoder_model_args
并入model_args
。
-
处理解码器:
- 类似地,获取
decoder
实例。 - 获取
decoder.forward
方法的参数名集合decoder_model_args
。 - 将
decoder_model_args
的参数名加上前缀"decoder_"
,以匹配生成过程中参数的命名方式。 - 将修改后的
decoder_model_args
并入model_args
。
- 类似地,获取
-
-
特殊情况说明:
- 对于某些模型(例如
MusicgenForConditionalGeneration
),编码器和解码器的命名和结构可能与标准情况不同,需要特别处理。
- 对于某些模型(例如
6. 检查model_kwargs
中的未使用参数
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
-
目的:找出
model_kwargs
中那些未被模型方法使用的参数。 -
解释:
- 遍历
model_kwargs
中的所有键值对。 - 如果参数值不为
None
且参数名不在model_args
集合中,说明模型并不会使用该参数。 - 将这些未使用的参数名添加到
unused_model_args
列表中。
- 遍历
7. 抛出错误提示未使用的参数
if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)
-
目的:如果存在未使用的参数,抛出
ValueError
,提醒用户这些参数未被模型使用,可能存在拼写错误或传入了不支持的参数。 -
解释:
- 检查
unused_model_args
列表是否非空。 - 抛出错误,错误信息中包含未使用的参数列表,并提示用户可能是由于参数名的拼写错误导致的。
- 检查
_prepare_model_inputs
方法定义
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[torch.Tensor] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# 方法体...
参数说明:
inputs
:可选的torch.Tensor
,表示要作为模型输入的张量。可能是input_ids
、inputs_embeds
等形式,具体取决于模型的要求。bos_token_id
:可选的torch.Tensor
,表示序列开始的 token ID(BOS = Begin Of Sequence)。在生成任务中,如果未提供输入,可能需要使用该 token 进行初始化。model_kwargs
:可选的字典,包含传递给模型的其他关键字参数。
返回值:
inputs
:torch.Tensor
,准备好的模型输入张量。input_name
:str
,模型输入的名称,可能是input_ids
或inputs_embeds
。model_kwargs
:字典,更新后的模型关键字参数。
方法功能概述
该方法的主要目的是在生成过程中,准备和验证模型的输入,确保输入与模型的预期格式和要求一致。具体任务包括:
- 确定模型所需的主要输入名称(
input_name
),这可能取决于模型是编码器-解码器模型还是仅解码器模型。 - 处理传入的
inputs
和model_kwargs
,以避免重复传递相同的输入参数。 - 在需要时,使用
bos_token_id
初始化input_ids
,以开始生成过程。 - 对于支持
inputs_embeds
的模型,正确处理inputs_embeds
参数。
逐步详解
步骤 1:确定模型的主要输入名称
# 判断模型是否是编码器-解码器,并获取正确的输入名称
if (
self.config.is_encoder_decoder
and hasattr(self, "encoder")
and self.encoder.main_input_name != self.main_input_name
):
input_name = self.encoder.main_input_name
else:
input_name = self.main_input_name
解释:
- 目的:获取模型预期的主要输入参数名称,可能是
input_ids
、inputs_embeds
等。 - 逻辑:
- 如果模型是 编码器-解码器模型,并且编码器的
main_input_name
与模型的main_input_name
不同,则使用编码器的main_input_name
。- 这是因为编码器和解码器可能需要不同的输入名称。例如,一些模型的编码器可能使用
inputs_embeds
,而解码器使用input_ids
。
- 这是因为编码器和解码器可能需要不同的输入名称。例如,一些模型的编码器可能使用
- 否则,使用模型的
main_input_name
作为输入名称。
- 如果模型是 编码器-解码器模型,并且编码器的
示例:
- 如果
self.main_input_name
为'input_ids'
,而self.encoder.main_input_name
为'inputs_embeds'
,则对于编码器,需要使用'inputs_embeds'
作为输入名称。
步骤 2:从 model_kwargs
中提取非输入相关的参数
# 过滤掉模型输入的关键字参数,以及值为 None 的参数
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
解释:
- 目的:清理
model_kwargs
,只保留非模型输入的参数和值非None
的参数。 - 逻辑:
- 遍历
model_kwargs
中的所有键值对,保留以下情况的参数:- 值
v
不为None
。 - 或者键
k
不是模型的主要输入名称input_name
。
- 值
- 遍历
- 原因:避免模型输入参数(如
input_ids
)重复出现在inputs
和model_kwargs
中,以防止冲突。
示例:
- 如果
input_name
是'input_ids'
,并且model_kwargs
包含{'input_ids': None, 'attention_mask': tensor([...])}
,那么经过此步骤后,model_kwargs
变为{'attention_mask': tensor([...])}
。
步骤 3:检查模型输入是否作为关键字参数传入
# 从 model_kwargs 中获取输入参数
inputs_kwarg = model_kwargs.pop(input_name, None)
# 检查是否同时传入了 inputs 和 inputs_kwarg
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
解释:
- 目的:确保模型输入参数只通过一个途径传入,避免冲突。
- 逻辑:
- 从
model_kwargs
中取出键为input_name
的参数,赋值给inputs_kwarg
。 - 如果同时传入了
inputs
和inputs_kwarg
,抛出异常,提示用户只能通过一种方式传入模型输入。 - 如果
inputs
为None
,但inputs_kwarg
不为None
,则将inputs_kwarg
赋值给inputs
。
- 从
示例:
- 用户通过位置参数传入了
inputs
,同时在model_kwargs
中传入了input_ids
,这将导致异常,因为模型无法确定使用哪个输入。
步骤 4:处理 inputs_embeds
的情况
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if not self.config.is_encoder_decoder:
# 检查模型是否支持 inputs_embeds
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
raise ValueError(
f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
# 将 input_ids 初始化并加入 model_kwargs
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs=model_kwargs
)
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
# 更新 inputs 和 input_name
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
解释:
- 目的:处理用户通过
inputs_embeds
提供输入的情况,确保模型支持这种输入方式,并正确处理。 - 逻辑:
- 当模型的输入名称为
'input_ids'
,且model_kwargs
中存在'inputs_embeds'
键时,进入此逻辑。 - 对于非编码器-解码器模型:
- 检查模型的
prepare_inputs_for_generation
方法是否接受inputs_embeds
参数。 - 如果不支持,则抛出异常,提示模型不支持通过
inputs_embeds
进行生成。 - 如果支持,则需要初始化
input_ids
,以便在生成过程中处理诸如 attention mask 等依赖input_ids
的自动操作。 - 将初始化的
input_ids
添加到model_kwargs
中。
- 检查模型的
- 对于编码器-解码器模型:
- 如果同时传入了
inputs
和inputs_embeds
,抛出异常,提示只能选择一种输入方式。
- 如果同时传入了
- 最后,将
inputs
设置为inputs_embeds
,并更新input_name
为'inputs_embeds'
。
- 当模型的输入名称为
示例:
- 用户在生成时提供了
inputs_embeds
,如果模型支持,则将其用于生成,否则抛出异常。
步骤 5:如果 inputs
仍为 None
,尝试从 bos_token_id
创建 input_ids
# 如果 inputs 为 None,使用 bos_token_id 初始化 input_ids
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
解释:
- 目的:在用户未提供任何输入的情况下,使用开始标记
bos_token_id
初始化输入,以启动生成过程。 - 逻辑:
- 调用
_maybe_initialize_input_ids_for_generation
方法,如果inputs
为None
,则尝试使用bos_token_id
创建input_ids
。 - 如果
inputs
已存在,则保持不变。
- 调用
示例:
- 用户未提供
inputs
,模型使用bos_token_id
[101]
(假设为开始标记的 ID)初始化input_ids
。
步骤 6:返回处理后的结果
# 返回准备好的 inputs、input_name 和更新后的 model_kwargs
return inputs, input_name, model_kwargs
示例:
# 示例 1:用户只提供 inputs
inputs = torch.tensor([[101, 102, 103]])
inputs, input_name, model_kwargs = model._prepare_model_inputs(inputs=inputs)
# 示例 2:用户只提供 input_ids 作为关键字参数
model_kwargs = {'input_ids': torch.tensor([[101, 102, 103]])}
inputs, input_name, model_kwargs = model._prepare_model_inputs(model_kwargs=model_kwargs)
# 示例 3:用户提供 inputs_embeds,模型支持 inputs_embeds
model_kwargs = {'inputs_embeds': torch.tensor([...])}
inputs, input_name, model_kwargs = model._prepare_model_inputs(model_kwargs=model_kwargs)
# 示例 4:用户未提供任何输入,使用 bos_token_id 初始化
bos_token_id = torch.tensor([101])
inputs, input_name, model_kwargs = model._prepare_model_inputs(bos_token_id=bos_token_id)
_prepare_special_tokens
主要目的是准备生成所需的特殊标记(如开始标记 bos_token_id
、结束标记 eos_token_id
、填充标记 pad_token_id
和解码器起始标记 decoder_start_token_id
),并将这些标记转换为张量。具体步骤如下:
-
定义辅助函数
_tensor_or_none
:- 将特殊标记转换为张量,如果标记为
None
则返回None
。
- 将特殊标记转换为张量,如果标记为
-
将特殊标记转换为张量:
- 使用
_tensor_or_none
函数将bos_token_id
、eos_token_id
、pad_token_id
和decoder_start_token_id
转换为张量。
- 使用
-
处理编码器-解码器模型:
- 如果模型是编码器-解码器类型,并且
decoder_start_token_id
未设置,则使用bos_token_id
作为decoder_start_token_id
。
- 如果模型是编码器-解码器类型,并且
-
处理
eos_token_tensor
:- 如果
eos_token_tensor
是 0 维张量,则将其扩展为 1 维张量。
- 如果
-
设置
pad_token_tensor
:- 如果
pad_token_tensor
未设置且eos_token_tensor
存在,则将pad_token_tensor
设置为eos_token_tensor
的第一个元素,并发出警告。
- 如果
-
安全检查和警告:
- 检查编码器-解码器模型是否设置了
decoder_start_token_id
。 - 检查
eos_token_tensor
是否与pad_token_tensor
相同,并在未设置注意力掩码时发出警告。 - 检查
eos_token_tensor
是否包含负数或浮点数,并发出警告。
- 检查编码器-解码器模型是否设置了
-
更新生成配置:
- 将转换后的特殊标记张量存储在
generation_config
的新属性中,以启用端到端编译。
- 将转换后的特殊标记张量存储在
_prepare_attention_mask_for_generation
该方法用于在生成过程中为模型准备 attention_mask
,以确保模型在生成序列时正确地关注输入序列的非填充部分(非 pad_token
部分)。
方法定义
def _prepare_attention_mask_for_generation(
self,
inputs_tensor: torch.Tensor,
generation_config: GenerationConfig,
model_kwargs: Dict[str, Any],
) -> torch.LongTensor:
# 方法体...
参数说明:
inputs_tensor
:torch.Tensor
,输入张量,可能是模型的主要输入。generation_config
:GenerationConfig
,生成配置对象,包含生成过程中所需的配置参数。model_kwargs
:Dict[str, Any]
,包含模型其他关键字参数的字典。
返回值:
attention_mask
:torch.LongTensor
,返回生成过程中使用的attention_mask
,用于指示模型需要关注的输入序列位置。
方法功能概述
该方法的主要目的是根据输入张量 inputs_tensor
和生成配置 generation_config
,为生成过程准备合适的 attention_mask
。具体而言,它会根据以下情况生成 attention_mask
:
- 如果输入张量包含
pad_token_id
,则在attention_mask
中标记出非填充的位置(即非pad_token_id
的位置为 1,pad_token_id
的位置为 0)。 - 如果无法判断是否需要生成
attention_mask
,则返回默认的attention_mask
,即全 1 的张量,表示所有位置都需要关注。
逐步详解
步骤 1:获取 pad_token_id
和 eos_token_id
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
- 说明:从
generation_config
中获取用于填充的pad_token_id
和序列结束的eos_token_id
。
步骤 2:检查 model_kwargs
中是否有 input_ids
# `input_ids` 可能存在于 model_kwargs 中,而不是主要输入(例如多模态模型)
if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
inputs_tensor = model_kwargs["input_ids"]
-
解释:
- 在某些情况下,
inputs_tensor
并不是模型的主要输入,如在多模态模型中,input_ids
可能存在于model_kwargs
中。 - 如果
model_kwargs
中有input_ids
,并且其形状的第二维度长度大于 0(即有数据),则用model_kwargs["input_ids"]
替换inputs_tensor
。
- 在某些情况下,
-
目的:确保
inputs_tensor
是正确的input_ids
,以便后续的attention_mask
计算。
步骤 3:创建默认的 attention_mask
# 如果无法推断 attention mask,则返回默认的 attention mask
default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
-
解释:
- 创建一个形状与
inputs_tensor
前两个维度相同的全 1 张量,作为默认的attention_mask
。 - 该默认
attention_mask
表示输入序列的所有位置都需要关注。 - 切片的基本语法是 start: end:step,其中 start 是起始索引,end 是结束索引(不包括),step 是步长。写为 [:2] 是一种简写表示方式,没有显式指定起始和步长,默认从头开始且步长为1
- 创建一个形状与
步骤 4:如果 pad_token_id
为 None
,返回默认的 attention_mask
if pad_token_id is None:
return default_attention_mask
-
说明:
- 如果
pad_token_id
未定义,则无法根据填充标记生成attention_mask
,因此直接返回默认的attention_mask
。
- 如果
步骤 5:检查 inputs_tensor
是否为有效的 input_ids
is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
if not is_input_ids:
return default_attention_mask
-
解释:
-
检查
inputs_tensor
是否满足以下条件:inputs_tensor
的形状为二维,即形状为(batch_size, sequence_length)
。inputs_tensor
的数据类型为整数类型torch.int
或torch.long
。
-
如果不满足上述条件,则认为无法根据
inputs_tensor
推断出attention_mask
,因此返回默认的attention_mask
。
-
步骤 6:检查 inputs_tensor
中是否包含 pad_token_id
is_pad_token_in_inputs = (pad_token_id is not None) and (
isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
)
-
解释:
-
使用
isin_mps_friendly
函数检查inputs_tensor
中是否包含pad_token_id
。isin_mps_friendly(elements, test_elements)
:检查elements
中的元素是否在test_elements
中,返回布尔张量。
-
isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
:- 检查
inputs_tensor
中是否有元素等于pad_token_id
,如果有,则返回True
。
- 检查
-
-
结果:
is_pad_token_in_inputs
为一个布尔值,表示inputs_tensor
中是否包含pad_token_id
。
步骤 7:检查 pad_token_id
是否不等于 eos_token_id
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
)
-
解释:
-
需要确保
pad_token_id
和eos_token_id
不相等,以避免将eos_token_id
误认为填充标记。 -
逻辑:
-
如果
eos_token_id
为None
,则认为pad_token_id
不等于eos_token_id
。 -
如果
eos_token_id
不为None
,则检查eos_token_id
是否等于pad_token_id
:-
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
:- 检查
eos_token_id
是否等于pad_token_id
。
- 检查
-
使用按位取反运算符
~
,将结果取反。
-
-
-
-
结果:
is_pad_token_not_equal_to_eos_token_id
为一个布尔值,表示pad_token_id
是否不等于eos_token_id
。
步骤 8:确定是否可以推断 attention_mask
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
-
解释:
- 通过将
is_pad_token_in_inputs
和is_pad_token_not_equal_to_eos_token_id
相乘(布尔值相乘相当于逻辑与),判断是否可以基于pad_token_id
推断出attention_mask
。
- 通过将
-
结果:
-
can_infer_attention_mask
为布尔值:-
如果
True
,表示可以根据pad_token_id
推断attention_mask
。 -
如果
False
,则无法推断,需要返回默认的attention_mask
。
-
-
步骤 9:基于 pad_token_id
生成 attention_mask
attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
-
解释:
-
使用
inputs_tensor.ne(pad_token_id)
,得到一个布尔张量,标记出inputs_tensor
中不等于pad_token_id
的位置。 -
.long()
:将布尔张量转换为长整型,即True
转换为1
,False
转换为0
。
-
-
结果:
attention_mask_from_padding
为一个长整型张量,形状与inputs_tensor
相同,非pad_token_id
的位置为1
,pad_token_id
的位置为0
。
步骤 10:根据是否可以推断 attention_mask
,选择最终的 attention_mask
attention_mask = (
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
-
解释:
-
如果可以推断
attention_mask
(can_infer_attention_mask
为True
):- 使用
attention_mask_from_padding
。
- 使用
-
如果不能推断:
- 使用
default_attention_mask
。
- 使用
-
计算方式:
-
attention_mask = attention_mask_from_padding * can_infer_attention_mask
:-
当
can_infer_attention_mask
为True
时,attention_mask
等于attention_mask_from_padding
。 -
当
can_infer_attention_mask
为False
时,乘积为0
。
-
-
default_attention_mask * ~can_infer_attention_mask
:-
~can_infer_attention_mask
对can_infer_attention_mask
取反。 -
当
can_infer_attention_mask
为False
时,~can_infer_attention_mask
为True
。 -
因此,当不能推断时,
attention_mask
等于default_attention_mask
。
-
-
最终,将两个结果相加,得到正确的
attention_mask
。
-
-
-
结果:
attention_mask
为最终的注意力掩码张量,表示模型在生成过程中需要关注的输入位置。
步骤 11:返回最终的 attention_mask
return attention_mask
-
解释:
- 返回生成的
attention_mask
张量,用于在生成过程中指导模型关注正确的输入序列位置。
- 返回生成的
_prepare_encoder_decoder_kwargs_for_generation
以下是对您提供的 _prepare_encoder_decoder_kwargs_for_generation
方法的详细解释。该方法用于在生成过程中为 编码器-解码器模型 准备必要的关键字参数,以便在生成序列时正确调用编码器。
方法定义
def _prepare_encoder_decoder_kwargs_for_generation(
self,
inputs_tensor: torch.Tensor,
model_kwargs,
model_input_name: Optional[str],
generation_config: GenerationConfig,
) -> Dict[str, Any]:
# 方法体...
参数说明:
inputs_tensor
:torch.Tensor
,输入张量,通常是input_ids
,表示输入序列的标记(tokens)。model_kwargs
:Dict[str, Any]
,包含传递给模型的其他关键字参数。model_input_name
:Optional[str]
,模型输入的名称,默认为None
。如果为None
,则使用self.main_input_name
。generation_config
:GenerationConfig
,生成配置对象,包含生成过程中所需的配置参数。
返回值:
model_kwargs
: 返回更新后的model_kwargs
,其中包括编码器的输出encoder_outputs
,以供生成器使用。
方法功能概述
该方法的主要目的是:
- 获取模型的 编码器 部分。
- 从
model_kwargs
和generation_config
中提取 编码器所需的参数,并准备传递给编码器的关键字参数encoder_kwargs
。 - 调用 编码器的
forward
方法,获取编码器的输出,并将其添加到model_kwargs
中,以供 解码器 在生成过程中使用。
逐步详解
步骤 1:获取编码器
# 1. get encoder
encoder = self.get_encoder()
-
说明:
- 调用模型的
get_encoder()
方法,获取编码器对象。 - 该编码器将用于处理输入的
inputs_tensor
,生成编码器的输出。
- 调用模型的
步骤 1.1:兼容性处理
# Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
# as the inputs.
if hasattr(self, "hf_device_map"):
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
else:
add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
-
解释:
- 目的:确保在使用 Accelerate 库进行大型模型推理时,编码器的输出与输入位于 同一设备 上(如 GPU),避免跨设备的数据传输开销。
-
逻辑:
- 检查模型是否具有
hf_device_map
属性,如果存在,表示模型使用了 Accelerate 库进行设备映射。 - 检查编码器是否具有
_hf_hook
属性:- 如果有,设置其
io_same_device
属性为True
,表示编码器的输入和输出在同一设备上。 - 如果没有,使用
add_hook_to_module
函数,将AlignDevicesHook(io_same_device=True)
添加到编码器模块上。
- 如果有,设置其
- 检查模型是否具有
-
相关函数:
add_hook_to_module(module, hook)
: 将钩子函数添加到指定的模块上,控制模块的输入输出行为。
步骤 2:准备编码器的参数
# 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
-
目的:
- 从
model_kwargs
中提取与编码器相关的参数,过滤掉与解码器或交叉注意力相关的参数。
- 从
-
逻辑:
- 定义一个列表
irrelevant_prefix
,包含了不相关的参数前缀,如"decoder_"
、"cross_attn"
、"use_cache"
。 - 使用字典推导式,从
model_kwargs
中过滤掉以这些前缀开头的参数。 - 结果是
encoder_kwargs
,其中包含了需要传递给编码器的参数。
- 定义一个列表
-
示例:
-
如果
model_kwargs
包含:model_kwargs = { "input_ids": tensor(...), "attention_mask": tensor(...), "decoder_input_ids": tensor(...), "use_cache": True, }
-
过滤后,
encoder_kwargs
为:encoder_kwargs = { "input_ids": tensor(...), "attention_mask": tensor(...), }
-
步骤 2.1:检查编码器的签名
encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:
encoder_kwargs = {
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
}
-
解释:
- 目的:确保传递给编码器的参数在其
forward
方法的参数列表中,即编码器能够接受这些参数。
- 目的:确保传递给编码器的参数在其
-
逻辑:
- 使用
inspect.signature(encoder.forward).parameters
获取编码器forward
方法的参数名称集合encoder_signature
。 - 检查编码器是否接受通配参数
**kwargs
或**model_kwargs
,如果接受,则无需进一步过滤参数。 - 如果编码器不接受通配参数,则过滤
encoder_kwargs
,仅保留在encoder_signature
中的参数。
- 使用
步骤 2.2:添加生成配置中的参数
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
-
说明:
- 从
generation_config
中提取output_attentions
和output_hidden_states
配置,添加到encoder_kwargs
中。 - 这些配置决定了编码器在前向计算时,是否返回注意力权重和隐藏状态。
- 从
步骤 3:调用编码器并更新 model_kwargs
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore
return model_kwargs
-
解释:
-
确保编码器返回
ModelOutput
对象:- 设置
encoder_kwargs["return_dict"] = True
,使编码器的输出为ModelOutput
格式,而不是元组。
- 设置
-
准备编码器的输入:
- 确定模型的输入名称
model_input_name
,如果传入了model_input_name
,则使用它,否则使用self.main_input_name
。 - 将
inputs_tensor
添加到encoder_kwargs
,键为model_input_name
。
- 确定模型的输入名称
-
调用编码器的
forward
方法:model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore
- 调用编码器,传入
encoder_kwargs
,得到编码器的输出encoder_outputs
。 - 将
encoder_outputs
添加到model_kwargs
中,键为"encoder_outputs"
。 - 使用类型注解
: ModelOutput
指定类型,这是为了让静态类型检查器知道encoder_outputs
的类型。
- 调用编码器,传入
-
返回更新后的
model_kwargs
:- 包含了
encoder_outputs
,供后续的解码器生成过程中使用。
- 包含了
-
_prepare_decoder_input_ids_for_generation
以下是对您提供的 _prepare_decoder_input_ids_for_generation
方法的详细解释。这个方法用于在生成过程中为 编码器-解码器模型 准备 decoder_input_ids
,确保解码器在生成时能够正确地开始。
方法定义
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, torch.Tensor],
decoder_start_token_id: torch.Tensor,
device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 方法体...
参数说明:
batch_size
:int
,表示批次大小,即一次输入中样本的数量。model_input_name
:str
,模型输入的名称,通常为"input_ids"
。model_kwargs
:Dict[str, torch.Tensor]
,包含传递给模型的其他关键字参数。decoder_start_token_id
:torch.Tensor
,解码器开始标记的 token ID。device
:torch.device
,可选,指定要在哪个设备上运行,如 GPU 或 CPU。如果未指定,则使用模型默认的设备。
返回值:
decoder_input_ids
:torch.LongTensor
,准备好的用于生成的解码器输入 IDs。model_kwargs
: 更新后的模型关键字参数字典。
方法功能概述
这个方法的主要作用是:
- 检查用户是否手动提供了
decoder_input_ids
,如果没有,则根据情况初始化它。 - 确保
decoder_start_token_id
的形状正确,并适应批次大小。 - 确保
decoder_input_ids
以特殊的开始标记(如 BOS token)开头,如果没有,则自动添加。 - 处理特定模型的例外情况,例如 “Donut” 和 “Whisper” 模型可能有不同的处理方式。
通过这些步骤,该方法确保了在生成过程中,解码器能够正确地开始生成序列。
逐步详解
步骤 1:检查用户是否提供了 decoder_input_ids
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
decoder_input_ids = model_kwargs.pop("input_ids")
else:
decoder_input_ids = None
解释:
-
目的: 确定是否需要初始化
decoder_input_ids
。 -
逻辑:
- 如果
model_kwargs
中存在decoder_input_ids
,则将其取出,并从model_kwargs
中删除,避免重复。 - 如果
model_kwargs
中存在input_ids
,且model_input_name
不是"input_ids"
,则将其视为decoder_input_ids
。这是为了方便某些模型的输入命名。 - 如果以上都不满足,则将
decoder_input_ids
设为None
,表示需要初始化。
- 如果
示例:
- 用户在调用生成方法时,可能通过
model_kwargs
提供了decoder_input_ids
,方法将直接使用它。 - 如果用户没有提供,方法将尝试从
input_ids
中获取(当编码器不使用input_ids
作为主要输入时)。 - 如果都没有提供,方法将在后续步骤中初始化
decoder_input_ids
。
步骤 2:调整 decoder_start_token_id
的形状
if device is None:
device = self.device
if decoder_start_token_id.ndim == 1:
if decoder_start_token_id.shape[0] != batch_size:
raise ValueError(
f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
)
decoder_start_token_id = decoder_start_token_id.view(-1, 1)
else:
decoder_start_token_id = (
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
)
解释:
-
目的: 确保
decoder_start_token_id
的形状为(batch_size, 1)
,即每个样本都有一个开始标记。 -
逻辑:
- 如果未指定
device
,则使用模型默认的设备self.device
。 - 检查
decoder_start_token_id
的维度:- 如果是一维张量(向量),即
decoder_start_token_id.ndim == 1
:- 检查其长度是否等于
batch_size
,否则抛出错误。 - 重塑为形状
(batch_size, 1)
,即每个样本一个开始标记。
- 检查其长度是否等于
- 如果不是一维张量,则认为是单个标量:
- 创建一个形状为
(batch_size, 1)
的张量,元素全为decoder_start_token_id
的值。
- 创建一个形状为
- 如果是一维张量(向量),即
- 如果未指定
示例:
- 如果
decoder_start_token_id
是标量101
(假设是开始标记的 ID):- 创建一个形状为
(batch_size, 1)
的张量,所有元素都是101
。
- 创建一个形状为
- 如果
decoder_start_token_id
是张量[101, 102]
,batch_size
为2
:- 将其重塑为:
[[101], [102]]
- 将其重塑为:
步骤 3:确保 decoder_input_ids
以开始标记开头
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
decoder_input_ids = decoder_start_token_id
解释:
- 情况 1: 如果
decoder_input_ids
为None
,即用户未提供任何解码器输入:- 直接使用
decoder_start_token_id
作为decoder_input_ids
,表示解码器将从开始标记开始生成。
- 直接使用
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token...
elif "donut" in self.__class__.__name__.lower() or (
self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
):
pass
elif self.config.model_type in ["whisper"]:
pass
解释:
- 情况 2: 针对特定模型的例外情况:
- Donut 模型:
- Donut 模型有特定的解码器开始标记,且不需要额外的 BOS(Begin of Sequence)标记。
- 如果模型名包含
"donut"
,则不对decoder_input_ids
进行处理。
- Whisper 模型:
- Whisper 模型也有自己的处理逻辑,不需要添加开始标记。
- Donut 模型:
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
decoder_attention_mask = torch.cat(
(torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
dim=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
解释:
- 情况 3: 用户提供了
decoder_input_ids
,但其首个 token 并非decoder_start_token_id
:- 检查首个 token 是否等于
decoder_start_token_id
:(decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
:- 比较每个样本的
decoder_input_ids
首个 token 是否不等于对应的decoder_start_token_id
。 - 如果对于所有样本都不相等,则返回
True
。
- 比较每个样本的
- 如果首个 token 不匹配,则在
decoder_input_ids
前添加decoder_start_token_id
:- 使用
torch.cat
在维度-1
(序列长度维度)上拼接decoder_start_token_id
和decoder_input_ids
。
- 使用
- 调整
decoder_attention_mask
(如果提供):- 如果
model_kwargs
中存在decoder_attention_mask
:- 在其前面添加一个值为
1
的位置,表示新添加的开始标记需要被注意。 - 更新
model_kwargs["decoder_attention_mask"]
。
- 在其前面添加一个值为
- 如果
- 检查首个 token 是否等于
示例:
-
如果原始
decoder_input_ids
为:[[5, 6, 7], [8, 9, 10]]
-
decoder_start_token_id
为:[[101], [101]]
-
拼接后得到:
[[101, 5, 6, 7], [101, 8, 9, 10]]
-
同时调整
decoder_attention_mask
,在前面添加一个1
。
步骤 4:返回处理后的结果
return decoder_input_ids, model_kwargs
- 返回更新后的
decoder_input_ids
和model_kwargs
。
整体流程总结
-
输入处理:
- 检查用户是否提供了
decoder_input_ids
,如果没有,则需要初始化。 - 通过
model_kwargs
获取decoder_input_ids
或input_ids
,如果适用。
- 检查用户是否提供了
-
确保解码器开始标记的形状正确:
- 将
decoder_start_token_id
调整为形状(batch_size, 1)
,确保每个样本都有对应的开始标记。
- 将
-
确保
decoder_input_ids
以开始标记开头:- 如果
decoder_input_ids
为None
,直接使用decoder_start_token_id
。 - 对于特定模型(如 Donut 和 Whisper),保留用户提供的
decoder_input_ids
,不做修改。 - 如果用户提供的
decoder_input_ids
不以decoder_start_token_id
开头,自动在其前添加。 - 同时,调整
decoder_attention_mask
,确保新添加的开始标记在注意力掩码中被考虑。
- 如果
-
返回处理后的
decoder_input_ids
和model_kwargs
,供后续生成过程使用。
示例代码
示例 1:用户未提供 decoder_input_ids
# 假设 batch_size = 2
decoder_start_token_id = torch.tensor([101], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=2,
model_input_name="input_ids",
model_kwargs={},
decoder_start_token_id=decoder_start_token_id,
device=device,
)
# 结果:decoder_input_ids = [[101], [101]]
示例 2:用户提供了 decoder_input_ids
,但未以开始标记开头
decoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]], device=device)
model_kwargs = {"decoder_input_ids": decoder_input_ids}
decoder_start_token_id = torch.tensor([101, 102], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=2,
model_input_name="input_ids",
model_kwargs=model_kwargs,
decoder_start_token_id=decoder_start_token_id,
device=device,
)
# 结果:
# decoder_input_ids = [[101, 5, 6, 7], [102, 8, 9, 10]]
示例 3:针对 Donut 模型
# 设模型名称包含 "Donut",用户提供了 decoder_input_ids
self.__class__.__name__ = "DonutModel"
decoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]], device=device)
model_kwargs = {"decoder_input_ids": decoder_input_ids}
decoder_start_token_id = torch.tensor([101], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=2,
model_input_name="input_ids",
model_kwargs=model_kwargs,
decoder_start_token_id=decoder_start_token_id,
device=device,
)
# 结果:
# decoder_input_ids 保持不变,不添加 decoder_start_token_id
注意事项
-
模型特定的例外情况: 对于某些模型,如 Donut 和 Whisper,需要特殊处理,不能盲目添加开始标记。
-
一致性检查: 确保
batch_size
与decoder_start_token_id
的长度一致,否则抛出异常。 -
处理注意力掩码: 如果修改了
decoder_input_ids
,且提供了decoder_attention_mask
,需要相应地调整掩码。 -
设备一致性: 所有张量都应该在同一设备上(CPU 或 GPU),方法中确保了这一点。
总结
该方法的主要功能是为生成过程准备合适的 decoder_input_ids
:
- 初始值设置: 如果用户未提供,则使用
decoder_start_token_id
初始化。 - 形状调整: 确保
decoder_start_token_id
的形状与批次大小匹配。 - 开头标记: 确保
decoder_input_ids
以特定的开始标记开头,以符合模型的要求。 - 特例处理: 对于特定模型,按其特殊需求处理。
通过这些步骤,保证了编码器-解码器模型在生成过程中能够正确地开始解码器的序列生成。
希望以上解释能够帮助您理解 _prepare_decoder_input_ids_for_generation
方法的功能和每个步骤的具体作用。如果您还有其他问题,欢迎继续提问!
heal_tokens
这个方法用于在生成过程中,对输入的 input_ids
进行修复(healing),以增强模型的输出质量。具体来说,它会根据输入序列的最后一个 token,寻找可能的扩展,并替换原始的尾部 token。这在某些情况下可以纠正模型生成中的不一致或错误。
方法定义
def heal_tokens(
self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
) -> torch.LongTensor:
r"""
Generates sequences of token ids for models with a language modeling head.
Parameters:
input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
Return:
`torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
"""
# 方法体...
参数说明:
input_ids
:torch.LongTensor
,形状为(batch_size, sequence_length)
,表示输入的 token 序列,用于生成的提示。tokenizer
:PreTrainedTokenizerBase
,可选参数,模型对应的 tokenizer,用于解码和编码 token IDs。
返回值:
torch.LongTensor
,形状与input_ids
相同,其中每个序列的尾部 token 被替换为了适当的扩展 token。
方法功能概述
该方法的主要目的是:
- 修复输入序列的尾部 token:对于每个输入序列,检查其最后一个 token,寻找可能的扩展 token,并替换之。
- 改进模型生成的连贯性:通过纠正输入序列的尾部,使得生成的序列在语义和形式上更加连贯。
- 处理空序列和特殊情况:在方法中包含了对空序列和特殊情况的处理,确保方法的稳健性。
逐步详解
步骤 1:验证 tokenizer
是否提供
if tokenizer is None:
raise ValueError(
" When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
"argument of `generate`."
)
- 解释:
- 方法要求必须提供
tokenizer
参数,否则无法进行解码和编码操作。 - 如果未提供
tokenizer
,则抛出ValueError
。
- 方法要求必须提供
步骤 2:获取特殊 token IDs 和构建词汇前缀树
bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
- 解释:
- 获取特殊 token IDs:
bos_token_id
: 序列开始的 token ID(BOS = Begin Of Sequence)。pad_token_id
: 填充 token 的 ID。
- 构建词汇前缀树:
tokenizer.get_vocab()
: 获取 tokenizer 的词汇表,返回一个字典{token: token_id}
。ExtensionsTrie
: 自定义的前缀树类,用于快速查找以某个前缀开始的所有 token。
- 创建生成配置:
GenerationConfig
: 配置生成过程的参数。max_new_tokens=1
: 生成的最大新 token 数设置为 1,因为我们只需要替换尾部 token。pad_token_id=pad_token_id
: 设置填充 token ID。
- 获取特殊 token IDs:
步骤 3:处理提示文本
# assumption: leading/trailing whitespace is not meaningful, so the prompts are
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
- 解释:
- 假设:首尾的空白字符(whitespace)对生成过程没有实质影响,因此在重新编码前去除。
- 步骤:
tokenizer.batch_decode
: 将input_ids
解码为字符串列表,跳过特殊 tokens。- 使用列表推导式,对每个字符串进行
strip
操作,去除首尾空白字符。 - 结果存储在
prompts
列表中。
步骤 4:重新编码输入序列
input_ids = tokenizer(
prompts,
return_tensors="pt",
padding=True,
).input_ids.to(input_ids.device)
- 解释:
- 重新编码:
- 将处理后的
prompts
列表再次通过tokenizer
编码为input_ids
。 return_tensors="pt"
: 返回 PyTorch 张量。padding=True
: 对序列进行填充,以使它们的长度一致。
- 将处理后的
- 调整设备:
- 使用
.to(input_ids.device)
将新生成的input_ids
移动到与原始input_ids
相同的设备上(如 GPU 或 CPU)。
- 使用
- 重新编码:
步骤 5:替换序列中的 bos_token_id
为 pad_token_id
# replace bos with pad to not condition healing on it
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
- 解释:
- 目的:
- 在后续的修复过程中,不希望序列开始的特殊标记(BOS)影响结果,因此将其替换为填充标记。
- 操作:
- 使用
torch.where
函数,将input_ids
中等于bos_token_id
的位置替换为pad_token_id
。
- 使用
- 目的:
步骤 6:检查 input_ids
是否为空
if input_ids.numel() == 0:
return input_ids
- 解释:
- 目的:
- 如果
input_ids
为空(即没有元素),则无需进行后续处理,直接返回。
- 如果
- 检查:
input_ids.numel()
: 返回张量中元素的总数。- 如果为
0
,则返回input_ids
。
- 目的:
步骤 7:获取每个序列的尾部 token ID 和对应的 token
tail_ids = input_ids[:, -1].tolist()
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
- 解释:
- 获取尾部 token ID:
input_ids[:, -1]
: 获取每个序列的最后一个 token ID。.tolist()
: 转换为 Python 列表,方便后续处理。
- 处理空格 token:
tokenizer.convert_tokens_to_ids(" ")
: 将空格字符" "
转换为对应的 token ID。tokenizer.convert_ids_to_tokens
: 再将 token ID 转换为 token 字符串。[0]
: 获取结果中的第一个元素(列表中可能包含多个 token)。- 结果
space_tok
为空格对应的 token,例如在 BPE(Byte-Pair Encoding)中,空格可能对应特殊字符,例如'Ġ'
。
- 获取尾部 token 的字符串表示:
- 使用生成器表达式
(tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
:- 对每个
tail_id
:tokenizer.decode(t)
: 解码为字符串。.replace(" ", space_tok)
: 将字符串中的空格替换为space_tok
,以便于后续前缀搜索。
- 对每个
- 使用生成器表达式
- 获取尾部 token ID:
步骤 8:遍历每个序列,尝试修复尾部 token
for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
batch_ids = input_ids[batch_idx]
if torch.all(batch_ids == pad_token_id).item():
continue # skip empty sequences (all pad ids)
- 解释:
- 遍历序列:
- 使用
enumerate
对tail_ids
和tail_toks
进行遍历。 batch_idx
: 当前序列的索引。tail_id
: 当前序列的尾部 token ID。tail_tok
: 当前序列的尾部 token 字符串(经过特殊处理的)。
- 使用
- 获取当前序列的
input_ids
:batch_ids = input_ids[batch_idx]
: 获取当前序列的input_ids
。
- 检查序列是否为空:
torch.all(batch_ids == pad_token_id).item()
: 检查序列中的所有位置是否都是pad_token_id
。- 如果是,表示序列为空,跳过后续处理。
- 遍历序列:
步骤 9:构建可能的替代 token 的偏置字典
# apply bias for alternatives (extensions) to the tail token
"""
seq_bias key has to be tuple with int so have to use
tokenizer function to convert str to int
"""
seq_bias = {
(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
}
if len(seq_bias) == 1:
continue # skip if there are no token alternatives to heal with
# slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
seq_bias[(tail_id,)] += 1.0
generation_config.update(sequence_bias=seq_bias)
- 解释:
- 寻找可能的扩展 tokens:
vocab_trie.extensions(prefix=tail_tok)
: 在词汇前缀树中,找到以tail_tok
为前缀的所有 token。
- 构建偏置字典
seq_bias
:- 键:以单个 token ID 组成的元组
(token_id,)
,因为后续需要的键是元组形式。 - 值:偏置值
10.0
,用于在生成过程中提升这些 tokens 的概率。 - 需要使用
tokenizer.convert_tokens_to_ids(alt_tok)
将 token 字符串转换为 token ID。
- 键:以单个 token ID 组成的元组
- 检查是否有可替代的 tokens:
- 如果
seq_bias
的长度为1
,表示只有原始的tail_id
,没有其他可替代的 tokens,跳过处理。
- 如果
- 稍微提升原始 token 的概率:
seq_bias[(tail_id,)] += 1.0
: 对原始的tail_id
,增加一个较小的偏置1.0
,以防止过度修复(例如,将'http'
修复为'https'
)。
- 更新生成配置:
generation_config.update(sequence_bias=seq_bias)
: 将偏置字典添加到生成配置中,供生成过程使用。
- 寻找可能的扩展 tokens:
步骤 10:准备生成输入,去除尾部 token
trimmed_ids = batch_ids[:-1]
"""
the latter code assumes trimmed_ids is not empty
so have to check its element count
"""
if trimmed_ids.numel() == 0:
continue
# if the prompt is a single (non-pad) token, regenerate from bos
if len(batch_ids[batch_ids != pad_token_id]) == 1:
trimmed_ids[-1] = bos_token_id
- 解释:
- 去除尾部 token:
trimmed_ids = batch_ids[:-1]
: 获取当前序列,去除最后一个 token,即tail_id
。
- 检查
trimmed_ids
是否为空:- 如果
trimmed_ids.numel() == 0
,表示序列长度为 0,无需处理,继续下一个序列。
- 如果
- 特殊情况处理:
- 如果序列只有一个非填充 token(除去
pad_token_id
后长度为 1):- 需要将
trimmed_ids
的最后一个位置替换为bos_token_id
,以重新从开始标记生成。
- 需要将
- 如果序列只有一个非填充 token(除去
- 去除尾部 token:
步骤 11:生成新的 token 并替换尾部 token
input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
- 解释:
- 调用生成方法:
self.generate
: 调用模型的generate
方法,生成新序列。trimmed_ids.unsqueeze(0)
: 为了适应批量维度,将trimmed_ids
添加一个维度,形状从(sequence_length - 1,)
变为(1, sequence_length - 1)
。generation_config=generation_config
: 使用之前配置的生成参数,包括偏置字典seq_bias
。
- 更新
input_ids
:- 将生成的新序列替换到
input_ids
中对应的位置。 - 注意,这里生成的序列长度为
sequence_length
,因为max_new_tokens=1
,所以会在trimmed_ids
的基础上生成一个新 token。
- 将生成的新序列替换到
- 调用生成方法:
步骤 12:返回修复后的 input_ids
return input_ids
- 解释:
- 返回处理后的
input_ids
,其中每个序列的尾部 token 已根据可能的扩展进行了修复。
- 返回处理后的
整体流程总结
-
准备工作:
- 验证
tokenizer
参数。 - 获取特殊 token IDs。
- 构建词汇前缀树
vocab_trie
,用于快速查找可能的扩展 token。 - 配置生成参数
generation_config
。
- 验证
-
处理输入序列:
- 将
input_ids
解码为字符串列表prompts
,去除首尾空白。 - 重新编码
prompts
为input_ids
,确保一致性。 - 将
bos_token_id
替换为pad_token_id
,避免对序列开始标记的影响。
- 将
-
遍历每个序列,尝试修复尾部 token:
- 获取每个序列的尾部 token ID 和对应的 token 字符串。
- 查找以尾部 token 为前缀的可能扩展 tokens。
- 构建偏置字典
seq_bias
,提升这些扩展 token 在生成过程中的概率。 - 去除序列的尾部 token,准备生成新的尾部 token。
- 调用
self.generate
方法,生成新的序列,并替换到input_ids
中。
-
返回处理后的
input_ids
。
_prepare_generated_length
这个方法用于在生成过程中,根据用户提供的生成配置和模型输入,准备和调整生成的最大长度 (max_length
) 和最小长度 (min_length
),以避免类似属性之间的冲突。
方法定义
def _prepare_generated_length(
self,
generation_config,
has_default_max_length,
has_default_min_length,
model_input_name,
input_ids_length,
inputs_tensor,
):
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
# 方法体...
参数说明:
-
generation_config
:GenerationConfig
对象,包含了生成过程中需要的各种配置参数,如max_length
、min_length
、max_new_tokens
、min_new_tokens
等。 -
has_default_max_length
:bool
类型,指示max_length
是否使用了默认值。如果为True
,表示用户未显式设置max_length
。 -
has_default_min_length
:bool
类型,指示min_length
是否使用了默认值。 -
model_input_name
:str
类型,模型输入的名称,通常为"input_ids"
或"inputs_embeds"
。 -
input_ids_length
:int
类型,输入序列input_ids
的长度,即输入的 token 数量。 -
inputs_tensor
:torch.Tensor
对象,模型的输入张量。
方法功能概述
该方法的主要作用是:
-
调整
max_length
和min_length
:根据用户提供的max_new_tokens
、min_new_tokens
、max_length
和min_length
,以及输入序列的长度,计算并设置最终的生成长度参数,确保生成过程按照预期进行。 -
避免冲突:如果用户同时设置了类似的属性(例如同时设置了
max_length
和max_new_tokens
),该方法会明确优先级,并在必要时发出警告,提示用户可能存在的冲突。 -
处理特殊情况:针对一些特殊的输入情况,例如使用了
inputs_embeds
、模型是编码器-解码器模型等,做出相应的调整。
逐步详解
1. 处理 max_length
if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
解释:
-
情况 1:用户设置了
max_new_tokens
-
逻辑:
-
如果
generation_config.max_new_tokens
不为None
,表示用户希望通过max_new_tokens
来指定要生成的新 token 数量。 -
冲突处理:
-
如果用户同时设置了
max_length
(并且不是默认值),则发出警告,提示两者可能存在冲突,但以max_new_tokens
为准。 -
not has_default_max_length
:表示max_length
被用户显式设置了。
-
-
计算
max_length
:-
将
max_length
设置为input_ids_length + max_new_tokens
。-
input_ids_length
:输入序列的长度。 -
max_new_tokens
:用户希望生成的新 token 数量。
-
-
-
这样,
max_length
就表示生成的序列(包括输入和生成的部分)总长度。
-
-
示例:
- 如果
input_ids_length = 10
,max_new_tokens = 20
,则max_length = 10 + 20 = 30
。
# if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
# otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`
elif (
model_input_name == "inputs_embeds"
and input_ids_length != inputs_tensor.shape[1]
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]
解释:
-
情况 2:模型输入是
inputs_embeds
,且存在输入长度不匹配-
逻辑:
-
条件判断:
-
model_input_name == "inputs_embeds"
:表示模型的输入是嵌入表示。 -
input_ids_length != inputs_tensor.shape[1]
:输入的input_ids
长度与inputs_tensor
的长度(序列维度大小)不一致。 -
not self.config.is_encoder_decoder
:模型不是编码器-解码器模型。
-
-
处理:
- 需要调整
max_length
,减去inputs_tensor.shape[1]
,即输入的序列长度。
- 需要调整
-
原因:
- 当用户提供了
inputs_embeds
而非input_ids
,且两者长度不一致,为了确保生成的总长度不超过用户预期,需要调整max_length
。
- 当用户提供了
-
-
elif has_default_max_length: # by default let's always generate 20 new tokens
if generation_config.max_length == GenerationConfig().max_length:
generation_config.max_length = generation_config.max_length + input_ids_length
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
if max_position_embeddings is not None:
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
解释:
-
情况 3:用户未设置
max_length
,使用默认值-
逻辑:
-
has_default_max_length
为True
,即用户未显式设置max_length
。 -
如果
generation_config.max_length
等于默认的max_length
,则执行以下操作:-
计算新的
max_length
:-
将
max_length
设置为原来的max_length
加上input_ids_length
。- 这样,默认情况下,会在输入的基础上生成
20
个新 tokens(假设默认max_length
为20
)。
- 这样,默认情况下,会在输入的基础上生成
-
-
考虑模型的最大位置嵌入长度:
-
获取模型配置中的
max_position_embeddings
,它表示模型能处理的最大序列长度。 -
如果存在,就将
generation_config.max_length
限制在max_position_embeddings
之内,防止生成长度超过模型的能力。
-
-
-
-
2. 处理 min_length
if generation_config.min_new_tokens is not None:
if not has_default_min_length:
logger.warning(
f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.min_length = generation_config.min_new_tokens + input_ids_length
解释:
-
情况 1:用户设置了
min_new_tokens
-
逻辑:
-
如果
generation_config.min_new_tokens
不为None
,表示用户希望指定要生成的最小新 token 数量。 -
冲突处理:
- 如果用户同时设置了
min_length
(并且不是默认值),则发出警告,提示两者存在冲突,以min_new_tokens
为准。
- 如果用户同时设置了
-
计算
min_length
:- 将
min_length
设置为min_new_tokens + input_ids_length
。
- 将
-
-
elif (
model_input_name == "inputs_embeds"
and input_ids_length != inputs_tensor.shape[1]
and not self.config.is_encoder_decoder
):
generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
解释:
-
情况 2:模型输入是
inputs_embeds
,且存在输入长度不匹配-
逻辑:
-
与处理
max_length
时类似,但这里调整的是min_length
。 -
将
generation_config.min_length
减去inputs_tensor.shape[1]
,并取最大值0
(防止出现负值)。 -
这样可以确保生成的最小长度不会超过用户预期。
-
-
3. 返回更新后的 generation_config
return generation_config
- 返回调整过后的
generation_config
,包含更新的max_length
和min_length
,以供生成过程使用。
_validate_generated_length
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
"""Performs validation related to the resulting generated length"""
# 函数主体从这里开始
功能说明
-
目的:该函数用于对生成的长度进行验证,确保
generation_config
中的参数设置合理,避免在生成过程中出现不可预期的行为或错误。 -
主要任务:
- 检查
max_length
和max_new_tokens
之间的关系,给出适当的警告或错误。 - 检查输入序列长度
input_ids_length
与max_length
之间的关系。 - 检查
min_length
和max_length
之间的关系,给出警告。 - 确保
min_new_tokens
加上input_ids_length
不超过max_length
,并给出警告。
- 检查
参数说明
self
:类的实例,典型的 Python 类方法的第一个参数。generation_config
:生成配置对象,包含了生成过程中使用的各项参数设置,例如max_length
、min_length
、max_new_tokens
等。input_ids_length
:整数,输入序列input_ids
的长度,即序列的长度。has_default_max_length
:布尔值,指示是否使用了默认的max_length
设置(即用户没有在调用时显式指定max_length
)。
代码详细解释
1. 编译时不进行警告或异常抛出
# Can't throw warnings/exceptions during compilation
if is_torchdynamo_compiling():
return
解释:
-
目的:在使用 TorchDynamo(PyTorch 编译器)进行编译时,不要抛出警告或异常。
-
逻辑:
is_torchdynamo_compiling()
:检查当前是否在使用 TorchDynamo 进行编译。if is_torchdynamo_compiling(): return
:如果正在编译,直接返回,不进行后续的验证。
-
原因:在编译过程中,抛出异常或者发出警告可能会导致编译失败或行为异常,因此在编译时跳过验证。
2. 第一部分:与参数设置相关的 max_length
警告
# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
"generation.",
UserWarning,
)
解释:
-
目的:当用户未指定
max_length
且未设置max_new_tokens
时,发出警告提示。 -
逻辑:
-
条件判断:
has_default_max_length
:用户未显式指定max_length
,使用了默认值。generation_config.max_new_tokens is None
:max_new_tokens
未设置。generation_config.max_length == 20
:max_length
等于20
,这是GenerationConfig
的默认值。
-
处理:
- 如果以上条件都满足,使用
warnings.warn()
发出警告,提示用户正在使用模型无关的默认max_length
来控制生成长度,建议用户设置max_new_tokens
来控制生成的最大长度。
- 如果以上条件都满足,使用
-
-
原因:
- 如果用户没有明确设置生成长度的参数,可能会导致生成的序列长度与预期不符。
- 建议用户使用
max_new_tokens
,因为它更直观地控制生成的新 tokens 数量。
3. 检查输入序列长度是否超过或等于 max_length
if input_ids_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
raise ValueError(
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_length` or, better yet, setting `max_new_tokens`."
)
解释:
-
目的:如果输入序列的长度大于或等于
max_length
,抛出ValueError
,因为这样会导致生成过程中的问题。 -
逻辑:
-
条件判断:
if input_ids_length >= generation_config.max_length
:如果输入序列的长度input_ids_length
大于或等于max_length
。
-
处理:
-
根据模型类型确定输入 IDs 的名称:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
- 如果是编码器-解码器模型,使用
'decoder_input_ids'
。 - 否则,使用
'input_ids'
。
- 如果是编码器-解码器模型,使用
-
抛出
ValueError
,包含详细的错误信息:- 提示输入序列的长度与
max_length
的值,以及可能导致意外行为。 - 建议用户增加
max_length
或者更好的,设置max_new_tokens
。
- 提示输入序列的长度与
-
-
-
原因:
- 如果输入序列的长度已经达到或超过
max_length
,模型将无法生成新的 tokens,可能导致生成过程立即结束或出现错误。 - 为了避免这种情况,必须确保
max_length
大于输入序列的长度。
- 如果输入序列的长度已经达到或超过
4. 第二部分:由于不可行的参数组合导致的 min_length
警告
# 2. Min length warnings due to unfeasible parameter combinations
min_length_error_suffix = (
" Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
"increase the maximum length."
)
if has_default_max_length:
min_length_error_suffix += (
f" Note that `max_length` is set to {generation_config.max_length}, its default value."
)
解释:
-
目的:准备
min_length
相关的错误信息,供后续警告使用。 -
逻辑:
-
定义错误信息的后缀:
min_length_error_suffix
:提示用户需要降低最小长度和/或增加最大长度。
-
如果使用了默认的
max_length
:- 如果
has_default_max_length
为True
,则在min_length_error_suffix
后面追加一句,说明max_length
被设置为其默认值。
- 如果
-
5. 检查 min_length
是否大于 max_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
warnings.warn(
f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
UserWarning,
)
解释:
-
目的:如果
min_length
大于max_length
,发出警告,因为这是不可行的约束。 -
逻辑:
-
条件判断:
generation_config.min_length is not None
:min_length
已被设置。generation_config.min_length > generation_config.max_length
:min_length
大于max_length
。
-
处理:
- 使用
warnings.warn()
发出警告,提示min_length
大于max_length
,并附加之前准备的min_length_error_suffix
。
- 使用
-
-
原因:
- 如果最小生成长度大于最大生成长度,模型无法满足这样的约束,可能导致生成过程在达到最大长度时停止,而未达到最小长度。
- 提醒用户调整参数,使其合理。
6. 检查 min_new_tokens
加上 input_ids_length
是否超过 max_length
if generation_config.min_new_tokens is not None:
min_length = generation_config.min_new_tokens + input_ids_length
if min_length > generation_config.max_length:
warnings.warn(
f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
f"added to the prompt length ({input_ids_length}), is larger than"
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
UserWarning,
)
解释:
-
目的:如果
min_new_tokens
加上输入序列长度超过max_length
,发出警告。 -
逻辑:
-
条件判断:
if generation_config.min_new_tokens is not None
:min_new_tokens
已被设置。
-
计算最小长度:
min_length = generation_config.min_new_tokens + input_ids_length
:计算最小生成长度,即min_new_tokens
加上输入序列长度。
-
检查是否超过
max_length
:if min_length > generation_config.max_length
:如果计算得到的min_length
大于max_length
。
-
处理:
- 使用
warnings.warn()
发出警告,提示min_new_tokens
加上输入长度超过了可能的最大长度,并附加min_length_error_suffix
。
- 使用
-
-
原因:
- 如果
min_new_tokens
与输入序列长度之和超过max_length
,模型无法生成满足最小新 tokens 数量的序列。 - 提醒用户调整参数,降低
min_new_tokens
、增加max_length
,或减少输入序列的长度。
- 如果
_supports_logits_to_keep
方法定义
def _supports_logits_to_keep(self) -> bool:
"""
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
to save memory. Checking it in this way allows to avoid using a new model attribute.
"""
return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
方法功能概述
- 目的:确定当前模型的
forward
方法是否支持logits_to_keep
参数。 - 返回值:布尔值
True
:如果forward
方法的参数中包含logits_to_keep
。False
:如果forward
方法的参数中不包含logits_to_keep
。
逐步详解
-
使用
inspect
模块分析forward
方法的签名inspect.signature(self.forward)
- 作用:获取
self.forward
方法的签名信息,包括参数列表和参数默认值等。 inspect
模块:Python 内置模块,提供了检查和获取对象(如函数、类、模块等)信息的功能。
- 作用:获取
-
获取
forward
方法的参数字典inspect.signature(self.forward).parameters
- 返回值:一个有序字典(
OrderedDict
),键为参数名称,值为参数对应的Parameter
对象。
- 返回值:一个有序字典(
-
获取参数名称列表并转换为集合
set(inspect.signature(self.forward).parameters.keys())
parameters.keys()
:返回参数名称的可迭代对象。set(...)
:将参数名称转换为集合,方便后续进行快速查找(in
操作)。
-
检查是否包含
logits_to_keep
参数"logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
- 作用:判断字符串
"logits_to_keep"
是否在参数名称集合中。 - 返回值:布尔值。
- 作用:判断字符串
-
返回判断结果
- 如果包含:返回
True
。 - 如果不包含:返回
False
。
- 如果包含:返回
方法用途和背景
-
节省内存:
在生成任务中,模型可能会生成大量的 logits(每个时间步长预测下一个 token 的概率分布,通常是一个包含了整个词汇表大小的张量)。如果能够限制保留的 logits 数量(例如只保留 top-k 个 logits),可以大大节省内存。
-
动态检查模型功能:
不同的模型可能实现了不同的功能。通过这种动态检查的方法,可以在不修改模型代码的情况下,了解模型是否支持某个特定的参数或功能。这有助于编写通用的代码,适用于多种模型。
-
避免使用额外的模型属性:
通过检查方法签名,而不是增加一个模型属性,可以减少模型类的复杂性和维护成本。
举例说明
假设我们有一个模型,其 forward
方法定义如下:
def forward(self, input_ids, attention_mask=None, logits_to_keep=None):
# 模型的前向计算逻辑
logits = self.compute_logits(input_ids, attention_mask)
if logits_to_keep is not None:
# 只保留指定数量的 logits
logits = logits[:, -1, :logits_to_keep]
return logits
-
模型支持
logits_to_keep
参数:在这种情况下,_supports_logits_to_keep
方法会返回True
,因为forward
方法的参数中包含logits_to_keep
。 -
使用示例:
if self._supports_logits_to_keep(): outputs = self.forward(input_ids, attention_mask=attention_mask, logits_to_keep=10) else: outputs = self.forward(input_ids, attention_mask=attention_mask)
- 解释:代码首先检查模型是否支持
logits_to_keep
参数,如果支持,则在调用forward
方法时传入该参数,以只保留 top-10 的 logits,从而节省内存。
- 解释:代码首先检查模型是否支持
_prepare_cache_for_generation
def _prepare_cache_for_generation(
self,
generation_config: GenerationConfig,
model_kwargs: Dict,
assistant_model: "PreTrainedModel",
batch_size: int,
max_cache_length: int,
device: torch.device,
) -> bool:
"""
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
instantiated, writes it to `model_kwargs`, under the name expected by the model.
"""
# 函数主体从这里开始
功能说明
-
目的:准备生成过程中使用的缓存(cache),根据给定的
generation_config
和其他参数,初始化或调整缓存。如果缓存被实例化,它将被写入到model_kwargs
中,使用模型期望的缓存名称。 -
背景:在文本生成任务中,使用缓存可以加速生成过程,特别是在自回归模型中,缓存先前的计算结果可以避免重复计算。在不同的模型或配置下,缓存的实现方式可能不同,因此需要根据情况准备合适的缓存。
参数说明
-
self
:当前类的实例,典型的 Python 类方法的第一个参数。 -
generation_config
:GenerationConfig
类型,表示生成配置,其中包含生成过程中的各种参数设置,如是否使用缓存、缓存的实现方式等。 -
model_kwargs
:Dict
类型,包含传递给模型的关键字参数。在函数中,可能会对其进行修改,添加缓存相关的参数。 -
assistant_model
:PreTrainedModel
类型,可选的辅助模型,用于加速生成或其他目的。 -
batch_size
:整数,表示批次大小。 -
max_cache_length
:整数,表示缓存的最大长度,即缓存可以存储的最大序列长度。 -
device
:torch.device
类型,表示在何种设备(CPU 或 GPU)上运行。
1. 确定缓存名称
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
-
解释:
-
这行代码根据当前模型类的名称,确定缓存在
model_kwargs
中的键名称。 -
如果类名中不包含
"mamba"
,则缓存名称为"past_key_values"
;否则,缓存名称为"cache_params"
。
-
-
原因:
- 不同的模型可能期望的缓存名称不同。模型需要从
model_kwargs
中获取缓存,如果名称不一致,可能导致缓存无法正确工作。
- 不同的模型可能期望的缓存名称不同。模型需要从
2. 确定是否需要跨注意力缓存(cross-attention cache)
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
-
解释:
-
requires_cross_attention_cache
是一个布尔值,表示是否需要准备跨注意力缓存。 -
条件:
-
self.config.is_encoder_decoder
:如果模型是编码器-解码器架构,则需要跨注意力缓存。 -
model_kwargs.get("encoder_outputs") is not None
:如果在model_kwargs
中提供了编码器的输出,则也需要跨注意力缓存。
-
-
-
原因:
- 在编码器-解码器模型(如 BART、T5)中,解码器需要访问编码器的输出,因此需要跨注意力缓存。
3. 快速退出路径 1:用户已在 model_kwargs
中指定了缓存
# 快速退出路径 1:如果用户指定了缓存,我们只需要:
# a) 检查是否有冲突的 `generate` 参数
# b) 如果用户传递了旧的缓存格式,并且模型支持,将其转换为新的缓存格式
user_defined_cache = model_kwargs.get(cache_name)
if user_defined_cache is not None:
if generation_config.cache_implementation is not None:
raise ValueError(
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
"Cache object) is unsupported. Please use only one of the two."
)
if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
model_kwargs[cache_name] = (
DynamicCache.from_legacy_cache(user_defined_cache)
if not requires_cross_attention_cache
else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
)
return
-
解释:
-
获取用户定义的缓存:
user_defined_cache = model_kwargs.get(cache_name)
:从model_kwargs
中获取用户可能提供的缓存。
-
检查用户是否同时指定了
cache_implementation
:- 如果用户既在
model_kwargs
中提供了缓存,又在generation_config
中指定了cache_implementation
,这是冲突的,会引发错误。
- 如果用户既在
-
处理旧的缓存格式:
-
如果
user_defined_cache
是一个元组(旧的缓存格式),并且模型支持默认的动态缓存(self._supports_default_dynamic_cache()
返回True
),则将旧的缓存转换为新的缓存格式。 -
根据是否需要跨注意力缓存,使用不同的缓存类:
-
如果不需要跨注意力缓存,使用
DynamicCache.from_legacy_cache(user_defined_cache)
。 -
如果需要跨注意力缓存,使用
EncoderDecoderCache.from_legacy_cache(user_defined_cache)
。
-
-
-
返回:
- 在处理完用户提供的缓存后,直接返回,不再进行后续的缓存准备。
-
4. 快速退出路径 2:用户指定不使用缓存
# 快速退出路径 2:如果用户指定不使用缓存。(冲突的参数已在 `generation_config.validate()` 中处理)
if generation_config.use_cache is False:
return
-
解释:
-
如果在
generation_config
配置中,用户设置了use_cache=False
,表示不使用缓存。 -
直接返回,不需要准备缓存。
-
5. 快速退出路径 3:模型仅支持旧的缓存格式
# 快速退出路径 3:模型仅支持旧的缓存格式,无需准备
if not self._supports_default_dynamic_cache():
if generation_config.cache_implementation is not None:
warnings.warn(
"This model does not support `Cache` instances, it only supports the legacy cache format (tuple "
f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be "
"ignored.",
UserWarning,
)
return
-
解释:
-
如果模型不支持默认的动态缓存(
self._supports_default_dynamic_cache()
返回False
),则无法使用新的缓存实现。 -
如果用户在
generation_config
中指定了cache_implementation
,则发出警告,指出模型仅支持旧的缓存格式,cache_implementation
将被忽略。 -
直接返回,不需要进一步准备缓存。
-
6. 需要准备缓存,根据 generation_config.cache_implementation
# 否则,我们需要根据 `generation_config.cache_implementation` 准备缓存
# TODO(joao): 在辅助生成中支持静态缓存。辅助生成需要回滚缓存,目前只有动态缓存支持
if assistant_model is not None and generation_config.cache_implementation is not None:
logger.warning_once(
"An assistant model is provided, using a dynamic cache instead of a cache of type="
f"'{generation_config.cache_implementation}'."
)
generation_config.cache_implementation = None
-
解释:
-
如果上述快速退出条件都不满足,且需要准备缓存,则需要根据
generation_config.cache_implementation
的值来准备缓存。 -
特殊情况:辅助模型和缓存实现的冲突:
-
如果提供了
assistant_model
,并且指定了cache_implementation
,则发出警告,指出由于提供了辅助模型,将使用动态缓存,而不是指定类型的缓存。 -
将
generation_config.cache_implementation
设置为None
,以确保使用动态缓存。
-
-
-
原因:
- 在辅助生成过程中,需要回滚缓存,目前只有动态缓存支持回滚。因此,即使用户指定了其他缓存实现,也需要使用动态缓存。
7. 根据缓存实现方式准备缓存
if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation,
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
max_cache_len=max_cache_length,
device=device,
model_kwargs=model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue and tag @zucchini-nlp."
)
cache_config = (
generation_config.cache_config
if generation_config.cache_config is not None
else QuantizedCacheConfig()
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
if cache_config.backend == "quanto" and not is_optimum_quanto_available():
raise ImportError(
"You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
"Please install it via with `pip install optimum-quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
"Please install it via with `pip install hqq`"
)
model_kwargs[cache_name] = cache_class(cache_config)
elif generation_config.cache_implementation == "offloaded":
model_kwargs[cache_name] = OffloadedCache()
-
解释:
-
检查缓存实现方式是否需要特别的准备:
-
NEED_SETUP_CACHE_CLASSES_MAPPING
:一个映射,包含需要特殊设置的缓存类。 -
如果
generation_config.cache_implementation
在NEED_SETUP_CACHE_CLASSES_MAPPING
中,则需要调用_get_cache
方法来获取缓存实例。
-
-
处理静态缓存:
-
如果
cache_implementation
是"static"
,并且模型不支持静态缓存(not self._supports_static_cache
),则抛出ValueError
,提示模型不支持静态缓存。 -
提供一个 GitHub Issue 链接,供用户了解更多信息。
-
-
获取缓存实例:
-
调用
self._get_cache()
方法,传入缓存实现方式、批次大小、最大缓存长度、设备等参数,获取缓存实例。 -
批次大小计算:
-
max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size
:计算缓存所需的实际批次大小。- 在束搜索或其他情况下,批次大小可能需要乘以
num_beams
或num_return_sequences
。
- 在束搜索或其他情况下,批次大小可能需要乘以
-
-
-
处理量化缓存(quantized cache):
-
如果
cache_implementation
是"quantized"
,需要特殊处理。 -
检查模型是否支持量化缓存(
self._supports_quantized_cache
)。 -
如果不支持,抛出
ValueError
。 -
获取缓存配置
cache_config
,如果用户未提供generation_config.cache_config
,则使用默认的QuantizedCacheConfig()
。 -
根据缓存配置的后端,获取对应的缓存类
cache_class
。 -
检查所需的包是否已安装:
-
对于
"quanto"
后端,检查is_optimum_quanto_available()
。 -
对于
"HQQ"
后端,检查is_hqq_available()
。
-
-
如果未安装,抛出
ImportError
,提示用户安装相应的包。 -
实例化缓存类,并将其存储在
model_kwargs[cache_name]
中。
-
-
处理离线缓存(offloaded cache):
- 如果
cache_implementation
是"offloaded"
,则实例化OffloadedCache()
,并存储在model_kwargs[cache_name]
中。
- 如果
-
8. 默认情况下,使用动态缓存
# 默认情况下,使用 DynamicCache() 实例。这将避免在旧格式之间来回转换,从而避免复制缓存,节省内存
else:
model_kwargs[cache_name] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
-
解释:
-
如果
generation_config.cache_implementation
为None
,即用户未指定特定的缓存实现方式,并且上述条件都不满足,则默认使用动态缓存。 -
根据是否需要跨注意力缓存,实例化不同的缓存类:
-
如果不需要跨注意力缓存,使用
DynamicCache()
。 -
如果需要跨注意力缓存,使用
EncoderDecoderCache(DynamicCache(), DynamicCache())
,即分别为编码器和解码器缓存初始化动态缓存。
-
-
-
原因:
-
动态缓存可以在生成过程中动态扩展,并支持回滚等特性。
-
使用默认的动态缓存,可以避免在旧缓存格式与新缓存格式之间来回转换,减少内存使用和数据复制。
-
_get_logits_processor
def _get_logits_processor(
self,
generation_config: GenerationConfig,
input_ids_seq_length: int,
encoder_input_ids: torch.LongTensor,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
logits_processor: Optional[LogitsProcessorList],
device: str = None,
model_kwargs: Optional[Dict[str, Any]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
instances used to modify the scores of the language model head.
"""
# 函数主体从这里开始
功能说明
-
目的:该函数用于构建一个
LogitsProcessorList
对象,其中包含了一系列LogitsProcessor
实例,这些实例用于在生成过程中修改语言模型头(language model head)的logits(模型输出的分数),以实现各种生成控制策略。 -
背景:在生成任务中,通过调整模型输出的logits,可以引入各种生成策略,如重复惩罚、温度调节、Top-K采样、禁用某些词汇等,从而控制生成文本的风格、内容和长度。
参数说明
-
self
:类的实例,典型的Python类方法的第一个参数。 -
generation_config
:GenerationConfig
对象,包含了生成过程中的各种配置参数,如重复惩罚系数、最小生成长度、温度等。 -
input_ids_seq_length
:整数,表示输入序列的长度,即input_ids
的长度。 -
encoder_input_ids
:torch.LongTensor
,编码器的输入IDs。如果模型是编码器-解码器架构,这对应于编码器的输入。 -
prefix_allowed_tokens_fn
:可选的函数,类型为Callable[[int, torch.Tensor], List[int]]
。用于在生成过程中限制每个位置上允许生成的tokens,通常用于受限生成任务。 -
logits_processor
:可选的LogitsProcessorList
对象,用户自定义的LogitsProcessor列表,可用于补充或覆盖默认的processor。 -
device
:字符串,可选参数,指定设备(如'cpu'
或'cuda'
)。如果未提供,默认为None
。 -
model_kwargs
:可选的字典,包含了传递给模型的其他关键字参数。 -
negative_prompt_ids
:可选的torch.Tensor
,用于一些生成策略(如Classifier-Free Guidance)中的负面提示IDs。 -
negative_prompt_attention_mask
:可选的torch.Tensor
,对应negative_prompt_ids
的注意力掩码。
1. 初始化LogitsProcessorList
# instantiate processors list
processors = LogitsProcessorList()
- 解释:创建一个空的
LogitsProcessorList
对象processors
,用于存储将要应用的所有LogitsProcessor
实例。
2. 处理Classifier-Free Guidance(CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
processors.append(
UnbatchedClassifierFreeGuidanceLogitsProcessor(
generation_config.guidance_scale,
self,
unconditional_ids=negative_prompt_ids,
unconditional_attention_mask=negative_prompt_attention_mask,
use_cache=generation_config.use_cache,
)
)
-
解释:
-
条件:如果
guidance_scale
不为None
且不等于1,则说明需要应用Classifier-Free Guidance(CFG)。guidance_scale
是CFG的缩放因子,通常大于1,用于调整生成的多样性与准确性。
-
操作:向
processors
中添加一个UnbatchedClassifierFreeGuidanceLogitsProcessor
实例。-
参数说明:
-
generation_config.guidance_scale
:CFG的缩放因子。 -
self
:模型实例,用于在LogitsProcessor中调用模型的其他方法。 -
unconditional_ids
:负面提示的IDs,即negative_prompt_ids
。 -
unconditional_attention_mask
:负面提示的注意力掩码,即negative_prompt_attention_mask
。 -
use_cache
:是否使用缓存,来自generation_config
。
-
-
-
3. 处理序列偏置(Sequence Bias)
if generation_config.sequence_bias is not None:
processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
-
解释:
-
条件:如果
sequence_bias
不为None
,则需要应用序列偏置。sequence_bias
是一种机制,可对特定的token序列施加偏置,提高或降低它们在生成中的概率。
-
操作:向
processors
中添加一个SequenceBiasLogitsProcessor
实例,传入sequence_bias
参数。
-
4. 处理多样性惩罚(Diversity Penalty)
if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
processors.append(
HammingDiversityLogitsProcessor(
diversity_penalty=generation_config.diversity_penalty,
num_beams=generation_config.num_beams,
num_beam_groups=generation_config.num_beam_groups,
)
)
-
解释:
-
条件:如果
diversity_penalty
不为None
且大于0,则需要应用多样性惩罚。- 多样性惩罚用于在束搜索中鼓励生成更多样化的序列。
-
操作:向
processors
中添加一个HammingDiversityLogitsProcessor
实例。-
参数说明:
-
diversity_penalty
:多样性惩罚系数。 -
num_beams
:束搜索的束宽,即同时考虑的序列数量。 -
num_beam_groups
:束搜索的组数,用于分组束搜索。
-
-
-
5. 处理编码器重复惩罚(Encoder Repetition Penalty)
if (
generation_config.encoder_repetition_penalty is not None
and generation_config.encoder_repetition_penalty != 1.0
):
if len(encoder_input_ids.shape) == 2:
processors.append(
EncoderRepetitionPenaltyLogitsProcessor(
penalty=generation_config.encoder_repetition_penalty,
encoder_input_ids=encoder_input_ids,
)
)
else:
warnings.warn(
"Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to "
"`generate`, ignoring the argument.",
UserWarning,
)
-
解释:
-
条件:如果
encoder_repetition_penalty
不为None
且不等于1.0,则需要应用编码器重复惩罚。- 编码器重复惩罚用于减少模型在生成时重复输入内容的可能性。
-
检查:如果
encoder_input_ids
的形状为二维(即存在有效的编码器输入),则应用惩罚。 -
操作:向
processors
中添加一个EncoderRepetitionPenaltyLogitsProcessor
实例。-
参数说明:
-
penalty
:重复惩罚系数。 -
encoder_input_ids
:编码器的输入IDs。
-
-
-
否则:发出警告,提示需要提供
input_ids
以应用该惩罚,忽略该参数。
-
6. 处理重复惩罚(Repetition Penalty)
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
-
解释:
-
条件:如果
repetition_penalty
不为None
且不等于1.0,则需要应用重复惩罚。- 重复惩罚用于减少模型在生成时重复之前生成内容的可能性。
-
操作:向
processors
中添加一个RepetitionPenaltyLogitsProcessor
实例,传入penalty
参数。
-
7. 处理禁止重复的n-gram(No Repeat N-Gram)
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
-
解释:
-
条件:如果
no_repeat_ngram_size
不为None
且大于0,则需要禁止重复的n-gram。- 这用于防止模型在生成时重复生成相同的n-gram,提高生成的多样性。
-
操作:向
processors
中添加一个NoRepeatNGramLogitsProcessor
实例,传入no_repeat_ngram_size
参数。
-
8. 处理编码器禁止重复的n-gram(Encoder No Repeat N-Gram)
if (
generation_config.encoder_no_repeat_ngram_size is not None
and generation_config.encoder_no_repeat_ngram_size > 0
):
if len(encoder_input_ids.shape) == 2:
processors.append(
EncoderNoRepeatNGramLogitsProcessor(
generation_config.encoder_no_repeat_ngram_size,
encoder_input_ids,
)
)
else:
warnings.warn(
"Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to "
"`generate`, ignoring the argument.",
UserWarning,
)
-
解释:
-
条件:如果
encoder_no_repeat_ngram_size
不为None
且大于0,则需要在生成时避免重复输入中的n-gram。- 这用于防止模型在生成时重复输入序列中的n-gram。
-
检查:如果
encoder_input_ids
的形状为二维(存在有效的编码器输入),则应用该处理器。 -
操作:向
processors
中添加一个EncoderNoRepeatNGramLogitsProcessor
实例。-
参数说明:
-
encoder_no_repeat_ngram_size
:禁止重复的n-gram大小。 -
encoder_input_ids
:编码器的输入IDs。
-
-
-
否则:发出警告,提示需要提供
input_ids
以应用该处理器,忽略该参数。
-
9. 处理坏词(Bad Words)
if generation_config.bad_words_ids is not None:
processors.append(
NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
generation_config._eos_token_tensor,
)
)
-
解释:
-
条件:如果
bad_words_ids
不为None
,则需要在生成过程中禁止某些词。bad_words_ids
是一个列表,包含需要禁止的词的token IDs。
-
操作:向
processors
中添加一个NoBadWordsLogitsProcessor
实例。-
参数说明:
-
bad_words_ids
:需要禁止的词的token IDs。 -
_eos_token_tensor
:结束标记的token张量,用于在必要时停止生成。
-
-
-
10. 处理最小长度(Minimum Length)
if (
generation_config.min_length is not None
and generation_config._eos_token_tensor is not None
and generation_config.min_length > 0
):
processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
generation_config._eos_token_tensor,
device=device,
)
)
-
解释:
-
条件:如果
min_length
不为None
,_eos_token_tensor
不为None
,且min_length
大于0,则需要在生成达到最小长度之前禁止生成结束标记。 -
操作:向
processors
中添加一个MinLengthLogitsProcessor
实例。-
参数说明:
-
min_length
:最小生成长度。 -
_eos_token_tensor
:结束标记的token张量。 -
device
:设备信息。
-
-
-
11. 处理最小新tokens的长度(Minimum New Tokens Length)
if (
generation_config.min_new_tokens is not None
and generation_config._eos_token_tensor is not None
and generation_config.min_new_tokens > 0
):
processors.append(
MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config._eos_token_tensor,
device=device,
)
)
-
解释:
-
条件:如果
min_new_tokens
不为None
,_eos_token_tensor
不为None
,且min_new_tokens
大于0,则需要在生成新tokens达到最小数量之前禁止生成结束标记。 -
操作:向
processors
中添加一个MinNewTokensLengthLogitsProcessor
实例。-
参数说明:
-
input_ids_seq_length
:输入序列的长度。 -
min_new_tokens
:最小新生成的tokens数量。 -
_eos_token_tensor
:结束标记的token张量。 -
device
:设备信息。
-
-
-
12. 处理前缀限制(Prefix Allowed Tokens Function)
if prefix_allowed_tokens_fn is not None:
processors.append(
PrefixConstrainedLogitsProcessor(
prefix_allowed_tokens_fn,
generation_config.num_beams // generation_config.num_beam_groups,
)
)
-
解释:
-
条件:如果
prefix_allowed_tokens_fn
不为None
,则需要在生成过程中限制每个位置上允许生成的tokens。- 这通常用于受限生成任务,例如自动补全或基于前缀的约束生成。
-
操作:向
processors
中添加一个PrefixConstrainedLogitsProcessor
实例。-
参数说明:
-
prefix_allowed_tokens_fn
:用于限制每个位置上允许生成的tokens的函数。 -
generation_config.num_beams // generation_config.num_beam_groups
:计算每个组中的束宽。
-
-
-
13. 处理强制起始token(Forced BOS Token)
if generation_config.forced_bos_token_id is not None:
processors.append(
ForcedBOSTokenLogitsProcessor(
generation_config.forced_bos_token_id,
)
)
-
解释:
-
条件:如果
forced_bos_token_id
不为None
,则需要在生成的第一个位置强制生成指定的起始token。- 这用于确保生成的序列以特定的token开始,例如在某些任务中需要强制生成特定的起始标记。
-
操作:向
processors
中添加一个ForcedBOSTokenLogitsProcessor
实例,传入forced_bos_token_id
。
-
14. 处理强制结束token(Forced EOS Token)
if generation_config.forced_eos_token_id is not None:
processors.append(
ForcedEOSTokenLogitsProcessor(
generation_config.max_length,
generation_config.forced_eos_token_id,
device=device,
)
)
-
解释:
-
条件:如果
forced_eos_token_id
不为None
,则需要在生成达到最大长度时强制生成指定的结束token。- 这用于确保生成的序列以特定的token结束。
-
操作:向
processors
中添加一个ForcedEOSTokenLogitsProcessor
实例。-
参数说明:
-
generation_config.max_length
:最大生成长度。 -
forced_eos_token_id
:强制的结束token ID。 -
device
:设备信息。
-
-
-
15. 处理无效值移除(Remove Invalid Values)
if generation_config.remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor())
-
解释:
-
条件:如果
remove_invalid_values
为True
,则需要在生成过程中移除inf
和nan
等无效值。- 这用于确保生成过程的稳定性,防止由于无效值导致的错误。
-
操作:向
processors
中添加一个InfNanRemoveLogitsProcessor
实例。
-
16. 处理指数衰减长度惩罚(Exponential Decay Length Penalty)
if generation_config.exponential_decay_length_penalty is not None:
processors.append(
ExponentialDecayLengthPenalty(
generation_config.exponential_decay_length_penalty,
generation_config._eos_token_tensor,
input_ids_seq_length,
)
)
-
解释:
-
条件:如果
exponential_decay_length_penalty
不为None
,则需要应用指数衰减的长度惩罚。- 这用于在生成过程中,对句子长度施加惩罚,鼓励模型生成特定长度的句子。
-
操作:向
processors
中添加一个ExponentialDecayLengthPenalty
实例。-
参数说明:
-
exponential_decay_length_penalty
:指数衰减长度惩罚的参数。 -
_eos_token_tensor
:结束标记的token张量。 -
input_ids_seq_length
:输入序列的长度。
-
-
-
17. 处理抑制特定tokens(Suppress Tokens)
if generation_config.suppress_tokens is not None:
processors.append(
SuppressTokensLogitsProcessor(
generation_config.suppress_tokens,
device=device,
)
)
-
解释:
-
条件:如果
suppress_tokens
不为None
,则需要在生成过程中抑制特定的tokens,不让它们生成。suppress_tokens
是需要抑制的token IDs列表。
-
操作:向
processors
中添加一个SuppressTokensLogitsProcessor
实例。-
参数说明:
-
suppress_tokens
:需要抑制的token IDs。 -
device
:设备信息。
-
-
-
18. 处理在开头抑制特定tokens(Suppress Tokens at Begin)
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = (
begin_index
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
processors.append(
SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens,
begin_index,
device=device,
)
)
-
解释:
-
条件:如果
begin_suppress_tokens
不为None
,则需要在生成的开头位置抑制特定的tokens。- 这用于避免模型在一开始生成某些不期望的tokens。
-
计算起始索引:
-
begin_index
初始值为input_ids_seq_length
,表示当前生成的位置。 -
如果
input_ids_seq_length <= 1
且forced_bos_token_id
不为None
,则begin_index += 1
。
-
-
操作:向
processors
中添加一个SuppressTokensAtBeginLogitsProcessor
实例。-
参数说明:
-
begin_suppress_tokens
:需要抑制的token IDs。 -
begin_index
:开始抑制的位置索引。 -
device
:设备信息。
-
-
-
19. 处理强制解码器IDs(Forced Decoder IDs)
if generation_config.forced_decoder_ids is not None:
# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
raise ValueError(
"You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
"in favour of `input_ids` or `decoder_input_ids` respectively.",
)
-
解释:
-
条件:如果
forced_decoder_ids
不为None
,则抛出异常。 -
原因:当前不支持
forced_decoder_ids
,建议用户使用input_ids
或decoder_input_ids
来替代。 -
备注:注释中提到,当TensorFlow和FLAX版本与PyTorch版本对齐后,可以将此异常移动到
GenerationConfig.validate()
中。
-
20. 合并用户自定义的logits_processor
# TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor)
-
解释:
-
操作:调用
self._merge_criteria_processor_list()
方法,将之前构建的processors
列表与用户自定义的logits_processor
进行合并。 -
备注:注释中提到需要找到一种策略来指定处理器的顺序。
-
21. 处理采样策略下的LogitsWarper
# 只有在使用采样策略时,才应用之前被称为`LogitsWarpers`的处理器
if generation_config.do_sample:
# 在beam方法中,我们需要至少保留一个非eos token,以探索可能具有更好得分的连续序列
if generation_config.num_beams > 1:
if isinstance(generation_config._eos_token_tensor, list):
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
# 以下思想主要来自PR:https://github.com/huggingface/transformers/pull/5420/files
# 所有的sampler都在`generation_utils_samplers.py`中
if generation_config.temperature is not None and generation_config.temperature != 1.0:
processors.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
processors.append(
TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.top_p is not None and generation_config.top_p < 1.0:
processors.append(
TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.min_p is not None:
# 在温度缩放之后应用(见:https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
processors.append(
MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
processors.append(
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
processors.append(
EpsilonLogitsWarper(
epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
)
)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
processors.append(
EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
)
-
解释:
-
条件:只有在
do_sample
为True
时,才应用这些处理器,因为它们与采样策略相关。 -
计算
min_tokens_to_keep
:-
在束搜索等方法中,需要保留至少一个非结束token,以确保能够探索可能更好的序列。
-
根据结束token的类型和数量,计算需要保留的最小token数量。
-
-
添加采样相关的
LogitsWarper
:-
根据
generation_config
中的配置,向processors
中添加相应的LogitsWarper
,包括:-
TemperatureLogitsWarper
:调整温度参数,控制生成的随机性。 -
TopKLogitsWarper
:仅保留概率最高的top_k
个tokens。 -
TopPLogitsWarper
:仅保留累计概率达到top_p
的tokens。 -
MinPLogitsWarper
:应用最小概率阈值。 -
TypicalLogitsWarper
:使用Typical采样策略。 -
EpsilonLogitsWarper
和EtaLogitsWarper
:应用epsilon或eta截断。
-
-
-
22. 处理水印(Watermarking)
# Watermarking应该在所有logits处理完成后再应用(参见#34630)
if generation_config.watermarking_config is not None:
processors.append(
generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
)
-
解释:
-
条件:如果
watermarking_config
不为None
,则需要在生成过程中应用水印策略。- 水印策略通常用于在生成的文本中嵌入隐式标记,以便于后续识别生成内容。
-
操作:向
processors
中添加一个由watermarking_config
构建的处理器。- 使用模型的词汇表大小
vocab_size
和设备信息device
。
- 使用模型的词汇表大小
-
备注:水印处理器应当在所有logits处理完成后再应用。
-
23. 处理Logits重归一化(Logit Renormalization)
# `LogitNormalization`应该始终是最后一个logit处理器(如果存在的话)
if generation_config.renormalize_logits is True:
processors.append(LogitNormalization())
-
解释:
-
条件:如果
renormalize_logits
为True
,则需要在处理完所有logits后重新归一化。- 这用于确保logits在所有处理后仍然形成有效的概率分布。
-
操作:向
processors
中添加一个LogitNormalization
实例。- 备注:
LogitNormalization
应当始终是最后应用的logits处理器。
- 备注:
-
24. 返回处理器列表
return processors
- 解释:返回构建好的
LogitsProcessorList
对象processors
,供生成过程使用。
_get_stopping_criteria
def _get_stopping_criteria(
self,
generation_config: GenerationConfig,
stopping_criteria: Optional[StoppingCriteriaList],
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
**kwargs,
) -> StoppingCriteriaList:
# 方法体...
参数说明
generation_config
:GenerationConfig
对象,包含生成过程中所需的各种配置参数,如max_length
、max_time
、stop_strings
等。stopping_criteria
:可选的StoppingCriteriaList
对象,用户可以传入自定义的停止条件列表,与默认的停止条件合并。tokenizer
:可选的PreTrainedTokenizerBase
对象,模型对应的 tokenizer,用于处理stop_strings
等需要 tokenizer 的功能。**kwargs
:其他可能的关键字参数,供扩展使用。
返回值
StoppingCriteriaList
:该方法返回一个StoppingCriteriaList
对象,包含生成过程中需要检查的停止条件。
该方法的主要作用是:
-
创建默认的停止条件列表:根据
generation_config
中的配置,生成对应的停止条件(如最大长度、最大时间、特殊字符串、结束标记等)。 -
合并用户提供的停止条件:如果用户在
stopping_criteria
中提供了自定义的停止条件,方法会将其与默认的停止条件合并。 -
返回完整的停止条件列表:生成过程会根据这个列表,在满足任何一个停止条件时结束生成。
步骤 1:初始化停止条件列表
criteria = StoppingCriteriaList()
- 说明:创建一个空的
StoppingCriteriaList
对象,用于存储后续添加的停止条件。
步骤 2:处理 max_length
停止条件
if generation_config.max_length is not None:
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
criteria.append(
MaxLengthCriteria(
max_length=generation_config.max_length,
max_position_embeddings=max_position_embeddings,
)
)
-
解释:
-
获取
max_length
:从generation_config
中获取max_length
,即生成的最大长度。 -
获取模型的最大位置嵌入长度:
-
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
-
从模型配置中获取
max_position_embeddings
,表示模型能处理的最大序列长度。
-
-
添加
MaxLengthCriteria
:-
创建一个
MaxLengthCriteria
对象,传入max_length
和max_position_embeddings
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当生成的序列长度达到
max_length
或模型的最大位置嵌入长度时,停止生成。
步骤 3:处理 max_time
停止条件
if generation_config.max_time is not None:
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
-
解释:
-
获取
max_time
:从generation_config
中获取max_time
,即生成的最长时间(以秒为单位)。 -
添加
MaxTimeCriteria
:-
创建一个
MaxTimeCriteria
对象,传入max_time
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当生成过程运行时间超过
max_time
秒时,停止生成。
步骤 4:处理 stop_strings
停止条件
if generation_config.stop_strings is not None:
if tokenizer is None:
raise ValueError(
"There are one or more stop strings, either in the arguments to `generate` or in the "
"model's generation config, but we could not locate a tokenizer. When generating with "
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
)
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
-
解释:
-
获取
stop_strings
:从generation_config
中获取stop_strings
,这是一个字符串列表,包含用于停止生成的特殊字符串。 -
检查
tokenizer
:因为处理字符串需要使用 tokenizer,如果tokenizer
为None
,抛出ValueError
。 -
添加
StopStringCriteria
:-
创建一个
StopStringCriteria
对象,传入stop_strings
和tokenizer
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当生成的文本中出现指定的字符串时,停止生成。
步骤 5:处理 eos_token_id
(结束标记)停止条件
if generation_config._eos_token_tensor is not None:
criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
-
解释:
-
获取
_eos_token_tensor
:从generation_config
中获取_eos_token_tensor
,即结束标记的 token ID。 -
添加
EosTokenCriteria
:-
创建一个
EosTokenCriteria
对象,传入eos_token_id
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当生成的序列中出现结束标记时,停止生成。
步骤 6:处理辅助模型的置信度停止条件(可选)
if (
generation_config.is_assistant
and generation_config.assistant_confidence_threshold is not None
and generation_config.assistant_confidence_threshold > 0
):
criteria.append(
ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)
)
-
解释:
-
条件检查:
-
generation_config.is_assistant
:判断是否使用了辅助模型。 -
generation_config.assistant_confidence_threshold
不为空且大于 0。
-
-
添加
ConfidenceCriteria
:-
创建一个
ConfidenceCriteria
对象,传入assistant_confidence_threshold
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当辅助模型的置信度达到一定阈值时,停止生成。
步骤 7:合并用户提供的停止条件
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
-
解释:
-
合并默认和用户提供的停止条件:
-
调用
_merge_criteria_processor_list
方法,将默认的criteria
与用户提供的stopping_criteria
合并。 -
这个方法通常会去除重复的条件,或者根据某些规则进行合并。
-
-
-
作用:确保生成过程中考虑所有相关的停止条件。
步骤 8:返回最终的停止条件列表
return criteria
- 解释:返回包含所有停止条件的
StoppingCriteriaList
,供生成过程使用。
整体流程总结
-
目的:为生成过程准备一个完整的停止条件列表,当满足任何一个条件时,生成过程将停止。
-
处理逻辑:
-
初始化:创建一个空的停止条件列表。
-
添加默认停止条件:根据
generation_config
中的配置,添加对应的停止条件,包括最大长度、最大时间、结束标记、特殊字符串等。 -
辅助模型条件:如果使用了辅助模型,并且设置了置信度阈值,添加对应的停止条件。
-
合并用户停止条件:将用户提供的
stopping_criteria
合并到默认条件列表中。 -
返回:输出完整的停止条件列表。
-