做网站用什么编程语言好百度官方网站下载安装
generate
以下是对您提供的 generate
方法的详细解释。这个方法用于大型语言模型(LLM)中的文本生成,尤其是具有语言模型头的模型。该方法包含了多个复杂的逻辑,支持多种生成模式,如贪心搜索、采样、束搜索等。
方法定义
@torch.no_grad()
def generate(self,inputs: Optional[torch.Tensor] = None,generation_config: Optional[GenerationConfig] = None,logits_processor: Optional[LogitsProcessorList] = None,stopping_criteria: Optional[StoppingCriteriaList] = None,prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,synced_gpus: Optional[bool] = None,assistant_model: Optional["PreTrainedModel"] = None,streamer: Optional["BaseStreamer"] = None,negative_prompt_ids: Optional[torch.Tensor] = None,negative_prompt_attention_mask: Optional[torch.Tensor] = None,**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:...
参数说明
inputs
:可选,输入张量,形状可能因模态而异。用于作为生成的提示或编码器的输入。generation_config
:可选,GenerationConfig
对象,包含生成时的参数配置。logits_processor
:可选,LogitsProcessorList
对象,自定义的 logits 处理器列表,用于在生成过程中调整 logits。stopping_criteria
:可选,StoppingCriteriaList
对象,自定义的停止标准列表,用于在满足条件时终止生成。prefix_allowed_tokens_fn
:可选,函数,用于在每一步生成时限制允许的 token。synced_gpus
:可选,布尔值,指示是否在多 GPU 环境下同步运行以避免死锁。assistant_model
:可选,用于加速生成的辅助模型,必须具有相同的 tokenizer。streamer
:可选,用于流式输出生成序列的对象。negative_prompt_ids
:可选,torch.LongTensor
,形状为(batch_size, sequence_length)
,用于一些处理器(如 CFG)的负提示。negative_prompt_attention_mask
:可选,torch.LongTensor
,形状为(batch_size, sequence_length)
,对应negative_prompt_ids
的 attention mask。kwargs
:其他参数,可用于覆盖generation_config
中的设置,或传递给模型的forward
方法。
返回值
GenerateOutput
或torch.LongTensor
:根据参数设置,返回生成的序列或者包含生成过程详细信息的输出对象。
方法逻辑解析
总体流程
- 处理生成配置和参数验证:确保
generation_config
和kwargs
的正确性。 - 设置生成参数:根据传入的配置或默认值,设置生成所需的参数。
- 准备模型输入:处理输入张量,生成
input_ids
,并管理模型需要的其他关键字参数。 - 确定生成模式:根据配置,选择合适的生成模式,例如贪心搜索、采样、束搜索等。
- 生成序列:调用相应的生成函数,生成目标序列。
- 返回结果:根据参数设置,返回生成的序列或包含更多信息的对象。
详细步骤
1. 处理生成配置和参数验证
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None)
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
_validate_model_class
:检查模型的类型是否支持生成。_validate_model_class- 提取
tokenizer
:用于停止条件等,非必要参数。 - 准备
generation_config
和model_kwargs
:处理传入的配置和额外参数。_prepare_generation_config - 验证模型关键字参数:确保传入的参数与模型的预期一致。_validate_model_kwargs
- 验证辅助模型:如果使用了
assistant_model
,确保其与主模型兼容。
2. 设置生成参数
if synced_gpus is None:synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
synced_gpus
:在多 GPU 环境下,判断是否需要同步。logits_processor
和stopping_criteria
:初始化或使用默认的处理器和停止标准。- StoppingCriteria代码分析
- LogitsProcessor代码分析
- 注意力掩码相关变量:检查模型的
forward
方法是否接受attention_mask
,以及当前参数中是否提供。
3. 准备模型输入
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
_prepare_model_inputs
:处理输入张量,获取模型输入名称,并更新model_kwargs
。_prepare_model_inputsbatch_size
和device
:获取批次大小和设备信息。- 准备特殊 token:如
bos_token_id
、eos_token_id
等。_prepare_special_tokens
4. 检查并处理注意力掩码
# decoder-only models must use left-padding for batched generation.if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.if (generation_config._pad_token_tensor is not Noneand batch_size > 1and len(inputs_tensor.shape) == 2and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0):logger.warning("A decoder-only architecture is being used, but right-padding was detected! For correct ""generation results, please set `padding_side='left'` when initializing the tokenizer.")
这段代码是关于为仅解码器架构(decoder-only models)处理输入时的填充方式建议。它检查是否使用了右填充(right-padding),在这种情况下给出警告。
-
架构类型检查:
self.config.is_encoder_decoder
用于检查模型是否属于编码器-解码器架构。- 如果模型不是编码器-解码器架构(即它是仅解码器架构),并且不是在TorchDynamo编译模式下,代码继续进行。
-
输入条件检查:
- 确保批处理大小大于1,即有多个序列在一起进行处理。
- 检查输入张量
inputs_tensor
的形状满足条件:它是二维张量,一般形式为[batch_size, sequence_length]
。 - 检查序列中的最后一个标识符是否为
pad_token_id
,指定的generation_config._pad_token_tensor
不为None
。 torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
用于确定至少有一个序列的最后一个令牌是填充令牌。
-
警告日志:
- 如果满足以上条件,显示警告信息。
- 提示在仅解码器模型中检测到右填充,它建议使用左填充(
padding_side='left'
),以获取正确的生成结果。
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are# generating the first new token or not, and we only want to use the embeddings for the first new token)if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":generation_config.use_cache = True
这段代码针对仅解码器架构(decoder-only models)进行了一种配置调整,特别是在使用 inputs_embeds
作为输入的时候。以下是代码的简单说明:
-
架构类型检查:
not self.config.is_encoder_decoder
用于判断模型是否是仅解码器架构。- 如果模型是仅解码器架构,代码继续进行。
-
输入类型检查:
model_input_name == "inputs_embeds"
检查输入类型是否为嵌入层(embeddings)。inputs_embeds
通常表示已经经过词嵌入层的输入,这意味着模型接收的不是直接的令牌ID,而是对应的词向量。
-
缓存使用设置:
generation_config.use_cache = True
设置生成配置的use_cache
属性为True
。- 启用缓存对于跟踪生成的序列特别重要,尤其是在无法确定新生成的第一个令牌是否已生成时。
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(inputs_tensor, generation_config, model_kwargs)
elif kwargs_has_attention_mask:if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:raise ValueError("`attention_mask` passed to `generate` must be 2D.")
- 如果需要
attention_mask
:且未提供,则生成默认的attention_mask
。 - 验证提供的
attention_mask
:确保其形状正确。 - _prepare_attention_mask_for_generation
5. 为编码器-解码器模型准备输入
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs_tensor, model_kwargs, model_input_name, generation_config)
- 对于编码器-解码器模型:如果未提供
encoder_outputs
,则进行编码器的前向计算,并更新model_kwargs
。计算编码器部分
6. 准备用于自回归生成的 input_ids
if self.config.is_encoder_decoder:input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(batch_size=batch_size,model_input_name=model_input_name,model_kwargs=model_kwargs,decoder_start_token_id=generation_config._decoder_start_token_tensor,device=inputs_tensor.device,)
else:input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
- 编码器-解码器模型:准备解码器的
input_ids
。 - _prepare_decoder_input_ids_for_generation
- _decoder_start_token_tensor
- 解码器模型:直接使用
inputs_tensor
作为input_ids
。
7. 处理特殊的生成配置
if generation_config.token_healing:input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:streamer.put(input_ids.cpu())
token_healing
:如果启用了此项配置,修复输入中的 tokens。token_healingstreamer
:如果提供了流式处理器,传递当前的input_ids
。
8. 准备生成长度相关的参数
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(generation_config=generation_config,has_default_max_length=has_default_max_length,has_default_min_length=has_default_min_length,model_input_name=model_input_name,inputs_tensor=inputs_tensor,input_ids_length=input_ids_length,
)
- 计算输入序列长度。
- 确定是否使用默认的最大和最小长度。
- 准备生成长度配置,可能会根据输入长度进行调整。_prepare_generated_length
9. 准备缓存和其他模型参数
if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:model_kwargs["logits_to_keep"] = 1
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
logits_to_keep
:如果模型支持,仅保留需要的 logits,减少内存占用。_supports_logits_to_keep- 验证生成长度:确保生成长度的合法性。_validate_generated_length
# 7. Prepare the cache.# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.# - different models have a different cache name expected by the model (default = "past_key_values")# - `max_length`, prepared above, is used to determine the maximum cache lengthmax_cache_length = generation_config.max_length - 1if (inputs_tensor.shape[1] != input_ids_lengthand model_input_name == "inputs_embeds"and not self.config.is_encoder_decoder):max_cache_length += inputs_tensor.shape[1]self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device)
- 准备缓存:为生成过程中的缓存(如注意力缓存)分配空间。
_prepare_cache_for_generation
10. 确定生成模式
generation_mode = generation_config.get_generation_mode(assistant_model)
- 根据生成配置和辅助模型,确定生成模式,例如:
- 辅助生成(Assisted Generation)
- DoLa 生成(DOLA Generation)
- 对比搜索(Contrastive Search)
- 采样或贪心搜索
- 束搜索(Beam Search)
- 组束搜索(Group Beam Search)
- 受限束搜索(Constrained Beam Search)
11. 准备 logits 处理器和停止标准
prepared_logits_processor = self._get_logits_processor(generation_config=generation_config,input_ids_seq_length=input_ids_length,encoder_input_ids=inputs_tensor,prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,logits_processor=logits_processor,device=inputs_tensor.device,model_kwargs=model_kwargs,negative_prompt_ids=negative_prompt_ids,negative_prompt_attention_mask=negative_prompt_attention_mask,
)
prepared_stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
model_kwargs["use_cache"] = generation_config.use_cache
- 获取 logits 处理器:整合默认和自定义的 logits 处理器,用于在生成过程中调整 logits。
- 获取停止标准:整合默认和自定义的停止标准,用于在满足条件时终止生成。
- 设置
use_cache
:根据配置,决定是否在生成过程中使用缓存。
_get_logits_processor
_get_stopping_criteria
12. 根据生成模式调用相应的生成函数
- 辅助生成
if generation_mode == GenerationMode.ASSISTED_GENERATION:# 验证条件# 获取候选生成器# 执行辅助生成
- DoLa 生成
elif generation_mode == GenerationMode.DOLA_GENERATION:# 执行 DoLa 解码
- 对比搜索
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:# 执行对比搜索
- 采样或贪心搜索
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):# 扩展 input_ids# 执行采样或贪心搜索
- 束搜索
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):# 准备束搜索评分器# 扩展 input_ids# 执行束搜索
GenerationMixin:_sample方法(GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH)
- 组束搜索
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:# 准备组束搜索评分器# 扩展 input_ids# 执行组束搜索
- 受限束搜索
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:# 准备约束条件# 准备受限束搜索评分器# 扩展 input_ids# 执行受限束搜索
13. 处理生成结果
# 如果需要,将缓存转换为传统格式
if (generation_config.return_legacy_cache is Trueand not is_torchdynamo_compiling()and hasattr(result, "past_key_values")and getattr(result.past_key_values, "to_legacy_cache") is not None
):result.past_key_values = result.past_key_values.to_legacy_cache()
return result
- 转换缓存格式:如果配置需要,将生成过程中使用的缓存转换为传统格式。
- 返回结果:最终将生成的结果返回。
_validate_model_class
函数功能概述
这个函数名为_validate_model_class
,用于验证当前的模型类是否支持生成(generation)操作。如果不支持生成,则会抛出一个异常,提示用户使用合适的模型类。
-
条件判断
if not is_torchdynamo_compiling() and not self.can_generate():
- 这里有一个
if
条件,用于检查两个条件是否同时满足:not is_torchdynamo_compiling()
not self.can_generate()
is_torchdynamo_compiling()
:- 这是一个函数,检查当前是否处于 TorchDynamo 的编译环境中。
- TorchDynamo 是 PyTorch 的一个 JIT 编译器框架,可以对模型进行编译优化。
- 在编译过程中,某些检查可能需要被跳过,因此在编译时不进行此验证。
self.can_generate()
:- 这是模型类的一个方法,返回布尔值,表示模型是否支持生成操作。
- 如果模型具有语言模型头(language model head),则通常支持生成,即
self.can_generate()
返回True
。
因此,只有在不处于编译环境且模型不能生成的情况下,才会进入
if
语句内部,抛出异常。 - 这里有一个
-
定义支持生成的模型类名后缀列表
terminations_with_generation_support = ["ForCausalLM","ForConditionalGeneration","ForSpeechSeq2Seq","ForVision2Seq", ]
- 这是一个列表,包含了通常支持生成操作的模型类名称的后缀。
- 这些后缀包括:
"ForCausalLM"
:用于自回归语言模型,如 GPT-2、GPT-3 等。"ForConditionalGeneration"
:用于条件生成模型,如 BART、T5 等。"ForSpeechSeq2Seq"
:用于语音序列到序列模型。"ForVision2Seq"
:用于视觉到序列的模型,如图像描述生成。
-
抛出异常
raise TypeError(f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as ""it doesn't have a language model head. Classes that support generation often end in one of these "f"names: {terminations_with_generation_support}." )
- 如果条件满足,说明当前模型不支持生成,则抛出一个
TypeError
异常。 - 异常信息包括:
- 当前模型类的名称:
{self.__class__.__name__}
。 - 说明模型不兼容
.generate()
方法,因为它没有语言模型头。 - 提示支持生成的模型类通常以哪些后缀结尾:
{terminations_with_generation_support}
。
- 当前模型类的名称:
- 如果条件满足,说明当前模型不支持生成,则抛出一个
_prepare_generation_config
函数功能概述
这个函数名为_prepare_generation_config
,它的作用是准备生成所需的配置对象generation_config
,并应用从kwargs
(关键字参数)中传入的任何生成配置选项。这个函数还处理了与模型配置文件(config
)的向后兼容性。
参数和返回值
-
参数:
self
:类的实例。generation_config
:可选的GenerationConfig
对象,表示生成配置。如果为None
,则需要从其他地方获取或生成。**kwargs
:其他关键字参数,可能包含生成配置的参数,将用于更新generation_config
。
-
返回值:
generation_config
:最终准备好的生成配置对象。model_kwargs
:用于模型的关键字参数字典。
处理逻辑详解
-
初始设置
# 设置一个标志,指示是否使用模型的默认生成配置 using_model_generation_config = False
-
处理
generation_config
为None
的情况if generation_config is None:...
当用户没有提供
generation_config
时,需要从模型中获取。但在处理之前,先考虑到可能的向后兼容性问题。-
遗留(Legacy)行为的处理
# 遗留支持:用户可能修改了模型的配置来控制生成。要触发这种遗留行为,需要满足以下条件: # 1) generation_config 是从模型配置创建的(`_from_model_config`字段为True) # 2) generation_config 自创建以来没有被修改过(哈希值相同) # 3) 模型配置中有非默认的生成参数 # 4) 用户在模型配置中设置了新的生成参数 # 注意:`torch.compile`无法编译`hash`函数,因此在编译时,这种遗留支持被禁用 if (not is_torchdynamo_compiling()and self.generation_config._from_model_config # 条件1and self.generation_config._original_object_hash == hash(self.generation_config) # 条件2and len(self.config._get_non_default_generation_parameters()) > 0 # 条件3 ):new_generation_config = GenerationConfig.from_model_config(self.config)if new_generation_config != self.generation_config: # 条件4warnings.warn("You have modified the pretrained model configuration to control generation. This is a"" deprecated strategy to control generation and will be removed in v5."" Please use and modify the model generation configuration (see"" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",UserWarning,)self.generation_config = new_generation_config
-
解释条件:
self.generation_config._from_model_config
:- 检查
generation_config
是否是从模型配置创建的。
- 检查
self.generation_config._original_object_hash == hash(self.generation_config)
:- 检查
generation_config
自创建以来是否没有被修改过。 - 由于
torch.compile
无法编译hash
函数,因此在编译时无法进行此检查。
- 检查
len(self.config._get_non_default_generation_parameters()) > 0
:- 检查模型配置中是否有非默认的生成参数。
new_generation_config != self.generation_config
:- 检查新的生成配置是否与当前的不同,表示用户在模型配置中设置了新的生成参数。
-
操作:
- 如果以上条件都满足,表示用户通过修改模型配置来控制生成。这是一种已被弃用的做法。
- 发出警告,提示此策略将在v5版本中移除,建议用户使用并修改生成配置对象
generation_config
。 - 更新
self.generation_config
为新的生成配置new_generation_config
。
-
-
设置
generation_config
generation_config = self.generation_config using_model_generation_config = True
- 如果没有提供
generation_config
,则使用模型的默认生成配置self.generation_config
。 - 设置标志
using_model_generation_config = True
,表示正在使用模型的默认生成配置。
- 如果没有提供
-
-
处理
torch.compile
相关的问题# `torch.compile`无法编译`copy.deepcopy`等函数,因此需要根据是否在编译中,决定如何处理 if not is_torchdynamo_compiling():... else:model_kwargs = kwargs
-
非编译环境下的处理
if not is_torchdynamo_compiling():generation_config = copy.deepcopy(generation_config)model_kwargs = generation_config.update(**kwargs)...
-
深拷贝
generation_config
:- 使用
copy.deepcopy
创建generation_config
的深拷贝,以避免修改原始对象。 - 由于
torch.compile
无法编译copy.deepcopy
,因此在编译环境下无法进行此操作。
- 使用
-
更新
generation_config
:- 调用
generation_config.update(**kwargs)
方法,用传入的kwargs
更新生成配置。 - 这个方法返回未被
generation_config
使用的参数,即那些不属于生成配置的参数,存储在model_kwargs
中。
- 调用
-
处理特殊的Token ID:
# 如果提供了`generation_config`,需要确保所有特殊的Token ID都有默认值 if not using_model_generation_config:if generation_config.bos_token_id is None:generation_config.bos_token_id = self.generation_config.bos_token_idif generation_config.eos_token_id is None:generation_config.eos_token_id = self.generation_config.eos_token_idif generation_config.pad_token_id is None:generation_config.pad_token_id = self.generation_config.pad_token_idif generation_config.decoder_start_token_id is None:generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
- 如果用户提供了自己的
generation_config
(即不使用模型的默认生成配置),需要确保特殊的Token ID(开始、结束、填充、解码器开始)有默认值。 - 如果这些ID在用户提供的
generation_config
中为None
,则使用模型默认的self.generation_config
中的值。
- 如果用户提供了自己的
-
-
编译环境下的处理
else:model_kwargs = kwargs
- 在编译环境下,由于无法使用
copy.deepcopy
和hash
,直接将传入的kwargs
赋值给model_kwargs
。 - 不进行深拷贝和更新操作。
- 在编译环境下,由于无法使用
-
-
返回结果
return generation_config, model_kwargs
- 函数返回准备好的
generation_config
和model_kwargs
。 model_kwargs
包含了模型需要的其他参数。
- 函数返回准备好的
_validate_model_kwargs
函数功能概述
这个函数名为_validate_model_kwargs
,用于在生成(generation)过程中对传入的模型关键字参数model_kwargs
进行验证。它的主要作用是:
- 验证传入的
model_kwargs
是否被模型的生成方法使用,以确保参数的正确性,防止参数拼写错误或传入不支持的参数。 - 给出明确的错误信息,帮助用户及时发现并修正问题。
1. 检查past_key_values
是否为Cache
实例,并验证模型是否支持
if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:raise ValueError(f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please ""check the model documentation for supported cache formats.")
-
目的:如果
model_kwargs
中包含past_key_values
,并且它是一个Cache
实例,而模型不支持Cache
类型的缓存,则抛出错误。 -
解释:
model_kwargs.get("past_key_values", None)
:尝试获取past_key_values
的值,如果不存在则为None
。isinstance(..., Cache)
:检查past_key_values
是否是Cache
类的实例。not self._supports_cache_class
:检查当前模型是否不支持Cache
类型的缓存。
-
错误处理:如果以上条件满足,则抛出
ValueError
,提示用户当前模型不支持Cache
实例作为past_key_values
,并建议查看模型文档以了解支持的缓存格式。
2. 对于编码器-解码器模型,移除特定的参数
if self.config.is_encoder_decoder:for key in ["decoder_input_ids"]:model_kwargs.pop(key, None)
-
目的:在生成过程中,不需要直接传入
decoder_input_ids
,因此从model_kwargs
中移除该参数,防止后续产生错误。 -
解释:
self.config.is_encoder_decoder
:检查模型是否为编码器-解码器结构。model_kwargs.pop(key, None)
:从model_kwargs
中移除指定的键,如果不存在则返回None
。
3. 初始化未使用的模型参数列表和获取模型方法参数
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-
unused_model_args
:用于收集未被模型方法使用的参数名。 -
model_args
:获取prepare_inputs_for_generation
方法的参数名集合,表示模型在生成过程中可能使用的参数。 -
解释:
inspect.signature(...)
:获取函数的完整参数签名。.parameters
:获取参数名组成的有序字典。set(...)
:将参数名转换为集合,方便后续操作。
4. 扩展模型参数集合,包含forward
方法的参数
if "kwargs" in model_args or "model_kwargs" in model_args:model_args |= set(inspect.signature(self.forward).parameters)
-
目的:如果
prepare_inputs_for_generation
方法接受kwargs
(可变关键字参数),则需要将self.forward
方法的参数也包含进来,因为这些参数可能会通过kwargs
传递。 -
解释:
- 检查
model_args
中是否包含"kwargs"
或"model_kwargs"
,表示接受可变关键字参数。 - 使用
set.union
(|=
)操作,将self.forward
方法的参数集合并入model_args
。
- 检查
5. 对于编码器-解码器模型,包含编码器和解码器的参数
if self.config.is_encoder_decoder:base_model = getattr(self, self.base_model_prefix, None)# 允许编码器的参数encoder = getattr(self, "encoder", None)# 特殊情况处理(如Musicgen模型)if encoder is None and base_model is not None:encoder = getattr(base_model, "encoder", None)if encoder is not None:encoder_model_args = set(inspect.signature(encoder.forward).parameters)model_args |= encoder_model_args# 允许解码器的参数decoder = getattr(self, "decoder", None)if decoder is None and base_model is not None:decoder = getattr(base_model, "decoder", None)if decoder is not None:decoder_model_args = set(inspect.signature(decoder.forward).parameters)model_args |= {f"decoder_{x}" for x in decoder_model_args}
-
目的:将编码器和解码器的
forward
方法的参数名添加到model_args
中,以便在验证model_kwargs
时考虑到这些参数。 -
详细解释:
-
base_model = getattr(self, self.base_model_prefix, None)
:获取基础模型实例,self.base_model_prefix
通常是模型的前缀名。 -
处理编码器:
encoder = getattr(self, "encoder", None)
:尝试直接获取self.encoder
属性。- 如果
encoder
为None
且base_model
存在,则尝试从base_model
中获取encoder
。 - 获取
encoder.forward
方法的参数名集合encoder_model_args
。 - 将
encoder_model_args
并入model_args
。
-
处理解码器:
- 类似地,获取
decoder
实例。 - 获取
decoder.forward
方法的参数名集合decoder_model_args
。 - 将
decoder_model_args
的参数名加上前缀"decoder_"
,以匹配生成过程中参数的命名方式。 - 将修改后的
decoder_model_args
并入model_args
。
- 类似地,获取
-
-
特殊情况说明:
- 对于某些模型(例如
MusicgenForConditionalGeneration
),编码器和解码器的命名和结构可能与标准情况不同,需要特别处理。
- 对于某些模型(例如
6. 检查model_kwargs
中的未使用参数
for key, value in model_kwargs.items():if value is not None and key not in model_args:unused_model_args.append(key)
-
目的:找出
model_kwargs
中那些未被模型方法使用的参数。 -
解释:
- 遍历
model_kwargs
中的所有键值对。 - 如果参数值不为
None
且参数名不在model_args
集合中,说明模型并不会使用该参数。 - 将这些未使用的参数名添加到
unused_model_args
列表中。
- 遍历
7. 抛出错误提示未使用的参数
if unused_model_args:raise ValueError(f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"" generate arguments will also show up in this list)")
-
目的:如果存在未使用的参数,抛出
ValueError
,提醒用户这些参数未被模型使用,可能存在拼写错误或传入了不支持的参数。 -
解释:
- 检查
unused_model_args
列表是否非空。 - 抛出错误,错误信息中包含未使用的参数列表,并提示用户可能是由于参数名的拼写错误导致的。
- 检查
_prepare_model_inputs
方法定义
def _prepare_model_inputs(self,inputs: Optional[torch.Tensor] = None,bos_token_id: Optional[torch.Tensor] = None,model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:"""This function extracts the model-specific `inputs` for generation."""# 方法体...
参数说明:
inputs
:可选的torch.Tensor
,表示要作为模型输入的张量。可能是input_ids
、inputs_embeds
等形式,具体取决于模型的要求。bos_token_id
:可选的torch.Tensor
,表示序列开始的 token ID(BOS = Begin Of Sequence)。在生成任务中,如果未提供输入,可能需要使用该 token 进行初始化。model_kwargs
:可选的字典,包含传递给模型的其他关键字参数。
返回值:
inputs
:torch.Tensor
,准备好的模型输入张量。input_name
:str
,模型输入的名称,可能是input_ids
或inputs_embeds
。model_kwargs
:字典,更新后的模型关键字参数。
方法功能概述
该方法的主要目的是在生成过程中,准备和验证模型的输入,确保输入与模型的预期格式和要求一致。具体任务包括:
- 确定模型所需的主要输入名称(
input_name
),这可能取决于模型是编码器-解码器模型还是仅解码器模型。 - 处理传入的
inputs
和model_kwargs
,以避免重复传递相同的输入参数。 - 在需要时,使用
bos_token_id
初始化input_ids
,以开始生成过程。 - 对于支持
inputs_embeds
的模型,正确处理inputs_embeds
参数。
逐步详解
步骤 1:确定模型的主要输入名称
# 判断模型是否是编码器-解码器,并获取正确的输入名称
if (self.config.is_encoder_decoderand hasattr(self, "encoder")and self.encoder.main_input_name != self.main_input_name
):input_name = self.encoder.main_input_name
else:input_name = self.main_input_name
解释:
- 目的:获取模型预期的主要输入参数名称,可能是
input_ids
、inputs_embeds
等。 - 逻辑:
- 如果模型是 编码器-解码器模型,并且编码器的
main_input_name
与模型的main_input_name
不同,则使用编码器的main_input_name
。- 这是因为编码器和解码器可能需要不同的输入名称。例如,一些模型的编码器可能使用
inputs_embeds
,而解码器使用input_ids
。
- 这是因为编码器和解码器可能需要不同的输入名称。例如,一些模型的编码器可能使用
- 否则,使用模型的
main_input_name
作为输入名称。
- 如果模型是 编码器-解码器模型,并且编码器的
示例:
- 如果
self.main_input_name
为'input_ids'
,而self.encoder.main_input_name
为'inputs_embeds'
,则对于编码器,需要使用'inputs_embeds'
作为输入名称。
步骤 2:从 model_kwargs
中提取非输入相关的参数
# 过滤掉模型输入的关键字参数,以及值为 None 的参数
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
解释:
- 目的:清理
model_kwargs
,只保留非模型输入的参数和值非None
的参数。 - 逻辑:
- 遍历
model_kwargs
中的所有键值对,保留以下情况的参数:- 值
v
不为None
。 - 或者键
k
不是模型的主要输入名称input_name
。
- 值
- 遍历
- 原因:避免模型输入参数(如
input_ids
)重复出现在inputs
和model_kwargs
中,以防止冲突。
示例:
- 如果
input_name
是'input_ids'
,并且model_kwargs
包含{'input_ids': None, 'attention_mask': tensor([...])}
,那么经过此步骤后,model_kwargs
变为{'attention_mask': tensor([...])}
。
步骤 3:检查模型输入是否作为关键字参数传入
# 从 model_kwargs 中获取输入参数
inputs_kwarg = model_kwargs.pop(input_name, None)
# 检查是否同时传入了 inputs 和 inputs_kwarg
if inputs_kwarg is not None and inputs is not None:raise ValueError(f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "f"Make sure to either pass {inputs} or {input_name}=...")
elif inputs_kwarg is not None:inputs = inputs_kwarg
解释:
- 目的:确保模型输入参数只通过一个途径传入,避免冲突。
- 逻辑:
- 从
model_kwargs
中取出键为input_name
的参数,赋值给inputs_kwarg
。 - 如果同时传入了
inputs
和inputs_kwarg
,抛出异常,提示用户只能通过一种方式传入模型输入。 - 如果
inputs
为None
,但inputs_kwarg
不为None
,则将inputs_kwarg
赋值给inputs
。
- 从
示例:
- 用户通过位置参数传入了
inputs
,同时在model_kwargs
中传入了input_ids
,这将导致异常,因为模型无法确定使用哪个输入。
步骤 4:处理 inputs_embeds
的情况
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:if not self.config.is_encoder_decoder:# 检查模型是否支持 inputs_embedshas_inputs_embeds_forwarding = "inputs_embeds" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())if not has_inputs_embeds_forwarding:raise ValueError(f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} ""doesn't have its forwarding implemented. See the GPT2 implementation for an example ""(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!")# 将 input_ids 初始化并加入 model_kwargsmodel_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs=model_kwargs)else:if inputs is not None:raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")# 更新 inputs 和 input_nameinputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
解释:
- 目的:处理用户通过
inputs_embeds
提供输入的情况,确保模型支持这种输入方式,并正确处理。 - 逻辑:
- 当模型的输入名称为
'input_ids'
,且model_kwargs
中存在'inputs_embeds'
键时,进入此逻辑。 - 对于非编码器-解码器模型:
- 检查模型的
prepare_inputs_for_generation
方法是否接受inputs_embeds
参数。 - 如果不支持,则抛出异常,提示模型不支持通过
inputs_embeds
进行生成。 - 如果支持,则需要初始化
input_ids
,以便在生成过程中处理诸如 attention mask 等依赖input_ids
的自动操作。 - 将初始化的
input_ids
添加到model_kwargs
中。
- 检查模型的
- 对于编码器-解码器模型:
- 如果同时传入了
inputs
和inputs_embeds
,抛出异常,提示只能选择一种输入方式。
- 如果同时传入了
- 最后,将
inputs
设置为inputs_embeds
,并更新input_name
为'inputs_embeds'
。
- 当模型的输入名称为
示例:
- 用户在生成时提供了
inputs_embeds
,如果模型支持,则将其用于生成,否则抛出异常。
步骤 5:如果 inputs
仍为 None
,尝试从 bos_token_id
创建 input_ids
# 如果 inputs 为 None,使用 bos_token_id 初始化 input_ids
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
解释:
- 目的:在用户未提供任何输入的情况下,使用开始标记
bos_token_id
初始化输入,以启动生成过程。 - 逻辑:
- 调用
_maybe_initialize_input_ids_for_generation
方法,如果inputs
为None
,则尝试使用bos_token_id
创建input_ids
。 - 如果
inputs
已存在,则保持不变。
- 调用
示例:
- 用户未提供
inputs
,模型使用bos_token_id
[101]
(假设为开始标记的 ID)初始化input_ids
。
步骤 6:返回处理后的结果
# 返回准备好的 inputs、input_name 和更新后的 model_kwargs
return inputs, input_name, model_kwargs
示例:
# 示例 1:用户只提供 inputs
inputs = torch.tensor([[101, 102, 103]])
inputs, input_name, model_kwargs = model._prepare_model_inputs(inputs=inputs)# 示例 2:用户只提供 input_ids 作为关键字参数
model_kwargs = {'input_ids': torch.tensor([[101, 102, 103]])}
inputs, input_name, model_kwargs = model._prepare_model_inputs(model_kwargs=model_kwargs)# 示例 3:用户提供 inputs_embeds,模型支持 inputs_embeds
model_kwargs = {'inputs_embeds': torch.tensor([...])}
inputs, input_name, model_kwargs = model._prepare_model_inputs(model_kwargs=model_kwargs)# 示例 4:用户未提供任何输入,使用 bos_token_id 初始化
bos_token_id = torch.tensor([101])
inputs, input_name, model_kwargs = model._prepare_model_inputs(bos_token_id=bos_token_id)
_prepare_special_tokens
主要目的是准备生成所需的特殊标记(如开始标记 bos_token_id
、结束标记 eos_token_id
、填充标记 pad_token_id
和解码器起始标记 decoder_start_token_id
),并将这些标记转换为张量。具体步骤如下:
-
定义辅助函数
_tensor_or_none
:- 将特殊标记转换为张量,如果标记为
None
则返回None
。
- 将特殊标记转换为张量,如果标记为
-
将特殊标记转换为张量:
- 使用
_tensor_or_none
函数将bos_token_id
、eos_token_id
、pad_token_id
和decoder_start_token_id
转换为张量。
- 使用
-
处理编码器-解码器模型:
- 如果模型是编码器-解码器类型,并且
decoder_start_token_id
未设置,则使用bos_token_id
作为decoder_start_token_id
。
- 如果模型是编码器-解码器类型,并且
-
处理
eos_token_tensor
:- 如果
eos_token_tensor
是 0 维张量,则将其扩展为 1 维张量。
- 如果
-
设置
pad_token_tensor
:- 如果
pad_token_tensor
未设置且eos_token_tensor
存在,则将pad_token_tensor
设置为eos_token_tensor
的第一个元素,并发出警告。
- 如果
-
安全检查和警告:
- 检查编码器-解码器模型是否设置了
decoder_start_token_id
。 - 检查
eos_token_tensor
是否与pad_token_tensor
相同,并在未设置注意力掩码时发出警告。 - 检查
eos_token_tensor
是否包含负数或浮点数,并发出警告。
- 检查编码器-解码器模型是否设置了
-
更新生成配置:
- 将转换后的特殊标记张量存储在
generation_config
的新属性中,以启用端到端编译。
- 将转换后的特殊标记张量存储在
_prepare_attention_mask_for_generation
该方法用于在生成过程中为模型准备 attention_mask
,以确保模型在生成序列时正确地关注输入序列的非填充部分(非 pad_token
部分)。
方法定义
def _prepare_attention_mask_for_generation(self,inputs_tensor: torch.Tensor,generation_config: GenerationConfig,model_kwargs: Dict[str, Any],
) -> torch.LongTensor:# 方法体...
参数说明:
inputs_tensor
:torch.Tensor
,输入张量,可能是模型的主要输入。generation_config
:GenerationConfig
,生成配置对象,包含生成过程中所需的配置参数。model_kwargs
:Dict[str, Any]
,包含模型其他关键字参数的字典。
返回值:
attention_mask
:torch.LongTensor
,返回生成过程中使用的attention_mask
,用于指示模型需要关注的输入序列位置。
方法功能概述
该方法的主要目的是根据输入张量 inputs_tensor
和生成配置 generation_config
,为生成过程准备合适的 attention_mask
。具体而言,它会根据以下情况生成 attention_mask
:
- 如果输入张量包含
pad_token_id
,则在attention_mask
中标记出非填充的位置(即非pad_token_id
的位置为 1,pad_token_id
的位置为 0)。 - 如果无法判断是否需要生成
attention_mask
,则返回默认的attention_mask
,即全 1 的张量,表示所有位置都需要关注。
逐步详解
步骤 1:获取 pad_token_id
和 eos_token_id
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
- 说明:从
generation_config
中获取用于填充的pad_token_id
和序列结束的eos_token_id
。
步骤 2:检查 model_kwargs
中是否有 input_ids
# `input_ids` 可能存在于 model_kwargs 中,而不是主要输入(例如多模态模型)
if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:inputs_tensor = model_kwargs["input_ids"]
-
解释:
- 在某些情况下,
inputs_tensor
并不是模型的主要输入,如在多模态模型中,input_ids
可能存在于model_kwargs
中。 - 如果
model_kwargs
中有input_ids
,并且其形状的第二维度长度大于 0(即有数据),则用model_kwargs["input_ids"]
替换inputs_tensor
。
- 在某些情况下,
-
目的:确保
inputs_tensor
是正确的input_ids
,以便后续的attention_mask
计算。
步骤 3:创建默认的 attention_mask
# 如果无法推断 attention mask,则返回默认的 attention mask
default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
-
解释:
- 创建一个形状与
inputs_tensor
前两个维度相同的全 1 张量,作为默认的attention_mask
。 - 该默认
attention_mask
表示输入序列的所有位置都需要关注。 - 切片的基本语法是 start: end:step,其中 start 是起始索引,end 是结束索引(不包括),step 是步长。写为 [:2] 是一种简写表示方式,没有显式指定起始和步长,默认从头开始且步长为1
- 创建一个形状与
步骤 4:如果 pad_token_id
为 None
,返回默认的 attention_mask
if pad_token_id is None:return default_attention_mask
-
说明:
- 如果
pad_token_id
未定义,则无法根据填充标记生成attention_mask
,因此直接返回默认的attention_mask
。
- 如果
步骤 5:检查 inputs_tensor
是否为有效的 input_ids
is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
if not is_input_ids:return default_attention_mask
-
解释:
-
检查
inputs_tensor
是否满足以下条件:inputs_tensor
的形状为二维,即形状为(batch_size, sequence_length)
。inputs_tensor
的数据类型为整数类型torch.int
或torch.long
。
-
如果不满足上述条件,则认为无法根据
inputs_tensor
推断出attention_mask
,因此返回默认的attention_mask
。
-
步骤 6:检查 inputs_tensor
中是否包含 pad_token_id
is_pad_token_in_inputs = (pad_token_id is not None) and (isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
)
-
解释:
-
使用
isin_mps_friendly
函数检查inputs_tensor
中是否包含pad_token_id
。isin_mps_friendly(elements, test_elements)
:检查elements
中的元素是否在test_elements
中,返回布尔张量。
-
isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
:- 检查
inputs_tensor
中是否有元素等于pad_token_id
,如果有,则返回True
。
- 检查
-
-
结果:
is_pad_token_in_inputs
为一个布尔值,表示inputs_tensor
中是否包含pad_token_id
。
步骤 7:检查 pad_token_id
是否不等于 eos_token_id
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
)
-
解释:
-
需要确保
pad_token_id
和eos_token_id
不相等,以避免将eos_token_id
误认为填充标记。 -
逻辑:
-
如果
eos_token_id
为None
,则认为pad_token_id
不等于eos_token_id
。 -
如果
eos_token_id
不为None
,则检查eos_token_id
是否等于pad_token_id
:-
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
:- 检查
eos_token_id
是否等于pad_token_id
。
- 检查
-
使用按位取反运算符
~
,将结果取反。
-
-
-
-
结果:
is_pad_token_not_equal_to_eos_token_id
为一个布尔值,表示pad_token_id
是否不等于eos_token_id
。
步骤 8:确定是否可以推断 attention_mask
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
-
解释:
- 通过将
is_pad_token_in_inputs
和is_pad_token_not_equal_to_eos_token_id
相乘(布尔值相乘相当于逻辑与),判断是否可以基于pad_token_id
推断出attention_mask
。
- 通过将
-
结果:
-
can_infer_attention_mask
为布尔值:-
如果
True
,表示可以根据pad_token_id
推断attention_mask
。 -
如果
False
,则无法推断,需要返回默认的attention_mask
。
-
-
步骤 9:基于 pad_token_id
生成 attention_mask
attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
-
解释:
-
使用
inputs_tensor.ne(pad_token_id)
,得到一个布尔张量,标记出inputs_tensor
中不等于pad_token_id
的位置。 -
.long()
:将布尔张量转换为长整型,即True
转换为1
,False
转换为0
。
-
-
结果:
attention_mask_from_padding
为一个长整型张量,形状与inputs_tensor
相同,非pad_token_id
的位置为1
,pad_token_id
的位置为0
。
步骤 10:根据是否可以推断 attention_mask
,选择最终的 attention_mask
attention_mask = (attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
-
解释:
-
如果可以推断
attention_mask
(can_infer_attention_mask
为True
):- 使用
attention_mask_from_padding
。
- 使用
-
如果不能推断:
- 使用
default_attention_mask
。
- 使用
-
计算方式:
-
attention_mask = attention_mask_from_padding * can_infer_attention_mask
:-
当
can_infer_attention_mask
为True
时,attention_mask
等于attention_mask_from_padding
。 -
当
can_infer_attention_mask
为False
时,乘积为0
。
-
-
default_attention_mask * ~can_infer_attention_mask
:-
~can_infer_attention_mask
对can_infer_attention_mask
取反。 -
当
can_infer_attention_mask
为False
时,~can_infer_attention_mask
为True
。 -
因此,当不能推断时,
attention_mask
等于default_attention_mask
。
-
-
最终,将两个结果相加,得到正确的
attention_mask
。
-
-
-
结果:
attention_mask
为最终的注意力掩码张量,表示模型在生成过程中需要关注的输入位置。
步骤 11:返回最终的 attention_mask
return attention_mask
-
解释:
- 返回生成的
attention_mask
张量,用于在生成过程中指导模型关注正确的输入序列位置。
- 返回生成的
_prepare_encoder_decoder_kwargs_for_generation
以下是对您提供的 _prepare_encoder_decoder_kwargs_for_generation
方法的详细解释。该方法用于在生成过程中为 编码器-解码器模型 准备必要的关键字参数,以便在生成序列时正确调用编码器。
方法定义
def _prepare_encoder_decoder_kwargs_for_generation(self,inputs_tensor: torch.Tensor,model_kwargs,model_input_name: Optional[str],generation_config: GenerationConfig,
) -> Dict[str, Any]:# 方法体...
参数说明:
inputs_tensor
:torch.Tensor
,输入张量,通常是input_ids
,表示输入序列的标记(tokens)。model_kwargs
:Dict[str, Any]
,包含传递给模型的其他关键字参数。model_input_name
:Optional[str]
,模型输入的名称,默认为None
。如果为None
,则使用self.main_input_name
。generation_config
:GenerationConfig
,生成配置对象,包含生成过程中所需的配置参数。
返回值:
model_kwargs
: 返回更新后的model_kwargs
,其中包括编码器的输出encoder_outputs
,以供生成器使用。
方法功能概述
该方法的主要目的是:
- 获取模型的 编码器 部分。
- 从
model_kwargs
和generation_config
中提取 编码器所需的参数,并准备传递给编码器的关键字参数encoder_kwargs
。 - 调用 编码器的
forward
方法,获取编码器的输出,并将其添加到model_kwargs
中,以供 解码器 在生成过程中使用。
逐步详解
步骤 1:获取编码器
# 1. get encoder
encoder = self.get_encoder()
-
说明:
- 调用模型的
get_encoder()
方法,获取编码器对象。 - 该编码器将用于处理输入的
inputs_tensor
,生成编码器的输出。
- 调用模型的
步骤 1.1:兼容性处理
# Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
# as the inputs.
if hasattr(self, "hf_device_map"):if hasattr(encoder, "_hf_hook"):encoder._hf_hook.io_same_device = Trueelse:add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
-
解释:
- 目的:确保在使用 Accelerate 库进行大型模型推理时,编码器的输出与输入位于 同一设备 上(如 GPU),避免跨设备的数据传输开销。
-
逻辑:
- 检查模型是否具有
hf_device_map
属性,如果存在,表示模型使用了 Accelerate 库进行设备映射。 - 检查编码器是否具有
_hf_hook
属性:- 如果有,设置其
io_same_device
属性为True
,表示编码器的输入和输出在同一设备上。 - 如果没有,使用
add_hook_to_module
函数,将AlignDevicesHook(io_same_device=True)
添加到编码器模块上。
- 如果有,设置其
- 检查模型是否具有
-
相关函数:
add_hook_to_module(module, hook)
: 将钩子函数添加到指定的模块上,控制模块的输入输出行为。
步骤 2:准备编码器的参数
# 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {argument: valuefor argument, value in model_kwargs.items()if not any(argument.startswith(p) for p in irrelevant_prefix)
}
-
目的:
- 从
model_kwargs
中提取与编码器相关的参数,过滤掉与解码器或交叉注意力相关的参数。
- 从
-
逻辑:
- 定义一个列表
irrelevant_prefix
,包含了不相关的参数前缀,如"decoder_"
、"cross_attn"
、"use_cache"
。 - 使用字典推导式,从
model_kwargs
中过滤掉以这些前缀开头的参数。 - 结果是
encoder_kwargs
,其中包含了需要传递给编码器的参数。
- 定义一个列表
-
示例:
-
如果
model_kwargs
包含:model_kwargs = {"input_ids": tensor(...),"attention_mask": tensor(...),"decoder_input_ids": tensor(...),"use_cache": True, }
-
过滤后,
encoder_kwargs
为:encoder_kwargs = {"input_ids": tensor(...),"attention_mask": tensor(...), }
-
步骤 2.1:检查编码器的签名
encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:encoder_kwargs = {argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature}
-
解释:
- 目的:确保传递给编码器的参数在其
forward
方法的参数列表中,即编码器能够接受这些参数。
- 目的:确保传递给编码器的参数在其
-
逻辑:
- 使用
inspect.signature(encoder.forward).parameters
获取编码器forward
方法的参数名称集合encoder_signature
。 - 检查编码器是否接受通配参数
**kwargs
或**model_kwargs
,如果接受,则无需进一步过滤参数。 - 如果编码器不接受通配参数,则过滤
encoder_kwargs
,仅保留在encoder_signature
中的参数。
- 使用
步骤 2.2:添加生成配置中的参数
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
-
说明:
- 从
generation_config
中提取output_attentions
和output_hidden_states
配置,添加到encoder_kwargs
中。 - 这些配置决定了编码器在前向计算时,是否返回注意力权重和隐藏状态。
- 从
步骤 3:调用编码器并更新 model_kwargs
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore
return model_kwargs
-
解释:
-
确保编码器返回
ModelOutput
对象:- 设置
encoder_kwargs["return_dict"] = True
,使编码器的输出为ModelOutput
格式,而不是元组。
- 设置
-
准备编码器的输入:
- 确定模型的输入名称
model_input_name
,如果传入了model_input_name
,则使用它,否则使用self.main_input_name
。 - 将
inputs_tensor
添加到encoder_kwargs
,键为model_input_name
。
- 确定模型的输入名称
-
调用编码器的
forward
方法:model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore
- 调用编码器,传入
encoder_kwargs
,得到编码器的输出encoder_outputs
。 - 将
encoder_outputs
添加到model_kwargs
中,键为"encoder_outputs"
。 - 使用类型注解
: ModelOutput
指定类型,这是为了让静态类型检查器知道encoder_outputs
的类型。
- 调用编码器,传入
-
返回更新后的
model_kwargs
:- 包含了
encoder_outputs
,供后续的解码器生成过程中使用。
- 包含了
-
_prepare_decoder_input_ids_for_generation
以下是对您提供的 _prepare_decoder_input_ids_for_generation
方法的详细解释。这个方法用于在生成过程中为 编码器-解码器模型 准备 decoder_input_ids
,确保解码器在生成时能够正确地开始。
方法定义
def _prepare_decoder_input_ids_for_generation(self,batch_size: int,model_input_name: str,model_kwargs: Dict[str, torch.Tensor],decoder_start_token_id: torch.Tensor,device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""# 方法体...
参数说明:
batch_size
:int
,表示批次大小,即一次输入中样本的数量。model_input_name
:str
,模型输入的名称,通常为"input_ids"
。model_kwargs
:Dict[str, torch.Tensor]
,包含传递给模型的其他关键字参数。decoder_start_token_id
:torch.Tensor
,解码器开始标记的 token ID。device
:torch.device
,可选,指定要在哪个设备上运行,如 GPU 或 CPU。如果未指定,则使用模型默认的设备。
返回值:
decoder_input_ids
:torch.LongTensor
,准备好的用于生成的解码器输入 IDs。model_kwargs
: 更新后的模型关键字参数字典。
方法功能概述
这个方法的主要作用是:
- 检查用户是否手动提供了
decoder_input_ids
,如果没有,则根据情况初始化它。 - 确保
decoder_start_token_id
的形状正确,并适应批次大小。 - 确保
decoder_input_ids
以特殊的开始标记(如 BOS token)开头,如果没有,则自动添加。 - 处理特定模型的例外情况,例如 “Donut” 和 “Whisper” 模型可能有不同的处理方式。
通过这些步骤,该方法确保了在生成过程中,解码器能够正确地开始生成序列。
逐步详解
步骤 1:检查用户是否提供了 decoder_input_ids
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":decoder_input_ids = model_kwargs.pop("input_ids")
else:decoder_input_ids = None
解释:
-
目的: 确定是否需要初始化
decoder_input_ids
。 -
逻辑:
- 如果
model_kwargs
中存在decoder_input_ids
,则将其取出,并从model_kwargs
中删除,避免重复。 - 如果
model_kwargs
中存在input_ids
,且model_input_name
不是"input_ids"
,则将其视为decoder_input_ids
。这是为了方便某些模型的输入命名。 - 如果以上都不满足,则将
decoder_input_ids
设为None
,表示需要初始化。
- 如果
示例:
- 用户在调用生成方法时,可能通过
model_kwargs
提供了decoder_input_ids
,方法将直接使用它。 - 如果用户没有提供,方法将尝试从
input_ids
中获取(当编码器不使用input_ids
作为主要输入时)。 - 如果都没有提供,方法将在后续步骤中初始化
decoder_input_ids
。
步骤 2:调整 decoder_start_token_id
的形状
if device is None:device = self.device
if decoder_start_token_id.ndim == 1:if decoder_start_token_id.shape[0] != batch_size:raise ValueError(f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}")decoder_start_token_id = decoder_start_token_id.view(-1, 1)
else:decoder_start_token_id = (torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id)
解释:
-
目的: 确保
decoder_start_token_id
的形状为(batch_size, 1)
,即每个样本都有一个开始标记。 -
逻辑:
- 如果未指定
device
,则使用模型默认的设备self.device
。 - 检查
decoder_start_token_id
的维度:- 如果是一维张量(向量),即
decoder_start_token_id.ndim == 1
:- 检查其长度是否等于
batch_size
,否则抛出错误。 - 重塑为形状
(batch_size, 1)
,即每个样本一个开始标记。
- 检查其长度是否等于
- 如果不是一维张量,则认为是单个标量:
- 创建一个形状为
(batch_size, 1)
的张量,元素全为decoder_start_token_id
的值。
- 创建一个形状为
- 如果是一维张量(向量),即
- 如果未指定
示例:
- 如果
decoder_start_token_id
是标量101
(假设是开始标记的 ID):- 创建一个形状为
(batch_size, 1)
的张量,所有元素都是101
。
- 创建一个形状为
- 如果
decoder_start_token_id
是张量[101, 102]
,batch_size
为2
:- 将其重塑为:
[[101],[102]]
- 将其重塑为:
步骤 3:确保 decoder_input_ids
以开始标记开头
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:decoder_input_ids = decoder_start_token_id
解释:
- 情况 1: 如果
decoder_input_ids
为None
,即用户未提供任何解码器输入:- 直接使用
decoder_start_token_id
作为decoder_input_ids
,表示解码器将从开始标记开始生成。
- 直接使用
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token...
elif "donut" in self.__class__.__name__.lower() or (self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
):pass
elif self.config.model_type in ["whisper"]:pass
解释:
- 情况 2: 针对特定模型的例外情况:
- Donut 模型:
- Donut 模型有特定的解码器开始标记,且不需要额外的 BOS(Begin of Sequence)标记。
- 如果模型名包含
"donut"
,则不对decoder_input_ids
进行处理。
- Whisper 模型:
- Whisper 模型也有自己的处理逻辑,不需要添加开始标记。
- Donut 模型:
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)if "decoder_attention_mask" in model_kwargs:decoder_attention_mask = model_kwargs["decoder_attention_mask"]decoder_attention_mask = torch.cat((torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),dim=-1,)model_kwargs["decoder_attention_mask"] = decoder_attention_mask
解释:
- 情况 3: 用户提供了
decoder_input_ids
,但其首个 token 并非decoder_start_token_id
:- 检查首个 token 是否等于
decoder_start_token_id
:(decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
:- 比较每个样本的
decoder_input_ids
首个 token 是否不等于对应的decoder_start_token_id
。 - 如果对于所有样本都不相等,则返回
True
。
- 比较每个样本的
- 如果首个 token 不匹配,则在
decoder_input_ids
前添加decoder_start_token_id
:- 使用
torch.cat
在维度-1
(序列长度维度)上拼接decoder_start_token_id
和decoder_input_ids
。
- 使用
- 调整
decoder_attention_mask
(如果提供):- 如果
model_kwargs
中存在decoder_attention_mask
:- 在其前面添加一个值为
1
的位置,表示新添加的开始标记需要被注意。 - 更新
model_kwargs["decoder_attention_mask"]
。
- 在其前面添加一个值为
- 如果
- 检查首个 token 是否等于
示例:
-
如果原始
decoder_input_ids
为:[[5, 6, 7],[8, 9, 10]]
-
decoder_start_token_id
为:[[101],[101]]
-
拼接后得到:
[[101, 5, 6, 7],[101, 8, 9, 10]]
-
同时调整
decoder_attention_mask
,在前面添加一个1
。
步骤 4:返回处理后的结果
return decoder_input_ids, model_kwargs
- 返回更新后的
decoder_input_ids
和model_kwargs
。
整体流程总结
-
输入处理:
- 检查用户是否提供了
decoder_input_ids
,如果没有,则需要初始化。 - 通过
model_kwargs
获取decoder_input_ids
或input_ids
,如果适用。
- 检查用户是否提供了
-
确保解码器开始标记的形状正确:
- 将
decoder_start_token_id
调整为形状(batch_size, 1)
,确保每个样本都有对应的开始标记。
- 将
-
确保
decoder_input_ids
以开始标记开头:- 如果
decoder_input_ids
为None
,直接使用decoder_start_token_id
。 - 对于特定模型(如 Donut 和 Whisper),保留用户提供的
decoder_input_ids
,不做修改。 - 如果用户提供的
decoder_input_ids
不以decoder_start_token_id
开头,自动在其前添加。 - 同时,调整
decoder_attention_mask
,确保新添加的开始标记在注意力掩码中被考虑。
- 如果
-
返回处理后的
decoder_input_ids
和model_kwargs
,供后续生成过程使用。
示例代码
示例 1:用户未提供 decoder_input_ids
# 假设 batch_size = 2
decoder_start_token_id = torch.tensor([101], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(batch_size=2,model_input_name="input_ids",model_kwargs={},decoder_start_token_id=decoder_start_token_id,device=device,
)
# 结果:decoder_input_ids = [[101], [101]]
示例 2:用户提供了 decoder_input_ids
,但未以开始标记开头
decoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]], device=device)
model_kwargs = {"decoder_input_ids": decoder_input_ids}
decoder_start_token_id = torch.tensor([101, 102], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(batch_size=2,model_input_name="input_ids",model_kwargs=model_kwargs,decoder_start_token_id=decoder_start_token_id,device=device,
)
# 结果:
# decoder_input_ids = [[101, 5, 6, 7], [102, 8, 9, 10]]
示例 3:针对 Donut 模型
# 设模型名称包含 "Donut",用户提供了 decoder_input_ids
self.__class__.__name__ = "DonutModel"
decoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]], device=device)
model_kwargs = {"decoder_input_ids": decoder_input_ids}
decoder_start_token_id = torch.tensor([101], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(batch_size=2,model_input_name="input_ids",model_kwargs=model_kwargs,decoder_start_token_id=decoder_start_token_id,device=device,
)
# 结果:
# decoder_input_ids 保持不变,不添加 decoder_start_token_id
注意事项
-
模型特定的例外情况: 对于某些模型,如 Donut 和 Whisper,需要特殊处理,不能盲目添加开始标记。
-
一致性检查: 确保
batch_size
与decoder_start_token_id
的长度一致,否则抛出异常。 -
处理注意力掩码: 如果修改了
decoder_input_ids
,且提供了decoder_attention_mask
,需要相应地调整掩码。 -
设备一致性: 所有张量都应该在同一设备上(CPU 或 GPU),方法中确保了这一点。
总结
该方法的主要功能是为生成过程准备合适的 decoder_input_ids
:
- 初始值设置: 如果用户未提供,则使用
decoder_start_token_id
初始化。 - 形状调整: 确保
decoder_start_token_id
的形状与批次大小匹配。 - 开头标记: 确保
decoder_input_ids
以特定的开始标记开头,以符合模型的要求。 - 特例处理: 对于特定模型,按其特殊需求处理。
通过这些步骤,保证了编码器-解码器模型在生成过程中能够正确地开始解码器的序列生成。
希望以上解释能够帮助您理解 _prepare_decoder_input_ids_for_generation
方法的功能和每个步骤的具体作用。如果您还有其他问题,欢迎继续提问!
heal_tokens
这个方法用于在生成过程中,对输入的 input_ids
进行修复(healing),以增强模型的输出质量。具体来说,它会根据输入序列的最后一个 token,寻找可能的扩展,并替换原始的尾部 token。这在某些情况下可以纠正模型生成中的不一致或错误。
方法定义
def heal_tokens(self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
) -> torch.LongTensor:r"""Generates sequences of token ids for models with a language modeling head.Parameters:input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.Return:`torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension."""# 方法体...
参数说明:
input_ids
:torch.LongTensor
,形状为(batch_size, sequence_length)
,表示输入的 token 序列,用于生成的提示。tokenizer
:PreTrainedTokenizerBase
,可选参数,模型对应的 tokenizer,用于解码和编码 token IDs。
返回值:
torch.LongTensor
,形状与input_ids
相同,其中每个序列的尾部 token 被替换为了适当的扩展 token。
方法功能概述
该方法的主要目的是:
- 修复输入序列的尾部 token:对于每个输入序列,检查其最后一个 token,寻找可能的扩展 token,并替换之。
- 改进模型生成的连贯性:通过纠正输入序列的尾部,使得生成的序列在语义和形式上更加连贯。
- 处理空序列和特殊情况:在方法中包含了对空序列和特殊情况的处理,确保方法的稳健性。
逐步详解
步骤 1:验证 tokenizer
是否提供
if tokenizer is None:raise ValueError(" When generating with token healing, you must pass the model's tokenizer to the `tokenizer` ""argument of `generate`.")
- 解释:
- 方法要求必须提供
tokenizer
参数,否则无法进行解码和编码操作。 - 如果未提供
tokenizer
,则抛出ValueError
。
- 方法要求必须提供
步骤 2:获取特殊 token IDs 和构建词汇前缀树
bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
- 解释:
- 获取特殊 token IDs:
bos_token_id
: 序列开始的 token ID(BOS = Begin Of Sequence)。pad_token_id
: 填充 token 的 ID。
- 构建词汇前缀树:
tokenizer.get_vocab()
: 获取 tokenizer 的词汇表,返回一个字典{token: token_id}
。ExtensionsTrie
: 自定义的前缀树类,用于快速查找以某个前缀开始的所有 token。
- 创建生成配置:
GenerationConfig
: 配置生成过程的参数。max_new_tokens=1
: 生成的最大新 token 数设置为 1,因为我们只需要替换尾部 token。pad_token_id=pad_token_id
: 设置填充 token ID。
- 获取特殊 token IDs:
步骤 3:处理提示文本
# assumption: leading/trailing whitespace is not meaningful, so the prompts are
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
- 解释:
- 假设:首尾的空白字符(whitespace)对生成过程没有实质影响,因此在重新编码前去除。
- 步骤:
tokenizer.batch_decode
: 将input_ids
解码为字符串列表,跳过特殊 tokens。- 使用列表推导式,对每个字符串进行
strip
操作,去除首尾空白字符。 - 结果存储在
prompts
列表中。
步骤 4:重新编码输入序列
input_ids = tokenizer(prompts,return_tensors="pt",padding=True,
).input_ids.to(input_ids.device)
- 解释:
- 重新编码:
- 将处理后的
prompts
列表再次通过tokenizer
编码为input_ids
。 return_tensors="pt"
: 返回 PyTorch 张量。padding=True
: 对序列进行填充,以使它们的长度一致。
- 将处理后的
- 调整设备:
- 使用
.to(input_ids.device)
将新生成的input_ids
移动到与原始input_ids
相同的设备上(如 GPU 或 CPU)。
- 使用
- 重新编码:
步骤 5:替换序列中的 bos_token_id
为 pad_token_id
# replace bos with pad to not condition healing on it
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
- 解释:
- 目的:
- 在后续的修复过程中,不希望序列开始的特殊标记(BOS)影响结果,因此将其替换为填充标记。
- 操作:
- 使用
torch.where
函数,将input_ids
中等于bos_token_id
的位置替换为pad_token_id
。
- 使用
- 目的:
步骤 6:检查 input_ids
是否为空
if input_ids.numel() == 0:return input_ids
- 解释:
- 目的:
- 如果
input_ids
为空(即没有元素),则无需进行后续处理,直接返回。
- 如果
- 检查:
input_ids.numel()
: 返回张量中元素的总数。- 如果为
0
,则返回input_ids
。
- 目的:
步骤 7:获取每个序列的尾部 token ID 和对应的 token
tail_ids = input_ids[:, -1].tolist()
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
- 解释:
- 获取尾部 token ID:
input_ids[:, -1]
: 获取每个序列的最后一个 token ID。.tolist()
: 转换为 Python 列表,方便后续处理。
- 处理空格 token:
tokenizer.convert_tokens_to_ids(" ")
: 将空格字符" "
转换为对应的 token ID。tokenizer.convert_ids_to_tokens
: 再将 token ID 转换为 token 字符串。[0]
: 获取结果中的第一个元素(列表中可能包含多个 token)。- 结果
space_tok
为空格对应的 token,例如在 BPE(Byte-Pair Encoding)中,空格可能对应特殊字符,例如'Ġ'
。
- 获取尾部 token 的字符串表示:
- 使用生成器表达式
(tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
:- 对每个
tail_id
:tokenizer.decode(t)
: 解码为字符串。.replace(" ", space_tok)
: 将字符串中的空格替换为space_tok
,以便于后续前缀搜索。
- 对每个
- 使用生成器表达式
- 获取尾部 token ID:
步骤 8:遍历每个序列,尝试修复尾部 token
for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):batch_ids = input_ids[batch_idx]if torch.all(batch_ids == pad_token_id).item():continue # skip empty sequences (all pad ids)
- 解释:
- 遍历序列:
- 使用
enumerate
对tail_ids
和tail_toks
进行遍历。 batch_idx
: 当前序列的索引。tail_id
: 当前序列的尾部 token ID。tail_tok
: 当前序列的尾部 token 字符串(经过特殊处理的)。
- 使用
- 获取当前序列的
input_ids
:batch_ids = input_ids[batch_idx]
: 获取当前序列的input_ids
。
- 检查序列是否为空:
torch.all(batch_ids == pad_token_id).item()
: 检查序列中的所有位置是否都是pad_token_id
。- 如果是,表示序列为空,跳过后续处理。
- 遍历序列:
步骤 9:构建可能的替代 token 的偏置字典
# apply bias for alternatives (extensions) to the tail token
"""
seq_bias key has to be tuple with int so have to use
tokenizer function to convert str to int
"""
seq_bias = {(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
}
if len(seq_bias) == 1:continue # skip if there are no token alternatives to heal with
# slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
seq_bias[(tail_id,)] += 1.0
generation_config.update(sequence_bias=seq_bias)
- 解释:
- 寻找可能的扩展 tokens:
vocab_trie.extensions(prefix=tail_tok)
: 在词汇前缀树中,找到以tail_tok
为前缀的所有 token。
- 构建偏置字典
seq_bias
:- 键:以单个 token ID 组成的元组
(token_id,)
,因为后续需要的键是元组形式。 - 值:偏置值
10.0
,用于在生成过程中提升这些 tokens 的概率。 - 需要使用
tokenizer.convert_tokens_to_ids(alt_tok)
将 token 字符串转换为 token ID。
- 键:以单个 token ID 组成的元组
- 检查是否有可替代的 tokens:
- 如果
seq_bias
的长度为1
,表示只有原始的tail_id
,没有其他可替代的 tokens,跳过处理。
- 如果
- 稍微提升原始 token 的概率:
seq_bias[(tail_id,)] += 1.0
: 对原始的tail_id
,增加一个较小的偏置1.0
,以防止过度修复(例如,将'http'
修复为'https'
)。
- 更新生成配置:
generation_config.update(sequence_bias=seq_bias)
: 将偏置字典添加到生成配置中,供生成过程使用。
- 寻找可能的扩展 tokens:
步骤 10:准备生成输入,去除尾部 token
trimmed_ids = batch_ids[:-1]
"""
the latter code assumes trimmed_ids is not empty
so have to check its element count
"""
if trimmed_ids.numel() == 0:continue
# if the prompt is a single (non-pad) token, regenerate from bos
if len(batch_ids[batch_ids != pad_token_id]) == 1:trimmed_ids[-1] = bos_token_id
- 解释:
- 去除尾部 token:
trimmed_ids = batch_ids[:-1]
: 获取当前序列,去除最后一个 token,即tail_id
。
- 检查
trimmed_ids
是否为空:- 如果
trimmed_ids.numel() == 0
,表示序列长度为 0,无需处理,继续下一个序列。
- 如果
- 特殊情况处理:
- 如果序列只有一个非填充 token(除去
pad_token_id
后长度为 1):- 需要将
trimmed_ids
的最后一个位置替换为bos_token_id
,以重新从开始标记生成。
- 需要将
- 如果序列只有一个非填充 token(除去
- 去除尾部 token:
步骤 11:生成新的 token 并替换尾部 token
input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
- 解释:
- 调用生成方法:
self.generate
: 调用模型的generate
方法,生成新序列。trimmed_ids.unsqueeze(0)
: 为了适应批量维度,将trimmed_ids
添加一个维度,形状从(sequence_length - 1,)
变为(1, sequence_length - 1)
。generation_config=generation_config
: 使用之前配置的生成参数,包括偏置字典seq_bias
。
- 更新
input_ids
:- 将生成的新序列替换到
input_ids
中对应的位置。 - 注意,这里生成的序列长度为
sequence_length
,因为max_new_tokens=1
,所以会在trimmed_ids
的基础上生成一个新 token。
- 将生成的新序列替换到
- 调用生成方法:
步骤 12:返回修复后的 input_ids
return input_ids
- 解释:
- 返回处理后的
input_ids
,其中每个序列的尾部 token 已根据可能的扩展进行了修复。
- 返回处理后的
整体流程总结
-
准备工作:
- 验证
tokenizer
参数。 - 获取特殊 token IDs。
- 构建词汇前缀树
vocab_trie
,用于快速查找可能的扩展 token。 - 配置生成参数
generation_config
。
- 验证
-
处理输入序列:
- 将
input_ids
解码为字符串列表prompts
,去除首尾空白。 - 重新编码
prompts
为input_ids
,确保一致性。 - 将
bos_token_id
替换为pad_token_id
,避免对序列开始标记的影响。
- 将
-
遍历每个序列,尝试修复尾部 token:
- 获取每个序列的尾部 token ID 和对应的 token 字符串。
- 查找以尾部 token 为前缀的可能扩展 tokens。
- 构建偏置字典
seq_bias
,提升这些扩展 token 在生成过程中的概率。 - 去除序列的尾部 token,准备生成新的尾部 token。
- 调用
self.generate
方法,生成新的序列,并替换到input_ids
中。
-
返回处理后的
input_ids
。
_prepare_generated_length
这个方法用于在生成过程中,根据用户提供的生成配置和模型输入,准备和调整生成的最大长度 (max_length
) 和最小长度 (min_length
),以避免类似属性之间的冲突。
方法定义
def _prepare_generated_length(self,generation_config,has_default_max_length,has_default_min_length,model_input_name,input_ids_length,inputs_tensor,
):"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""# 方法体...
参数说明:
-
generation_config
:GenerationConfig
对象,包含了生成过程中需要的各种配置参数,如max_length
、min_length
、max_new_tokens
、min_new_tokens
等。 -
has_default_max_length
:bool
类型,指示max_length
是否使用了默认值。如果为True
,表示用户未显式设置max_length
。 -
has_default_min_length
:bool
类型,指示min_length
是否使用了默认值。 -
model_input_name
:str
类型,模型输入的名称,通常为"input_ids"
或"inputs_embeds"
。 -
input_ids_length
:int
类型,输入序列input_ids
的长度,即输入的 token 数量。 -
inputs_tensor
:torch.Tensor
对象,模型的输入张量。
方法功能概述
该方法的主要作用是:
-
调整
max_length
和min_length
:根据用户提供的max_new_tokens
、min_new_tokens
、max_length
和min_length
,以及输入序列的长度,计算并设置最终的生成长度参数,确保生成过程按照预期进行。 -
避免冲突:如果用户同时设置了类似的属性(例如同时设置了
max_length
和max_new_tokens
),该方法会明确优先级,并在必要时发出警告,提示用户可能存在的冲突。 -
处理特殊情况:针对一些特殊的输入情况,例如使用了
inputs_embeds
、模型是编码器-解码器模型等,做出相应的调整。
逐步详解
1. 处理 max_length
if generation_config.max_new_tokens is not None:if not has_default_max_length and generation_config.max_length is not None:logger.warning(f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ""Please refer to the documentation for more information. ""(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)")generation_config.max_length = generation_config.max_new_tokens + input_ids_length
解释:
-
情况 1:用户设置了
max_new_tokens
-
逻辑:
-
如果
generation_config.max_new_tokens
不为None
,表示用户希望通过max_new_tokens
来指定要生成的新 token 数量。 -
冲突处理:
-
如果用户同时设置了
max_length
(并且不是默认值),则发出警告,提示两者可能存在冲突,但以max_new_tokens
为准。 -
not has_default_max_length
:表示max_length
被用户显式设置了。
-
-
计算
max_length
:-
将
max_length
设置为input_ids_length + max_new_tokens
。-
input_ids_length
:输入序列的长度。 -
max_new_tokens
:用户希望生成的新 token 数量。
-
-
-
这样,
max_length
就表示生成的序列(包括输入和生成的部分)总长度。
-
-
示例:
- 如果
input_ids_length = 10
,max_new_tokens = 20
,则max_length = 10 + 20 = 30
。
# if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
# otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`
elif (model_input_name == "inputs_embeds"and input_ids_length != inputs_tensor.shape[1]and not self.config.is_encoder_decoder
):generation_config.max_length -= inputs_tensor.shape[1]
解释:
-
情况 2:模型输入是
inputs_embeds
,且存在输入长度不匹配-
逻辑:
-
条件判断:
-
model_input_name == "inputs_embeds"
:表示模型的输入是嵌入表示。 -
input_ids_length != inputs_tensor.shape[1]
:输入的input_ids
长度与inputs_tensor
的长度(序列维度大小)不一致。 -
not self.config.is_encoder_decoder
:模型不是编码器-解码器模型。
-
-
处理:
- 需要调整
max_length
,减去inputs_tensor.shape[1]
,即输入的序列长度。
- 需要调整
-
原因:
- 当用户提供了
inputs_embeds
而非input_ids
,且两者长度不一致,为了确保生成的总长度不超过用户预期,需要调整max_length
。
- 当用户提供了
-
-
elif has_default_max_length: # by default let's always generate 20 new tokensif generation_config.max_length == GenerationConfig().max_length:generation_config.max_length = generation_config.max_length + input_ids_lengthmax_position_embeddings = getattr(self.config, "max_position_embeddings", None)if max_position_embeddings is not None:generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
解释:
-
情况 3:用户未设置
max_length
,使用默认值-
逻辑:
-
has_default_max_length
为True
,即用户未显式设置max_length
。 -
如果
generation_config.max_length
等于默认的max_length
,则执行以下操作:-
计算新的
max_length
:-
将
max_length
设置为原来的max_length
加上input_ids_length
。- 这样,默认情况下,会在输入的基础上生成
20
个新 tokens(假设默认max_length
为20
)。
- 这样,默认情况下,会在输入的基础上生成
-
-
考虑模型的最大位置嵌入长度:
-
获取模型配置中的
max_position_embeddings
,它表示模型能处理的最大序列长度。 -
如果存在,就将
generation_config.max_length
限制在max_position_embeddings
之内,防止生成长度超过模型的能力。
-
-
-
-
2. 处理 min_length
if generation_config.min_new_tokens is not None:if not has_default_min_length:logger.warning(f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. ""Please refer to the documentation for more information. ""(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)")generation_config.min_length = generation_config.min_new_tokens + input_ids_length
解释:
-
情况 1:用户设置了
min_new_tokens
-
逻辑:
-
如果
generation_config.min_new_tokens
不为None
,表示用户希望指定要生成的最小新 token 数量。 -
冲突处理:
- 如果用户同时设置了
min_length
(并且不是默认值),则发出警告,提示两者存在冲突,以min_new_tokens
为准。
- 如果用户同时设置了
-
计算
min_length
:- 将
min_length
设置为min_new_tokens + input_ids_length
。
- 将
-
-
elif (model_input_name == "inputs_embeds"and input_ids_length != inputs_tensor.shape[1]and not self.config.is_encoder_decoder
):generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
解释:
-
情况 2:模型输入是
inputs_embeds
,且存在输入长度不匹配-
逻辑:
-
与处理
max_length
时类似,但这里调整的是min_length
。 -
将
generation_config.min_length
减去inputs_tensor.shape[1]
,并取最大值0
(防止出现负值)。 -
这样可以确保生成的最小长度不会超过用户预期。
-
-
3. 返回更新后的 generation_config
return generation_config
- 返回调整过后的
generation_config
,包含更新的max_length
和min_length
,以供生成过程使用。
_validate_generated_length
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):"""Performs validation related to the resulting generated length"""# 函数主体从这里开始
功能说明
-
目的:该函数用于对生成的长度进行验证,确保
generation_config
中的参数设置合理,避免在生成过程中出现不可预期的行为或错误。 -
主要任务:
- 检查
max_length
和max_new_tokens
之间的关系,给出适当的警告或错误。 - 检查输入序列长度
input_ids_length
与max_length
之间的关系。 - 检查
min_length
和max_length
之间的关系,给出警告。 - 确保
min_new_tokens
加上input_ids_length
不超过max_length
,并给出警告。
- 检查
参数说明
self
:类的实例,典型的 Python 类方法的第一个参数。generation_config
:生成配置对象,包含了生成过程中使用的各项参数设置,例如max_length
、min_length
、max_new_tokens
等。input_ids_length
:整数,输入序列input_ids
的长度,即序列的长度。has_default_max_length
:布尔值,指示是否使用了默认的max_length
设置(即用户没有在调用时显式指定max_length
)。
代码详细解释
1. 编译时不进行警告或异常抛出
# Can't throw warnings/exceptions during compilation
if is_torchdynamo_compiling():return
解释:
-
目的:在使用 TorchDynamo(PyTorch 编译器)进行编译时,不要抛出警告或异常。
-
逻辑:
is_torchdynamo_compiling()
:检查当前是否在使用 TorchDynamo 进行编译。if is_torchdynamo_compiling(): return
:如果正在编译,直接返回,不进行后续的验证。
-
原因:在编译过程中,抛出异常或者发出警告可能会导致编译失败或行为异常,因此在编译时跳过验证。
2. 第一部分:与参数设置相关的 max_length
警告
# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:# 20 is the default max_length of the generation configwarnings.warn(f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the ""generation length. We recommend setting `max_new_tokens` to control the maximum length of the ""generation.",UserWarning,)
解释:
-
目的:当用户未指定
max_length
且未设置max_new_tokens
时,发出警告提示。 -
逻辑:
-
条件判断:
has_default_max_length
:用户未显式指定max_length
,使用了默认值。generation_config.max_new_tokens is None
:max_new_tokens
未设置。generation_config.max_length == 20
:max_length
等于20
,这是GenerationConfig
的默认值。
-
处理:
- 如果以上条件都满足,使用
warnings.warn()
发出警告,提示用户正在使用模型无关的默认max_length
来控制生成长度,建议用户设置max_new_tokens
来控制生成的最大长度。
- 如果以上条件都满足,使用
-
-
原因:
- 如果用户没有明确设置生成长度的参数,可能会导致生成的序列长度与预期不符。
- 建议用户使用
max_new_tokens
,因为它更直观地控制生成的新 tokens 数量。
3. 检查输入序列长度是否超过或等于 max_length
if input_ids_length >= generation_config.max_length:input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"raise ValueError(f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"" increasing `max_length` or, better yet, setting `max_new_tokens`.")
解释:
-
目的:如果输入序列的长度大于或等于
max_length
,抛出ValueError
,因为这样会导致生成过程中的问题。 -
逻辑:
-
条件判断:
if input_ids_length >= generation_config.max_length
:如果输入序列的长度input_ids_length
大于或等于max_length
。
-
处理:
-
根据模型类型确定输入 IDs 的名称:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
- 如果是编码器-解码器模型,使用
'decoder_input_ids'
。 - 否则,使用
'input_ids'
。
- 如果是编码器-解码器模型,使用
-
抛出
ValueError
,包含详细的错误信息:- 提示输入序列的长度与
max_length
的值,以及可能导致意外行为。 - 建议用户增加
max_length
或者更好的,设置max_new_tokens
。
- 提示输入序列的长度与
-
-
-
原因:
- 如果输入序列的长度已经达到或超过
max_length
,模型将无法生成新的 tokens,可能导致生成过程立即结束或出现错误。 - 为了避免这种情况,必须确保
max_length
大于输入序列的长度。
- 如果输入序列的长度已经达到或超过
4. 第二部分:由于不可行的参数组合导致的 min_length
警告
# 2. Min length warnings due to unfeasible parameter combinations
min_length_error_suffix = (" Generation will stop at the defined maximum length. You should decrease the minimum length and/or ""increase the maximum length."
)
if has_default_max_length:min_length_error_suffix += (f" Note that `max_length` is set to {generation_config.max_length}, its default value.")
解释:
-
目的:准备
min_length
相关的错误信息,供后续警告使用。 -
逻辑:
-
定义错误信息的后缀:
min_length_error_suffix
:提示用户需要降低最小长度和/或增加最大长度。
-
如果使用了默认的
max_length
:- 如果
has_default_max_length
为True
,则在min_length_error_suffix
后面追加一句,说明max_length
被设置为其默认值。
- 如果
-
5. 检查 min_length
是否大于 max_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:warnings.warn(f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,UserWarning,)
解释:
-
目的:如果
min_length
大于max_length
,发出警告,因为这是不可行的约束。 -
逻辑:
-
条件判断:
generation_config.min_length is not None
:min_length
已被设置。generation_config.min_length > generation_config.max_length
:min_length
大于max_length
。
-
处理:
- 使用
warnings.warn()
发出警告,提示min_length
大于max_length
,并附加之前准备的min_length_error_suffix
。
- 使用
-
-
原因:
- 如果最小生成长度大于最大生成长度,模型无法满足这样的约束,可能导致生成过程在达到最大长度时停止,而未达到最小长度。
- 提醒用户调整参数,使其合理。
6. 检查 min_new_tokens
加上 input_ids_length
是否超过 max_length
if generation_config.min_new_tokens is not None:min_length = generation_config.min_new_tokens + input_ids_lengthif min_length > generation_config.max_length:warnings.warn(f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "f"added to the prompt length ({input_ids_length}), is larger than"f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,UserWarning,)
解释:
-
目的:如果
min_new_tokens
加上输入序列长度超过max_length
,发出警告。 -
逻辑:
-
条件判断:
if generation_config.min_new_tokens is not None
:min_new_tokens
已被设置。
-
计算最小长度:
min_length = generation_config.min_new_tokens + input_ids_length
:计算最小生成长度,即min_new_tokens
加上输入序列长度。
-
检查是否超过
max_length
:if min_length > generation_config.max_length
:如果计算得到的min_length
大于max_length
。
-
处理:
- 使用
warnings.warn()
发出警告,提示min_new_tokens
加上输入长度超过了可能的最大长度,并附加min_length_error_suffix
。
- 使用
-
-
原因:
- 如果
min_new_tokens
与输入序列长度之和超过max_length
,模型无法生成满足最小新 tokens 数量的序列。 - 提醒用户调整参数,降低
min_new_tokens
、增加max_length
,或减少输入序列的长度。
- 如果
_supports_logits_to_keep
方法定义
def _supports_logits_to_keep(self) -> bool:"""Return True if the current model supports the keyword argument `logits_to_keep` in forward()to save memory. Checking it in this way allows to avoid using a new model attribute."""return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
方法功能概述
- 目的:确定当前模型的
forward
方法是否支持logits_to_keep
参数。 - 返回值:布尔值
True
:如果forward
方法的参数中包含logits_to_keep
。False
:如果forward
方法的参数中不包含logits_to_keep
。
逐步详解
-
使用
inspect
模块分析forward
方法的签名inspect.signature(self.forward)
- 作用:获取
self.forward
方法的签名信息,包括参数列表和参数默认值等。 inspect
模块:Python 内置模块,提供了检查和获取对象(如函数、类、模块等)信息的功能。
- 作用:获取
-
获取
forward
方法的参数字典inspect.signature(self.forward).parameters
- 返回值:一个有序字典(
OrderedDict
),键为参数名称,值为参数对应的Parameter
对象。
- 返回值:一个有序字典(
-
获取参数名称列表并转换为集合
set(inspect.signature(self.forward).parameters.keys())
parameters.keys()
:返回参数名称的可迭代对象。set(...)
:将参数名称转换为集合,方便后续进行快速查找(in
操作)。
-
检查是否包含
logits_to_keep
参数"logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
- 作用:判断字符串
"logits_to_keep"
是否在参数名称集合中。 - 返回值:布尔值。
- 作用:判断字符串
-
返回判断结果
- 如果包含:返回
True
。 - 如果不包含:返回
False
。
- 如果包含:返回
方法用途和背景
-
节省内存:
在生成任务中,模型可能会生成大量的 logits(每个时间步长预测下一个 token 的概率分布,通常是一个包含了整个词汇表大小的张量)。如果能够限制保留的 logits 数量(例如只保留 top-k 个 logits),可以大大节省内存。
-
动态检查模型功能:
不同的模型可能实现了不同的功能。通过这种动态检查的方法,可以在不修改模型代码的情况下,了解模型是否支持某个特定的参数或功能。这有助于编写通用的代码,适用于多种模型。
-
避免使用额外的模型属性:
通过检查方法签名,而不是增加一个模型属性,可以减少模型类的复杂性和维护成本。
举例说明
假设我们有一个模型,其 forward
方法定义如下:
def forward(self, input_ids, attention_mask=None, logits_to_keep=None):# 模型的前向计算逻辑logits = self.compute_logits(input_ids, attention_mask)if logits_to_keep is not None:# 只保留指定数量的 logitslogits = logits[:, -1, :logits_to_keep]return logits
-
模型支持
logits_to_keep
参数:在这种情况下,_supports_logits_to_keep
方法会返回True
,因为forward
方法的参数中包含logits_to_keep
。 -
使用示例:
if self._supports_logits_to_keep():outputs = self.forward(input_ids, attention_mask=attention_mask, logits_to_keep=10) else:outputs = self.forward(input_ids, attention_mask=attention_mask)
- 解释:代码首先检查模型是否支持
logits_to_keep
参数,如果支持,则在调用forward
方法时传入该参数,以只保留 top-10 的 logits,从而节省内存。
- 解释:代码首先检查模型是否支持
_prepare_cache_for_generation
def _prepare_cache_for_generation(self,generation_config: GenerationConfig,model_kwargs: Dict,assistant_model: "PreTrainedModel",batch_size: int,max_cache_length: int,device: torch.device,
) -> bool:"""Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache isinstantiated, writes it to `model_kwargs`, under the name expected by the model."""# 函数主体从这里开始
功能说明
-
目的:准备生成过程中使用的缓存(cache),根据给定的
generation_config
和其他参数,初始化或调整缓存。如果缓存被实例化,它将被写入到model_kwargs
中,使用模型期望的缓存名称。 -
背景:在文本生成任务中,使用缓存可以加速生成过程,特别是在自回归模型中,缓存先前的计算结果可以避免重复计算。在不同的模型或配置下,缓存的实现方式可能不同,因此需要根据情况准备合适的缓存。
参数说明
-
self
:当前类的实例,典型的 Python 类方法的第一个参数。 -
generation_config
:GenerationConfig
类型,表示生成配置,其中包含生成过程中的各种参数设置,如是否使用缓存、缓存的实现方式等。 -
model_kwargs
:Dict
类型,包含传递给模型的关键字参数。在函数中,可能会对其进行修改,添加缓存相关的参数。 -
assistant_model
:PreTrainedModel
类型,可选的辅助模型,用于加速生成或其他目的。 -
batch_size
:整数,表示批次大小。 -
max_cache_length
:整数,表示缓存的最大长度,即缓存可以存储的最大序列长度。 -
device
:torch.device
类型,表示在何种设备(CPU 或 GPU)上运行。
1. 确定缓存名称
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
-
解释:
-
这行代码根据当前模型类的名称,确定缓存在
model_kwargs
中的键名称。 -
如果类名中不包含
"mamba"
,则缓存名称为"past_key_values"
;否则,缓存名称为"cache_params"
。
-
-
原因:
- 不同的模型可能期望的缓存名称不同。模型需要从
model_kwargs
中获取缓存,如果名称不一致,可能导致缓存无法正确工作。
- 不同的模型可能期望的缓存名称不同。模型需要从
2. 确定是否需要跨注意力缓存(cross-attention cache)
requires_cross_attention_cache = (self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
-
解释:
-
requires_cross_attention_cache
是一个布尔值,表示是否需要准备跨注意力缓存。 -
条件:
-
self.config.is_encoder_decoder
:如果模型是编码器-解码器架构,则需要跨注意力缓存。 -
model_kwargs.get("encoder_outputs") is not None
:如果在model_kwargs
中提供了编码器的输出,则也需要跨注意力缓存。
-
-
-
原因:
- 在编码器-解码器模型(如 BART、T5)中,解码器需要访问编码器的输出,因此需要跨注意力缓存。
3. 快速退出路径 1:用户已在 model_kwargs
中指定了缓存
# 快速退出路径 1:如果用户指定了缓存,我们只需要:
# a) 检查是否有冲突的 `generate` 参数
# b) 如果用户传递了旧的缓存格式,并且模型支持,将其转换为新的缓存格式
user_defined_cache = model_kwargs.get(cache_name)
if user_defined_cache is not None:if generation_config.cache_implementation is not None:raise ValueError(f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a ""Cache object) is unsupported. Please use only one of the two.")if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():model_kwargs[cache_name] = (DynamicCache.from_legacy_cache(user_defined_cache)if not requires_cross_attention_cacheelse EncoderDecoderCache.from_legacy_cache(user_defined_cache))return
-
解释:
-
获取用户定义的缓存:
user_defined_cache = model_kwargs.get(cache_name)
:从model_kwargs
中获取用户可能提供的缓存。
-
检查用户是否同时指定了
cache_implementation
:- 如果用户既在
model_kwargs
中提供了缓存,又在generation_config
中指定了cache_implementation
,这是冲突的,会引发错误。
- 如果用户既在
-
处理旧的缓存格式:
-
如果
user_defined_cache
是一个元组(旧的缓存格式),并且模型支持默认的动态缓存(self._supports_default_dynamic_cache()
返回True
),则将旧的缓存转换为新的缓存格式。 -
根据是否需要跨注意力缓存,使用不同的缓存类:
-
如果不需要跨注意力缓存,使用
DynamicCache.from_legacy_cache(user_defined_cache)
。 -
如果需要跨注意力缓存,使用
EncoderDecoderCache.from_legacy_cache(user_defined_cache)
。
-
-
-
返回:
- 在处理完用户提供的缓存后,直接返回,不再进行后续的缓存准备。
-
4. 快速退出路径 2:用户指定不使用缓存
# 快速退出路径 2:如果用户指定不使用缓存。(冲突的参数已在 `generation_config.validate()` 中处理)
if generation_config.use_cache is False:return
-
解释:
-
如果在
generation_config
配置中,用户设置了use_cache=False
,表示不使用缓存。 -
直接返回,不需要准备缓存。
-
5. 快速退出路径 3:模型仅支持旧的缓存格式
# 快速退出路径 3:模型仅支持旧的缓存格式,无需准备
if not self._supports_default_dynamic_cache():if generation_config.cache_implementation is not None:warnings.warn("This model does not support `Cache` instances, it only supports the legacy cache format (tuple "f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be ""ignored.",UserWarning,)return
-
解释:
-
如果模型不支持默认的动态缓存(
self._supports_default_dynamic_cache()
返回False
),则无法使用新的缓存实现。 -
如果用户在
generation_config
中指定了cache_implementation
,则发出警告,指出模型仅支持旧的缓存格式,cache_implementation
将被忽略。 -
直接返回,不需要进一步准备缓存。
-
6. 需要准备缓存,根据 generation_config.cache_implementation
# 否则,我们需要根据 `generation_config.cache_implementation` 准备缓存
# TODO(joao): 在辅助生成中支持静态缓存。辅助生成需要回滚缓存,目前只有动态缓存支持
if assistant_model is not None and generation_config.cache_implementation is not None:logger.warning_once("An assistant model is provided, using a dynamic cache instead of a cache of type="f"'{generation_config.cache_implementation}'.")generation_config.cache_implementation = None
-
解释:
-
如果上述快速退出条件都不满足,且需要准备缓存,则需要根据
generation_config.cache_implementation
的值来准备缓存。 -
特殊情况:辅助模型和缓存实现的冲突:
-
如果提供了
assistant_model
,并且指定了cache_implementation
,则发出警告,指出由于提供了辅助模型,将使用动态缓存,而不是指定类型的缓存。 -
将
generation_config.cache_implementation
设置为None
,以确保使用动态缓存。
-
-
-
原因:
- 在辅助生成过程中,需要回滚缓存,目前只有动态缓存支持回滚。因此,即使用户指定了其他缓存实现,也需要使用动态缓存。
7. 根据缓存实现方式准备缓存
if generation_config.cache_implementation is not None:if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:if generation_config.cache_implementation == "static" and not self._supports_static_cache:raise ValueError("This model does not support `cache_implementation='static'`. Please check the following ""issue: https://github.com/huggingface/transformers/issues/28981")model_kwargs[cache_name] = self._get_cache(cache_implementation=generation_config.cache_implementation,batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,max_cache_len=max_cache_length,device=device,model_kwargs=model_kwargs,)elif generation_config.cache_implementation == "quantized":if not self._supports_quantized_cache:raise ValueError("This model does not support the quantized cache. If you want your model to support quantized ""cache, please open an issue and tag @zucchini-nlp.")cache_config = (generation_config.cache_configif generation_config.cache_config is not Noneelse QuantizedCacheConfig())cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]if cache_config.backend == "quanto" and not is_optimum_quanto_available():raise ImportError("You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. ""Please install it via with `pip install optimum-quanto`")elif cache_config.backend == "HQQ" and not is_hqq_available():raise ImportError("You need to install `HQQ` in order to use KV cache quantization with HQQ backend. ""Please install it via with `pip install hqq`")model_kwargs[cache_name] = cache_class(cache_config)elif generation_config.cache_implementation == "offloaded":model_kwargs[cache_name] = OffloadedCache()
-
解释:
-
检查缓存实现方式是否需要特别的准备:
-
NEED_SETUP_CACHE_CLASSES_MAPPING
:一个映射,包含需要特殊设置的缓存类。 -
如果
generation_config.cache_implementation
在NEED_SETUP_CACHE_CLASSES_MAPPING
中,则需要调用_get_cache
方法来获取缓存实例。
-
-
处理静态缓存:
-
如果
cache_implementation
是"static"
,并且模型不支持静态缓存(not self._supports_static_cache
),则抛出ValueError
,提示模型不支持静态缓存。 -
提供一个 GitHub Issue 链接,供用户了解更多信息。
-
-
获取缓存实例:
-
调用
self._get_cache()
方法,传入缓存实现方式、批次大小、最大缓存长度、设备等参数,获取缓存实例。 -
批次大小计算:
-
max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size
:计算缓存所需的实际批次大小。- 在束搜索或其他情况下,批次大小可能需要乘以
num_beams
或num_return_sequences
。
- 在束搜索或其他情况下,批次大小可能需要乘以
-
-
-
处理量化缓存(quantized cache):
-
如果
cache_implementation
是"quantized"
,需要特殊处理。 -
检查模型是否支持量化缓存(
self._supports_quantized_cache
)。 -
如果不支持,抛出
ValueError
。 -
获取缓存配置
cache_config
,如果用户未提供generation_config.cache_config
,则使用默认的QuantizedCacheConfig()
。 -
根据缓存配置的后端,获取对应的缓存类
cache_class
。 -
检查所需的包是否已安装:
-
对于
"quanto"
后端,检查is_optimum_quanto_available()
。 -
对于
"HQQ"
后端,检查is_hqq_available()
。
-
-
如果未安装,抛出
ImportError
,提示用户安装相应的包。 -
实例化缓存类,并将其存储在
model_kwargs[cache_name]
中。
-
-
处理离线缓存(offloaded cache):
- 如果
cache_implementation
是"offloaded"
,则实例化OffloadedCache()
,并存储在model_kwargs[cache_name]
中。
- 如果
-
8. 默认情况下,使用动态缓存
# 默认情况下,使用 DynamicCache() 实例。这将避免在旧格式之间来回转换,从而避免复制缓存,节省内存
else:model_kwargs[cache_name] = (DynamicCache()if not requires_cross_attention_cacheelse EncoderDecoderCache(DynamicCache(), DynamicCache()))
-
解释:
-
如果
generation_config.cache_implementation
为None
,即用户未指定特定的缓存实现方式,并且上述条件都不满足,则默认使用动态缓存。 -
根据是否需要跨注意力缓存,实例化不同的缓存类:
-
如果不需要跨注意力缓存,使用
DynamicCache()
。 -
如果需要跨注意力缓存,使用
EncoderDecoderCache(DynamicCache(), DynamicCache())
,即分别为编码器和解码器缓存初始化动态缓存。
-
-
-
原因:
-
动态缓存可以在生成过程中动态扩展,并支持回滚等特性。
-
使用默认的动态缓存,可以避免在旧缓存格式与新缓存格式之间来回转换,减少内存使用和数据复制。
-
_get_logits_processor
def _get_logits_processor(self,generation_config: GenerationConfig,input_ids_seq_length: int,encoder_input_ids: torch.LongTensor,prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],logits_processor: Optional[LogitsProcessorList],device: str = None,model_kwargs: Optional[Dict[str, Any]] = None,negative_prompt_ids: Optional[torch.Tensor] = None,negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:"""This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]instances used to modify the scores of the language model head."""# 函数主体从这里开始
功能说明
-
目的:该函数用于构建一个
LogitsProcessorList
对象,其中包含了一系列LogitsProcessor
实例,这些实例用于在生成过程中修改语言模型头(language model head)的logits(模型输出的分数),以实现各种生成控制策略。 -
背景:在生成任务中,通过调整模型输出的logits,可以引入各种生成策略,如重复惩罚、温度调节、Top-K采样、禁用某些词汇等,从而控制生成文本的风格、内容和长度。
参数说明
-
self
:类的实例,典型的Python类方法的第一个参数。 -
generation_config
:GenerationConfig
对象,包含了生成过程中的各种配置参数,如重复惩罚系数、最小生成长度、温度等。 -
input_ids_seq_length
:整数,表示输入序列的长度,即input_ids
的长度。 -
encoder_input_ids
:torch.LongTensor
,编码器的输入IDs。如果模型是编码器-解码器架构,这对应于编码器的输入。 -
prefix_allowed_tokens_fn
:可选的函数,类型为Callable[[int, torch.Tensor], List[int]]
。用于在生成过程中限制每个位置上允许生成的tokens,通常用于受限生成任务。 -
logits_processor
:可选的LogitsProcessorList
对象,用户自定义的LogitsProcessor列表,可用于补充或覆盖默认的processor。 -
device
:字符串,可选参数,指定设备(如'cpu'
或'cuda'
)。如果未提供,默认为None
。 -
model_kwargs
:可选的字典,包含了传递给模型的其他关键字参数。 -
negative_prompt_ids
:可选的torch.Tensor
,用于一些生成策略(如Classifier-Free Guidance)中的负面提示IDs。 -
negative_prompt_attention_mask
:可选的torch.Tensor
,对应negative_prompt_ids
的注意力掩码。
1. 初始化LogitsProcessorList
# instantiate processors list
processors = LogitsProcessorList()
- 解释:创建一个空的
LogitsProcessorList
对象processors
,用于存储将要应用的所有LogitsProcessor
实例。
2. 处理Classifier-Free Guidance(CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:processors.append(UnbatchedClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale,self,unconditional_ids=negative_prompt_ids,unconditional_attention_mask=negative_prompt_attention_mask,use_cache=generation_config.use_cache,))
-
解释:
-
条件:如果
guidance_scale
不为None
且不等于1,则说明需要应用Classifier-Free Guidance(CFG)。guidance_scale
是CFG的缩放因子,通常大于1,用于调整生成的多样性与准确性。
-
操作:向
processors
中添加一个UnbatchedClassifierFreeGuidanceLogitsProcessor
实例。-
参数说明:
-
generation_config.guidance_scale
:CFG的缩放因子。 -
self
:模型实例,用于在LogitsProcessor中调用模型的其他方法。 -
unconditional_ids
:负面提示的IDs,即negative_prompt_ids
。 -
unconditional_attention_mask
:负面提示的注意力掩码,即negative_prompt_attention_mask
。 -
use_cache
:是否使用缓存,来自generation_config
。
-
-
-
3. 处理序列偏置(Sequence Bias)
if generation_config.sequence_bias is not None:processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
-
解释:
-
条件:如果
sequence_bias
不为None
,则需要应用序列偏置。sequence_bias
是一种机制,可对特定的token序列施加偏置,提高或降低它们在生成中的概率。
-
操作:向
processors
中添加一个SequenceBiasLogitsProcessor
实例,传入sequence_bias
参数。
-
4. 处理多样性惩罚(Diversity Penalty)
if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:processors.append(HammingDiversityLogitsProcessor(diversity_penalty=generation_config.diversity_penalty,num_beams=generation_config.num_beams,num_beam_groups=generation_config.num_beam_groups,))
-
解释:
-
条件:如果
diversity_penalty
不为None
且大于0,则需要应用多样性惩罚。- 多样性惩罚用于在束搜索中鼓励生成更多样化的序列。
-
操作:向
processors
中添加一个HammingDiversityLogitsProcessor
实例。-
参数说明:
-
diversity_penalty
:多样性惩罚系数。 -
num_beams
:束搜索的束宽,即同时考虑的序列数量。 -
num_beam_groups
:束搜索的组数,用于分组束搜索。
-
-
-
5. 处理编码器重复惩罚(Encoder Repetition Penalty)
if (generation_config.encoder_repetition_penalty is not Noneand generation_config.encoder_repetition_penalty != 1.0
):if len(encoder_input_ids.shape) == 2:processors.append(EncoderRepetitionPenaltyLogitsProcessor(penalty=generation_config.encoder_repetition_penalty,encoder_input_ids=encoder_input_ids,))else:warnings.warn("Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to ""`generate`, ignoring the argument.",UserWarning,)
-
解释:
-
条件:如果
encoder_repetition_penalty
不为None
且不等于1.0,则需要应用编码器重复惩罚。- 编码器重复惩罚用于减少模型在生成时重复输入内容的可能性。
-
检查:如果
encoder_input_ids
的形状为二维(即存在有效的编码器输入),则应用惩罚。 -
操作:向
processors
中添加一个EncoderRepetitionPenaltyLogitsProcessor
实例。-
参数说明:
-
penalty
:重复惩罚系数。 -
encoder_input_ids
:编码器的输入IDs。
-
-
-
否则:发出警告,提示需要提供
input_ids
以应用该惩罚,忽略该参数。
-
6. 处理重复惩罚(Repetition Penalty)
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
-
解释:
-
条件:如果
repetition_penalty
不为None
且不等于1.0,则需要应用重复惩罚。- 重复惩罚用于减少模型在生成时重复之前生成内容的可能性。
-
操作:向
processors
中添加一个RepetitionPenaltyLogitsProcessor
实例,传入penalty
参数。
-
7. 处理禁止重复的n-gram(No Repeat N-Gram)
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
-
解释:
-
条件:如果
no_repeat_ngram_size
不为None
且大于0,则需要禁止重复的n-gram。- 这用于防止模型在生成时重复生成相同的n-gram,提高生成的多样性。
-
操作:向
processors
中添加一个NoRepeatNGramLogitsProcessor
实例,传入no_repeat_ngram_size
参数。
-
8. 处理编码器禁止重复的n-gram(Encoder No Repeat N-Gram)
if (generation_config.encoder_no_repeat_ngram_size is not Noneand generation_config.encoder_no_repeat_ngram_size > 0
):if len(encoder_input_ids.shape) == 2:processors.append(EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size,encoder_input_ids,))else:warnings.warn("Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to ""`generate`, ignoring the argument.",UserWarning,)
-
解释:
-
条件:如果
encoder_no_repeat_ngram_size
不为None
且大于0,则需要在生成时避免重复输入中的n-gram。- 这用于防止模型在生成时重复输入序列中的n-gram。
-
检查:如果
encoder_input_ids
的形状为二维(存在有效的编码器输入),则应用该处理器。 -
操作:向
processors
中添加一个EncoderNoRepeatNGramLogitsProcessor
实例。-
参数说明:
-
encoder_no_repeat_ngram_size
:禁止重复的n-gram大小。 -
encoder_input_ids
:编码器的输入IDs。
-
-
-
否则:发出警告,提示需要提供
input_ids
以应用该处理器,忽略该参数。
-
9. 处理坏词(Bad Words)
if generation_config.bad_words_ids is not None:processors.append(NoBadWordsLogitsProcessor(generation_config.bad_words_ids,generation_config._eos_token_tensor,))
-
解释:
-
条件:如果
bad_words_ids
不为None
,则需要在生成过程中禁止某些词。bad_words_ids
是一个列表,包含需要禁止的词的token IDs。
-
操作:向
processors
中添加一个NoBadWordsLogitsProcessor
实例。-
参数说明:
-
bad_words_ids
:需要禁止的词的token IDs。 -
_eos_token_tensor
:结束标记的token张量,用于在必要时停止生成。
-
-
-
10. 处理最小长度(Minimum Length)
if (generation_config.min_length is not Noneand generation_config._eos_token_tensor is not Noneand generation_config.min_length > 0
):processors.append(MinLengthLogitsProcessor(generation_config.min_length,generation_config._eos_token_tensor,device=device,))
-
解释:
-
条件:如果
min_length
不为None
,_eos_token_tensor
不为None
,且min_length
大于0,则需要在生成达到最小长度之前禁止生成结束标记。 -
操作:向
processors
中添加一个MinLengthLogitsProcessor
实例。-
参数说明:
-
min_length
:最小生成长度。 -
_eos_token_tensor
:结束标记的token张量。 -
device
:设备信息。
-
-
-
11. 处理最小新tokens的长度(Minimum New Tokens Length)
if (generation_config.min_new_tokens is not Noneand generation_config._eos_token_tensor is not Noneand generation_config.min_new_tokens > 0
):processors.append(MinNewTokensLengthLogitsProcessor(input_ids_seq_length,generation_config.min_new_tokens,generation_config._eos_token_tensor,device=device,))
-
解释:
-
条件:如果
min_new_tokens
不为None
,_eos_token_tensor
不为None
,且min_new_tokens
大于0,则需要在生成新tokens达到最小数量之前禁止生成结束标记。 -
操作:向
processors
中添加一个MinNewTokensLengthLogitsProcessor
实例。-
参数说明:
-
input_ids_seq_length
:输入序列的长度。 -
min_new_tokens
:最小新生成的tokens数量。 -
_eos_token_tensor
:结束标记的token张量。 -
device
:设备信息。
-
-
-
12. 处理前缀限制(Prefix Allowed Tokens Function)
if prefix_allowed_tokens_fn is not None:processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn,generation_config.num_beams // generation_config.num_beam_groups,))
-
解释:
-
条件:如果
prefix_allowed_tokens_fn
不为None
,则需要在生成过程中限制每个位置上允许生成的tokens。- 这通常用于受限生成任务,例如自动补全或基于前缀的约束生成。
-
操作:向
processors
中添加一个PrefixConstrainedLogitsProcessor
实例。-
参数说明:
-
prefix_allowed_tokens_fn
:用于限制每个位置上允许生成的tokens的函数。 -
generation_config.num_beams // generation_config.num_beam_groups
:计算每个组中的束宽。
-
-
-
13. 处理强制起始token(Forced BOS Token)
if generation_config.forced_bos_token_id is not None:processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id,))
-
解释:
-
条件:如果
forced_bos_token_id
不为None
,则需要在生成的第一个位置强制生成指定的起始token。- 这用于确保生成的序列以特定的token开始,例如在某些任务中需要强制生成特定的起始标记。
-
操作:向
processors
中添加一个ForcedBOSTokenLogitsProcessor
实例,传入forced_bos_token_id
。
-
14. 处理强制结束token(Forced EOS Token)
if generation_config.forced_eos_token_id is not None:processors.append(ForcedEOSTokenLogitsProcessor(generation_config.max_length,generation_config.forced_eos_token_id,device=device,))
-
解释:
-
条件:如果
forced_eos_token_id
不为None
,则需要在生成达到最大长度时强制生成指定的结束token。- 这用于确保生成的序列以特定的token结束。
-
操作:向
processors
中添加一个ForcedEOSTokenLogitsProcessor
实例。-
参数说明:
-
generation_config.max_length
:最大生成长度。 -
forced_eos_token_id
:强制的结束token ID。 -
device
:设备信息。
-
-
-
15. 处理无效值移除(Remove Invalid Values)
if generation_config.remove_invalid_values is True:processors.append(InfNanRemoveLogitsProcessor())
-
解释:
-
条件:如果
remove_invalid_values
为True
,则需要在生成过程中移除inf
和nan
等无效值。- 这用于确保生成过程的稳定性,防止由于无效值导致的错误。
-
操作:向
processors
中添加一个InfNanRemoveLogitsProcessor
实例。
-
16. 处理指数衰减长度惩罚(Exponential Decay Length Penalty)
if generation_config.exponential_decay_length_penalty is not None:processors.append(ExponentialDecayLengthPenalty(generation_config.exponential_decay_length_penalty,generation_config._eos_token_tensor,input_ids_seq_length,))
-
解释:
-
条件:如果
exponential_decay_length_penalty
不为None
,则需要应用指数衰减的长度惩罚。- 这用于在生成过程中,对句子长度施加惩罚,鼓励模型生成特定长度的句子。
-
操作:向
processors
中添加一个ExponentialDecayLengthPenalty
实例。-
参数说明:
-
exponential_decay_length_penalty
:指数衰减长度惩罚的参数。 -
_eos_token_tensor
:结束标记的token张量。 -
input_ids_seq_length
:输入序列的长度。
-
-
-
17. 处理抑制特定tokens(Suppress Tokens)
if generation_config.suppress_tokens is not None:processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens,device=device,))
-
解释:
-
条件:如果
suppress_tokens
不为None
,则需要在生成过程中抑制特定的tokens,不让它们生成。suppress_tokens
是需要抑制的token IDs列表。
-
操作:向
processors
中添加一个SuppressTokensLogitsProcessor
实例。-
参数说明:
-
suppress_tokens
:需要抑制的token IDs。 -
device
:设备信息。
-
-
-
18. 处理在开头抑制特定tokens(Suppress Tokens at Begin)
if generation_config.begin_suppress_tokens is not None:begin_index = input_ids_seq_lengthbegin_index = (begin_indexif (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)else begin_index + 1)processors.append(SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens,begin_index,device=device,))
-
解释:
-
条件:如果
begin_suppress_tokens
不为None
,则需要在生成的开头位置抑制特定的tokens。- 这用于避免模型在一开始生成某些不期望的tokens。
-
计算起始索引:
-
begin_index
初始值为input_ids_seq_length
,表示当前生成的位置。 -
如果
input_ids_seq_length <= 1
且forced_bos_token_id
不为None
,则begin_index += 1
。
-
-
操作:向
processors
中添加一个SuppressTokensAtBeginLogitsProcessor
实例。-
参数说明:
-
begin_suppress_tokens
:需要抑制的token IDs。 -
begin_index
:开始抑制的位置索引。 -
device
:设备信息。
-
-
-
19. 处理强制解码器IDs(Forced Decoder IDs)
if generation_config.forced_decoder_ids is not None:# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PTraise ValueError("You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument ""in favour of `input_ids` or `decoder_input_ids` respectively.",)
-
解释:
-
条件:如果
forced_decoder_ids
不为None
,则抛出异常。 -
原因:当前不支持
forced_decoder_ids
,建议用户使用input_ids
或decoder_input_ids
来替代。 -
备注:注释中提到,当TensorFlow和FLAX版本与PyTorch版本对齐后,可以将此异常移动到
GenerationConfig.validate()
中。
-
20. 合并用户自定义的logits_processor
# TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor)
-
解释:
-
操作:调用
self._merge_criteria_processor_list()
方法,将之前构建的processors
列表与用户自定义的logits_processor
进行合并。 -
备注:注释中提到需要找到一种策略来指定处理器的顺序。
-
21. 处理采样策略下的LogitsWarper
# 只有在使用采样策略时,才应用之前被称为`LogitsWarpers`的处理器
if generation_config.do_sample:# 在beam方法中,我们需要至少保留一个非eos token,以探索可能具有更好得分的连续序列if generation_config.num_beams > 1:if isinstance(generation_config._eos_token_tensor, list):min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1elif isinstance(generation_config._eos_token_tensor, torch.Tensor):min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1else:min_tokens_to_keep = 2else:min_tokens_to_keep = 1# 以下思想主要来自PR:https://github.com/huggingface/transformers/pull/5420/files# 所有的sampler都在`generation_utils_samplers.py`中if generation_config.temperature is not None and generation_config.temperature != 1.0:processors.append(TemperatureLogitsWarper(generation_config.temperature))if generation_config.top_k is not None and generation_config.top_k != 0:processors.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))if generation_config.top_p is not None and generation_config.top_p < 1.0:processors.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))if generation_config.min_p is not None:# 在温度缩放之后应用(见:https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)processors.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))if generation_config.typical_p is not None and generation_config.typical_p < 1.0:processors.append(TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep))if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:processors.append(EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep))if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:processors.append(EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device))
-
解释:
-
条件:只有在
do_sample
为True
时,才应用这些处理器,因为它们与采样策略相关。 -
计算
min_tokens_to_keep
:-
在束搜索等方法中,需要保留至少一个非结束token,以确保能够探索可能更好的序列。
-
根据结束token的类型和数量,计算需要保留的最小token数量。
-
-
添加采样相关的
LogitsWarper
:-
根据
generation_config
中的配置,向processors
中添加相应的LogitsWarper
,包括:-
TemperatureLogitsWarper
:调整温度参数,控制生成的随机性。 -
TopKLogitsWarper
:仅保留概率最高的top_k
个tokens。 -
TopPLogitsWarper
:仅保留累计概率达到top_p
的tokens。 -
MinPLogitsWarper
:应用最小概率阈值。 -
TypicalLogitsWarper
:使用Typical采样策略。 -
EpsilonLogitsWarper
和EtaLogitsWarper
:应用epsilon或eta截断。
-
-
-
22. 处理水印(Watermarking)
# Watermarking应该在所有logits处理完成后再应用(参见#34630)
if generation_config.watermarking_config is not None:processors.append(generation_config.watermarking_config.construct_processor(self.config.vocab_size, device))
-
解释:
-
条件:如果
watermarking_config
不为None
,则需要在生成过程中应用水印策略。- 水印策略通常用于在生成的文本中嵌入隐式标记,以便于后续识别生成内容。
-
操作:向
processors
中添加一个由watermarking_config
构建的处理器。- 使用模型的词汇表大小
vocab_size
和设备信息device
。
- 使用模型的词汇表大小
-
备注:水印处理器应当在所有logits处理完成后再应用。
-
23. 处理Logits重归一化(Logit Renormalization)
# `LogitNormalization`应该始终是最后一个logit处理器(如果存在的话)
if generation_config.renormalize_logits is True:processors.append(LogitNormalization())
-
解释:
-
条件:如果
renormalize_logits
为True
,则需要在处理完所有logits后重新归一化。- 这用于确保logits在所有处理后仍然形成有效的概率分布。
-
操作:向
processors
中添加一个LogitNormalization
实例。- 备注:
LogitNormalization
应当始终是最后应用的logits处理器。
- 备注:
-
24. 返回处理器列表
return processors
- 解释:返回构建好的
LogitsProcessorList
对象processors
,供生成过程使用。
_get_stopping_criteria
def _get_stopping_criteria(self,generation_config: GenerationConfig,stopping_criteria: Optional[StoppingCriteriaList],tokenizer: Optional["PreTrainedTokenizerBase"] = None,**kwargs,
) -> StoppingCriteriaList:# 方法体...
参数说明
generation_config
:GenerationConfig
对象,包含生成过程中所需的各种配置参数,如max_length
、max_time
、stop_strings
等。stopping_criteria
:可选的StoppingCriteriaList
对象,用户可以传入自定义的停止条件列表,与默认的停止条件合并。tokenizer
:可选的PreTrainedTokenizerBase
对象,模型对应的 tokenizer,用于处理stop_strings
等需要 tokenizer 的功能。**kwargs
:其他可能的关键字参数,供扩展使用。
返回值
StoppingCriteriaList
:该方法返回一个StoppingCriteriaList
对象,包含生成过程中需要检查的停止条件。
该方法的主要作用是:
-
创建默认的停止条件列表:根据
generation_config
中的配置,生成对应的停止条件(如最大长度、最大时间、特殊字符串、结束标记等)。 -
合并用户提供的停止条件:如果用户在
stopping_criteria
中提供了自定义的停止条件,方法会将其与默认的停止条件合并。 -
返回完整的停止条件列表:生成过程会根据这个列表,在满足任何一个停止条件时结束生成。
步骤 1:初始化停止条件列表
criteria = StoppingCriteriaList()
- 说明:创建一个空的
StoppingCriteriaList
对象,用于存储后续添加的停止条件。
步骤 2:处理 max_length
停止条件
if generation_config.max_length is not None:max_position_embeddings = getattr(self.config, "max_position_embeddings", None)criteria.append(MaxLengthCriteria(max_length=generation_config.max_length,max_position_embeddings=max_position_embeddings,))
-
解释:
-
获取
max_length
:从generation_config
中获取max_length
,即生成的最大长度。 -
获取模型的最大位置嵌入长度:
-
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
-
从模型配置中获取
max_position_embeddings
,表示模型能处理的最大序列长度。
-
-
添加
MaxLengthCriteria
:-
创建一个
MaxLengthCriteria
对象,传入max_length
和max_position_embeddings
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当生成的序列长度达到
max_length
或模型的最大位置嵌入长度时,停止生成。
步骤 3:处理 max_time
停止条件
if generation_config.max_time is not None:criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
-
解释:
-
获取
max_time
:从generation_config
中获取max_time
,即生成的最长时间(以秒为单位)。 -
添加
MaxTimeCriteria
:-
创建一个
MaxTimeCriteria
对象,传入max_time
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当生成过程运行时间超过
max_time
秒时,停止生成。
步骤 4:处理 stop_strings
停止条件
if generation_config.stop_strings is not None:if tokenizer is None:raise ValueError("There are one or more stop strings, either in the arguments to `generate` or in the ""model's generation config, but we could not locate a tokenizer. When generating with ""stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`.")criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
-
解释:
-
获取
stop_strings
:从generation_config
中获取stop_strings
,这是一个字符串列表,包含用于停止生成的特殊字符串。 -
检查
tokenizer
:因为处理字符串需要使用 tokenizer,如果tokenizer
为None
,抛出ValueError
。 -
添加
StopStringCriteria
:-
创建一个
StopStringCriteria
对象,传入stop_strings
和tokenizer
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当生成的文本中出现指定的字符串时,停止生成。
步骤 5:处理 eos_token_id
(结束标记)停止条件
if generation_config._eos_token_tensor is not None:criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
-
解释:
-
获取
_eos_token_tensor
:从generation_config
中获取_eos_token_tensor
,即结束标记的 token ID。 -
添加
EosTokenCriteria
:-
创建一个
EosTokenCriteria
对象,传入eos_token_id
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当生成的序列中出现结束标记时,停止生成。
步骤 6:处理辅助模型的置信度停止条件(可选)
if (generation_config.is_assistantand generation_config.assistant_confidence_threshold is not Noneand generation_config.assistant_confidence_threshold > 0
):criteria.append(ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold))
-
解释:
-
条件检查:
-
generation_config.is_assistant
:判断是否使用了辅助模型。 -
generation_config.assistant_confidence_threshold
不为空且大于 0。
-
-
添加
ConfidenceCriteria
:-
创建一个
ConfidenceCriteria
对象,传入assistant_confidence_threshold
。 -
将其添加到
criteria
列表中。
-
-
-
作用:当辅助模型的置信度达到一定阈值时,停止生成。
步骤 7:合并用户提供的停止条件
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
-
解释:
-
合并默认和用户提供的停止条件:
-
调用
_merge_criteria_processor_list
方法,将默认的criteria
与用户提供的stopping_criteria
合并。 -
这个方法通常会去除重复的条件,或者根据某些规则进行合并。
-
-
-
作用:确保生成过程中考虑所有相关的停止条件。
步骤 8:返回最终的停止条件列表
return criteria
- 解释:返回包含所有停止条件的
StoppingCriteriaList
,供生成过程使用。
整体流程总结
-
目的:为生成过程准备一个完整的停止条件列表,当满足任何一个条件时,生成过程将停止。
-
处理逻辑:
-
初始化:创建一个空的停止条件列表。
-
添加默认停止条件:根据
generation_config
中的配置,添加对应的停止条件,包括最大长度、最大时间、结束标记、特殊字符串等。 -
辅助模型条件:如果使用了辅助模型,并且设置了置信度阈值,添加对应的停止条件。
-
合并用户停止条件:将用户提供的
stopping_criteria
合并到默认条件列表中。 -
返回:输出完整的停止条件列表。
-