当前位置: 首页 > wzjs >正文

做网站用什么编程语言好百度官方网站下载安装

做网站用什么编程语言好,百度官方网站下载安装,建立网站的基本过程,道路建设网站专题generate 以下是对您提供的 generate 方法的详细解释。这个方法用于大型语言模型(LLM)中的文本生成,尤其是具有语言模型头的模型。该方法包含了多个复杂的逻辑,支持多种生成模式,如贪心搜索、采样、束搜索等。 方法定…

generate

以下是对您提供的 generate 方法的详细解释。这个方法用于大型语言模型(LLM)中的文本生成,尤其是具有语言模型头的模型。该方法包含了多个复杂的逻辑,支持多种生成模式,如贪心搜索、采样、束搜索等。


方法定义

@torch.no_grad()
def generate(self,inputs: Optional[torch.Tensor] = None,generation_config: Optional[GenerationConfig] = None,logits_processor: Optional[LogitsProcessorList] = None,stopping_criteria: Optional[StoppingCriteriaList] = None,prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,synced_gpus: Optional[bool] = None,assistant_model: Optional["PreTrainedModel"] = None,streamer: Optional["BaseStreamer"] = None,negative_prompt_ids: Optional[torch.Tensor] = None,negative_prompt_attention_mask: Optional[torch.Tensor] = None,**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:...

参数说明

  • inputs:可选,输入张量,形状可能因模态而异。用于作为生成的提示或编码器的输入。
  • generation_config:可选,GenerationConfig 对象,包含生成时的参数配置。
  • logits_processor:可选,LogitsProcessorList 对象,自定义的 logits 处理器列表,用于在生成过程中调整 logits。
  • stopping_criteria:可选,StoppingCriteriaList 对象,自定义的停止标准列表,用于在满足条件时终止生成。
  • prefix_allowed_tokens_fn:可选,函数,用于在每一步生成时限制允许的 token。
  • synced_gpus:可选,布尔值,指示是否在多 GPU 环境下同步运行以避免死锁。
  • assistant_model:可选,用于加速生成的辅助模型,必须具有相同的 tokenizer。
  • streamer:可选,用于流式输出生成序列的对象。
  • negative_prompt_ids:可选,torch.LongTensor,形状为 (batch_size, sequence_length),用于一些处理器(如 CFG)的负提示。
  • negative_prompt_attention_mask:可选,torch.LongTensor,形状为 (batch_size, sequence_length),对应 negative_prompt_ids 的 attention mask。
  • kwargs:其他参数,可用于覆盖 generation_config 中的设置,或传递给模型的 forward 方法。

返回值

  • GenerateOutputtorch.LongTensor:根据参数设置,返回生成的序列或者包含生成过程详细信息的输出对象。

方法逻辑解析

总体流程

  1. 处理生成配置和参数验证:确保 generation_configkwargs 的正确性。
  2. 设置生成参数:根据传入的配置或默认值,设置生成所需的参数。
  3. 准备模型输入:处理输入张量,生成 input_ids,并管理模型需要的其他关键字参数。
  4. 确定生成模式:根据配置,选择合适的生成模式,例如贪心搜索、采样、束搜索等。
  5. 生成序列:调用相应的生成函数,生成目标序列。
  6. 返回结果:根据参数设置,返回生成的序列或包含更多信息的对象。

详细步骤

1. 处理生成配置和参数验证
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None)
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
  • _validate_model_class:检查模型的类型是否支持生成。_validate_model_class
  • 提取 tokenizer:用于停止条件等,非必要参数。
  • 准备 generation_configmodel_kwargs:处理传入的配置和额外参数。_prepare_generation_config
  • 验证模型关键字参数:确保传入的参数与模型的预期一致。_validate_model_kwargs
  • 验证辅助模型:如果使用了 assistant_model,确保其与主模型兼容。
2. 设置生成参数
if synced_gpus is None:synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
  • synced_gpus:在多 GPU 环境下,判断是否需要同步。
  • logits_processorstopping_criteria:初始化或使用默认的处理器和停止标准。
  • StoppingCriteria代码分析
  • LogitsProcessor代码分析
  • 注意力掩码相关变量:检查模型的 forward 方法是否接受 attention_mask,以及当前参数中是否提供。
3. 准备模型输入
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
  • _prepare_model_inputs:处理输入张量,获取模型输入名称,并更新 model_kwargs。_prepare_model_inputs
  • batch_sizedevice:获取批次大小和设备信息。
  • 准备特殊 token:如 bos_token_ideos_token_id 等。_prepare_special_tokens
4. 检查并处理注意力掩码
       # decoder-only models must use left-padding for batched generation.if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.if (generation_config._pad_token_tensor is not Noneand batch_size > 1and len(inputs_tensor.shape) == 2and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0):logger.warning("A decoder-only architecture is being used, but right-padding was detected! For correct ""generation results, please set `padding_side='left'` when initializing the tokenizer.")

这段代码是关于为仅解码器架构(decoder-only models)处理输入时的填充方式建议。它检查是否使用了右填充(right-padding),在这种情况下给出警告。

  1. 架构类型检查:

    • self.config.is_encoder_decoder 用于检查模型是否属于编码器-解码器架构。
    • 如果模型不是编码器-解码器架构(即它是仅解码器架构),并且不是在TorchDynamo编译模式下,代码继续进行。
  2. 输入条件检查:

    • 确保批处理大小大于1,即有多个序列在一起进行处理。
    • 检查输入张量 inputs_tensor 的形状满足条件:它是二维张量,一般形式为 [batch_size, sequence_length]
    • 检查序列中的最后一个标识符是否为 pad_token_id,指定的 generation_config._pad_token_tensor 不为 None
    • torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 用于确定至少有一个序列的最后一个令牌是填充令牌。
  3. 警告日志:

    • 如果满足以上条件,显示警告信息。
    • 提示在仅解码器模型中检测到右填充,它建议使用左填充(padding_side='left'),以获取正确的生成结果。
      # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are# generating the first new token or not, and we only want to use the embeddings for the first new token)if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":generation_config.use_cache = True

这段代码针对仅解码器架构(decoder-only models)进行了一种配置调整,特别是在使用 inputs_embeds 作为输入的时候。以下是代码的简单说明:

  1. 架构类型检查:

    • not self.config.is_encoder_decoder 用于判断模型是否是仅解码器架构。
    • 如果模型是仅解码器架构,代码继续进行。
  2. 输入类型检查:

    • model_input_name == "inputs_embeds" 检查输入类型是否为嵌入层(embeddings)。
    • inputs_embeds 通常表示已经经过词嵌入层的输入,这意味着模型接收的不是直接的令牌ID,而是对应的词向量。
  3. 缓存使用设置:

    • generation_config.use_cache = True 设置生成配置的 use_cache 属性为 True
    • 启用缓存对于跟踪生成的序列特别重要,尤其是在无法确定新生成的第一个令牌是否已生成时。
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(inputs_tensor, generation_config, model_kwargs)
elif kwargs_has_attention_mask:if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:raise ValueError("`attention_mask` passed to `generate` must be 2D.")
  • 如果需要 attention_mask:且未提供,则生成默认的 attention_mask
  • 验证提供的 attention_mask:确保其形状正确。
  • _prepare_attention_mask_for_generation
5. 为编码器-解码器模型准备输入
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs_tensor, model_kwargs, model_input_name, generation_config)
  • 对于编码器-解码器模型:如果未提供 encoder_outputs,则进行编码器的前向计算,并更新 model_kwargs。计算编码器部分
6. 准备用于自回归生成的 input_ids
if self.config.is_encoder_decoder:input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(batch_size=batch_size,model_input_name=model_input_name,model_kwargs=model_kwargs,decoder_start_token_id=generation_config._decoder_start_token_tensor,device=inputs_tensor.device,)
else:input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
  • 编码器-解码器模型:准备解码器的 input_ids
  • _prepare_decoder_input_ids_for_generation
  • _decoder_start_token_tensor
  • 解码器模型:直接使用 inputs_tensor 作为 input_ids
7. 处理特殊的生成配置
if generation_config.token_healing:input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:streamer.put(input_ids.cpu())
  • token_healing:如果启用了此项配置,修复输入中的 tokens。token_healing
  • streamer:如果提供了流式处理器,传递当前的 input_ids
8. 准备生成长度相关的参数
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(generation_config=generation_config,has_default_max_length=has_default_max_length,has_default_min_length=has_default_min_length,model_input_name=model_input_name,inputs_tensor=inputs_tensor,input_ids_length=input_ids_length,
)
  • 计算输入序列长度
  • 确定是否使用默认的最大和最小长度
  • 准备生成长度配置,可能会根据输入长度进行调整。_prepare_generated_length
9. 准备缓存和其他模型参数
if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:model_kwargs["logits_to_keep"] = 1
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
  • logits_to_keep:如果模型支持,仅保留需要的 logits,减少内存占用。_supports_logits_to_keep
  • 验证生成长度:确保生成长度的合法性。_validate_generated_length
        # 7. Prepare the cache.# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.# - different models have a different cache name expected by the model (default = "past_key_values")# - `max_length`, prepared above, is used to determine the maximum cache lengthmax_cache_length = generation_config.max_length - 1if (inputs_tensor.shape[1] != input_ids_lengthand model_input_name == "inputs_embeds"and not self.config.is_encoder_decoder):max_cache_length += inputs_tensor.shape[1]self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device)
  • 准备缓存:为生成过程中的缓存(如注意力缓存)分配空间。
    _prepare_cache_for_generation
10. 确定生成模式
generation_mode = generation_config.get_generation_mode(assistant_model)
  • 根据生成配置和辅助模型,确定生成模式,例如:
    • 辅助生成(Assisted Generation)
    • DoLa 生成(DOLA Generation)
    • 对比搜索(Contrastive Search)
    • 采样或贪心搜索
    • 束搜索(Beam Search)
    • 组束搜索(Group Beam Search)
    • 受限束搜索(Constrained Beam Search)
11. 准备 logits 处理器和停止标准
prepared_logits_processor = self._get_logits_processor(generation_config=generation_config,input_ids_seq_length=input_ids_length,encoder_input_ids=inputs_tensor,prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,logits_processor=logits_processor,device=inputs_tensor.device,model_kwargs=model_kwargs,negative_prompt_ids=negative_prompt_ids,negative_prompt_attention_mask=negative_prompt_attention_mask,
)
prepared_stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
model_kwargs["use_cache"] = generation_config.use_cache
  • 获取 logits 处理器:整合默认和自定义的 logits 处理器,用于在生成过程中调整 logits。
  • 获取停止标准:整合默认和自定义的停止标准,用于在满足条件时终止生成。
  • 设置 use_cache:根据配置,决定是否在生成过程中使用缓存。
    _get_logits_processor
    _get_stopping_criteria
12. 根据生成模式调用相应的生成函数
  • 辅助生成
if generation_mode == GenerationMode.ASSISTED_GENERATION:# 验证条件# 获取候选生成器# 执行辅助生成
  • DoLa 生成
elif generation_mode == GenerationMode.DOLA_GENERATION:# 执行 DoLa 解码
  • 对比搜索
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:# 执行对比搜索
  • 采样或贪心搜索
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):# 扩展 input_ids# 执行采样或贪心搜索
  • 束搜索
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):# 准备束搜索评分器# 扩展 input_ids# 执行束搜索

GenerationMixin:_sample方法(GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH)

  • 组束搜索
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:# 准备组束搜索评分器# 扩展 input_ids# 执行组束搜索
  • 受限束搜索
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:# 准备约束条件# 准备受限束搜索评分器# 扩展 input_ids# 执行受限束搜索
13. 处理生成结果
# 如果需要,将缓存转换为传统格式
if (generation_config.return_legacy_cache is Trueand not is_torchdynamo_compiling()and hasattr(result, "past_key_values")and getattr(result.past_key_values, "to_legacy_cache") is not None
):result.past_key_values = result.past_key_values.to_legacy_cache()
return result
  • 转换缓存格式:如果配置需要,将生成过程中使用的缓存转换为传统格式。
  • 返回结果:最终将生成的结果返回。

_validate_model_class

函数功能概述

这个函数名为_validate_model_class,用于验证当前的模型类是否支持生成(generation)操作。如果不支持生成,则会抛出一个异常,提示用户使用合适的模型类。

  1. 条件判断

    if not is_torchdynamo_compiling() and not self.can_generate():
    
    • 这里有一个 if 条件,用于检查两个条件是否同时满足:
      • not is_torchdynamo_compiling()
      • not self.can_generate()
    • is_torchdynamo_compiling()
      • 这是一个函数,检查当前是否处于 TorchDynamo 的编译环境中。
      • TorchDynamo 是 PyTorch 的一个 JIT 编译器框架,可以对模型进行编译优化。
      • 在编译过程中,某些检查可能需要被跳过,因此在编译时不进行此验证。
    • self.can_generate()
      • 这是模型类的一个方法,返回布尔值,表示模型是否支持生成操作。
      • 如果模型具有语言模型头(language model head),则通常支持生成,即 self.can_generate() 返回 True

    因此,只有在不处于编译环境模型不能生成的情况下,才会进入 if 语句内部,抛出异常。

  2. 定义支持生成的模型类名后缀列表

    terminations_with_generation_support = ["ForCausalLM","ForConditionalGeneration","ForSpeechSeq2Seq","ForVision2Seq",
    ]
    
    • 这是一个列表,包含了通常支持生成操作的模型类名称的后缀。
    • 这些后缀包括:
      • "ForCausalLM":用于自回归语言模型,如 GPT-2、GPT-3 等。
      • "ForConditionalGeneration":用于条件生成模型,如 BART、T5 等。
      • "ForSpeechSeq2Seq":用于语音序列到序列模型。
      • "ForVision2Seq":用于视觉到序列的模型,如图像描述生成。
  3. 抛出异常

    raise TypeError(f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as ""it doesn't have a language model head. Classes that support generation often end in one of these "f"names: {terminations_with_generation_support}."
    )
    
    • 如果条件满足,说明当前模型不支持生成,则抛出一个 TypeError 异常。
    • 异常信息包括:
      • 当前模型类的名称:{self.__class__.__name__}
      • 说明模型不兼容 .generate() 方法,因为它没有语言模型头。
      • 提示支持生成的模型类通常以哪些后缀结尾:{terminations_with_generation_support}

_prepare_generation_config

函数功能概述

这个函数名为_prepare_generation_config,它的作用是准备生成所需的配置对象generation_config,并应用从kwargs(关键字参数)中传入的任何生成配置选项。这个函数还处理了与模型配置文件(config)的向后兼容性。


参数和返回值
  • 参数

    • self:类的实例。
    • generation_config:可选的GenerationConfig对象,表示生成配置。如果为None,则需要从其他地方获取或生成。
    • **kwargs:其他关键字参数,可能包含生成配置的参数,将用于更新generation_config
  • 返回值

    • generation_config:最终准备好的生成配置对象。
    • model_kwargs:用于模型的关键字参数字典。

处理逻辑详解
  1. 初始设置

    # 设置一个标志,指示是否使用模型的默认生成配置
    using_model_generation_config = False
    
  2. 处理generation_configNone的情况

    if generation_config is None:...
    

    当用户没有提供generation_config时,需要从模型中获取。但在处理之前,先考虑到可能的向后兼容性问题。

    • 遗留(Legacy)行为的处理

      # 遗留支持:用户可能修改了模型的配置来控制生成。要触发这种遗留行为,需要满足以下条件:
      # 1) generation_config 是从模型配置创建的(`_from_model_config`字段为True)
      # 2) generation_config 自创建以来没有被修改过(哈希值相同)
      # 3) 模型配置中有非默认的生成参数
      # 4) 用户在模型配置中设置了新的生成参数
      # 注意:`torch.compile`无法编译`hash`函数,因此在编译时,这种遗留支持被禁用
      if (not is_torchdynamo_compiling()and self.generation_config._from_model_config  # 条件1and self.generation_config._original_object_hash == hash(self.generation_config)  # 条件2and len(self.config._get_non_default_generation_parameters()) > 0  # 条件3
      ):new_generation_config = GenerationConfig.from_model_config(self.config)if new_generation_config != self.generation_config:  # 条件4warnings.warn("You have modified the pretrained model configuration to control generation. This is a"" deprecated strategy to control generation and will be removed in v5."" Please use and modify the model generation configuration (see"" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",UserWarning,)self.generation_config = new_generation_config
      
      • 解释条件

        1. self.generation_config._from_model_config
          • 检查generation_config是否是从模型配置创建的。
        2. self.generation_config._original_object_hash == hash(self.generation_config)
          • 检查generation_config自创建以来是否没有被修改过。
          • 由于torch.compile无法编译hash函数,因此在编译时无法进行此检查。
        3. len(self.config._get_non_default_generation_parameters()) > 0
          • 检查模型配置中是否有非默认的生成参数。
        4. new_generation_config != self.generation_config
          • 检查新的生成配置是否与当前的不同,表示用户在模型配置中设置了新的生成参数。
      • 操作

        • 如果以上条件都满足,表示用户通过修改模型配置来控制生成。这是一种已被弃用的做法。
        • 发出警告,提示此策略将在v5版本中移除,建议用户使用并修改生成配置对象generation_config
        • 更新self.generation_config为新的生成配置new_generation_config
    • 设置generation_config

      generation_config = self.generation_config
      using_model_generation_config = True
      
      • 如果没有提供generation_config,则使用模型的默认生成配置self.generation_config
      • 设置标志using_model_generation_config = True,表示正在使用模型的默认生成配置。
  3. 处理torch.compile相关的问题

    # `torch.compile`无法编译`copy.deepcopy`等函数,因此需要根据是否在编译中,决定如何处理
    if not is_torchdynamo_compiling():...
    else:model_kwargs = kwargs
    
    • 非编译环境下的处理

      if not is_torchdynamo_compiling():generation_config = copy.deepcopy(generation_config)model_kwargs = generation_config.update(**kwargs)...
      
      • 深拷贝generation_config

        • 使用copy.deepcopy创建generation_config的深拷贝,以避免修改原始对象。
        • 由于torch.compile无法编译copy.deepcopy,因此在编译环境下无法进行此操作。
      • 更新generation_config

        • 调用generation_config.update(**kwargs)方法,用传入的kwargs更新生成配置。
        • 这个方法返回未被generation_config使用的参数,即那些不属于生成配置的参数,存储在model_kwargs中。
      • 处理特殊的Token ID

        # 如果提供了`generation_config`,需要确保所有特殊的Token ID都有默认值
        if not using_model_generation_config:if generation_config.bos_token_id is None:generation_config.bos_token_id = self.generation_config.bos_token_idif generation_config.eos_token_id is None:generation_config.eos_token_id = self.generation_config.eos_token_idif generation_config.pad_token_id is None:generation_config.pad_token_id = self.generation_config.pad_token_idif generation_config.decoder_start_token_id is None:generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
        
        • 如果用户提供了自己的generation_config(即不使用模型的默认生成配置),需要确保特殊的Token ID(开始、结束、填充、解码器开始)有默认值。
        • 如果这些ID在用户提供的generation_config中为None,则使用模型默认的self.generation_config中的值。
    • 编译环境下的处理

      else:model_kwargs = kwargs
      
      • 在编译环境下,由于无法使用copy.deepcopyhash,直接将传入的kwargs赋值给model_kwargs
      • 不进行深拷贝和更新操作。
  4. 返回结果

    return generation_config, model_kwargs
    
    • 函数返回准备好的generation_configmodel_kwargs
    • model_kwargs包含了模型需要的其他参数。

_validate_model_kwargs

函数功能概述

这个函数名为_validate_model_kwargs,用于在生成(generation)过程中对传入的模型关键字参数model_kwargs进行验证。它的主要作用是:

  • 验证传入的model_kwargs是否被模型的生成方法使用,以确保参数的正确性,防止参数拼写错误或传入不支持的参数。
  • 给出明确的错误信息,帮助用户及时发现并修正问题。

1. 检查past_key_values是否为Cache实例,并验证模型是否支持
if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:raise ValueError(f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please ""check the model documentation for supported cache formats.")
  • 目的:如果model_kwargs中包含past_key_values,并且它是一个Cache实例,而模型不支持Cache类型的缓存,则抛出错误。

  • 解释

    • model_kwargs.get("past_key_values", None):尝试获取past_key_values的值,如果不存在则为None
    • isinstance(..., Cache):检查past_key_values是否是Cache类的实例。
    • not self._supports_cache_class:检查当前模型是否不支持Cache类型的缓存。
  • 错误处理:如果以上条件满足,则抛出ValueError,提示用户当前模型不支持Cache实例作为past_key_values,并建议查看模型文档以了解支持的缓存格式。

2. 对于编码器-解码器模型,移除特定的参数
if self.config.is_encoder_decoder:for key in ["decoder_input_ids"]:model_kwargs.pop(key, None)
  • 目的:在生成过程中,不需要直接传入decoder_input_ids,因此从model_kwargs中移除该参数,防止后续产生错误。

  • 解释

    • self.config.is_encoder_decoder:检查模型是否为编码器-解码器结构。
    • model_kwargs.pop(key, None):从model_kwargs中移除指定的键,如果不存在则返回None
3. 初始化未使用的模型参数列表和获取模型方法参数
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
  • unused_model_args:用于收集未被模型方法使用的参数名。

  • model_args:获取prepare_inputs_for_generation方法的参数名集合,表示模型在生成过程中可能使用的参数。

  • 解释

    • inspect.signature(...):获取函数的完整参数签名。
    • .parameters:获取参数名组成的有序字典。
    • set(...):将参数名转换为集合,方便后续操作。
4. 扩展模型参数集合,包含forward方法的参数
if "kwargs" in model_args or "model_kwargs" in model_args:model_args |= set(inspect.signature(self.forward).parameters)
  • 目的:如果prepare_inputs_for_generation方法接受kwargs(可变关键字参数),则需要将self.forward方法的参数也包含进来,因为这些参数可能会通过kwargs传递。

  • 解释

    • 检查model_args中是否包含"kwargs""model_kwargs",表示接受可变关键字参数。
    • 使用set.union|=)操作,将self.forward方法的参数集合并入model_args
5. 对于编码器-解码器模型,包含编码器和解码器的参数
if self.config.is_encoder_decoder:base_model = getattr(self, self.base_model_prefix, None)# 允许编码器的参数encoder = getattr(self, "encoder", None)# 特殊情况处理(如Musicgen模型)if encoder is None and base_model is not None:encoder = getattr(base_model, "encoder", None)if encoder is not None:encoder_model_args = set(inspect.signature(encoder.forward).parameters)model_args |= encoder_model_args# 允许解码器的参数decoder = getattr(self, "decoder", None)if decoder is None and base_model is not None:decoder = getattr(base_model, "decoder", None)if decoder is not None:decoder_model_args = set(inspect.signature(decoder.forward).parameters)model_args |= {f"decoder_{x}" for x in decoder_model_args}
  • 目的:将编码器和解码器的forward方法的参数名添加到model_args中,以便在验证model_kwargs时考虑到这些参数。

  • 详细解释

    • base_model = getattr(self, self.base_model_prefix, None):获取基础模型实例,self.base_model_prefix通常是模型的前缀名。

    • 处理编码器

      • encoder = getattr(self, "encoder", None):尝试直接获取self.encoder属性。
      • 如果encoderNonebase_model存在,则尝试从base_model中获取encoder
      • 获取encoder.forward方法的参数名集合encoder_model_args
      • encoder_model_args并入model_args
    • 处理解码器

      • 类似地,获取decoder实例。
      • 获取decoder.forward方法的参数名集合decoder_model_args
      • decoder_model_args的参数名加上前缀"decoder_",以匹配生成过程中参数的命名方式。
      • 将修改后的decoder_model_args并入model_args
  • 特殊情况说明

    • 对于某些模型(例如MusicgenForConditionalGeneration),编码器和解码器的命名和结构可能与标准情况不同,需要特别处理。
6. 检查model_kwargs中的未使用参数
for key, value in model_kwargs.items():if value is not None and key not in model_args:unused_model_args.append(key)
  • 目的:找出model_kwargs中那些未被模型方法使用的参数。

  • 解释

    • 遍历model_kwargs中的所有键值对。
    • 如果参数值不为None且参数名不在model_args集合中,说明模型并不会使用该参数。
    • 将这些未使用的参数名添加到unused_model_args列表中。
7. 抛出错误提示未使用的参数
if unused_model_args:raise ValueError(f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"" generate arguments will also show up in this list)")
  • 目的:如果存在未使用的参数,抛出ValueError,提醒用户这些参数未被模型使用,可能存在拼写错误或传入了不支持的参数。

  • 解释

    • 检查unused_model_args列表是否非空。
    • 抛出错误,错误信息中包含未使用的参数列表,并提示用户可能是由于参数名的拼写错误导致的。

_prepare_model_inputs

方法定义

def _prepare_model_inputs(self,inputs: Optional[torch.Tensor] = None,bos_token_id: Optional[torch.Tensor] = None,model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:"""This function extracts the model-specific `inputs` for generation."""# 方法体...

参数说明:

  • inputs:可选的 torch.Tensor,表示要作为模型输入的张量。可能是 input_idsinputs_embeds 等形式,具体取决于模型的要求。
  • bos_token_id:可选的 torch.Tensor,表示序列开始的 token ID(BOS = Begin Of Sequence)。在生成任务中,如果未提供输入,可能需要使用该 token 进行初始化。
  • model_kwargs:可选的字典,包含传递给模型的其他关键字参数。

返回值:

  • inputstorch.Tensor,准备好的模型输入张量。
  • input_namestr,模型输入的名称,可能是 input_idsinputs_embeds
  • model_kwargs:字典,更新后的模型关键字参数。

方法功能概述

该方法的主要目的是在生成过程中,准备和验证模型的输入,确保输入与模型的预期格式和要求一致。具体任务包括:

  1. 确定模型所需的主要输入名称(input_name),这可能取决于模型是编码器-解码器模型还是仅解码器模型。
  2. 处理传入的 inputsmodel_kwargs,以避免重复传递相同的输入参数。
  3. 在需要时,使用 bos_token_id 初始化 input_ids,以开始生成过程。
  4. 对于支持 inputs_embeds 的模型,正确处理 inputs_embeds 参数。

逐步详解

步骤 1:确定模型的主要输入名称
# 判断模型是否是编码器-解码器,并获取正确的输入名称
if (self.config.is_encoder_decoderand hasattr(self, "encoder")and self.encoder.main_input_name != self.main_input_name
):input_name = self.encoder.main_input_name
else:input_name = self.main_input_name

解释:

  • 目的:获取模型预期的主要输入参数名称,可能是 input_idsinputs_embeds 等。
  • 逻辑:
    • 如果模型是 编码器-解码器模型,并且编码器的 main_input_name 与模型的 main_input_name 不同,则使用编码器的 main_input_name
      • 这是因为编码器和解码器可能需要不同的输入名称。例如,一些模型的编码器可能使用 inputs_embeds,而解码器使用 input_ids
    • 否则,使用模型的 main_input_name 作为输入名称。

示例:

  • 如果 self.main_input_name'input_ids',而 self.encoder.main_input_name'inputs_embeds',则对于编码器,需要使用 'inputs_embeds' 作为输入名称。
步骤 2:从 model_kwargs 中提取非输入相关的参数
# 过滤掉模型输入的关键字参数,以及值为 None 的参数
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

解释:

  • 目的:清理 model_kwargs,只保留非模型输入的参数和值非 None 的参数。
  • 逻辑:
    • 遍历 model_kwargs 中的所有键值对,保留以下情况的参数:
      • v 不为 None
      • 或者键 k 不是模型的主要输入名称 input_name
  • 原因:避免模型输入参数(如 input_ids)重复出现在 inputsmodel_kwargs 中,以防止冲突。

示例:

  • 如果 input_name'input_ids',并且 model_kwargs 包含 {'input_ids': None, 'attention_mask': tensor([...])},那么经过此步骤后,model_kwargs 变为 {'attention_mask': tensor([...])}
步骤 3:检查模型输入是否作为关键字参数传入
# 从 model_kwargs 中获取输入参数
inputs_kwarg = model_kwargs.pop(input_name, None)
# 检查是否同时传入了 inputs 和 inputs_kwarg
if inputs_kwarg is not None and inputs is not None:raise ValueError(f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "f"Make sure to either pass {inputs} or {input_name}=...")
elif inputs_kwarg is not None:inputs = inputs_kwarg

解释:

  • 目的:确保模型输入参数只通过一个途径传入,避免冲突。
  • 逻辑:
    • model_kwargs 中取出键为 input_name 的参数,赋值给 inputs_kwarg
    • 如果同时传入了 inputsinputs_kwarg,抛出异常,提示用户只能通过一种方式传入模型输入。
    • 如果 inputsNone,但 inputs_kwarg 不为 None,则将 inputs_kwarg 赋值给 inputs

示例:

  • 用户通过位置参数传入了 inputs,同时在 model_kwargs 中传入了 input_ids,这将导致异常,因为模型无法确定使用哪个输入。
步骤 4:处理 inputs_embeds 的情况
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:if not self.config.is_encoder_decoder:# 检查模型是否支持 inputs_embedshas_inputs_embeds_forwarding = "inputs_embeds" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())if not has_inputs_embeds_forwarding:raise ValueError(f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} ""doesn't have its forwarding implemented. See the GPT2 implementation for an example ""(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!")# 将 input_ids 初始化并加入 model_kwargsmodel_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs=model_kwargs)else:if inputs is not None:raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")# 更新 inputs 和 input_nameinputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

解释:

  • 目的:处理用户通过 inputs_embeds 提供输入的情况,确保模型支持这种输入方式,并正确处理。
  • 逻辑:
    • 当模型的输入名称为 'input_ids',且 model_kwargs 中存在 'inputs_embeds' 键时,进入此逻辑。
    • 对于非编码器-解码器模型:
      • 检查模型的 prepare_inputs_for_generation 方法是否接受 inputs_embeds 参数。
      • 如果不支持,则抛出异常,提示模型不支持通过 inputs_embeds 进行生成。
      • 如果支持,则需要初始化 input_ids,以便在生成过程中处理诸如 attention mask 等依赖 input_ids 的自动操作。
      • 将初始化的 input_ids 添加到 model_kwargs 中。
    • 对于编码器-解码器模型:
      • 如果同时传入了 inputsinputs_embeds,抛出异常,提示只能选择一种输入方式。
    • 最后,将 inputs 设置为 inputs_embeds,并更新 input_name'inputs_embeds'

示例:

  • 用户在生成时提供了 inputs_embeds,如果模型支持,则将其用于生成,否则抛出异常。
步骤 5:如果 inputs 仍为 None,尝试从 bos_token_id 创建 input_ids
# 如果 inputs 为 None,使用 bos_token_id 初始化 input_ids
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)

解释:

  • 目的:在用户未提供任何输入的情况下,使用开始标记 bos_token_id 初始化输入,以启动生成过程。
  • 逻辑:
    • 调用 _maybe_initialize_input_ids_for_generation 方法,如果 inputsNone,则尝试使用 bos_token_id 创建 input_ids
    • 如果 inputs 已存在,则保持不变。

示例:

  • 用户未提供 inputs,模型使用 bos_token_id [101](假设为开始标记的 ID)初始化 input_ids
步骤 6:返回处理后的结果
# 返回准备好的 inputs、input_name 和更新后的 model_kwargs
return inputs, input_name, model_kwargs

示例:

# 示例 1:用户只提供 inputs
inputs = torch.tensor([[101, 102, 103]])
inputs, input_name, model_kwargs = model._prepare_model_inputs(inputs=inputs)# 示例 2:用户只提供 input_ids 作为关键字参数
model_kwargs = {'input_ids': torch.tensor([[101, 102, 103]])}
inputs, input_name, model_kwargs = model._prepare_model_inputs(model_kwargs=model_kwargs)# 示例 3:用户提供 inputs_embeds,模型支持 inputs_embeds
model_kwargs = {'inputs_embeds': torch.tensor([...])}
inputs, input_name, model_kwargs = model._prepare_model_inputs(model_kwargs=model_kwargs)# 示例 4:用户未提供任何输入,使用 bos_token_id 初始化
bos_token_id = torch.tensor([101])
inputs, input_name, model_kwargs = model._prepare_model_inputs(bos_token_id=bos_token_id)

_prepare_special_tokens

主要目的是准备生成所需的特殊标记(如开始标记 bos_token_id、结束标记 eos_token_id、填充标记 pad_token_id 和解码器起始标记 decoder_start_token_id),并将这些标记转换为张量。具体步骤如下:

  1. 定义辅助函数 _tensor_or_none

    • 将特殊标记转换为张量,如果标记为 None 则返回 None
  2. 将特殊标记转换为张量

    • 使用 _tensor_or_none 函数将 bos_token_ideos_token_idpad_token_iddecoder_start_token_id 转换为张量。
  3. 处理编码器-解码器模型

    • 如果模型是编码器-解码器类型,并且 decoder_start_token_id 未设置,则使用 bos_token_id 作为 decoder_start_token_id
  4. 处理 eos_token_tensor

    • 如果 eos_token_tensor 是 0 维张量,则将其扩展为 1 维张量。
  5. 设置 pad_token_tensor

    • 如果 pad_token_tensor 未设置且 eos_token_tensor 存在,则将 pad_token_tensor 设置为 eos_token_tensor 的第一个元素,并发出警告。
  6. 安全检查和警告

    • 检查编码器-解码器模型是否设置了 decoder_start_token_id
    • 检查 eos_token_tensor 是否与 pad_token_tensor 相同,并在未设置注意力掩码时发出警告。
    • 检查 eos_token_tensor 是否包含负数或浮点数,并发出警告。
  7. 更新生成配置

    • 将转换后的特殊标记张量存储在 generation_config 的新属性中,以启用端到端编译。

_prepare_attention_mask_for_generation

该方法用于在生成过程中为模型准备 attention_mask,以确保模型在生成序列时正确地关注输入序列的非填充部分(非 pad_token 部分)。


方法定义

def _prepare_attention_mask_for_generation(self,inputs_tensor: torch.Tensor,generation_config: GenerationConfig,model_kwargs: Dict[str, Any],
) -> torch.LongTensor:# 方法体...

参数说明:

  • inputs_tensor: torch.Tensor,输入张量,可能是模型的主要输入。
  • generation_config: GenerationConfig,生成配置对象,包含生成过程中所需的配置参数。
  • model_kwargs: Dict[str, Any],包含模型其他关键字参数的字典。

返回值:

  • attention_mask: torch.LongTensor,返回生成过程中使用的 attention_mask,用于指示模型需要关注的输入序列位置。

方法功能概述

该方法的主要目的是根据输入张量 inputs_tensor 和生成配置 generation_config,为生成过程准备合适的 attention_mask。具体而言,它会根据以下情况生成 attention_mask

  1. 如果输入张量包含 pad_token_id,则在 attention_mask 中标记出非填充的位置(即非 pad_token_id 的位置为 1,pad_token_id 的位置为 0)。
  2. 如果无法判断是否需要生成 attention_mask,则返回默认的 attention_mask,即全 1 的张量,表示所有位置都需要关注。

逐步详解

步骤 1:获取 pad_token_ideos_token_id
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
  • 说明:从 generation_config 中获取用于填充的 pad_token_id 和序列结束的 eos_token_id
步骤 2:检查 model_kwargs 中是否有 input_ids
# `input_ids` 可能存在于 model_kwargs 中,而不是主要输入(例如多模态模型)
if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:inputs_tensor = model_kwargs["input_ids"]
  • 解释

    • 在某些情况下,inputs_tensor 并不是模型的主要输入,如在多模态模型中,input_ids 可能存在于 model_kwargs 中。
    • 如果 model_kwargs 中有 input_ids,并且其形状的第二维度长度大于 0(即有数据),则用 model_kwargs["input_ids"] 替换 inputs_tensor
  • 目的:确保 inputs_tensor 是正确的 input_ids,以便后续的 attention_mask 计算。

步骤 3:创建默认的 attention_mask
# 如果无法推断 attention mask,则返回默认的 attention mask
default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
  • 解释

    • 创建一个形状与 inputs_tensor 前两个维度相同的全 1 张量,作为默认的 attention_mask
    • 该默认 attention_mask 表示输入序列的所有位置都需要关注。
    • 切片的基本语法是 start: end:step,其中 start 是起始索引,end 是结束索引(不包括),step 是步长。写为 [:2] 是一种简写表示方式,没有显式指定起始和步长,默认从头开始且步长为1
步骤 4:如果 pad_token_idNone,返回默认的 attention_mask
if pad_token_id is None:return default_attention_mask
  • 说明

    • 如果 pad_token_id 未定义,则无法根据填充标记生成 attention_mask,因此直接返回默认的 attention_mask
步骤 5:检查 inputs_tensor 是否为有效的 input_ids
is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
if not is_input_ids:return default_attention_mask
  • 解释

    • 检查 inputs_tensor 是否满足以下条件:

      • inputs_tensor 的形状为二维,即形状为 (batch_size, sequence_length)
      • inputs_tensor 的数据类型为整数类型 torch.inttorch.long
    • 如果不满足上述条件,则认为无法根据 inputs_tensor 推断出 attention_mask,因此返回默认的 attention_mask

步骤 6:检查 inputs_tensor 中是否包含 pad_token_id
is_pad_token_in_inputs = (pad_token_id is not None) and (isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
)
  • 解释

    • 使用 isin_mps_friendly 函数检查 inputs_tensor 中是否包含 pad_token_id

      • isin_mps_friendly(elements, test_elements):检查 elements 中的元素是否在 test_elements 中,返回布尔张量。
    • isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()

      • 检查 inputs_tensor 中是否有元素等于 pad_token_id,如果有,则返回 True
  • 结果

    • is_pad_token_in_inputs 为一个布尔值,表示 inputs_tensor 中是否包含 pad_token_id
步骤 7:检查 pad_token_id 是否不等于 eos_token_id
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
)
  • 解释

    • 需要确保 pad_token_ideos_token_id 不相等,以避免将 eos_token_id 误认为填充标记。

    • 逻辑:

      • 如果 eos_token_idNone,则认为 pad_token_id 不等于 eos_token_id

      • 如果 eos_token_id 不为 None,则检查 eos_token_id 是否等于 pad_token_id

        • isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()

          • 检查 eos_token_id 是否等于 pad_token_id
        • 使用按位取反运算符 ~,将结果取反。

  • 结果

    • is_pad_token_not_equal_to_eos_token_id 为一个布尔值,表示 pad_token_id 是否不等于 eos_token_id
步骤 8:确定是否可以推断 attention_mask
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
  • 解释

    • 通过将 is_pad_token_in_inputsis_pad_token_not_equal_to_eos_token_id 相乘(布尔值相乘相当于逻辑与),判断是否可以基于 pad_token_id 推断出 attention_mask
  • 结果

    • can_infer_attention_mask 为布尔值:

      • 如果 True,表示可以根据 pad_token_id 推断 attention_mask

      • 如果 False,则无法推断,需要返回默认的 attention_mask

步骤 9:基于 pad_token_id 生成 attention_mask
attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
  • 解释

    • 使用 inputs_tensor.ne(pad_token_id),得到一个布尔张量,标记出 inputs_tensor 中不等于 pad_token_id 的位置。

    • .long():将布尔张量转换为长整型,即 True 转换为 1False 转换为 0

  • 结果

    • attention_mask_from_padding 为一个长整型张量,形状与 inputs_tensor 相同,非 pad_token_id 的位置为 1pad_token_id 的位置为 0
步骤 10:根据是否可以推断 attention_mask,选择最终的 attention_mask
attention_mask = (attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
  • 解释

    • 如果可以推断 attention_maskcan_infer_attention_maskTrue):

      • 使用 attention_mask_from_padding
    • 如果不能推断:

      • 使用 default_attention_mask
    • 计算方式:

      • attention_mask = attention_mask_from_padding * can_infer_attention_mask

        • can_infer_attention_maskTrue 时,attention_mask 等于 attention_mask_from_padding

        • can_infer_attention_maskFalse 时,乘积为 0

      • default_attention_mask * ~can_infer_attention_mask

        • ~can_infer_attention_maskcan_infer_attention_mask 取反。

        • can_infer_attention_maskFalse 时,~can_infer_attention_maskTrue

        • 因此,当不能推断时,attention_mask 等于 default_attention_mask

      • 最终,将两个结果相加,得到正确的 attention_mask

  • 结果

    • attention_mask 为最终的注意力掩码张量,表示模型在生成过程中需要关注的输入位置。
步骤 11:返回最终的 attention_mask
return attention_mask
  • 解释

    • 返回生成的 attention_mask 张量,用于在生成过程中指导模型关注正确的输入序列位置。

_prepare_encoder_decoder_kwargs_for_generation

以下是对您提供的 _prepare_encoder_decoder_kwargs_for_generation 方法的详细解释。该方法用于在生成过程中为 编码器-解码器模型 准备必要的关键字参数,以便在生成序列时正确调用编码器。


方法定义

def _prepare_encoder_decoder_kwargs_for_generation(self,inputs_tensor: torch.Tensor,model_kwargs,model_input_name: Optional[str],generation_config: GenerationConfig,
) -> Dict[str, Any]:# 方法体...

参数说明:

  • inputs_tensor: torch.Tensor,输入张量,通常是 input_ids,表示输入序列的标记(tokens)。
  • model_kwargs: Dict[str, Any],包含传递给模型的其他关键字参数。
  • model_input_name: Optional[str],模型输入的名称,默认为 None。如果为 None,则使用 self.main_input_name
  • generation_config: GenerationConfig,生成配置对象,包含生成过程中所需的配置参数。

返回值:

  • model_kwargs: 返回更新后的 model_kwargs,其中包括编码器的输出 encoder_outputs,以供生成器使用。

方法功能概述

该方法的主要目的是:

  1. 获取模型的 编码器 部分。
  2. model_kwargsgeneration_config 中提取 编码器所需的参数,并准备传递给编码器的关键字参数 encoder_kwargs
  3. 调用 编码器的 forward 方法,获取编码器的输出,并将其添加到 model_kwargs 中,以供 解码器 在生成过程中使用。

逐步详解

步骤 1:获取编码器
# 1. get encoder
encoder = self.get_encoder()
  • 说明:

    • 调用模型的 get_encoder() 方法,获取编码器对象。
    • 该编码器将用于处理输入的 inputs_tensor,生成编码器的输出。
步骤 1.1:兼容性处理
# Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
# as the inputs.
if hasattr(self, "hf_device_map"):if hasattr(encoder, "_hf_hook"):encoder._hf_hook.io_same_device = Trueelse:add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
  • 解释:

    • 目的:确保在使用 Accelerate 库进行大型模型推理时,编码器的输出与输入位于 同一设备 上(如 GPU),避免跨设备的数据传输开销。
  • 逻辑:

    • 检查模型是否具有 hf_device_map 属性,如果存在,表示模型使用了 Accelerate 库进行设备映射。
    • 检查编码器是否具有 _hf_hook 属性:
      • 如果有,设置其 io_same_device 属性为 True,表示编码器的输入和输出在同一设备上。
      • 如果没有,使用 add_hook_to_module 函数,将 AlignDevicesHook(io_same_device=True) 添加到编码器模块上。
  • 相关函数:

    • add_hook_to_module(module, hook): 将钩子函数添加到指定的模块上,控制模块的输入输出行为。
步骤 2:准备编码器的参数
# 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {argument: valuefor argument, value in model_kwargs.items()if not any(argument.startswith(p) for p in irrelevant_prefix)
}
  • 目的:

    • model_kwargs 中提取与编码器相关的参数,过滤掉与解码器或交叉注意力相关的参数。
  • 逻辑:

    • 定义一个列表 irrelevant_prefix,包含了不相关的参数前缀,如 "decoder_""cross_attn""use_cache"
    • 使用字典推导式,从 model_kwargs 中过滤掉以这些前缀开头的参数。
    • 结果是 encoder_kwargs,其中包含了需要传递给编码器的参数。
  • 示例:

    • 如果 model_kwargs 包含:

      model_kwargs = {"input_ids": tensor(...),"attention_mask": tensor(...),"decoder_input_ids": tensor(...),"use_cache": True,
      }
      
    • 过滤后,encoder_kwargs 为:

      encoder_kwargs = {"input_ids": tensor(...),"attention_mask": tensor(...),
      }
      
步骤 2.1:检查编码器的签名
encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:encoder_kwargs = {argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature}
  • 解释:

    • 目的:确保传递给编码器的参数在其 forward 方法的参数列表中,即编码器能够接受这些参数。
  • 逻辑:

    • 使用 inspect.signature(encoder.forward).parameters 获取编码器 forward 方法的参数名称集合 encoder_signature
    • 检查编码器是否接受通配参数 **kwargs**model_kwargs,如果接受,则无需进一步过滤参数。
    • 如果编码器不接受通配参数,则过滤 encoder_kwargs,仅保留在 encoder_signature 中的参数。
步骤 2.2:添加生成配置中的参数
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
  • 说明:

    • generation_config 中提取 output_attentionsoutput_hidden_states 配置,添加到 encoder_kwargs 中。
    • 这些配置决定了编码器在前向计算时,是否返回注意力权重和隐藏状态。
步骤 3:调用编码器并更新 model_kwargs
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)  # type: ignore
return model_kwargs
  • 解释:

    • 确保编码器返回 ModelOutput 对象:

      • 设置 encoder_kwargs["return_dict"] = True,使编码器的输出为 ModelOutput 格式,而不是元组。
    • 准备编码器的输入:

      • 确定模型的输入名称 model_input_name,如果传入了 model_input_name,则使用它,否则使用 self.main_input_name
      • inputs_tensor 添加到 encoder_kwargs,键为 model_input_name
    • 调用编码器的 forward 方法:

      model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)  # type: ignore
      
      • 调用编码器,传入 encoder_kwargs,得到编码器的输出 encoder_outputs
      • encoder_outputs 添加到 model_kwargs 中,键为 "encoder_outputs"
      • 使用类型注解 : ModelOutput 指定类型,这是为了让静态类型检查器知道 encoder_outputs 的类型。
    • 返回更新后的 model_kwargs

      • 包含了 encoder_outputs,供后续的解码器生成过程中使用。

_prepare_decoder_input_ids_for_generation

以下是对您提供的 _prepare_decoder_input_ids_for_generation 方法的详细解释。这个方法用于在生成过程中为 编码器-解码器模型 准备 decoder_input_ids,确保解码器在生成时能够正确地开始。


方法定义

def _prepare_decoder_input_ids_for_generation(self,batch_size: int,model_input_name: str,model_kwargs: Dict[str, torch.Tensor],decoder_start_token_id: torch.Tensor,device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""# 方法体...

参数说明:

  • batch_size: int,表示批次大小,即一次输入中样本的数量。
  • model_input_name: str,模型输入的名称,通常为 "input_ids"
  • model_kwargs: Dict[str, torch.Tensor],包含传递给模型的其他关键字参数。
  • decoder_start_token_id: torch.Tensor,解码器开始标记的 token ID。
  • device: torch.device,可选,指定要在哪个设备上运行,如 GPU 或 CPU。如果未指定,则使用模型默认的设备。

返回值:

  • decoder_input_ids: torch.LongTensor,准备好的用于生成的解码器输入 IDs。
  • model_kwargs: 更新后的模型关键字参数字典。

方法功能概述

这个方法的主要作用是:

  1. 检查用户是否手动提供了 decoder_input_ids,如果没有,则根据情况初始化它。
  2. 确保 decoder_start_token_id 的形状正确,并适应批次大小。
  3. 确保 decoder_input_ids 以特殊的开始标记(如 BOS token)开头,如果没有,则自动添加。
  4. 处理特定模型的例外情况,例如 “Donut” 和 “Whisper” 模型可能有不同的处理方式。

通过这些步骤,该方法确保了在生成过程中,解码器能够正确地开始生成序列。


逐步详解

步骤 1:检查用户是否提供了 decoder_input_ids
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":decoder_input_ids = model_kwargs.pop("input_ids")
else:decoder_input_ids = None

解释:

  • 目的: 确定是否需要初始化 decoder_input_ids

  • 逻辑:

    • 如果 model_kwargs 中存在 decoder_input_ids,则将其取出,并从 model_kwargs 中删除,避免重复。
    • 如果 model_kwargs 中存在 input_ids,且 model_input_name 不是 "input_ids",则将其视为 decoder_input_ids。这是为了方便某些模型的输入命名。
    • 如果以上都不满足,则将 decoder_input_ids 设为 None,表示需要初始化。

示例:

  • 用户在调用生成方法时,可能通过 model_kwargs 提供了 decoder_input_ids,方法将直接使用它。
  • 如果用户没有提供,方法将尝试从 input_ids 中获取(当编码器不使用 input_ids 作为主要输入时)。
  • 如果都没有提供,方法将在后续步骤中初始化 decoder_input_ids
步骤 2:调整 decoder_start_token_id 的形状
if device is None:device = self.device
if decoder_start_token_id.ndim == 1:if decoder_start_token_id.shape[0] != batch_size:raise ValueError(f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}")decoder_start_token_id = decoder_start_token_id.view(-1, 1)
else:decoder_start_token_id = (torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id)

解释:

  • 目的: 确保 decoder_start_token_id 的形状为 (batch_size, 1),即每个样本都有一个开始标记。

  • 逻辑:

    • 如果未指定 device,则使用模型默认的设备 self.device
    • 检查 decoder_start_token_id 的维度:
      • 如果是一维张量(向量),即 decoder_start_token_id.ndim == 1
        • 检查其长度是否等于 batch_size,否则抛出错误。
        • 重塑为形状 (batch_size, 1),即每个样本一个开始标记。
      • 如果不是一维张量,则认为是单个标量:
        • 创建一个形状为 (batch_size, 1) 的张量,元素全为 decoder_start_token_id 的值。

示例:

  • 如果 decoder_start_token_id 是标量 101(假设是开始标记的 ID):
    • 创建一个形状为 (batch_size, 1) 的张量,所有元素都是 101
  • 如果 decoder_start_token_id 是张量 [101, 102]batch_size2
    • 将其重塑为:
      [[101],[102]]
      
步骤 3:确保 decoder_input_ids 以开始标记开头
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:decoder_input_ids = decoder_start_token_id

解释:

  • 情况 1: 如果 decoder_input_idsNone,即用户未提供任何解码器输入:
    • 直接使用 decoder_start_token_id 作为 decoder_input_ids,表示解码器将从开始标记开始生成。

# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token...
elif "donut" in self.__class__.__name__.lower() or (self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
):pass
elif self.config.model_type in ["whisper"]:pass

解释:

  • 情况 2: 针对特定模型的例外情况:
    • Donut 模型:
      • Donut 模型有特定的解码器开始标记,且不需要额外的 BOS(Begin of Sequence)标记。
      • 如果模型名包含 "donut",则不对 decoder_input_ids 进行处理。
    • Whisper 模型:
      • Whisper 模型也有自己的处理逻辑,不需要添加开始标记。

# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)if "decoder_attention_mask" in model_kwargs:decoder_attention_mask = model_kwargs["decoder_attention_mask"]decoder_attention_mask = torch.cat((torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),dim=-1,)model_kwargs["decoder_attention_mask"] = decoder_attention_mask

解释:

  • 情况 3: 用户提供了 decoder_input_ids,但其首个 token 并非 decoder_start_token_id
    • 检查首个 token 是否等于 decoder_start_token_id
      • (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
        • 比较每个样本的 decoder_input_ids 首个 token 是否不等于对应的 decoder_start_token_id
        • 如果对于所有样本都不相等,则返回 True
    • 如果首个 token 不匹配,则在 decoder_input_ids 前添加 decoder_start_token_id
      • 使用 torch.cat 在维度 -1(序列长度维度)上拼接 decoder_start_token_iddecoder_input_ids
    • 调整 decoder_attention_mask(如果提供):
      • 如果 model_kwargs 中存在 decoder_attention_mask
        • 在其前面添加一个值为 1 的位置,表示新添加的开始标记需要被注意。
        • 更新 model_kwargs["decoder_attention_mask"]

示例:

  • 如果原始 decoder_input_ids 为:

    [[5, 6, 7],[8, 9, 10]]
    
  • decoder_start_token_id 为:

    [[101],[101]]
    
  • 拼接后得到:

    [[101, 5, 6, 7],[101, 8, 9, 10]]
    
  • 同时调整 decoder_attention_mask,在前面添加一个 1


步骤 4:返回处理后的结果
return decoder_input_ids, model_kwargs
  • 返回更新后的 decoder_input_idsmodel_kwargs

整体流程总结

  • 输入处理:

    • 检查用户是否提供了 decoder_input_ids,如果没有,则需要初始化。
    • 通过 model_kwargs 获取 decoder_input_idsinput_ids,如果适用。
  • 确保解码器开始标记的形状正确:

    • decoder_start_token_id 调整为形状 (batch_size, 1),确保每个样本都有对应的开始标记。
  • 确保 decoder_input_ids 以开始标记开头:

    • 如果 decoder_input_idsNone,直接使用 decoder_start_token_id
    • 对于特定模型(如 Donut 和 Whisper),保留用户提供的 decoder_input_ids,不做修改。
    • 如果用户提供的 decoder_input_ids 不以 decoder_start_token_id 开头,自动在其前添加。
    • 同时,调整 decoder_attention_mask,确保新添加的开始标记在注意力掩码中被考虑。
  • 返回处理后的 decoder_input_idsmodel_kwargs,供后续生成过程使用。


示例代码

示例 1:用户未提供 decoder_input_ids

# 假设 batch_size = 2
decoder_start_token_id = torch.tensor([101], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(batch_size=2,model_input_name="input_ids",model_kwargs={},decoder_start_token_id=decoder_start_token_id,device=device,
)
# 结果:decoder_input_ids = [[101], [101]]

示例 2:用户提供了 decoder_input_ids,但未以开始标记开头

decoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]], device=device)
model_kwargs = {"decoder_input_ids": decoder_input_ids}
decoder_start_token_id = torch.tensor([101, 102], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(batch_size=2,model_input_name="input_ids",model_kwargs=model_kwargs,decoder_start_token_id=decoder_start_token_id,device=device,
)
# 结果:
# decoder_input_ids = [[101, 5, 6, 7], [102, 8, 9, 10]]

示例 3:针对 Donut 模型

# 设模型名称包含 "Donut",用户提供了 decoder_input_ids
self.__class__.__name__ = "DonutModel"
decoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]], device=device)
model_kwargs = {"decoder_input_ids": decoder_input_ids}
decoder_start_token_id = torch.tensor([101], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(batch_size=2,model_input_name="input_ids",model_kwargs=model_kwargs,decoder_start_token_id=decoder_start_token_id,device=device,
)
# 结果:
# decoder_input_ids 保持不变,不添加 decoder_start_token_id

注意事项

  • 模型特定的例外情况: 对于某些模型,如 Donut 和 Whisper,需要特殊处理,不能盲目添加开始标记。

  • 一致性检查: 确保 batch_sizedecoder_start_token_id 的长度一致,否则抛出异常。

  • 处理注意力掩码: 如果修改了 decoder_input_ids,且提供了 decoder_attention_mask,需要相应地调整掩码。

  • 设备一致性: 所有张量都应该在同一设备上(CPU 或 GPU),方法中确保了这一点。


总结

该方法的主要功能是为生成过程准备合适的 decoder_input_ids

  • 初始值设置: 如果用户未提供,则使用 decoder_start_token_id 初始化。
  • 形状调整: 确保 decoder_start_token_id 的形状与批次大小匹配。
  • 开头标记: 确保 decoder_input_ids 以特定的开始标记开头,以符合模型的要求。
  • 特例处理: 对于特定模型,按其特殊需求处理。

通过这些步骤,保证了编码器-解码器模型在生成过程中能够正确地开始解码器的序列生成。


希望以上解释能够帮助您理解 _prepare_decoder_input_ids_for_generation 方法的功能和每个步骤的具体作用。如果您还有其他问题,欢迎继续提问!

heal_tokens

这个方法用于在生成过程中,对输入的 input_ids 进行修复(healing),以增强模型的输出质量。具体来说,它会根据输入序列的最后一个 token,寻找可能的扩展,并替换原始的尾部 token。这在某些情况下可以纠正模型生成中的不一致或错误。


方法定义

def heal_tokens(self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
) -> torch.LongTensor:r"""Generates sequences of token ids for models with a language modeling head.Parameters:input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.Return:`torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension."""# 方法体...

参数说明:

  • input_ids: torch.LongTensor,形状为 (batch_size, sequence_length),表示输入的 token 序列,用于生成的提示。
  • tokenizer: PreTrainedTokenizerBase,可选参数,模型对应的 tokenizer,用于解码和编码 token IDs。

返回值:

  • torch.LongTensor,形状与 input_ids 相同,其中每个序列的尾部 token 被替换为了适当的扩展 token。

方法功能概述

该方法的主要目的是:

  1. 修复输入序列的尾部 token:对于每个输入序列,检查其最后一个 token,寻找可能的扩展 token,并替换之。
  2. 改进模型生成的连贯性:通过纠正输入序列的尾部,使得生成的序列在语义和形式上更加连贯。
  3. 处理空序列和特殊情况:在方法中包含了对空序列和特殊情况的处理,确保方法的稳健性。

逐步详解

步骤 1:验证 tokenizer 是否提供

if tokenizer is None:raise ValueError(" When generating with token healing, you must pass the model's tokenizer to the `tokenizer` ""argument of `generate`.")
  • 解释:
    • 方法要求必须提供 tokenizer 参数,否则无法进行解码和编码操作。
    • 如果未提供 tokenizer,则抛出 ValueError

步骤 2:获取特殊 token IDs 和构建词汇前缀树

bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
  • 解释:
    • 获取特殊 token IDs:
      • bos_token_id: 序列开始的 token ID(BOS = Begin Of Sequence)。
      • pad_token_id: 填充 token 的 ID。
    • 构建词汇前缀树:
      • tokenizer.get_vocab(): 获取 tokenizer 的词汇表,返回一个字典 {token: token_id}
      • ExtensionsTrie: 自定义的前缀树类,用于快速查找以某个前缀开始的所有 token。
    • 创建生成配置:
      • GenerationConfig: 配置生成过程的参数。
      • max_new_tokens=1: 生成的最大新 token 数设置为 1,因为我们只需要替换尾部 token。
      • pad_token_id=pad_token_id: 设置填充 token ID。

步骤 3:处理提示文本

# assumption: leading/trailing whitespace is not meaningful, so the prompts are
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
  • 解释:
    • 假设:首尾的空白字符(whitespace)对生成过程没有实质影响,因此在重新编码前去除。
    • 步骤:
      • tokenizer.batch_decode: 将 input_ids 解码为字符串列表,跳过特殊 tokens。
      • 使用列表推导式,对每个字符串进行 strip 操作,去除首尾空白字符。
      • 结果存储在 prompts 列表中。

步骤 4:重新编码输入序列

input_ids = tokenizer(prompts,return_tensors="pt",padding=True,
).input_ids.to(input_ids.device)
  • 解释:
    • 重新编码:
      • 将处理后的 prompts 列表再次通过 tokenizer 编码为 input_ids
      • return_tensors="pt": 返回 PyTorch 张量。
      • padding=True: 对序列进行填充,以使它们的长度一致。
    • 调整设备:
      • 使用 .to(input_ids.device) 将新生成的 input_ids 移动到与原始 input_ids 相同的设备上(如 GPU 或 CPU)。

步骤 5:替换序列中的 bos_token_idpad_token_id

# replace bos with pad to not condition healing on it
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
  • 解释:
    • 目的:
      • 在后续的修复过程中,不希望序列开始的特殊标记(BOS)影响结果,因此将其替换为填充标记。
    • 操作:
      • 使用 torch.where 函数,将 input_ids 中等于 bos_token_id 的位置替换为 pad_token_id

步骤 6:检查 input_ids 是否为空

if input_ids.numel() == 0:return input_ids
  • 解释:
    • 目的:
      • 如果 input_ids 为空(即没有元素),则无需进行后续处理,直接返回。
    • 检查:
      • input_ids.numel(): 返回张量中元素的总数。
      • 如果为 0,则返回 input_ids

步骤 7:获取每个序列的尾部 token ID 和对应的 token

tail_ids = input_ids[:, -1].tolist()
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
  • 解释:
    • 获取尾部 token ID:
      • input_ids[:, -1]: 获取每个序列的最后一个 token ID。
      • .tolist(): 转换为 Python 列表,方便后续处理。
    • 处理空格 token:
      • tokenizer.convert_tokens_to_ids(" "): 将空格字符 " " 转换为对应的 token ID。
      • tokenizer.convert_ids_to_tokens: 再将 token ID 转换为 token 字符串。
      • [0]: 获取结果中的第一个元素(列表中可能包含多个 token)。
      • 结果 space_tok 为空格对应的 token,例如在 BPE(Byte-Pair Encoding)中,空格可能对应特殊字符,例如 'Ġ'
    • 获取尾部 token 的字符串表示:
      • 使用生成器表达式 (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
        • 对每个 tail_id
          • tokenizer.decode(t): 解码为字符串。
          • .replace(" ", space_tok): 将字符串中的空格替换为 space_tok,以便于后续前缀搜索。

步骤 8:遍历每个序列,尝试修复尾部 token

for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):batch_ids = input_ids[batch_idx]if torch.all(batch_ids == pad_token_id).item():continue  # skip empty sequences (all pad ids)
  • 解释:
    • 遍历序列:
      • 使用 enumeratetail_idstail_toks 进行遍历。
      • batch_idx: 当前序列的索引。
      • tail_id: 当前序列的尾部 token ID。
      • tail_tok: 当前序列的尾部 token 字符串(经过特殊处理的)。
    • 获取当前序列的 input_ids
      • batch_ids = input_ids[batch_idx]: 获取当前序列的 input_ids
    • 检查序列是否为空:
      • torch.all(batch_ids == pad_token_id).item(): 检查序列中的所有位置是否都是 pad_token_id
      • 如果是,表示序列为空,跳过后续处理。

步骤 9:构建可能的替代 token 的偏置字典

# apply bias for alternatives (extensions) to the tail token
"""
seq_bias key has to be tuple with int so have to use
tokenizer function to convert str to int
"""
seq_bias = {(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
}
if len(seq_bias) == 1:continue  # skip if there are no token alternatives to heal with
# slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
seq_bias[(tail_id,)] += 1.0
generation_config.update(sequence_bias=seq_bias)
  • 解释:
    • 寻找可能的扩展 tokens:
      • vocab_trie.extensions(prefix=tail_tok): 在词汇前缀树中,找到以 tail_tok 为前缀的所有 token。
    • 构建偏置字典 seq_bias
      • 键:以单个 token ID 组成的元组 (token_id,),因为后续需要的键是元组形式。
      • 值:偏置值 10.0,用于在生成过程中提升这些 tokens 的概率。
      • 需要使用 tokenizer.convert_tokens_to_ids(alt_tok) 将 token 字符串转换为 token ID。
    • 检查是否有可替代的 tokens:
      • 如果 seq_bias 的长度为 1,表示只有原始的 tail_id,没有其他可替代的 tokens,跳过处理。
    • 稍微提升原始 token 的概率:
      • seq_bias[(tail_id,)] += 1.0: 对原始的 tail_id,增加一个较小的偏置 1.0,以防止过度修复(例如,将 'http' 修复为 'https')。
    • 更新生成配置:
      • generation_config.update(sequence_bias=seq_bias): 将偏置字典添加到生成配置中,供生成过程使用。

步骤 10:准备生成输入,去除尾部 token

trimmed_ids = batch_ids[:-1]
"""
the latter code assumes trimmed_ids is not empty
so have to check its element count
"""
if trimmed_ids.numel() == 0:continue
# if the prompt is a single (non-pad) token, regenerate from bos
if len(batch_ids[batch_ids != pad_token_id]) == 1:trimmed_ids[-1] = bos_token_id
  • 解释:
    • 去除尾部 token:
      • trimmed_ids = batch_ids[:-1]: 获取当前序列,去除最后一个 token,即 tail_id
    • 检查 trimmed_ids 是否为空:
      • 如果 trimmed_ids.numel() == 0,表示序列长度为 0,无需处理,继续下一个序列。
    • 特殊情况处理:
      • 如果序列只有一个非填充 token(除去 pad_token_id 后长度为 1):
        • 需要将 trimmed_ids 的最后一个位置替换为 bos_token_id,以重新从开始标记生成。

步骤 11:生成新的 token 并替换尾部 token

input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
  • 解释:
    • 调用生成方法:
      • self.generate: 调用模型的 generate 方法,生成新序列。
      • trimmed_ids.unsqueeze(0): 为了适应批量维度,将 trimmed_ids 添加一个维度,形状从 (sequence_length - 1,) 变为 (1, sequence_length - 1)
      • generation_config=generation_config: 使用之前配置的生成参数,包括偏置字典 seq_bias
    • 更新 input_ids
      • 将生成的新序列替换到 input_ids 中对应的位置。
      • 注意,这里生成的序列长度为 sequence_length,因为 max_new_tokens=1,所以会在 trimmed_ids 的基础上生成一个新 token。

步骤 12:返回修复后的 input_ids

return input_ids
  • 解释:
    • 返回处理后的 input_ids,其中每个序列的尾部 token 已根据可能的扩展进行了修复。

整体流程总结

  1. 准备工作:

    • 验证 tokenizer 参数。
    • 获取特殊 token IDs。
    • 构建词汇前缀树 vocab_trie,用于快速查找可能的扩展 token。
    • 配置生成参数 generation_config
  2. 处理输入序列:

    • input_ids 解码为字符串列表 prompts,去除首尾空白。
    • 重新编码 promptsinput_ids,确保一致性。
    • bos_token_id 替换为 pad_token_id,避免对序列开始标记的影响。
  3. 遍历每个序列,尝试修复尾部 token:

    • 获取每个序列的尾部 token ID 和对应的 token 字符串。
    • 查找以尾部 token 为前缀的可能扩展 tokens。
    • 构建偏置字典 seq_bias,提升这些扩展 token 在生成过程中的概率。
    • 去除序列的尾部 token,准备生成新的尾部 token。
    • 调用 self.generate 方法,生成新的序列,并替换到 input_ids 中。
  4. 返回处理后的 input_ids


_prepare_generated_length

这个方法用于在生成过程中,根据用户提供的生成配置和模型输入,准备和调整生成的最大长度 (max_length) 和最小长度 (min_length),以避免类似属性之间的冲突。


方法定义
def _prepare_generated_length(self,generation_config,has_default_max_length,has_default_min_length,model_input_name,input_ids_length,inputs_tensor,
):"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""# 方法体...

参数说明:

  • generation_configGenerationConfig 对象,包含了生成过程中需要的各种配置参数,如 max_lengthmin_lengthmax_new_tokensmin_new_tokens 等。

  • has_default_max_lengthbool 类型,指示 max_length 是否使用了默认值。如果为 True,表示用户未显式设置 max_length

  • has_default_min_lengthbool 类型,指示 min_length 是否使用了默认值。

  • model_input_namestr 类型,模型输入的名称,通常为 "input_ids""inputs_embeds"

  • input_ids_lengthint 类型,输入序列 input_ids 的长度,即输入的 token 数量。

  • inputs_tensortorch.Tensor 对象,模型的输入张量。


方法功能概述

该方法的主要作用是:

  1. 调整 max_lengthmin_length:根据用户提供的 max_new_tokensmin_new_tokensmax_lengthmin_length,以及输入序列的长度,计算并设置最终的生成长度参数,确保生成过程按照预期进行。

  2. 避免冲突:如果用户同时设置了类似的属性(例如同时设置了 max_lengthmax_new_tokens),该方法会明确优先级,并在必要时发出警告,提示用户可能存在的冲突。

  3. 处理特殊情况:针对一些特殊的输入情况,例如使用了 inputs_embeds、模型是编码器-解码器模型等,做出相应的调整。


逐步详解
1. 处理 max_length
if generation_config.max_new_tokens is not None:if not has_default_max_length and generation_config.max_length is not None:logger.warning(f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ""Please refer to the documentation for more information. ""(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)")generation_config.max_length = generation_config.max_new_tokens + input_ids_length

解释:

  • 情况 1:用户设置了 max_new_tokens

    • 逻辑:

      • 如果 generation_config.max_new_tokens 不为 None,表示用户希望通过 max_new_tokens 来指定要生成的新 token 数量。

      • 冲突处理:

        • 如果用户同时设置了 max_length(并且不是默认值),则发出警告,提示两者可能存在冲突,但以 max_new_tokens 为准。

        • not has_default_max_length:表示 max_length 被用户显式设置了。

      • 计算 max_length

        • max_length 设置为 input_ids_length + max_new_tokens

          • input_ids_length:输入序列的长度。

          • max_new_tokens:用户希望生成的新 token 数量。

      • 这样,max_length 就表示生成的序列(包括输入和生成的部分)总长度。

示例:

  • 如果 input_ids_length = 10max_new_tokens = 20,则 max_length = 10 + 20 = 30

# if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
# otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`
elif (model_input_name == "inputs_embeds"and input_ids_length != inputs_tensor.shape[1]and not self.config.is_encoder_decoder
):generation_config.max_length -= inputs_tensor.shape[1]

解释:

  • 情况 2:模型输入是 inputs_embeds,且存在输入长度不匹配

    • 逻辑:

      • 条件判断:

        • model_input_name == "inputs_embeds":表示模型的输入是嵌入表示。

        • input_ids_length != inputs_tensor.shape[1]:输入的 input_ids 长度与 inputs_tensor 的长度(序列维度大小)不一致。

        • not self.config.is_encoder_decoder:模型不是编码器-解码器模型。

      • 处理:

        • 需要调整 max_length,减去 inputs_tensor.shape[1],即输入的序列长度。
      • 原因:

        • 当用户提供了 inputs_embeds 而非 input_ids,且两者长度不一致,为了确保生成的总长度不超过用户预期,需要调整 max_length

elif has_default_max_length:  # by default let's always generate 20 new tokensif generation_config.max_length == GenerationConfig().max_length:generation_config.max_length = generation_config.max_length + input_ids_lengthmax_position_embeddings = getattr(self.config, "max_position_embeddings", None)if max_position_embeddings is not None:generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

解释:

  • 情况 3:用户未设置 max_length,使用默认值

    • 逻辑:

      • has_default_max_lengthTrue,即用户未显式设置 max_length

      • 如果 generation_config.max_length 等于默认的 max_length,则执行以下操作:

        • 计算新的 max_length

          • max_length 设置为原来的 max_length 加上 input_ids_length

            • 这样,默认情况下,会在输入的基础上生成 20 个新 tokens(假设默认 max_length20)。
        • 考虑模型的最大位置嵌入长度:

          • 获取模型配置中的 max_position_embeddings,它表示模型能处理的最大序列长度。

          • 如果存在,就将 generation_config.max_length 限制在 max_position_embeddings 之内,防止生成长度超过模型的能力。


2. 处理 min_length
if generation_config.min_new_tokens is not None:if not has_default_min_length:logger.warning(f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. ""Please refer to the documentation for more information. ""(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)")generation_config.min_length = generation_config.min_new_tokens + input_ids_length

解释:

  • 情况 1:用户设置了 min_new_tokens

    • 逻辑:

      • 如果 generation_config.min_new_tokens 不为 None,表示用户希望指定要生成的最小新 token 数量。

      • 冲突处理:

        • 如果用户同时设置了 min_length(并且不是默认值),则发出警告,提示两者存在冲突,以 min_new_tokens 为准。
      • 计算 min_length

        • min_length 设置为 min_new_tokens + input_ids_length

elif (model_input_name == "inputs_embeds"and input_ids_length != inputs_tensor.shape[1]and not self.config.is_encoder_decoder
):generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)

解释:

  • 情况 2:模型输入是 inputs_embeds,且存在输入长度不匹配

    • 逻辑:

      • 与处理 max_length 时类似,但这里调整的是 min_length

      • generation_config.min_length 减去 inputs_tensor.shape[1],并取最大值 0(防止出现负值)。

      • 这样可以确保生成的最小长度不会超过用户预期。


3. 返回更新后的 generation_config
return generation_config
  • 返回调整过后的 generation_config,包含更新的 max_lengthmin_length,以供生成过程使用。

_validate_generated_length

def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):"""Performs validation related to the resulting generated length"""# 函数主体从这里开始

功能说明

  • 目的:该函数用于对生成的长度进行验证,确保 generation_config 中的参数设置合理,避免在生成过程中出现不可预期的行为或错误。

  • 主要任务

    • 检查 max_lengthmax_new_tokens 之间的关系,给出适当的警告或错误。
    • 检查输入序列长度 input_ids_lengthmax_length 之间的关系。
    • 检查 min_lengthmax_length 之间的关系,给出警告。
    • 确保 min_new_tokens 加上 input_ids_length 不超过 max_length,并给出警告。

参数说明

  • self:类的实例,典型的 Python 类方法的第一个参数。
  • generation_config:生成配置对象,包含了生成过程中使用的各项参数设置,例如 max_lengthmin_lengthmax_new_tokens 等。
  • input_ids_length:整数,输入序列 input_ids 的长度,即序列的长度。
  • has_default_max_length:布尔值,指示是否使用了默认的 max_length 设置(即用户没有在调用时显式指定 max_length)。

代码详细解释

1. 编译时不进行警告或异常抛出

# Can't throw warnings/exceptions during compilation
if is_torchdynamo_compiling():return

解释

  • 目的:在使用 TorchDynamo(PyTorch 编译器)进行编译时,不要抛出警告或异常。

  • 逻辑

    • is_torchdynamo_compiling():检查当前是否在使用 TorchDynamo 进行编译。
    • if is_torchdynamo_compiling(): return:如果正在编译,直接返回,不进行后续的验证。
  • 原因:在编译过程中,抛出异常或者发出警告可能会导致编译失败或行为异常,因此在编译时跳过验证。

2. 第一部分:与参数设置相关的 max_length 警告

# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:# 20 is the default max_length of the generation configwarnings.warn(f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the ""generation length. We recommend setting `max_new_tokens` to control the maximum length of the ""generation.",UserWarning,)

解释

  • 目的:当用户未指定 max_length 且未设置 max_new_tokens 时,发出警告提示。

  • 逻辑

    • 条件判断

      • has_default_max_length:用户未显式指定 max_length,使用了默认值。
      • generation_config.max_new_tokens is Nonemax_new_tokens 未设置。
      • generation_config.max_length == 20max_length 等于 20,这是 GenerationConfig 的默认值。
    • 处理

      • 如果以上条件都满足,使用 warnings.warn() 发出警告,提示用户正在使用模型无关的默认 max_length 来控制生成长度,建议用户设置 max_new_tokens 来控制生成的最大长度。
  • 原因

    • 如果用户没有明确设置生成长度的参数,可能会导致生成的序列长度与预期不符。
    • 建议用户使用 max_new_tokens,因为它更直观地控制生成的新 tokens 数量。

3. 检查输入序列长度是否超过或等于 max_length

if input_ids_length >= generation_config.max_length:input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"raise ValueError(f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"" increasing `max_length` or, better yet, setting `max_new_tokens`.")

解释

  • 目的:如果输入序列的长度大于或等于 max_length,抛出 ValueError,因为这样会导致生成过程中的问题。

  • 逻辑

    • 条件判断

      • if input_ids_length >= generation_config.max_length:如果输入序列的长度 input_ids_length 大于或等于 max_length
    • 处理

      • 根据模型类型确定输入 IDs 的名称:

        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        
        • 如果是编码器-解码器模型,使用 'decoder_input_ids'
        • 否则,使用 'input_ids'
      • 抛出 ValueError,包含详细的错误信息:

        • 提示输入序列的长度与 max_length 的值,以及可能导致意外行为。
        • 建议用户增加 max_length 或者更好的,设置 max_new_tokens
  • 原因

    • 如果输入序列的长度已经达到或超过 max_length,模型将无法生成新的 tokens,可能导致生成过程立即结束或出现错误。
    • 为了避免这种情况,必须确保 max_length 大于输入序列的长度。

4. 第二部分:由于不可行的参数组合导致的 min_length 警告

# 2. Min length warnings due to unfeasible parameter combinations
min_length_error_suffix = (" Generation will stop at the defined maximum length. You should decrease the minimum length and/or ""increase the maximum length."
)
if has_default_max_length:min_length_error_suffix += (f" Note that `max_length` is set to {generation_config.max_length}, its default value.")

解释

  • 目的:准备 min_length 相关的错误信息,供后续警告使用。

  • 逻辑

    • 定义错误信息的后缀

      • min_length_error_suffix:提示用户需要降低最小长度和/或增加最大长度。
    • 如果使用了默认的 max_length

      • 如果 has_default_max_lengthTrue,则在 min_length_error_suffix 后面追加一句,说明 max_length 被设置为其默认值。

5. 检查 min_length 是否大于 max_length

if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:warnings.warn(f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,UserWarning,)

解释

  • 目的:如果 min_length 大于 max_length,发出警告,因为这是不可行的约束。

  • 逻辑

    • 条件判断

      • generation_config.min_length is not Nonemin_length 已被设置。
      • generation_config.min_length > generation_config.max_lengthmin_length 大于 max_length
    • 处理

      • 使用 warnings.warn() 发出警告,提示 min_length 大于 max_length,并附加之前准备的 min_length_error_suffix
  • 原因

    • 如果最小生成长度大于最大生成长度,模型无法满足这样的约束,可能导致生成过程在达到最大长度时停止,而未达到最小长度。
    • 提醒用户调整参数,使其合理。

6. 检查 min_new_tokens 加上 input_ids_length 是否超过 max_length

if generation_config.min_new_tokens is not None:min_length = generation_config.min_new_tokens + input_ids_lengthif min_length > generation_config.max_length:warnings.warn(f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "f"added to the prompt length ({input_ids_length}), is larger than"f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,UserWarning,)

解释

  • 目的:如果 min_new_tokens 加上输入序列长度超过 max_length,发出警告。

  • 逻辑

    • 条件判断

      • if generation_config.min_new_tokens is not Nonemin_new_tokens 已被设置。
    • 计算最小长度

      • min_length = generation_config.min_new_tokens + input_ids_length:计算最小生成长度,即 min_new_tokens 加上输入序列长度。
    • 检查是否超过 max_length

      • if min_length > generation_config.max_length:如果计算得到的 min_length 大于 max_length
    • 处理

      • 使用 warnings.warn() 发出警告,提示 min_new_tokens 加上输入长度超过了可能的最大长度,并附加 min_length_error_suffix
  • 原因

    • 如果 min_new_tokens 与输入序列长度之和超过 max_length,模型无法生成满足最小新 tokens 数量的序列。
    • 提醒用户调整参数,降低 min_new_tokens、增加 max_length,或减少输入序列的长度。

_supports_logits_to_keep

方法定义
def _supports_logits_to_keep(self) -> bool:"""Return True if the current model supports the keyword argument `logits_to_keep` in forward()to save memory. Checking it in this way allows to avoid using a new model attribute."""return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())

方法功能概述
  • 目的:确定当前模型的 forward 方法是否支持 logits_to_keep 参数。
  • 返回值:布尔值
    • True:如果 forward 方法的参数中包含 logits_to_keep
    • False:如果 forward 方法的参数中不包含 logits_to_keep

逐步详解
  1. 使用 inspect 模块分析 forward 方法的签名

    inspect.signature(self.forward)
    
    • 作用:获取 self.forward 方法的签名信息,包括参数列表和参数默认值等。
    • inspect 模块:Python 内置模块,提供了检查和获取对象(如函数、类、模块等)信息的功能。
  2. 获取 forward 方法的参数字典

    inspect.signature(self.forward).parameters
    
    • 返回值:一个有序字典(OrderedDict),键为参数名称,值为参数对应的 Parameter 对象。
  3. 获取参数名称列表并转换为集合

    set(inspect.signature(self.forward).parameters.keys())
    
    • parameters.keys():返回参数名称的可迭代对象。
    • set(...):将参数名称转换为集合,方便后续进行快速查找(in 操作)。
  4. 检查是否包含 logits_to_keep 参数

    "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
    
    • 作用:判断字符串 "logits_to_keep" 是否在参数名称集合中。
    • 返回值:布尔值。
  5. 返回判断结果

    • 如果包含:返回 True
    • 如果不包含:返回 False

方法用途和背景
  • 节省内存

    在生成任务中,模型可能会生成大量的 logits(每个时间步长预测下一个 token 的概率分布,通常是一个包含了整个词汇表大小的张量)。如果能够限制保留的 logits 数量(例如只保留 top-k 个 logits),可以大大节省内存。

  • 动态检查模型功能

    不同的模型可能实现了不同的功能。通过这种动态检查的方法,可以在不修改模型代码的情况下,了解模型是否支持某个特定的参数或功能。这有助于编写通用的代码,适用于多种模型。

  • 避免使用额外的模型属性

    通过检查方法签名,而不是增加一个模型属性,可以减少模型类的复杂性和维护成本。


举例说明

假设我们有一个模型,其 forward 方法定义如下:

def forward(self, input_ids, attention_mask=None, logits_to_keep=None):# 模型的前向计算逻辑logits = self.compute_logits(input_ids, attention_mask)if logits_to_keep is not None:# 只保留指定数量的 logitslogits = logits[:, -1, :logits_to_keep]return logits
  • 模型支持 logits_to_keep 参数:在这种情况下,_supports_logits_to_keep 方法会返回 True,因为 forward 方法的参数中包含 logits_to_keep

  • 使用示例

    if self._supports_logits_to_keep():outputs = self.forward(input_ids, attention_mask=attention_mask, logits_to_keep=10)
    else:outputs = self.forward(input_ids, attention_mask=attention_mask)
    
    • 解释:代码首先检查模型是否支持 logits_to_keep 参数,如果支持,则在调用 forward 方法时传入该参数,以只保留 top-10 的 logits,从而节省内存。

_prepare_cache_for_generation

def _prepare_cache_for_generation(self,generation_config: GenerationConfig,model_kwargs: Dict,assistant_model: "PreTrainedModel",batch_size: int,max_cache_length: int,device: torch.device,
) -> bool:"""Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache isinstantiated, writes it to `model_kwargs`, under the name expected by the model."""# 函数主体从这里开始

功能说明

  • 目的:准备生成过程中使用的缓存(cache),根据给定的 generation_config 和其他参数,初始化或调整缓存。如果缓存被实例化,它将被写入到 model_kwargs 中,使用模型期望的缓存名称。

  • 背景:在文本生成任务中,使用缓存可以加速生成过程,特别是在自回归模型中,缓存先前的计算结果可以避免重复计算。在不同的模型或配置下,缓存的实现方式可能不同,因此需要根据情况准备合适的缓存。

参数说明

  • self:当前类的实例,典型的 Python 类方法的第一个参数。

  • generation_configGenerationConfig 类型,表示生成配置,其中包含生成过程中的各种参数设置,如是否使用缓存、缓存的实现方式等。

  • model_kwargsDict 类型,包含传递给模型的关键字参数。在函数中,可能会对其进行修改,添加缓存相关的参数。

  • assistant_modelPreTrainedModel 类型,可选的辅助模型,用于加速生成或其他目的。

  • batch_size:整数,表示批次大小。

  • max_cache_length:整数,表示缓存的最大长度,即缓存可以存储的最大序列长度。

  • devicetorch.device 类型,表示在何种设备(CPU 或 GPU)上运行。


1. 确定缓存名称

cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
  • 解释

    • 这行代码根据当前模型类的名称,确定缓存在 model_kwargs 中的键名称。

    • 如果类名中不包含 "mamba",则缓存名称为 "past_key_values";否则,缓存名称为 "cache_params"

  • 原因

    • 不同的模型可能期望的缓存名称不同。模型需要从 model_kwargs 中获取缓存,如果名称不一致,可能导致缓存无法正确工作。

2. 确定是否需要跨注意力缓存(cross-attention cache)

requires_cross_attention_cache = (self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
  • 解释

    • requires_cross_attention_cache 是一个布尔值,表示是否需要准备跨注意力缓存。

    • 条件:

      • self.config.is_encoder_decoder:如果模型是编码器-解码器架构,则需要跨注意力缓存。

      • model_kwargs.get("encoder_outputs") is not None:如果在 model_kwargs 中提供了编码器的输出,则也需要跨注意力缓存。

  • 原因

    • 在编码器-解码器模型(如 BART、T5)中,解码器需要访问编码器的输出,因此需要跨注意力缓存。

3. 快速退出路径 1:用户已在 model_kwargs 中指定了缓存

# 快速退出路径 1:如果用户指定了缓存,我们只需要:
# a) 检查是否有冲突的 `generate` 参数
# b) 如果用户传递了旧的缓存格式,并且模型支持,将其转换为新的缓存格式
user_defined_cache = model_kwargs.get(cache_name)
if user_defined_cache is not None:if generation_config.cache_implementation is not None:raise ValueError(f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a ""Cache object) is unsupported. Please use only one of the two.")if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():model_kwargs[cache_name] = (DynamicCache.from_legacy_cache(user_defined_cache)if not requires_cross_attention_cacheelse EncoderDecoderCache.from_legacy_cache(user_defined_cache))return
  • 解释

    • 获取用户定义的缓存

      • user_defined_cache = model_kwargs.get(cache_name):从 model_kwargs 中获取用户可能提供的缓存。
    • 检查用户是否同时指定了 cache_implementation

      • 如果用户既在 model_kwargs 中提供了缓存,又在 generation_config 中指定了 cache_implementation,这是冲突的,会引发错误。
    • 处理旧的缓存格式

      • 如果 user_defined_cache 是一个元组(旧的缓存格式),并且模型支持默认的动态缓存(self._supports_default_dynamic_cache() 返回 True),则将旧的缓存转换为新的缓存格式。

      • 根据是否需要跨注意力缓存,使用不同的缓存类:

        • 如果不需要跨注意力缓存,使用 DynamicCache.from_legacy_cache(user_defined_cache)

        • 如果需要跨注意力缓存,使用 EncoderDecoderCache.from_legacy_cache(user_defined_cache)

    • 返回

      • 在处理完用户提供的缓存后,直接返回,不再进行后续的缓存准备。

4. 快速退出路径 2:用户指定不使用缓存

# 快速退出路径 2:如果用户指定不使用缓存。(冲突的参数已在 `generation_config.validate()` 中处理)
if generation_config.use_cache is False:return
  • 解释

    • 如果在 generation_config 配置中,用户设置了 use_cache=False,表示不使用缓存。

    • 直接返回,不需要准备缓存。

5. 快速退出路径 3:模型仅支持旧的缓存格式

# 快速退出路径 3:模型仅支持旧的缓存格式,无需准备
if not self._supports_default_dynamic_cache():if generation_config.cache_implementation is not None:warnings.warn("This model does not support `Cache` instances, it only supports the legacy cache format (tuple "f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be ""ignored.",UserWarning,)return
  • 解释

    • 如果模型不支持默认的动态缓存(self._supports_default_dynamic_cache() 返回 False),则无法使用新的缓存实现。

    • 如果用户在 generation_config 中指定了 cache_implementation,则发出警告,指出模型仅支持旧的缓存格式,cache_implementation 将被忽略。

    • 直接返回,不需要进一步准备缓存。

6. 需要准备缓存,根据 generation_config.cache_implementation

# 否则,我们需要根据 `generation_config.cache_implementation` 准备缓存
# TODO(joao): 在辅助生成中支持静态缓存。辅助生成需要回滚缓存,目前只有动态缓存支持
if assistant_model is not None and generation_config.cache_implementation is not None:logger.warning_once("An assistant model is provided, using a dynamic cache instead of a cache of type="f"'{generation_config.cache_implementation}'.")generation_config.cache_implementation = None
  • 解释

    • 如果上述快速退出条件都不满足,且需要准备缓存,则需要根据 generation_config.cache_implementation 的值来准备缓存。

    • 特殊情况:辅助模型和缓存实现的冲突

      • 如果提供了 assistant_model,并且指定了 cache_implementation,则发出警告,指出由于提供了辅助模型,将使用动态缓存,而不是指定类型的缓存。

      • generation_config.cache_implementation 设置为 None,以确保使用动态缓存。

  • 原因

    • 在辅助生成过程中,需要回滚缓存,目前只有动态缓存支持回滚。因此,即使用户指定了其他缓存实现,也需要使用动态缓存。

7. 根据缓存实现方式准备缓存

if generation_config.cache_implementation is not None:if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:if generation_config.cache_implementation == "static" and not self._supports_static_cache:raise ValueError("This model does not support `cache_implementation='static'`. Please check the following ""issue: https://github.com/huggingface/transformers/issues/28981")model_kwargs[cache_name] = self._get_cache(cache_implementation=generation_config.cache_implementation,batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,max_cache_len=max_cache_length,device=device,model_kwargs=model_kwargs,)elif generation_config.cache_implementation == "quantized":if not self._supports_quantized_cache:raise ValueError("This model does not support the quantized cache. If you want your model to support quantized ""cache, please open an issue and tag @zucchini-nlp.")cache_config = (generation_config.cache_configif generation_config.cache_config is not Noneelse QuantizedCacheConfig())cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]if cache_config.backend == "quanto" and not is_optimum_quanto_available():raise ImportError("You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. ""Please install it via  with `pip install optimum-quanto`")elif cache_config.backend == "HQQ" and not is_hqq_available():raise ImportError("You need to install `HQQ` in order to use KV cache quantization with HQQ backend. ""Please install it via  with `pip install hqq`")model_kwargs[cache_name] = cache_class(cache_config)elif generation_config.cache_implementation == "offloaded":model_kwargs[cache_name] = OffloadedCache()
  • 解释

    • 检查缓存实现方式是否需要特别的准备

      • NEED_SETUP_CACHE_CLASSES_MAPPING:一个映射,包含需要特殊设置的缓存类。

      • 如果 generation_config.cache_implementationNEED_SETUP_CACHE_CLASSES_MAPPING 中,则需要调用 _get_cache 方法来获取缓存实例。

    • 处理静态缓存

      • 如果 cache_implementation"static",并且模型不支持静态缓存(not self._supports_static_cache),则抛出 ValueError,提示模型不支持静态缓存。

      • 提供一个 GitHub Issue 链接,供用户了解更多信息。

    • 获取缓存实例

      • 调用 self._get_cache() 方法,传入缓存实现方式、批次大小、最大缓存长度、设备等参数,获取缓存实例。

      • 批次大小计算

        • max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size:计算缓存所需的实际批次大小。

          • 在束搜索或其他情况下,批次大小可能需要乘以 num_beamsnum_return_sequences
    • 处理量化缓存(quantized cache)

      • 如果 cache_implementation"quantized",需要特殊处理。

      • 检查模型是否支持量化缓存(self._supports_quantized_cache)。

      • 如果不支持,抛出 ValueError

      • 获取缓存配置 cache_config,如果用户未提供 generation_config.cache_config,则使用默认的 QuantizedCacheConfig()

      • 根据缓存配置的后端,获取对应的缓存类 cache_class

      • 检查所需的包是否已安装:

        • 对于 "quanto" 后端,检查 is_optimum_quanto_available()

        • 对于 "HQQ" 后端,检查 is_hqq_available()

      • 如果未安装,抛出 ImportError,提示用户安装相应的包。

      • 实例化缓存类,并将其存储在 model_kwargs[cache_name] 中。

    • 处理离线缓存(offloaded cache)

      • 如果 cache_implementation"offloaded",则实例化 OffloadedCache(),并存储在 model_kwargs[cache_name] 中。

8. 默认情况下,使用动态缓存

# 默认情况下,使用 DynamicCache() 实例。这将避免在旧格式之间来回转换,从而避免复制缓存,节省内存
else:model_kwargs[cache_name] = (DynamicCache()if not requires_cross_attention_cacheelse EncoderDecoderCache(DynamicCache(), DynamicCache()))
  • 解释

    • 如果 generation_config.cache_implementationNone,即用户未指定特定的缓存实现方式,并且上述条件都不满足,则默认使用动态缓存。

    • 根据是否需要跨注意力缓存,实例化不同的缓存类:

      • 如果不需要跨注意力缓存,使用 DynamicCache()

      • 如果需要跨注意力缓存,使用 EncoderDecoderCache(DynamicCache(), DynamicCache()),即分别为编码器和解码器缓存初始化动态缓存。

  • 原因

    • 动态缓存可以在生成过程中动态扩展,并支持回滚等特性。

    • 使用默认的动态缓存,可以避免在旧缓存格式与新缓存格式之间来回转换,减少内存使用和数据复制。


_get_logits_processor

def _get_logits_processor(self,generation_config: GenerationConfig,input_ids_seq_length: int,encoder_input_ids: torch.LongTensor,prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],logits_processor: Optional[LogitsProcessorList],device: str = None,model_kwargs: Optional[Dict[str, Any]] = None,negative_prompt_ids: Optional[torch.Tensor] = None,negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:"""This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]instances used to modify the scores of the language model head."""# 函数主体从这里开始

功能说明

  • 目的:该函数用于构建一个LogitsProcessorList对象,其中包含了一系列LogitsProcessor实例,这些实例用于在生成过程中修改语言模型头(language model head)的logits(模型输出的分数),以实现各种生成控制策略。

  • 背景:在生成任务中,通过调整模型输出的logits,可以引入各种生成策略,如重复惩罚、温度调节、Top-K采样、禁用某些词汇等,从而控制生成文本的风格、内容和长度。

参数说明

  • self:类的实例,典型的Python类方法的第一个参数。

  • generation_configGenerationConfig对象,包含了生成过程中的各种配置参数,如重复惩罚系数、最小生成长度、温度等。

  • input_ids_seq_length:整数,表示输入序列的长度,即input_ids的长度。

  • encoder_input_idstorch.LongTensor,编码器的输入IDs。如果模型是编码器-解码器架构,这对应于编码器的输入。

  • prefix_allowed_tokens_fn:可选的函数,类型为Callable[[int, torch.Tensor], List[int]]。用于在生成过程中限制每个位置上允许生成的tokens,通常用于受限生成任务。

  • logits_processor:可选的LogitsProcessorList对象,用户自定义的LogitsProcessor列表,可用于补充或覆盖默认的processor。

  • device:字符串,可选参数,指定设备(如'cpu''cuda')。如果未提供,默认为None

  • model_kwargs:可选的字典,包含了传递给模型的其他关键字参数。

  • negative_prompt_ids:可选的torch.Tensor,用于一些生成策略(如Classifier-Free Guidance)中的负面提示IDs。

  • negative_prompt_attention_mask:可选的torch.Tensor,对应negative_prompt_ids的注意力掩码。


1. 初始化LogitsProcessorList

# instantiate processors list
processors = LogitsProcessorList()
  • 解释:创建一个空的LogitsProcessorList对象processors,用于存储将要应用的所有LogitsProcessor实例。

2. 处理Classifier-Free Guidance(CFG)

if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:processors.append(UnbatchedClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale,self,unconditional_ids=negative_prompt_ids,unconditional_attention_mask=negative_prompt_attention_mask,use_cache=generation_config.use_cache,))
  • 解释

    • 条件:如果guidance_scale不为None且不等于1,则说明需要应用Classifier-Free Guidance(CFG)。

      • guidance_scale是CFG的缩放因子,通常大于1,用于调整生成的多样性与准确性。
    • 操作:向processors中添加一个UnbatchedClassifierFreeGuidanceLogitsProcessor实例。

      • 参数说明

        • generation_config.guidance_scale:CFG的缩放因子。

        • self:模型实例,用于在LogitsProcessor中调用模型的其他方法。

        • unconditional_ids:负面提示的IDs,即negative_prompt_ids

        • unconditional_attention_mask:负面提示的注意力掩码,即negative_prompt_attention_mask

        • use_cache:是否使用缓存,来自generation_config


3. 处理序列偏置(Sequence Bias)

if generation_config.sequence_bias is not None:processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
  • 解释

    • 条件:如果sequence_bias不为None,则需要应用序列偏置。

      • sequence_bias是一种机制,可对特定的token序列施加偏置,提高或降低它们在生成中的概率。
    • 操作:向processors中添加一个SequenceBiasLogitsProcessor实例,传入sequence_bias参数。


4. 处理多样性惩罚(Diversity Penalty)

if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:processors.append(HammingDiversityLogitsProcessor(diversity_penalty=generation_config.diversity_penalty,num_beams=generation_config.num_beams,num_beam_groups=generation_config.num_beam_groups,))
  • 解释

    • 条件:如果diversity_penalty不为None且大于0,则需要应用多样性惩罚。

      • 多样性惩罚用于在束搜索中鼓励生成更多样化的序列。
    • 操作:向processors中添加一个HammingDiversityLogitsProcessor实例。

      • 参数说明

        • diversity_penalty:多样性惩罚系数。

        • num_beams:束搜索的束宽,即同时考虑的序列数量。

        • num_beam_groups:束搜索的组数,用于分组束搜索。


5. 处理编码器重复惩罚(Encoder Repetition Penalty)

if (generation_config.encoder_repetition_penalty is not Noneand generation_config.encoder_repetition_penalty != 1.0
):if len(encoder_input_ids.shape) == 2:processors.append(EncoderRepetitionPenaltyLogitsProcessor(penalty=generation_config.encoder_repetition_penalty,encoder_input_ids=encoder_input_ids,))else:warnings.warn("Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to ""`generate`, ignoring the argument.",UserWarning,)
  • 解释

    • 条件:如果encoder_repetition_penalty不为None且不等于1.0,则需要应用编码器重复惩罚。

      • 编码器重复惩罚用于减少模型在生成时重复输入内容的可能性。
    • 检查:如果encoder_input_ids的形状为二维(即存在有效的编码器输入),则应用惩罚。

    • 操作:向processors中添加一个EncoderRepetitionPenaltyLogitsProcessor实例。

      • 参数说明

        • penalty:重复惩罚系数。

        • encoder_input_ids:编码器的输入IDs。

    • 否则:发出警告,提示需要提供input_ids以应用该惩罚,忽略该参数。


6. 处理重复惩罚(Repetition Penalty)

if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
  • 解释

    • 条件:如果repetition_penalty不为None且不等于1.0,则需要应用重复惩罚。

      • 重复惩罚用于减少模型在生成时重复之前生成内容的可能性。
    • 操作:向processors中添加一个RepetitionPenaltyLogitsProcessor实例,传入penalty参数。


7. 处理禁止重复的n-gram(No Repeat N-Gram)

if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
  • 解释

    • 条件:如果no_repeat_ngram_size不为None且大于0,则需要禁止重复的n-gram。

      • 这用于防止模型在生成时重复生成相同的n-gram,提高生成的多样性。
    • 操作:向processors中添加一个NoRepeatNGramLogitsProcessor实例,传入no_repeat_ngram_size参数。


8. 处理编码器禁止重复的n-gram(Encoder No Repeat N-Gram)

if (generation_config.encoder_no_repeat_ngram_size is not Noneand generation_config.encoder_no_repeat_ngram_size > 0
):if len(encoder_input_ids.shape) == 2:processors.append(EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size,encoder_input_ids,))else:warnings.warn("Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to ""`generate`, ignoring the argument.",UserWarning,)
  • 解释

    • 条件:如果encoder_no_repeat_ngram_size不为None且大于0,则需要在生成时避免重复输入中的n-gram。

      • 这用于防止模型在生成时重复输入序列中的n-gram。
    • 检查:如果encoder_input_ids的形状为二维(存在有效的编码器输入),则应用该处理器。

    • 操作:向processors中添加一个EncoderNoRepeatNGramLogitsProcessor实例。

      • 参数说明

        • encoder_no_repeat_ngram_size:禁止重复的n-gram大小。

        • encoder_input_ids:编码器的输入IDs。

    • 否则:发出警告,提示需要提供input_ids以应用该处理器,忽略该参数。


9. 处理坏词(Bad Words)

if generation_config.bad_words_ids is not None:processors.append(NoBadWordsLogitsProcessor(generation_config.bad_words_ids,generation_config._eos_token_tensor,))
  • 解释

    • 条件:如果bad_words_ids不为None,则需要在生成过程中禁止某些词。

      • bad_words_ids是一个列表,包含需要禁止的词的token IDs。
    • 操作:向processors中添加一个NoBadWordsLogitsProcessor实例。

      • 参数说明

        • bad_words_ids:需要禁止的词的token IDs。

        • _eos_token_tensor:结束标记的token张量,用于在必要时停止生成。


10. 处理最小长度(Minimum Length)

if (generation_config.min_length is not Noneand generation_config._eos_token_tensor is not Noneand generation_config.min_length > 0
):processors.append(MinLengthLogitsProcessor(generation_config.min_length,generation_config._eos_token_tensor,device=device,))
  • 解释

    • 条件:如果min_length不为None_eos_token_tensor不为None,且min_length大于0,则需要在生成达到最小长度之前禁止生成结束标记。

    • 操作:向processors中添加一个MinLengthLogitsProcessor实例。

      • 参数说明

        • min_length:最小生成长度。

        • _eos_token_tensor:结束标记的token张量。

        • device:设备信息。


11. 处理最小新tokens的长度(Minimum New Tokens Length)

if (generation_config.min_new_tokens is not Noneand generation_config._eos_token_tensor is not Noneand generation_config.min_new_tokens > 0
):processors.append(MinNewTokensLengthLogitsProcessor(input_ids_seq_length,generation_config.min_new_tokens,generation_config._eos_token_tensor,device=device,))
  • 解释

    • 条件:如果min_new_tokens不为None_eos_token_tensor不为None,且min_new_tokens大于0,则需要在生成新tokens达到最小数量之前禁止生成结束标记。

    • 操作:向processors中添加一个MinNewTokensLengthLogitsProcessor实例。

      • 参数说明

        • input_ids_seq_length:输入序列的长度。

        • min_new_tokens:最小新生成的tokens数量。

        • _eos_token_tensor:结束标记的token张量。

        • device:设备信息。


12. 处理前缀限制(Prefix Allowed Tokens Function)

if prefix_allowed_tokens_fn is not None:processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn,generation_config.num_beams // generation_config.num_beam_groups,))
  • 解释

    • 条件:如果prefix_allowed_tokens_fn不为None,则需要在生成过程中限制每个位置上允许生成的tokens。

      • 这通常用于受限生成任务,例如自动补全或基于前缀的约束生成。
    • 操作:向processors中添加一个PrefixConstrainedLogitsProcessor实例。

      • 参数说明

        • prefix_allowed_tokens_fn:用于限制每个位置上允许生成的tokens的函数。

        • generation_config.num_beams // generation_config.num_beam_groups:计算每个组中的束宽。


13. 处理强制起始token(Forced BOS Token)

if generation_config.forced_bos_token_id is not None:processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id,))
  • 解释

    • 条件:如果forced_bos_token_id不为None,则需要在生成的第一个位置强制生成指定的起始token。

      • 这用于确保生成的序列以特定的token开始,例如在某些任务中需要强制生成特定的起始标记。
    • 操作:向processors中添加一个ForcedBOSTokenLogitsProcessor实例,传入forced_bos_token_id


14. 处理强制结束token(Forced EOS Token)

if generation_config.forced_eos_token_id is not None:processors.append(ForcedEOSTokenLogitsProcessor(generation_config.max_length,generation_config.forced_eos_token_id,device=device,))
  • 解释

    • 条件:如果forced_eos_token_id不为None,则需要在生成达到最大长度时强制生成指定的结束token。

      • 这用于确保生成的序列以特定的token结束。
    • 操作:向processors中添加一个ForcedEOSTokenLogitsProcessor实例。

      • 参数说明

        • generation_config.max_length:最大生成长度。

        • forced_eos_token_id:强制的结束token ID。

        • device:设备信息。


15. 处理无效值移除(Remove Invalid Values)

if generation_config.remove_invalid_values is True:processors.append(InfNanRemoveLogitsProcessor())
  • 解释

    • 条件:如果remove_invalid_valuesTrue,则需要在生成过程中移除infnan等无效值。

      • 这用于确保生成过程的稳定性,防止由于无效值导致的错误。
    • 操作:向processors中添加一个InfNanRemoveLogitsProcessor实例。


16. 处理指数衰减长度惩罚(Exponential Decay Length Penalty)

if generation_config.exponential_decay_length_penalty is not None:processors.append(ExponentialDecayLengthPenalty(generation_config.exponential_decay_length_penalty,generation_config._eos_token_tensor,input_ids_seq_length,))
  • 解释

    • 条件:如果exponential_decay_length_penalty不为None,则需要应用指数衰减的长度惩罚。

      • 这用于在生成过程中,对句子长度施加惩罚,鼓励模型生成特定长度的句子。
    • 操作:向processors中添加一个ExponentialDecayLengthPenalty实例。

      • 参数说明

        • exponential_decay_length_penalty:指数衰减长度惩罚的参数。

        • _eos_token_tensor:结束标记的token张量。

        • input_ids_seq_length:输入序列的长度。


17. 处理抑制特定tokens(Suppress Tokens)

if generation_config.suppress_tokens is not None:processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens,device=device,))
  • 解释

    • 条件:如果suppress_tokens不为None,则需要在生成过程中抑制特定的tokens,不让它们生成。

      • suppress_tokens是需要抑制的token IDs列表。
    • 操作:向processors中添加一个SuppressTokensLogitsProcessor实例。

      • 参数说明

        • suppress_tokens:需要抑制的token IDs。

        • device:设备信息。


18. 处理在开头抑制特定tokens(Suppress Tokens at Begin)

if generation_config.begin_suppress_tokens is not None:begin_index = input_ids_seq_lengthbegin_index = (begin_indexif (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)else begin_index + 1)processors.append(SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens,begin_index,device=device,))
  • 解释

    • 条件:如果begin_suppress_tokens不为None,则需要在生成的开头位置抑制特定的tokens。

      • 这用于避免模型在一开始生成某些不期望的tokens。
    • 计算起始索引

      • begin_index初始值为input_ids_seq_length,表示当前生成的位置。

      • 如果input_ids_seq_length <= 1forced_bos_token_id不为None,则begin_index += 1

    • 操作:向processors中添加一个SuppressTokensAtBeginLogitsProcessor实例。

      • 参数说明

        • begin_suppress_tokens:需要抑制的token IDs。

        • begin_index:开始抑制的位置索引。

        • device:设备信息。


19. 处理强制解码器IDs(Forced Decoder IDs)

if generation_config.forced_decoder_ids is not None:# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PTraise ValueError("You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument ""in favour of `input_ids` or `decoder_input_ids` respectively.",)
  • 解释

    • 条件:如果forced_decoder_ids不为None,则抛出异常。

    • 原因:当前不支持forced_decoder_ids,建议用户使用input_idsdecoder_input_ids来替代。

    • 备注:注释中提到,当TensorFlow和FLAX版本与PyTorch版本对齐后,可以将此异常移动到GenerationConfig.validate()中。


20. 合并用户自定义的logits_processor

# TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor)
  • 解释

    • 操作:调用self._merge_criteria_processor_list()方法,将之前构建的processors列表与用户自定义的logits_processor进行合并。

    • 备注:注释中提到需要找到一种策略来指定处理器的顺序。


21. 处理采样策略下的LogitsWarper

# 只有在使用采样策略时,才应用之前被称为`LogitsWarpers`的处理器
if generation_config.do_sample:# 在beam方法中,我们需要至少保留一个非eos token,以探索可能具有更好得分的连续序列if generation_config.num_beams > 1:if isinstance(generation_config._eos_token_tensor, list):min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1elif isinstance(generation_config._eos_token_tensor, torch.Tensor):min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1else:min_tokens_to_keep = 2else:min_tokens_to_keep = 1# 以下思想主要来自PR:https://github.com/huggingface/transformers/pull/5420/files# 所有的sampler都在`generation_utils_samplers.py`中if generation_config.temperature is not None and generation_config.temperature != 1.0:processors.append(TemperatureLogitsWarper(generation_config.temperature))if generation_config.top_k is not None and generation_config.top_k != 0:processors.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))if generation_config.top_p is not None and generation_config.top_p < 1.0:processors.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))if generation_config.min_p is not None:# 在温度缩放之后应用(见:https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)processors.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))if generation_config.typical_p is not None and generation_config.typical_p < 1.0:processors.append(TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep))if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:processors.append(EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep))if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:processors.append(EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device))
  • 解释

    • 条件:只有在do_sampleTrue时,才应用这些处理器,因为它们与采样策略相关。

    • 计算min_tokens_to_keep

      • 在束搜索等方法中,需要保留至少一个非结束token,以确保能够探索可能更好的序列。

      • 根据结束token的类型和数量,计算需要保留的最小token数量。

    • 添加采样相关的LogitsWarper

      • 根据generation_config中的配置,向processors中添加相应的LogitsWarper,包括:

        • TemperatureLogitsWarper:调整温度参数,控制生成的随机性。

        • TopKLogitsWarper:仅保留概率最高的top_k个tokens。

        • TopPLogitsWarper:仅保留累计概率达到top_p的tokens。

        • MinPLogitsWarper:应用最小概率阈值。

        • TypicalLogitsWarper:使用Typical采样策略。

        • EpsilonLogitsWarperEtaLogitsWarper:应用epsilon或eta截断。


22. 处理水印(Watermarking)

# Watermarking应该在所有logits处理完成后再应用(参见#34630)
if generation_config.watermarking_config is not None:processors.append(generation_config.watermarking_config.construct_processor(self.config.vocab_size, device))
  • 解释

    • 条件:如果watermarking_config不为None,则需要在生成过程中应用水印策略。

      • 水印策略通常用于在生成的文本中嵌入隐式标记,以便于后续识别生成内容。
    • 操作:向processors中添加一个由watermarking_config构建的处理器。

      • 使用模型的词汇表大小vocab_size和设备信息device
    • 备注:水印处理器应当在所有logits处理完成后再应用。


23. 处理Logits重归一化(Logit Renormalization)

# `LogitNormalization`应该始终是最后一个logit处理器(如果存在的话)
if generation_config.renormalize_logits is True:processors.append(LogitNormalization())
  • 解释

    • 条件:如果renormalize_logitsTrue,则需要在处理完所有logits后重新归一化。

      • 这用于确保logits在所有处理后仍然形成有效的概率分布。
    • 操作:向processors中添加一个LogitNormalization实例。

      • 备注LogitNormalization应当始终是最后应用的logits处理器。

24. 返回处理器列表

return processors
  • 解释:返回构建好的LogitsProcessorList对象processors,供生成过程使用。

_get_stopping_criteria

def _get_stopping_criteria(self,generation_config: GenerationConfig,stopping_criteria: Optional[StoppingCriteriaList],tokenizer: Optional["PreTrainedTokenizerBase"] = None,**kwargs,
) -> StoppingCriteriaList:# 方法体...

参数说明

  • generation_configGenerationConfig 对象,包含生成过程中所需的各种配置参数,如 max_lengthmax_timestop_strings 等。
  • stopping_criteria:可选的 StoppingCriteriaList 对象,用户可以传入自定义的停止条件列表,与默认的停止条件合并。
  • tokenizer:可选的 PreTrainedTokenizerBase 对象,模型对应的 tokenizer,用于处理 stop_strings 等需要 tokenizer 的功能。
  • **kwargs:其他可能的关键字参数,供扩展使用。

返回值

  • StoppingCriteriaList:该方法返回一个 StoppingCriteriaList 对象,包含生成过程中需要检查的停止条件。

该方法的主要作用是:

  1. 创建默认的停止条件列表:根据 generation_config 中的配置,生成对应的停止条件(如最大长度、最大时间、特殊字符串、结束标记等)。

  2. 合并用户提供的停止条件:如果用户在 stopping_criteria 中提供了自定义的停止条件,方法会将其与默认的停止条件合并。

  3. 返回完整的停止条件列表:生成过程会根据这个列表,在满足任何一个停止条件时结束生成。


步骤 1:初始化停止条件列表

criteria = StoppingCriteriaList()
  • 说明:创建一个空的 StoppingCriteriaList 对象,用于存储后续添加的停止条件。

步骤 2:处理 max_length 停止条件

if generation_config.max_length is not None:max_position_embeddings = getattr(self.config, "max_position_embeddings", None)criteria.append(MaxLengthCriteria(max_length=generation_config.max_length,max_position_embeddings=max_position_embeddings,))
  • 解释

    • 获取 max_length:从 generation_config 中获取 max_length,即生成的最大长度。

    • 获取模型的最大位置嵌入长度

      • max_position_embeddings = getattr(self.config, "max_position_embeddings", None)

      • 从模型配置中获取 max_position_embeddings,表示模型能处理的最大序列长度。

    • 添加 MaxLengthCriteria

      • 创建一个 MaxLengthCriteria 对象,传入 max_lengthmax_position_embeddings

      • 将其添加到 criteria 列表中。

  • 作用:当生成的序列长度达到 max_length 或模型的最大位置嵌入长度时,停止生成。

步骤 3:处理 max_time 停止条件

if generation_config.max_time is not None:criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
  • 解释

    • 获取 max_time:从 generation_config 中获取 max_time,即生成的最长时间(以秒为单位)。

    • 添加 MaxTimeCriteria

      • 创建一个 MaxTimeCriteria 对象,传入 max_time

      • 将其添加到 criteria 列表中。

  • 作用:当生成过程运行时间超过 max_time 秒时,停止生成。

步骤 4:处理 stop_strings 停止条件

if generation_config.stop_strings is not None:if tokenizer is None:raise ValueError("There are one or more stop strings, either in the arguments to `generate` or in the ""model's generation config, but we could not locate a tokenizer. When generating with ""stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`.")criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
  • 解释

    • 获取 stop_strings:从 generation_config 中获取 stop_strings,这是一个字符串列表,包含用于停止生成的特殊字符串。

    • 检查 tokenizer:因为处理字符串需要使用 tokenizer,如果 tokenizerNone,抛出 ValueError

    • 添加 StopStringCriteria

      • 创建一个 StopStringCriteria 对象,传入 stop_stringstokenizer

      • 将其添加到 criteria 列表中。

  • 作用:当生成的文本中出现指定的字符串时,停止生成。

步骤 5:处理 eos_token_id(结束标记)停止条件

if generation_config._eos_token_tensor is not None:criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
  • 解释

    • 获取 _eos_token_tensor:从 generation_config 中获取 _eos_token_tensor,即结束标记的 token ID。

    • 添加 EosTokenCriteria

      • 创建一个 EosTokenCriteria 对象,传入 eos_token_id

      • 将其添加到 criteria 列表中。

  • 作用:当生成的序列中出现结束标记时,停止生成。

步骤 6:处理辅助模型的置信度停止条件(可选)

if (generation_config.is_assistantand generation_config.assistant_confidence_threshold is not Noneand generation_config.assistant_confidence_threshold > 0
):criteria.append(ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold))
  • 解释

    • 条件检查

      • generation_config.is_assistant:判断是否使用了辅助模型。

      • generation_config.assistant_confidence_threshold 不为空且大于 0。

    • 添加 ConfidenceCriteria

      • 创建一个 ConfidenceCriteria 对象,传入 assistant_confidence_threshold

      • 将其添加到 criteria 列表中。

  • 作用:当辅助模型的置信度达到一定阈值时,停止生成。

步骤 7:合并用户提供的停止条件

criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
  • 解释

    • 合并默认和用户提供的停止条件

      • 调用 _merge_criteria_processor_list 方法,将默认的 criteria 与用户提供的 stopping_criteria 合并。

      • 这个方法通常会去除重复的条件,或者根据某些规则进行合并。

  • 作用:确保生成过程中考虑所有相关的停止条件。

步骤 8:返回最终的停止条件列表

return criteria
  • 解释:返回包含所有停止条件的 StoppingCriteriaList,供生成过程使用。

整体流程总结

  • 目的:为生成过程准备一个完整的停止条件列表,当满足任何一个条件时,生成过程将停止。

  • 处理逻辑

    1. 初始化:创建一个空的停止条件列表。

    2. 添加默认停止条件:根据 generation_config 中的配置,添加对应的停止条件,包括最大长度、最大时间、结束标记、特殊字符串等。

    3. 辅助模型条件:如果使用了辅助模型,并且设置了置信度阈值,添加对应的停止条件。

    4. 合并用户停止条件:将用户提供的 stopping_criteria 合并到默认条件列表中。

    5. 返回:输出完整的停止条件列表。


http://www.dtcms.com/wzjs/108753.html

相关文章:

  • 做采集网站难不网络营销平台有哪些?
  • 效果好的魔站建站系统网页设计html代码大全
  • 湖南郴州人很穷吗优化公司
  • 自做衣服网站网络广告营销方案策划内容
  • 重庆网站seo公司哪家好营销策略ppt
  • 组服务器做网站seo领导屋
  • 网页设计怎么样东莞seo建站公司
  • 网站建设交流完整的网页设计代码
  • 湖南网站制作团队注册百度账号
  • 网站asp模板网络营销毕业论文范文
  • 音乐网站开发教程站长工具seo综合查询广告
  • 安阳手机网站建设网店seo关键词
  • wordpress子主题安全seo优质友链购买
  • 设计9大软件安新seo优化排名网站
  • 网上商城运营推广方案纵横seo
  • 婚庆公司网站建设得多少钱网络精准推广
  • 衡水市网站建设河南网站建设报价
  • 广州大型网站建设抖音信息流广告怎么投放
  • 互联网网站建设营销推广计划怎么写
  • 重庆网站建客户管理系统
  • 网站开发语成都网络推广运营公司
  • 无锡专业网站建设百度指数支持数据下载吗
  • nas wordpress 外网访问成都网站seo推广
  • 百度网站入口特效词黄冈地区免费网站推广平台
  • 网站建设公司如何盈利项目推广平台排行榜
  • 基于jquery做的网站各种手艺培训班
  • 深圳骏域网站建设专家88百度推广账号注册流程
  • 爱做网站网址最新推广赚钱的app
  • 专注营销型网站建设广州seo排名优化公司
  • 沈阳网站建设哪家便宜crm网站