当前位置: 首页 > news >正文

GenerationMixin:generate

generate

以下是对您提供的 generate 方法的详细解释。这个方法用于大型语言模型(LLM)中的文本生成,尤其是具有语言模型头的模型。该方法包含了多个复杂的逻辑,支持多种生成模式,如贪心搜索、采样、束搜索等。


方法定义

@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    ...

参数说明

  • inputs:可选,输入张量,形状可能因模态而异。用于作为生成的提示或编码器的输入。
  • generation_config:可选,GenerationConfig 对象,包含生成时的参数配置。
  • logits_processor:可选,LogitsProcessorList 对象,自定义的 logits 处理器列表,用于在生成过程中调整 logits。
  • stopping_criteria:可选,StoppingCriteriaList 对象,自定义的停止标准列表,用于在满足条件时终止生成。
  • prefix_allowed_tokens_fn:可选,函数,用于在每一步生成时限制允许的 token。
  • synced_gpus:可选,布尔值,指示是否在多 GPU 环境下同步运行以避免死锁。
  • assistant_model:可选,用于加速生成的辅助模型,必须具有相同的 tokenizer。
  • streamer:可选,用于流式输出生成序列的对象。
  • negative_prompt_ids:可选,torch.LongTensor,形状为 (batch_size, sequence_length),用于一些处理器(如 CFG)的负提示。
  • negative_prompt_attention_mask:可选,torch.LongTensor,形状为 (batch_size, sequence_length),对应 negative_prompt_ids 的 attention mask。
  • kwargs:其他参数,可用于覆盖 generation_config 中的设置,或传递给模型的 forward 方法。

返回值

  • GenerateOutputtorch.LongTensor:根据参数设置,返回生成的序列或者包含生成过程详细信息的输出对象。

方法逻辑解析

总体流程

  1. 处理生成配置和参数验证:确保 generation_configkwargs 的正确性。
  2. 设置生成参数:根据传入的配置或默认值,设置生成所需的参数。
  3. 准备模型输入:处理输入张量,生成 input_ids,并管理模型需要的其他关键字参数。
  4. 确定生成模式:根据配置,选择合适的生成模式,例如贪心搜索、采样、束搜索等。
  5. 生成序列:调用相应的生成函数,生成目标序列。
  6. 返回结果:根据参数设置,返回生成的序列或包含更多信息的对象。

详细步骤

1. 处理生成配置和参数验证
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None)
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
  • _validate_model_class:检查模型的类型是否支持生成。_validate_model_class
  • 提取 tokenizer:用于停止条件等,非必要参数。
  • 准备 generation_configmodel_kwargs:处理传入的配置和额外参数。_prepare_generation_config
  • 验证模型关键字参数:确保传入的参数与模型的预期一致。_validate_model_kwargs
  • 验证辅助模型:如果使用了 assistant_model,确保其与主模型兼容。
2. 设置生成参数
if synced_gpus is None:
    synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
  • synced_gpus:在多 GPU 环境下,判断是否需要同步。
  • logits_processorstopping_criteria:初始化或使用默认的处理器和停止标准。
  • StoppingCriteria代码分析
  • LogitsProcessor代码分析
  • 注意力掩码相关变量:检查模型的 forward 方法是否接受 attention_mask,以及当前参数中是否提供。
3. 准备模型输入
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
    inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
  • _prepare_model_inputs:处理输入张量,获取模型输入名称,并更新 model_kwargs。_prepare_model_inputs
  • batch_sizedevice:获取批次大小和设备信息。
  • 准备特殊 token:如 bos_token_ideos_token_id 等。_prepare_special_tokens
4. 检查并处理注意力掩码
       # decoder-only models must use left-padding for batched generation.
        if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
            if (
                generation_config._pad_token_tensor is not None
                and batch_size > 1
                and len(inputs_tensor.shape) == 2
                and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

这段代码是关于为仅解码器架构(decoder-only models)处理输入时的填充方式建议。它检查是否使用了右填充(right-padding),在这种情况下给出警告。

  1. 架构类型检查:

    • self.config.is_encoder_decoder 用于检查模型是否属于编码器-解码器架构。
    • 如果模型不是编码器-解码器架构(即它是仅解码器架构),并且不是在TorchDynamo编译模式下,代码继续进行。
  2. 输入条件检查:

    • 确保批处理大小大于1,即有多个序列在一起进行处理。
    • 检查输入张量 inputs_tensor 的形状满足条件:它是二维张量,一般形式为 [batch_size, sequence_length]
    • 检查序列中的最后一个标识符是否为 pad_token_id,指定的 generation_config._pad_token_tensor 不为 None
    • torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 用于确定至少有一个序列的最后一个令牌是填充令牌。
  3. 警告日志:

    • 如果满足以上条件,显示警告信息。
    • 提示在仅解码器模型中检测到右填充,它建议使用左填充(padding_side='left'),以获取正确的生成结果。
      # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
        # generating the first new token or not, and we only want to use the embeddings for the first new token)
        if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
            generation_config.use_cache = True

这段代码针对仅解码器架构(decoder-only models)进行了一种配置调整,特别是在使用 inputs_embeds 作为输入的时候。以下是代码的简单说明:

  1. 架构类型检查:

    • not self.config.is_encoder_decoder 用于判断模型是否是仅解码器架构。
    • 如果模型是仅解码器架构,代码继续进行。
  2. 输入类型检查:

    • model_input_name == "inputs_embeds" 检查输入类型是否为嵌入层(embeddings)。
    • inputs_embeds 通常表示已经经过词嵌入层的输入,这意味着模型接收的不是直接的令牌ID,而是对应的词向量。
  3. 缓存使用设置:

    • generation_config.use_cache = True 设置生成配置的 use_cache 属性为 True
    • 启用缓存对于跟踪生成的序列特别重要,尤其是在无法确定新生成的第一个令牌是否已生成时。
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
    model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
        inputs_tensor, generation_config, model_kwargs
    )
elif kwargs_has_attention_mask:
    if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
        raise ValueError("`attention_mask` passed to `generate` must be 2D.")
  • 如果需要 attention_mask:且未提供,则生成默认的 attention_mask
  • 验证提供的 attention_mask:确保其形状正确。
  • _prepare_attention_mask_for_generation
5. 为编码器-解码器模型准备输入
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
        inputs_tensor, model_kwargs, model_input_name, generation_config
    )
  • 对于编码器-解码器模型:如果未提供 encoder_outputs,则进行编码器的前向计算,并更新 model_kwargs。计算编码器部分
6. 准备用于自回归生成的 input_ids
if self.config.is_encoder_decoder:
    input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
        batch_size=batch_size,
        model_input_name=model_input_name,
        model_kwargs=model_kwargs,
        decoder_start_token_id=generation_config._decoder_start_token_tensor,
        device=inputs_tensor.device,
    )
else:
    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
  • 编码器-解码器模型:准备解码器的 input_ids
  • _prepare_decoder_input_ids_for_generation
  • _decoder_start_token_tensor
  • 解码器模型:直接使用 inputs_tensor 作为 input_ids
7. 处理特殊的生成配置
if generation_config.token_healing:
    input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:
    streamer.put(input_ids.cpu())
  • token_healing:如果启用了此项配置,修复输入中的 tokens。token_healing
  • streamer:如果提供了流式处理器,传递当前的 input_ids
8. 准备生成长度相关的参数
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
    generation_config=generation_config,
    has_default_max_length=has_default_max_length,
    has_default_min_length=has_default_min_length,
    model_input_name=model_input_name,
    inputs_tensor=inputs_tensor,
    input_ids_length=input_ids_length,
)
  • 计算输入序列长度
  • 确定是否使用默认的最大和最小长度
  • 准备生成长度配置,可能会根据输入长度进行调整。_prepare_generated_length
9. 准备缓存和其他模型参数
if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
    model_kwargs["logits_to_keep"] = 1
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

  • logits_to_keep:如果模型支持,仅保留需要的 logits,减少内存占用。_supports_logits_to_keep
  • 验证生成长度:确保生成长度的合法性。_validate_generated_length
        # 7. Prepare the cache.
        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
        # - different models have a different cache name expected by the model (default = "past_key_values")
        # - `max_length`, prepared above, is used to determine the maximum cache length
        max_cache_length = generation_config.max_length - 1
        if (
            inputs_tensor.shape[1] != input_ids_length
            and model_input_name == "inputs_embeds"
            and not self.config.is_encoder_decoder
        ):
            max_cache_length += inputs_tensor.shape[1]
        self._prepare_cache_for_generation(
            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
        )
  • 准备缓存:为生成过程中的缓存(如注意力缓存)分配空间。
    _prepare_cache_for_generation
10. 确定生成模式
generation_mode = generation_config.get_generation_mode(assistant_model)
  • 根据生成配置和辅助模型,确定生成模式,例如:
    • 辅助生成(Assisted Generation)
    • DoLa 生成(DOLA Generation)
    • 对比搜索(Contrastive Search)
    • 采样或贪心搜索
    • 束搜索(Beam Search)
    • 组束搜索(Group Beam Search)
    • 受限束搜索(Constrained Beam Search)
11. 准备 logits 处理器和停止标准
prepared_logits_processor = self._get_logits_processor(
    generation_config=generation_config,
    input_ids_seq_length=input_ids_length,
    encoder_input_ids=inputs_tensor,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    logits_processor=logits_processor,
    device=inputs_tensor.device,
    model_kwargs=model_kwargs,
    negative_prompt_ids=negative_prompt_ids,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
)
prepared_stopping_criteria = self._get_stopping_criteria(
    generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
model_kwargs["use_cache"] = generation_config.use_cache
  • 获取 logits 处理器:整合默认和自定义的 logits 处理器,用于在生成过程中调整 logits。
  • 获取停止标准:整合默认和自定义的停止标准,用于在满足条件时终止生成。
  • 设置 use_cache:根据配置,决定是否在生成过程中使用缓存。
    _get_logits_processor
    _get_stopping_criteria
12. 根据生成模式调用相应的生成函数
  • 辅助生成
if generation_mode == GenerationMode.ASSISTED_GENERATION:
    # 验证条件
    # 获取候选生成器
    # 执行辅助生成
  • DoLa 生成
elif generation_mode == GenerationMode.DOLA_GENERATION:
    # 执行 DoLa 解码
  • 对比搜索
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
    # 执行对比搜索
  • 采样或贪心搜索
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
    # 扩展 input_ids
    # 执行采样或贪心搜索
  • 束搜索
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
    # 准备束搜索评分器
    # 扩展 input_ids
    # 执行束搜索

GenerationMixin:_sample方法(GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH)

  • 组束搜索
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
    # 准备组束搜索评分器
    # 扩展 input_ids
    # 执行组束搜索
  • 受限束搜索
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
    # 准备约束条件
    # 准备受限束搜索评分器
    # 扩展 input_ids
    # 执行受限束搜索
13. 处理生成结果
# 如果需要,将缓存转换为传统格式
if (
    generation_config.return_legacy_cache is True
    and not is_torchdynamo_compiling()
    and hasattr(result, "past_key_values")
    and getattr(result.past_key_values, "to_legacy_cache") is not None
):
    result.past_key_values = result.past_key_values.to_legacy_cache()
return result
  • 转换缓存格式:如果配置需要,将生成过程中使用的缓存转换为传统格式。
  • 返回结果:最终将生成的结果返回。

_validate_model_class

函数功能概述

这个函数名为_validate_model_class,用于验证当前的模型类是否支持生成(generation)操作。如果不支持生成,则会抛出一个异常,提示用户使用合适的模型类。

  1. 条件判断

    if not is_torchdynamo_compiling() and not self.can_generate():
    
    • 这里有一个 if 条件,用于检查两个条件是否同时满足:
      • not is_torchdynamo_compiling()
      • not self.can_generate()
    • is_torchdynamo_compiling()
      • 这是一个函数,检查当前是否处于 TorchDynamo 的编译环境中。
      • TorchDynamo 是 PyTorch 的一个 JIT 编译器框架,可以对模型进行编译优化。
      • 在编译过程中,某些检查可能需要被跳过,因此在编译时不进行此验证。
    • self.can_generate()
      • 这是模型类的一个方法,返回布尔值,表示模型是否支持生成操作。
      • 如果模型具有语言模型头(language model head),则通常支持生成,即 self.can_generate() 返回 True

    因此,只有在不处于编译环境模型不能生成的情况下,才会进入 if 语句内部,抛出异常。

  2. 定义支持生成的模型类名后缀列表

    terminations_with_generation_support = [
        "ForCausalLM",
        "ForConditionalGeneration",
        "ForSpeechSeq2Seq",
        "ForVision2Seq",
    ]
    
    • 这是一个列表,包含了通常支持生成操作的模型类名称的后缀。
    • 这些后缀包括:
      • "ForCausalLM":用于自回归语言模型,如 GPT-2、GPT-3 等。
      • "ForConditionalGeneration":用于条件生成模型,如 BART、T5 等。
      • "ForSpeechSeq2Seq":用于语音序列到序列模型。
      • "ForVision2Seq":用于视觉到序列的模型,如图像描述生成。
  3. 抛出异常

    raise TypeError(
        f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
        "it doesn't have a language model head. Classes that support generation often end in one of these "
        f"names: {terminations_with_generation_support}."
    )
    
    • 如果条件满足,说明当前模型不支持生成,则抛出一个 TypeError 异常。
    • 异常信息包括:
      • 当前模型类的名称:{self.__class__.__name__}
      • 说明模型不兼容 .generate() 方法,因为它没有语言模型头。
      • 提示支持生成的模型类通常以哪些后缀结尾:{terminations_with_generation_support}

_prepare_generation_config

函数功能概述

这个函数名为_prepare_generation_config,它的作用是准备生成所需的配置对象generation_config,并应用从kwargs(关键字参数)中传入的任何生成配置选项。这个函数还处理了与模型配置文件(config)的向后兼容性。


参数和返回值
  • 参数

    • self:类的实例。
    • generation_config:可选的GenerationConfig对象,表示生成配置。如果为None,则需要从其他地方获取或生成。
    • **kwargs:其他关键字参数,可能包含生成配置的参数,将用于更新generation_config
  • 返回值

    • generation_config:最终准备好的生成配置对象。
    • model_kwargs:用于模型的关键字参数字典。

处理逻辑详解
  1. 初始设置

    # 设置一个标志,指示是否使用模型的默认生成配置
    using_model_generation_config = False
    
  2. 处理generation_configNone的情况

    if generation_config is None:
        ...
    

    当用户没有提供generation_config时,需要从模型中获取。但在处理之前,先考虑到可能的向后兼容性问题。

    • 遗留(Legacy)行为的处理

      # 遗留支持:用户可能修改了模型的配置来控制生成。要触发这种遗留行为,需要满足以下条件:
      # 1) generation_config 是从模型配置创建的(`_from_model_config`字段为True)
      # 2) generation_config 自创建以来没有被修改过(哈希值相同)
      # 3) 模型配置中有非默认的生成参数
      # 4) 用户在模型配置中设置了新的生成参数
      # 注意:`torch.compile`无法编译`hash`函数,因此在编译时,这种遗留支持被禁用
      if (
          not is_torchdynamo_compiling()
          and self.generation_config._from_model_config  # 条件1
          and self.generation_config._original_object_hash == hash(self.generation_config)  # 条件2
          and len(self.config._get_non_default_generation_parameters()) > 0  # 条件3
      ):
          new_generation_config = GenerationConfig.from_model_config(self.config)
          if new_generation_config != self.generation_config:  # 条件4
              warnings.warn(
                  "You have modified the pretrained model configuration to control generation. This is a"
                  " deprecated strategy to control generation and will be removed in v5."
                  " Please use and modify the model generation configuration (see"
                  " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
                  UserWarning,
              )
              self.generation_config = new_generation_config
      
      • 解释条件

        1. self.generation_config._from_model_config
          • 检查generation_config是否是从模型配置创建的。
        2. self.generation_config._original_object_hash == hash(self.generation_config)
          • 检查generation_config自创建以来是否没有被修改过。
          • 由于torch.compile无法编译hash函数,因此在编译时无法进行此检查。
        3. len(self.config._get_non_default_generation_parameters()) > 0
          • 检查模型配置中是否有非默认的生成参数。
        4. new_generation_config != self.generation_config
          • 检查新的生成配置是否与当前的不同,表示用户在模型配置中设置了新的生成参数。
      • 操作

        • 如果以上条件都满足,表示用户通过修改模型配置来控制生成。这是一种已被弃用的做法。
        • 发出警告,提示此策略将在v5版本中移除,建议用户使用并修改生成配置对象generation_config
        • 更新self.generation_config为新的生成配置new_generation_config
    • 设置generation_config

      generation_config = self.generation_config
      using_model_generation_config = True
      
      • 如果没有提供generation_config,则使用模型的默认生成配置self.generation_config
      • 设置标志using_model_generation_config = True,表示正在使用模型的默认生成配置。
  3. 处理torch.compile相关的问题

    # `torch.compile`无法编译`copy.deepcopy`等函数,因此需要根据是否在编译中,决定如何处理
    if not is_torchdynamo_compiling():
        ...
    else:
        model_kwargs = kwargs
    
    • 非编译环境下的处理

      if not is_torchdynamo_compiling():
          generation_config = copy.deepcopy(generation_config)
          model_kwargs = generation_config.update(**kwargs)
          ...
      
      • 深拷贝generation_config

        • 使用copy.deepcopy创建generation_config的深拷贝,以避免修改原始对象。
        • 由于torch.compile无法编译copy.deepcopy,因此在编译环境下无法进行此操作。
      • 更新generation_config

        • 调用generation_config.update(**kwargs)方法,用传入的kwargs更新生成配置。
        • 这个方法返回未被generation_config使用的参数,即那些不属于生成配置的参数,存储在model_kwargs中。
      • 处理特殊的Token ID

        # 如果提供了`generation_config`,需要确保所有特殊的Token ID都有默认值
        if not using_model_generation_config:
            if generation_config.bos_token_id is None:
                generation_config.bos_token_id = self.generation_config.bos_token_id
            if generation_config.eos_token_id is None:
                generation_config.eos_token_id = self.generation_config.eos_token_id
            if generation_config.pad_token_id is None:
                generation_config.pad_token_id = self.generation_config.pad_token_id
            if generation_config.decoder_start_token_id is None:
                generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
        
        • 如果用户提供了自己的generation_config(即不使用模型的默认生成配置),需要确保特殊的Token ID(开始、结束、填充、解码器开始)有默认值。
        • 如果这些ID在用户提供的generation_config中为None,则使用模型默认的self.generation_config中的值。
    • 编译环境下的处理

      else:
          model_kwargs = kwargs
      
      • 在编译环境下,由于无法使用copy.deepcopyhash,直接将传入的kwargs赋值给model_kwargs
      • 不进行深拷贝和更新操作。
  4. 返回结果

    return generation_config, model_kwargs
    
    • 函数返回准备好的generation_configmodel_kwargs
    • model_kwargs包含了模型需要的其他参数。

_validate_model_kwargs

函数功能概述

这个函数名为_validate_model_kwargs,用于在生成(generation)过程中对传入的模型关键字参数model_kwargs进行验证。它的主要作用是:

  • 验证传入的model_kwargs是否被模型的生成方法使用,以确保参数的正确性,防止参数拼写错误或传入不支持的参数。
  • 给出明确的错误信息,帮助用户及时发现并修正问题。

1. 检查past_key_values是否为Cache实例,并验证模型是否支持
if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
    raise ValueError(
        f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
        "check the model documentation for supported cache formats."
    )
  • 目的:如果model_kwargs中包含past_key_values,并且它是一个Cache实例,而模型不支持Cache类型的缓存,则抛出错误。

  • 解释

    • model_kwargs.get("past_key_values", None):尝试获取past_key_values的值,如果不存在则为None
    • isinstance(..., Cache):检查past_key_values是否是Cache类的实例。
    • not self._supports_cache_class:检查当前模型是否不支持Cache类型的缓存。
  • 错误处理:如果以上条件满足,则抛出ValueError,提示用户当前模型不支持Cache实例作为past_key_values,并建议查看模型文档以了解支持的缓存格式。

2. 对于编码器-解码器模型,移除特定的参数
if self.config.is_encoder_decoder:
    for key in ["decoder_input_ids"]:
        model_kwargs.pop(key, None)
  • 目的:在生成过程中,不需要直接传入decoder_input_ids,因此从model_kwargs中移除该参数,防止后续产生错误。

  • 解释

    • self.config.is_encoder_decoder:检查模型是否为编码器-解码器结构。
    • model_kwargs.pop(key, None):从model_kwargs中移除指定的键,如果不存在则返回None
3. 初始化未使用的模型参数列表和获取模型方法参数
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
  • unused_model_args:用于收集未被模型方法使用的参数名。

  • model_args:获取prepare_inputs_for_generation方法的参数名集合,表示模型在生成过程中可能使用的参数。

  • 解释

    • inspect.signature(...):获取函数的完整参数签名。
    • .parameters:获取参数名组成的有序字典。
    • set(...):将参数名转换为集合,方便后续操作。
4. 扩展模型参数集合,包含forward方法的参数
if "kwargs" in model_args or "model_kwargs" in model_args:
    model_args |= set(inspect.signature(self.forward).parameters)
  • 目的:如果prepare_inputs_for_generation方法接受kwargs(可变关键字参数),则需要将self.forward方法的参数也包含进来,因为这些参数可能会通过kwargs传递。

  • 解释

    • 检查model_args中是否包含"kwargs""model_kwargs",表示接受可变关键字参数。
    • 使用set.union|=)操作,将self.forward方法的参数集合并入model_args
5. 对于编码器-解码器模型,包含编码器和解码器的参数
if self.config.is_encoder_decoder:
    base_model = getattr(self, self.base_model_prefix, None)
    # 允许编码器的参数
    encoder = getattr(self, "encoder", None)
    # 特殊情况处理(如Musicgen模型)
    if encoder is None and base_model is not None:
        encoder = getattr(base_model, "encoder", None)
    if encoder is not None:
        encoder_model_args = set(inspect.signature(encoder.forward).parameters)
        model_args |= encoder_model_args
    # 允许解码器的参数
    decoder = getattr(self, "decoder", None)
    if decoder is None and base_model is not None:
        decoder = getattr(base_model, "decoder", None)
    if decoder is not None:
        decoder_model_args = set(inspect.signature(decoder.forward).parameters)
        model_args |= {f"decoder_{x}" for x in decoder_model_args}
  • 目的:将编码器和解码器的forward方法的参数名添加到model_args中,以便在验证model_kwargs时考虑到这些参数。

  • 详细解释

    • base_model = getattr(self, self.base_model_prefix, None):获取基础模型实例,self.base_model_prefix通常是模型的前缀名。

    • 处理编码器

      • encoder = getattr(self, "encoder", None):尝试直接获取self.encoder属性。
      • 如果encoderNonebase_model存在,则尝试从base_model中获取encoder
      • 获取encoder.forward方法的参数名集合encoder_model_args
      • encoder_model_args并入model_args
    • 处理解码器

      • 类似地,获取decoder实例。
      • 获取decoder.forward方法的参数名集合decoder_model_args
      • decoder_model_args的参数名加上前缀"decoder_",以匹配生成过程中参数的命名方式。
      • 将修改后的decoder_model_args并入model_args
  • 特殊情况说明

    • 对于某些模型(例如MusicgenForConditionalGeneration),编码器和解码器的命名和结构可能与标准情况不同,需要特别处理。
6. 检查model_kwargs中的未使用参数
for key, value in model_kwargs.items():
    if value is not None and key not in model_args:
        unused_model_args.append(key)
  • 目的:找出model_kwargs中那些未被模型方法使用的参数。

  • 解释

    • 遍历model_kwargs中的所有键值对。
    • 如果参数值不为None且参数名不在model_args集合中,说明模型并不会使用该参数。
    • 将这些未使用的参数名添加到unused_model_args列表中。
7. 抛出错误提示未使用的参数
if unused_model_args:
    raise ValueError(
        f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
        " generate arguments will also show up in this list)"
    )
  • 目的:如果存在未使用的参数,抛出ValueError,提醒用户这些参数未被模型使用,可能存在拼写错误或传入了不支持的参数。

  • 解释

    • 检查unused_model_args列表是否非空。
    • 抛出错误,错误信息中包含未使用的参数列表,并提示用户可能是由于参数名的拼写错误导致的。

_prepare_model_inputs

方法定义

def _prepare_model_inputs(
    self,
    inputs: Optional[torch.Tensor] = None,
    bos_token_id: Optional[torch.Tensor] = None,
    model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
    """
    This function extracts the model-specific `inputs` for generation.
    """
    # 方法体...

参数说明:

  • inputs:可选的 torch.Tensor,表示要作为模型输入的张量。可能是 input_idsinputs_embeds 等形式,具体取决于模型的要求。
  • bos_token_id:可选的 torch.Tensor,表示序列开始的 token ID(BOS = Begin Of Sequence)。在生成任务中,如果未提供输入,可能需要使用该 token 进行初始化。
  • model_kwargs:可选的字典,包含传递给模型的其他关键字参数。

返回值:

  • inputstorch.Tensor,准备好的模型输入张量。
  • input_namestr,模型输入的名称,可能是 input_idsinputs_embeds
  • model_kwargs:字典,更新后的模型关键字参数。

方法功能概述

该方法的主要目的是在生成过程中,准备和验证模型的输入,确保输入与模型的预期格式和要求一致。具体任务包括:

  1. 确定模型所需的主要输入名称(input_name),这可能取决于模型是编码器-解码器模型还是仅解码器模型。
  2. 处理传入的 inputsmodel_kwargs,以避免重复传递相同的输入参数。
  3. 在需要时,使用 bos_token_id 初始化 input_ids,以开始生成过程。
  4. 对于支持 inputs_embeds 的模型,正确处理 inputs_embeds 参数。

逐步详解

步骤 1:确定模型的主要输入名称
# 判断模型是否是编码器-解码器,并获取正确的输入名称
if (
    self.config.is_encoder_decoder
    and hasattr(self, "encoder")
    and self.encoder.main_input_name != self.main_input_name
):
    input_name = self.encoder.main_input_name
else:
    input_name = self.main_input_name

解释:

  • 目的:获取模型预期的主要输入参数名称,可能是 input_idsinputs_embeds 等。
  • 逻辑:
    • 如果模型是 编码器-解码器模型,并且编码器的 main_input_name 与模型的 main_input_name 不同,则使用编码器的 main_input_name
      • 这是因为编码器和解码器可能需要不同的输入名称。例如,一些模型的编码器可能使用 inputs_embeds,而解码器使用 input_ids
    • 否则,使用模型的 main_input_name 作为输入名称。

示例:

  • 如果 self.main_input_name'input_ids',而 self.encoder.main_input_name'inputs_embeds',则对于编码器,需要使用 'inputs_embeds' 作为输入名称。
步骤 2:从 model_kwargs 中提取非输入相关的参数
# 过滤掉模型输入的关键字参数,以及值为 None 的参数
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

解释:

  • 目的:清理 model_kwargs,只保留非模型输入的参数和值非 None 的参数。
  • 逻辑:
    • 遍历 model_kwargs 中的所有键值对,保留以下情况的参数:
      • v 不为 None
      • 或者键 k 不是模型的主要输入名称 input_name
  • 原因:避免模型输入参数(如 input_ids)重复出现在 inputsmodel_kwargs 中,以防止冲突。

示例:

  • 如果 input_name'input_ids',并且 model_kwargs 包含 {'input_ids': None, 'attention_mask': tensor([...])},那么经过此步骤后,model_kwargs 变为 {'attention_mask': tensor([...])}
步骤 3:检查模型输入是否作为关键字参数传入
# 从 model_kwargs 中获取输入参数
inputs_kwarg = model_kwargs.pop(input_name, None)
# 检查是否同时传入了 inputs 和 inputs_kwarg
if inputs_kwarg is not None and inputs is not None:
    raise ValueError(
        f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
        f"Make sure to either pass {inputs} or {input_name}=..."
    )
elif inputs_kwarg is not None:
    inputs = inputs_kwarg

解释:

  • 目的:确保模型输入参数只通过一个途径传入,避免冲突。
  • 逻辑:
    • model_kwargs 中取出键为 input_name 的参数,赋值给 inputs_kwarg
    • 如果同时传入了 inputsinputs_kwarg,抛出异常,提示用户只能通过一种方式传入模型输入。
    • 如果 inputsNone,但 inputs_kwarg 不为 None,则将 inputs_kwarg 赋值给 inputs

示例:

  • 用户通过位置参数传入了 inputs,同时在 model_kwargs 中传入了 input_ids,这将导致异常,因为模型无法确定使用哪个输入。
步骤 4:处理 inputs_embeds 的情况
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
    if not self.config.is_encoder_decoder:
        # 检查模型是否支持 inputs_embeds
        has_inputs_embeds_forwarding = "inputs_embeds" in set(
            inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
        )
        if not has_inputs_embeds_forwarding:
            raise ValueError(
                f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
                "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
            )
        # 将 input_ids 初始化并加入 model_kwargs
        model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
            inputs, bos_token_id, model_kwargs=model_kwargs
        )
    else:
        if inputs is not None:
            raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
    # 更新 inputs 和 input_name
    inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

解释:

  • 目的:处理用户通过 inputs_embeds 提供输入的情况,确保模型支持这种输入方式,并正确处理。
  • 逻辑:
    • 当模型的输入名称为 'input_ids',且 model_kwargs 中存在 'inputs_embeds' 键时,进入此逻辑。
    • 对于非编码器-解码器模型:
      • 检查模型的 prepare_inputs_for_generation 方法是否接受 inputs_embeds 参数。
      • 如果不支持,则抛出异常,提示模型不支持通过 inputs_embeds 进行生成。
      • 如果支持,则需要初始化 input_ids,以便在生成过程中处理诸如 attention mask 等依赖 input_ids 的自动操作。
      • 将初始化的 input_ids 添加到 model_kwargs 中。
    • 对于编码器-解码器模型:
      • 如果同时传入了 inputsinputs_embeds,抛出异常,提示只能选择一种输入方式。
    • 最后,将 inputs 设置为 inputs_embeds,并更新 input_name'inputs_embeds'

示例:

  • 用户在生成时提供了 inputs_embeds,如果模型支持,则将其用于生成,否则抛出异常。
步骤 5:如果 inputs 仍为 None,尝试从 bos_token_id 创建 input_ids
# 如果 inputs 为 None,使用 bos_token_id 初始化 input_ids
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)

解释:

  • 目的:在用户未提供任何输入的情况下,使用开始标记 bos_token_id 初始化输入,以启动生成过程。
  • 逻辑:
    • 调用 _maybe_initialize_input_ids_for_generation 方法,如果 inputsNone,则尝试使用 bos_token_id 创建 input_ids
    • 如果 inputs 已存在,则保持不变。

示例:

  • 用户未提供 inputs,模型使用 bos_token_id [101](假设为开始标记的 ID)初始化 input_ids
步骤 6:返回处理后的结果
# 返回准备好的 inputs、input_name 和更新后的 model_kwargs
return inputs, input_name, model_kwargs

示例:

# 示例 1:用户只提供 inputs
inputs = torch.tensor([[101, 102, 103]])
inputs, input_name, model_kwargs = model._prepare_model_inputs(inputs=inputs)

# 示例 2:用户只提供 input_ids 作为关键字参数
model_kwargs = {'input_ids': torch.tensor([[101, 102, 103]])}
inputs, input_name, model_kwargs = model._prepare_model_inputs(model_kwargs=model_kwargs)

# 示例 3:用户提供 inputs_embeds,模型支持 inputs_embeds
model_kwargs = {'inputs_embeds': torch.tensor([...])}
inputs, input_name, model_kwargs = model._prepare_model_inputs(model_kwargs=model_kwargs)

# 示例 4:用户未提供任何输入,使用 bos_token_id 初始化
bos_token_id = torch.tensor([101])
inputs, input_name, model_kwargs = model._prepare_model_inputs(bos_token_id=bos_token_id)

_prepare_special_tokens

主要目的是准备生成所需的特殊标记(如开始标记 bos_token_id、结束标记 eos_token_id、填充标记 pad_token_id 和解码器起始标记 decoder_start_token_id),并将这些标记转换为张量。具体步骤如下:

  1. 定义辅助函数 _tensor_or_none

    • 将特殊标记转换为张量,如果标记为 None 则返回 None
  2. 将特殊标记转换为张量

    • 使用 _tensor_or_none 函数将 bos_token_ideos_token_idpad_token_iddecoder_start_token_id 转换为张量。
  3. 处理编码器-解码器模型

    • 如果模型是编码器-解码器类型,并且 decoder_start_token_id 未设置,则使用 bos_token_id 作为 decoder_start_token_id
  4. 处理 eos_token_tensor

    • 如果 eos_token_tensor 是 0 维张量,则将其扩展为 1 维张量。
  5. 设置 pad_token_tensor

    • 如果 pad_token_tensor 未设置且 eos_token_tensor 存在,则将 pad_token_tensor 设置为 eos_token_tensor 的第一个元素,并发出警告。
  6. 安全检查和警告

    • 检查编码器-解码器模型是否设置了 decoder_start_token_id
    • 检查 eos_token_tensor 是否与 pad_token_tensor 相同,并在未设置注意力掩码时发出警告。
    • 检查 eos_token_tensor 是否包含负数或浮点数,并发出警告。
  7. 更新生成配置

    • 将转换后的特殊标记张量存储在 generation_config 的新属性中,以启用端到端编译。

_prepare_attention_mask_for_generation

该方法用于在生成过程中为模型准备 attention_mask,以确保模型在生成序列时正确地关注输入序列的非填充部分(非 pad_token 部分)。


方法定义

def _prepare_attention_mask_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    generation_config: GenerationConfig,
    model_kwargs: Dict[str, Any],
) -> torch.LongTensor:
    # 方法体...

参数说明:

  • inputs_tensor: torch.Tensor,输入张量,可能是模型的主要输入。
  • generation_config: GenerationConfig,生成配置对象,包含生成过程中所需的配置参数。
  • model_kwargs: Dict[str, Any],包含模型其他关键字参数的字典。

返回值:

  • attention_mask: torch.LongTensor,返回生成过程中使用的 attention_mask,用于指示模型需要关注的输入序列位置。

方法功能概述

该方法的主要目的是根据输入张量 inputs_tensor 和生成配置 generation_config,为生成过程准备合适的 attention_mask。具体而言,它会根据以下情况生成 attention_mask

  1. 如果输入张量包含 pad_token_id,则在 attention_mask 中标记出非填充的位置(即非 pad_token_id 的位置为 1,pad_token_id 的位置为 0)。
  2. 如果无法判断是否需要生成 attention_mask,则返回默认的 attention_mask,即全 1 的张量,表示所有位置都需要关注。

逐步详解

步骤 1:获取 pad_token_ideos_token_id
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
  • 说明:从 generation_config 中获取用于填充的 pad_token_id 和序列结束的 eos_token_id
步骤 2:检查 model_kwargs 中是否有 input_ids
# `input_ids` 可能存在于 model_kwargs 中,而不是主要输入(例如多模态模型)
if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
    inputs_tensor = model_kwargs["input_ids"]
  • 解释

    • 在某些情况下,inputs_tensor 并不是模型的主要输入,如在多模态模型中,input_ids 可能存在于 model_kwargs 中。
    • 如果 model_kwargs 中有 input_ids,并且其形状的第二维度长度大于 0(即有数据),则用 model_kwargs["input_ids"] 替换 inputs_tensor
  • 目的:确保 inputs_tensor 是正确的 input_ids,以便后续的 attention_mask 计算。

步骤 3:创建默认的 attention_mask
# 如果无法推断 attention mask,则返回默认的 attention mask
default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
  • 解释

    • 创建一个形状与 inputs_tensor 前两个维度相同的全 1 张量,作为默认的 attention_mask
    • 该默认 attention_mask 表示输入序列的所有位置都需要关注。
    • 切片的基本语法是 start: end:step,其中 start 是起始索引,end 是结束索引(不包括),step 是步长。写为 [:2] 是一种简写表示方式,没有显式指定起始和步长,默认从头开始且步长为1
步骤 4:如果 pad_token_idNone,返回默认的 attention_mask
if pad_token_id is None:
    return default_attention_mask
  • 说明

    • 如果 pad_token_id 未定义,则无法根据填充标记生成 attention_mask,因此直接返回默认的 attention_mask
步骤 5:检查 inputs_tensor 是否为有效的 input_ids
is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
if not is_input_ids:
    return default_attention_mask
  • 解释

    • 检查 inputs_tensor 是否满足以下条件:

      • inputs_tensor 的形状为二维,即形状为 (batch_size, sequence_length)
      • inputs_tensor 的数据类型为整数类型 torch.inttorch.long
    • 如果不满足上述条件,则认为无法根据 inputs_tensor 推断出 attention_mask,因此返回默认的 attention_mask

步骤 6:检查 inputs_tensor 中是否包含 pad_token_id
is_pad_token_in_inputs = (pad_token_id is not None) and (
    isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
)
  • 解释

    • 使用 isin_mps_friendly 函数检查 inputs_tensor 中是否包含 pad_token_id

      • isin_mps_friendly(elements, test_elements):检查 elements 中的元素是否在 test_elements 中,返回布尔张量。
    • isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()

      • 检查 inputs_tensor 中是否有元素等于 pad_token_id,如果有,则返回 True
  • 结果

    • is_pad_token_in_inputs 为一个布尔值,表示 inputs_tensor 中是否包含 pad_token_id
步骤 7:检查 pad_token_id 是否不等于 eos_token_id
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
    isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
)
  • 解释

    • 需要确保 pad_token_ideos_token_id 不相等,以避免将 eos_token_id 误认为填充标记。

    • 逻辑:

      • 如果 eos_token_idNone,则认为 pad_token_id 不等于 eos_token_id

      • 如果 eos_token_id 不为 None,则检查 eos_token_id 是否等于 pad_token_id

        • isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()

          • 检查 eos_token_id 是否等于 pad_token_id
        • 使用按位取反运算符 ~,将结果取反。

  • 结果

    • is_pad_token_not_equal_to_eos_token_id 为一个布尔值,表示 pad_token_id 是否不等于 eos_token_id
步骤 8:确定是否可以推断 attention_mask
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
  • 解释

    • 通过将 is_pad_token_in_inputsis_pad_token_not_equal_to_eos_token_id 相乘(布尔值相乘相当于逻辑与),判断是否可以基于 pad_token_id 推断出 attention_mask
  • 结果

    • can_infer_attention_mask 为布尔值:

      • 如果 True,表示可以根据 pad_token_id 推断 attention_mask

      • 如果 False,则无法推断,需要返回默认的 attention_mask

步骤 9:基于 pad_token_id 生成 attention_mask
attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
  • 解释

    • 使用 inputs_tensor.ne(pad_token_id),得到一个布尔张量,标记出 inputs_tensor 中不等于 pad_token_id 的位置。

    • .long():将布尔张量转换为长整型,即 True 转换为 1False 转换为 0

  • 结果

    • attention_mask_from_padding 为一个长整型张量,形状与 inputs_tensor 相同,非 pad_token_id 的位置为 1pad_token_id 的位置为 0
步骤 10:根据是否可以推断 attention_mask,选择最终的 attention_mask
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
  • 解释

    • 如果可以推断 attention_maskcan_infer_attention_maskTrue):

      • 使用 attention_mask_from_padding
    • 如果不能推断:

      • 使用 default_attention_mask
    • 计算方式:

      • attention_mask = attention_mask_from_padding * can_infer_attention_mask

        • can_infer_attention_maskTrue 时,attention_mask 等于 attention_mask_from_padding

        • can_infer_attention_maskFalse 时,乘积为 0

      • default_attention_mask * ~can_infer_attention_mask

        • ~can_infer_attention_maskcan_infer_attention_mask 取反。

        • can_infer_attention_maskFalse 时,~can_infer_attention_maskTrue

        • 因此,当不能推断时,attention_mask 等于 default_attention_mask

      • 最终,将两个结果相加,得到正确的 attention_mask

  • 结果

    • attention_mask 为最终的注意力掩码张量,表示模型在生成过程中需要关注的输入位置。
步骤 11:返回最终的 attention_mask
return attention_mask
  • 解释

    • 返回生成的 attention_mask 张量,用于在生成过程中指导模型关注正确的输入序列位置。

_prepare_encoder_decoder_kwargs_for_generation

以下是对您提供的 _prepare_encoder_decoder_kwargs_for_generation 方法的详细解释。该方法用于在生成过程中为 编码器-解码器模型 准备必要的关键字参数,以便在生成序列时正确调用编码器。


方法定义

def _prepare_encoder_decoder_kwargs_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    model_kwargs,
    model_input_name: Optional[str],
    generation_config: GenerationConfig,
) -> Dict[str, Any]:
    # 方法体...

参数说明:

  • inputs_tensor: torch.Tensor,输入张量,通常是 input_ids,表示输入序列的标记(tokens)。
  • model_kwargs: Dict[str, Any],包含传递给模型的其他关键字参数。
  • model_input_name: Optional[str],模型输入的名称,默认为 None。如果为 None,则使用 self.main_input_name
  • generation_config: GenerationConfig,生成配置对象,包含生成过程中所需的配置参数。

返回值:

  • model_kwargs: 返回更新后的 model_kwargs,其中包括编码器的输出 encoder_outputs,以供生成器使用。

方法功能概述

该方法的主要目的是:

  1. 获取模型的 编码器 部分。
  2. model_kwargsgeneration_config 中提取 编码器所需的参数,并准备传递给编码器的关键字参数 encoder_kwargs
  3. 调用 编码器的 forward 方法,获取编码器的输出,并将其添加到 model_kwargs 中,以供 解码器 在生成过程中使用。

逐步详解

步骤 1:获取编码器
# 1. get encoder
encoder = self.get_encoder()
  • 说明:

    • 调用模型的 get_encoder() 方法,获取编码器对象。
    • 该编码器将用于处理输入的 inputs_tensor,生成编码器的输出。
步骤 1.1:兼容性处理
# Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
# as the inputs.
if hasattr(self, "hf_device_map"):
    if hasattr(encoder, "_hf_hook"):
        encoder._hf_hook.io_same_device = True
    else:
        add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
  • 解释:

    • 目的:确保在使用 Accelerate 库进行大型模型推理时,编码器的输出与输入位于 同一设备 上(如 GPU),避免跨设备的数据传输开销。
  • 逻辑:

    • 检查模型是否具有 hf_device_map 属性,如果存在,表示模型使用了 Accelerate 库进行设备映射。
    • 检查编码器是否具有 _hf_hook 属性:
      • 如果有,设置其 io_same_device 属性为 True,表示编码器的输入和输出在同一设备上。
      • 如果没有,使用 add_hook_to_module 函数,将 AlignDevicesHook(io_same_device=True) 添加到编码器模块上。
  • 相关函数:

    • add_hook_to_module(module, hook): 将钩子函数添加到指定的模块上,控制模块的输入输出行为。
步骤 2:准备编码器的参数
# 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
    argument: value
    for argument, value in model_kwargs.items()
    if not any(argument.startswith(p) for p in irrelevant_prefix)
}
  • 目的:

    • model_kwargs 中提取与编码器相关的参数,过滤掉与解码器或交叉注意力相关的参数。
  • 逻辑:

    • 定义一个列表 irrelevant_prefix,包含了不相关的参数前缀,如 "decoder_""cross_attn""use_cache"
    • 使用字典推导式,从 model_kwargs 中过滤掉以这些前缀开头的参数。
    • 结果是 encoder_kwargs,其中包含了需要传递给编码器的参数。
  • 示例:

    • 如果 model_kwargs 包含:

      model_kwargs = {
          "input_ids": tensor(...),
          "attention_mask": tensor(...),
          "decoder_input_ids": tensor(...),
          "use_cache": True,
      }
      
    • 过滤后,encoder_kwargs 为:

      encoder_kwargs = {
          "input_ids": tensor(...),
          "attention_mask": tensor(...),
      }
      
步骤 2.1:检查编码器的签名
encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:
    encoder_kwargs = {
        argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
    }
  • 解释:

    • 目的:确保传递给编码器的参数在其 forward 方法的参数列表中,即编码器能够接受这些参数。
  • 逻辑:

    • 使用 inspect.signature(encoder.forward).parameters 获取编码器 forward 方法的参数名称集合 encoder_signature
    • 检查编码器是否接受通配参数 **kwargs**model_kwargs,如果接受,则无需进一步过滤参数。
    • 如果编码器不接受通配参数,则过滤 encoder_kwargs,仅保留在 encoder_signature 中的参数。
步骤 2.2:添加生成配置中的参数
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
  • 说明:

    • generation_config 中提取 output_attentionsoutput_hidden_states 配置,添加到 encoder_kwargs 中。
    • 这些配置决定了编码器在前向计算时,是否返回注意力权重和隐藏状态。
步骤 3:调用编码器并更新 model_kwargs
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)  # type: ignore
return model_kwargs
  • 解释:

    • 确保编码器返回 ModelOutput 对象:

      • 设置 encoder_kwargs["return_dict"] = True,使编码器的输出为 ModelOutput 格式,而不是元组。
    • 准备编码器的输入:

      • 确定模型的输入名称 model_input_name,如果传入了 model_input_name,则使用它,否则使用 self.main_input_name
      • inputs_tensor 添加到 encoder_kwargs,键为 model_input_name
    • 调用编码器的 forward 方法:

      model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)  # type: ignore
      
      • 调用编码器,传入 encoder_kwargs,得到编码器的输出 encoder_outputs
      • encoder_outputs 添加到 model_kwargs 中,键为 "encoder_outputs"
      • 使用类型注解 : ModelOutput 指定类型,这是为了让静态类型检查器知道 encoder_outputs 的类型。
    • 返回更新后的 model_kwargs

      • 包含了 encoder_outputs,供后续的解码器生成过程中使用。

_prepare_decoder_input_ids_for_generation

以下是对您提供的 _prepare_decoder_input_ids_for_generation 方法的详细解释。这个方法用于在生成过程中为 编码器-解码器模型 准备 decoder_input_ids,确保解码器在生成时能够正确地开始。


方法定义

def _prepare_decoder_input_ids_for_generation(
    self,
    batch_size: int,
    model_input_name: str,
    model_kwargs: Dict[str, torch.Tensor],
    decoder_start_token_id: torch.Tensor,
    device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
    """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
    # 方法体...

参数说明:

  • batch_size: int,表示批次大小,即一次输入中样本的数量。
  • model_input_name: str,模型输入的名称,通常为 "input_ids"
  • model_kwargs: Dict[str, torch.Tensor],包含传递给模型的其他关键字参数。
  • decoder_start_token_id: torch.Tensor,解码器开始标记的 token ID。
  • device: torch.device,可选,指定要在哪个设备上运行,如 GPU 或 CPU。如果未指定,则使用模型默认的设备。

返回值:

  • decoder_input_ids: torch.LongTensor,准备好的用于生成的解码器输入 IDs。
  • model_kwargs: 更新后的模型关键字参数字典。

方法功能概述

这个方法的主要作用是:

  1. 检查用户是否手动提供了 decoder_input_ids,如果没有,则根据情况初始化它。
  2. 确保 decoder_start_token_id 的形状正确,并适应批次大小。
  3. 确保 decoder_input_ids 以特殊的开始标记(如 BOS token)开头,如果没有,则自动添加。
  4. 处理特定模型的例外情况,例如 “Donut” 和 “Whisper” 模型可能有不同的处理方式。

通过这些步骤,该方法确保了在生成过程中,解码器能够正确地开始生成序列。


逐步详解

步骤 1:检查用户是否提供了 decoder_input_ids
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
    decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
    decoder_input_ids = model_kwargs.pop("input_ids")
else:
    decoder_input_ids = None

解释:

  • 目的: 确定是否需要初始化 decoder_input_ids

  • 逻辑:

    • 如果 model_kwargs 中存在 decoder_input_ids,则将其取出,并从 model_kwargs 中删除,避免重复。
    • 如果 model_kwargs 中存在 input_ids,且 model_input_name 不是 "input_ids",则将其视为 decoder_input_ids。这是为了方便某些模型的输入命名。
    • 如果以上都不满足,则将 decoder_input_ids 设为 None,表示需要初始化。

示例:

  • 用户在调用生成方法时,可能通过 model_kwargs 提供了 decoder_input_ids,方法将直接使用它。
  • 如果用户没有提供,方法将尝试从 input_ids 中获取(当编码器不使用 input_ids 作为主要输入时)。
  • 如果都没有提供,方法将在后续步骤中初始化 decoder_input_ids
步骤 2:调整 decoder_start_token_id 的形状
if device is None:
    device = self.device
if decoder_start_token_id.ndim == 1:
    if decoder_start_token_id.shape[0] != batch_size:
        raise ValueError(
            f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
        )
    decoder_start_token_id = decoder_start_token_id.view(-1, 1)
else:
    decoder_start_token_id = (
        torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
    )

解释:

  • 目的: 确保 decoder_start_token_id 的形状为 (batch_size, 1),即每个样本都有一个开始标记。

  • 逻辑:

    • 如果未指定 device,则使用模型默认的设备 self.device
    • 检查 decoder_start_token_id 的维度:
      • 如果是一维张量(向量),即 decoder_start_token_id.ndim == 1
        • 检查其长度是否等于 batch_size,否则抛出错误。
        • 重塑为形状 (batch_size, 1),即每个样本一个开始标记。
      • 如果不是一维张量,则认为是单个标量:
        • 创建一个形状为 (batch_size, 1) 的张量,元素全为 decoder_start_token_id 的值。

示例:

  • 如果 decoder_start_token_id 是标量 101(假设是开始标记的 ID):
    • 创建一个形状为 (batch_size, 1) 的张量,所有元素都是 101
  • 如果 decoder_start_token_id 是张量 [101, 102]batch_size2
    • 将其重塑为:
      [[101],
       [102]]
      
步骤 3:确保 decoder_input_ids 以开始标记开头
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
    decoder_input_ids = decoder_start_token_id

解释:

  • 情况 1: 如果 decoder_input_idsNone,即用户未提供任何解码器输入:
    • 直接使用 decoder_start_token_id 作为 decoder_input_ids,表示解码器将从开始标记开始生成。

# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token...
elif "donut" in self.__class__.__name__.lower() or (
    self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
):
    pass
elif self.config.model_type in ["whisper"]:
    pass

解释:

  • 情况 2: 针对特定模型的例外情况:
    • Donut 模型:
      • Donut 模型有特定的解码器开始标记,且不需要额外的 BOS(Begin of Sequence)标记。
      • 如果模型名包含 "donut",则不对 decoder_input_ids 进行处理。
    • Whisper 模型:
      • Whisper 模型也有自己的处理逻辑,不需要添加开始标记。

# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
    decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
    if "decoder_attention_mask" in model_kwargs:
        decoder_attention_mask = model_kwargs["decoder_attention_mask"]
        decoder_attention_mask = torch.cat(
            (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
            dim=-1,
        )
        model_kwargs["decoder_attention_mask"] = decoder_attention_mask

解释:

  • 情况 3: 用户提供了 decoder_input_ids,但其首个 token 并非 decoder_start_token_id
    • 检查首个 token 是否等于 decoder_start_token_id
      • (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
        • 比较每个样本的 decoder_input_ids 首个 token 是否不等于对应的 decoder_start_token_id
        • 如果对于所有样本都不相等,则返回 True
    • 如果首个 token 不匹配,则在 decoder_input_ids 前添加 decoder_start_token_id
      • 使用 torch.cat 在维度 -1(序列长度维度)上拼接 decoder_start_token_iddecoder_input_ids
    • 调整 decoder_attention_mask(如果提供):
      • 如果 model_kwargs 中存在 decoder_attention_mask
        • 在其前面添加一个值为 1 的位置,表示新添加的开始标记需要被注意。
        • 更新 model_kwargs["decoder_attention_mask"]

示例:

  • 如果原始 decoder_input_ids 为:

    [[5, 6, 7],
     [8, 9, 10]]
    
  • decoder_start_token_id 为:

    [[101],
     [101]]
    
  • 拼接后得到:

    [[101, 5, 6, 7],
     [101, 8, 9, 10]]
    
  • 同时调整 decoder_attention_mask,在前面添加一个 1


步骤 4:返回处理后的结果
return decoder_input_ids, model_kwargs
  • 返回更新后的 decoder_input_idsmodel_kwargs

整体流程总结

  • 输入处理:

    • 检查用户是否提供了 decoder_input_ids,如果没有,则需要初始化。
    • 通过 model_kwargs 获取 decoder_input_idsinput_ids,如果适用。
  • 确保解码器开始标记的形状正确:

    • decoder_start_token_id 调整为形状 (batch_size, 1),确保每个样本都有对应的开始标记。
  • 确保 decoder_input_ids 以开始标记开头:

    • 如果 decoder_input_idsNone,直接使用 decoder_start_token_id
    • 对于特定模型(如 Donut 和 Whisper),保留用户提供的 decoder_input_ids,不做修改。
    • 如果用户提供的 decoder_input_ids 不以 decoder_start_token_id 开头,自动在其前添加。
    • 同时,调整 decoder_attention_mask,确保新添加的开始标记在注意力掩码中被考虑。
  • 返回处理后的 decoder_input_idsmodel_kwargs,供后续生成过程使用。


示例代码

示例 1:用户未提供 decoder_input_ids

# 假设 batch_size = 2
decoder_start_token_id = torch.tensor([101], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
    batch_size=2,
    model_input_name="input_ids",
    model_kwargs={},
    decoder_start_token_id=decoder_start_token_id,
    device=device,
)
# 结果:decoder_input_ids = [[101], [101]]

示例 2:用户提供了 decoder_input_ids,但未以开始标记开头

decoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]], device=device)
model_kwargs = {"decoder_input_ids": decoder_input_ids}
decoder_start_token_id = torch.tensor([101, 102], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
    batch_size=2,
    model_input_name="input_ids",
    model_kwargs=model_kwargs,
    decoder_start_token_id=decoder_start_token_id,
    device=device,
)
# 结果:
# decoder_input_ids = [[101, 5, 6, 7], [102, 8, 9, 10]]

示例 3:针对 Donut 模型

# 设模型名称包含 "Donut",用户提供了 decoder_input_ids
self.__class__.__name__ = "DonutModel"
decoder_input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]], device=device)
model_kwargs = {"decoder_input_ids": decoder_input_ids}
decoder_start_token_id = torch.tensor([101], device=device)
decoder_input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
    batch_size=2,
    model_input_name="input_ids",
    model_kwargs=model_kwargs,
    decoder_start_token_id=decoder_start_token_id,
    device=device,
)
# 结果:
# decoder_input_ids 保持不变,不添加 decoder_start_token_id

注意事项

  • 模型特定的例外情况: 对于某些模型,如 Donut 和 Whisper,需要特殊处理,不能盲目添加开始标记。

  • 一致性检查: 确保 batch_sizedecoder_start_token_id 的长度一致,否则抛出异常。

  • 处理注意力掩码: 如果修改了 decoder_input_ids,且提供了 decoder_attention_mask,需要相应地调整掩码。

  • 设备一致性: 所有张量都应该在同一设备上(CPU 或 GPU),方法中确保了这一点。


总结

该方法的主要功能是为生成过程准备合适的 decoder_input_ids

  • 初始值设置: 如果用户未提供,则使用 decoder_start_token_id 初始化。
  • 形状调整: 确保 decoder_start_token_id 的形状与批次大小匹配。
  • 开头标记: 确保 decoder_input_ids 以特定的开始标记开头,以符合模型的要求。
  • 特例处理: 对于特定模型,按其特殊需求处理。

通过这些步骤,保证了编码器-解码器模型在生成过程中能够正确地开始解码器的序列生成。


希望以上解释能够帮助您理解 _prepare_decoder_input_ids_for_generation 方法的功能和每个步骤的具体作用。如果您还有其他问题,欢迎继续提问!

heal_tokens

这个方法用于在生成过程中,对输入的 input_ids 进行修复(healing),以增强模型的输出质量。具体来说,它会根据输入序列的最后一个 token,寻找可能的扩展,并替换原始的尾部 token。这在某些情况下可以纠正模型生成中的不一致或错误。


方法定义

def heal_tokens(
    self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
) -> torch.LongTensor:
    r"""
    Generates sequences of token ids for models with a language modeling head.

    Parameters:
        input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
        tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.

    Return:
        `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
    """
    # 方法体...

参数说明:

  • input_ids: torch.LongTensor,形状为 (batch_size, sequence_length),表示输入的 token 序列,用于生成的提示。
  • tokenizer: PreTrainedTokenizerBase,可选参数,模型对应的 tokenizer,用于解码和编码 token IDs。

返回值:

  • torch.LongTensor,形状与 input_ids 相同,其中每个序列的尾部 token 被替换为了适当的扩展 token。

方法功能概述

该方法的主要目的是:

  1. 修复输入序列的尾部 token:对于每个输入序列,检查其最后一个 token,寻找可能的扩展 token,并替换之。
  2. 改进模型生成的连贯性:通过纠正输入序列的尾部,使得生成的序列在语义和形式上更加连贯。
  3. 处理空序列和特殊情况:在方法中包含了对空序列和特殊情况的处理,确保方法的稳健性。

逐步详解

步骤 1:验证 tokenizer 是否提供

if tokenizer is None:
    raise ValueError(
        " When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
        "argument of `generate`."
    )
  • 解释:
    • 方法要求必须提供 tokenizer 参数,否则无法进行解码和编码操作。
    • 如果未提供 tokenizer,则抛出 ValueError

步骤 2:获取特殊 token IDs 和构建词汇前缀树

bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
  • 解释:
    • 获取特殊 token IDs:
      • bos_token_id: 序列开始的 token ID(BOS = Begin Of Sequence)。
      • pad_token_id: 填充 token 的 ID。
    • 构建词汇前缀树:
      • tokenizer.get_vocab(): 获取 tokenizer 的词汇表,返回一个字典 {token: token_id}
      • ExtensionsTrie: 自定义的前缀树类,用于快速查找以某个前缀开始的所有 token。
    • 创建生成配置:
      • GenerationConfig: 配置生成过程的参数。
      • max_new_tokens=1: 生成的最大新 token 数设置为 1,因为我们只需要替换尾部 token。
      • pad_token_id=pad_token_id: 设置填充 token ID。

步骤 3:处理提示文本

# assumption: leading/trailing whitespace is not meaningful, so the prompts are
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
  • 解释:
    • 假设:首尾的空白字符(whitespace)对生成过程没有实质影响,因此在重新编码前去除。
    • 步骤:
      • tokenizer.batch_decode: 将 input_ids 解码为字符串列表,跳过特殊 tokens。
      • 使用列表推导式,对每个字符串进行 strip 操作,去除首尾空白字符。
      • 结果存储在 prompts 列表中。

步骤 4:重新编码输入序列

input_ids = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,
).input_ids.to(input_ids.device)
  • 解释:
    • 重新编码:
      • 将处理后的 prompts 列表再次通过 tokenizer 编码为 input_ids
      • return_tensors="pt": 返回 PyTorch 张量。
      • padding=True: 对序列进行填充,以使它们的长度一致。
    • 调整设备:
      • 使用 .to(input_ids.device) 将新生成的 input_ids 移动到与原始 input_ids 相同的设备上(如 GPU 或 CPU)。

步骤 5:替换序列中的 bos_token_idpad_token_id

# replace bos with pad to not condition healing on it
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
  • 解释:
    • 目的:
      • 在后续的修复过程中,不希望序列开始的特殊标记(BOS)影响结果,因此将其替换为填充标记。
    • 操作:
      • 使用 torch.where 函数,将 input_ids 中等于 bos_token_id 的位置替换为 pad_token_id

步骤 6:检查 input_ids 是否为空

if input_ids.numel() == 0:
    return input_ids
  • 解释:
    • 目的:
      • 如果 input_ids 为空(即没有元素),则无需进行后续处理,直接返回。
    • 检查:
      • input_ids.numel(): 返回张量中元素的总数。
      • 如果为 0,则返回 input_ids

步骤 7:获取每个序列的尾部 token ID 和对应的 token

tail_ids = input_ids[:, -1].tolist()
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
  • 解释:
    • 获取尾部 token ID:
      • input_ids[:, -1]: 获取每个序列的最后一个 token ID。
      • .tolist(): 转换为 Python 列表,方便后续处理。
    • 处理空格 token:
      • tokenizer.convert_tokens_to_ids(" "): 将空格字符 " " 转换为对应的 token ID。
      • tokenizer.convert_ids_to_tokens: 再将 token ID 转换为 token 字符串。
      • [0]: 获取结果中的第一个元素(列表中可能包含多个 token)。
      • 结果 space_tok 为空格对应的 token,例如在 BPE(Byte-Pair Encoding)中,空格可能对应特殊字符,例如 'Ġ'
    • 获取尾部 token 的字符串表示:
      • 使用生成器表达式 (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
        • 对每个 tail_id
          • tokenizer.decode(t): 解码为字符串。
          • .replace(" ", space_tok): 将字符串中的空格替换为 space_tok,以便于后续前缀搜索。

步骤 8:遍历每个序列,尝试修复尾部 token

for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
    batch_ids = input_ids[batch_idx]
    if torch.all(batch_ids == pad_token_id).item():
        continue  # skip empty sequences (all pad ids)
  • 解释:
    • 遍历序列:
      • 使用 enumeratetail_idstail_toks 进行遍历。
      • batch_idx: 当前序列的索引。
      • tail_id: 当前序列的尾部 token ID。
      • tail_tok: 当前序列的尾部 token 字符串(经过特殊处理的)。
    • 获取当前序列的 input_ids
      • batch_ids = input_ids[batch_idx]: 获取当前序列的 input_ids
    • 检查序列是否为空:
      • torch.all(batch_ids == pad_token_id).item(): 检查序列中的所有位置是否都是 pad_token_id
      • 如果是,表示序列为空,跳过后续处理。

步骤 9:构建可能的替代 token 的偏置字典

# apply bias for alternatives (extensions) to the tail token
"""
seq_bias key has to be tuple with int so have to use
tokenizer function to convert str to int
"""
seq_bias = {
    (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
}
if len(seq_bias) == 1:
    continue  # skip if there are no token alternatives to heal with
# slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
seq_bias[(tail_id,)] += 1.0
generation_config.update(sequence_bias=seq_bias)
  • 解释:
    • 寻找可能的扩展 tokens:
      • vocab_trie.extensions(prefix=tail_tok): 在词汇前缀树中,找到以 tail_tok 为前缀的所有 token。
    • 构建偏置字典 seq_bias
      • 键:以单个 token ID 组成的元组 (token_id,),因为后续需要的键是元组形式。
      • 值:偏置值 10.0,用于在生成过程中提升这些 tokens 的概率。
      • 需要使用 tokenizer.convert_tokens_to_ids(alt_tok) 将 token 字符串转换为 token ID。
    • 检查是否有可替代的 tokens:
      • 如果 seq_bias 的长度为 1,表示只有原始的 tail_id,没有其他可替代的 tokens,跳过处理。
    • 稍微提升原始 token 的概率:
      • seq_bias[(tail_id,)] += 1.0: 对原始的 tail_id,增加一个较小的偏置 1.0,以防止过度修复(例如,将 'http' 修复为 'https')。
    • 更新生成配置:
      • generation_config.update(sequence_bias=seq_bias): 将偏置字典添加到生成配置中,供生成过程使用。

步骤 10:准备生成输入,去除尾部 token

trimmed_ids = batch_ids[:-1]
"""
the latter code assumes trimmed_ids is not empty
so have to check its element count
"""
if trimmed_ids.numel() == 0:
    continue
# if the prompt is a single (non-pad) token, regenerate from bos
if len(batch_ids[batch_ids != pad_token_id]) == 1:
    trimmed_ids[-1] = bos_token_id
  • 解释:
    • 去除尾部 token:
      • trimmed_ids = batch_ids[:-1]: 获取当前序列,去除最后一个 token,即 tail_id
    • 检查 trimmed_ids 是否为空:
      • 如果 trimmed_ids.numel() == 0,表示序列长度为 0,无需处理,继续下一个序列。
    • 特殊情况处理:
      • 如果序列只有一个非填充 token(除去 pad_token_id 后长度为 1):
        • 需要将 trimmed_ids 的最后一个位置替换为 bos_token_id,以重新从开始标记生成。

步骤 11:生成新的 token 并替换尾部 token

input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
  • 解释:
    • 调用生成方法:
      • self.generate: 调用模型的 generate 方法,生成新序列。
      • trimmed_ids.unsqueeze(0): 为了适应批量维度,将 trimmed_ids 添加一个维度,形状从 (sequence_length - 1,) 变为 (1, sequence_length - 1)
      • generation_config=generation_config: 使用之前配置的生成参数,包括偏置字典 seq_bias
    • 更新 input_ids
      • 将生成的新序列替换到 input_ids 中对应的位置。
      • 注意,这里生成的序列长度为 sequence_length,因为 max_new_tokens=1,所以会在 trimmed_ids 的基础上生成一个新 token。

步骤 12:返回修复后的 input_ids

return input_ids
  • 解释:
    • 返回处理后的 input_ids,其中每个序列的尾部 token 已根据可能的扩展进行了修复。

整体流程总结

  1. 准备工作:

    • 验证 tokenizer 参数。
    • 获取特殊 token IDs。
    • 构建词汇前缀树 vocab_trie,用于快速查找可能的扩展 token。
    • 配置生成参数 generation_config
  2. 处理输入序列:

    • input_ids 解码为字符串列表 prompts,去除首尾空白。
    • 重新编码 promptsinput_ids,确保一致性。
    • bos_token_id 替换为 pad_token_id,避免对序列开始标记的影响。
  3. 遍历每个序列,尝试修复尾部 token:

    • 获取每个序列的尾部 token ID 和对应的 token 字符串。
    • 查找以尾部 token 为前缀的可能扩展 tokens。
    • 构建偏置字典 seq_bias,提升这些扩展 token 在生成过程中的概率。
    • 去除序列的尾部 token,准备生成新的尾部 token。
    • 调用 self.generate 方法,生成新的序列,并替换到 input_ids 中。
  4. 返回处理后的 input_ids


_prepare_generated_length

这个方法用于在生成过程中,根据用户提供的生成配置和模型输入,准备和调整生成的最大长度 (max_length) 和最小长度 (min_length),以避免类似属性之间的冲突。


方法定义
def _prepare_generated_length(
    self,
    generation_config,
    has_default_max_length,
    has_default_min_length,
    model_input_name,
    input_ids_length,
    inputs_tensor,
):
    """Prepared max and min length in generation configs to avoid clashes between similar attributes"""
    # 方法体...

参数说明:

  • generation_configGenerationConfig 对象,包含了生成过程中需要的各种配置参数,如 max_lengthmin_lengthmax_new_tokensmin_new_tokens 等。

  • has_default_max_lengthbool 类型,指示 max_length 是否使用了默认值。如果为 True,表示用户未显式设置 max_length

  • has_default_min_lengthbool 类型,指示 min_length 是否使用了默认值。

  • model_input_namestr 类型,模型输入的名称,通常为 "input_ids""inputs_embeds"

  • input_ids_lengthint 类型,输入序列 input_ids 的长度,即输入的 token 数量。

  • inputs_tensortorch.Tensor 对象,模型的输入张量。


方法功能概述

该方法的主要作用是:

  1. 调整 max_lengthmin_length:根据用户提供的 max_new_tokensmin_new_tokensmax_lengthmin_length,以及输入序列的长度,计算并设置最终的生成长度参数,确保生成过程按照预期进行。

  2. 避免冲突:如果用户同时设置了类似的属性(例如同时设置了 max_lengthmax_new_tokens),该方法会明确优先级,并在必要时发出警告,提示用户可能存在的冲突。

  3. 处理特殊情况:针对一些特殊的输入情况,例如使用了 inputs_embeds、模型是编码器-解码器模型等,做出相应的调整。


逐步详解
1. 处理 max_length
if generation_config.max_new_tokens is not None:
    if not has_default_max_length and generation_config.max_length is not None:
        logger.warning(
            f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
            f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
            "Please refer to the documentation for more information. "
            "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
        )
    generation_config.max_length = generation_config.max_new_tokens + input_ids_length

解释:

  • 情况 1:用户设置了 max_new_tokens

    • 逻辑:

      • 如果 generation_config.max_new_tokens 不为 None,表示用户希望通过 max_new_tokens 来指定要生成的新 token 数量。

      • 冲突处理:

        • 如果用户同时设置了 max_length(并且不是默认值),则发出警告,提示两者可能存在冲突,但以 max_new_tokens 为准。

        • not has_default_max_length:表示 max_length 被用户显式设置了。

      • 计算 max_length

        • max_length 设置为 input_ids_length + max_new_tokens

          • input_ids_length:输入序列的长度。

          • max_new_tokens:用户希望生成的新 token 数量。

      • 这样,max_length 就表示生成的序列(包括输入和生成的部分)总长度。

示例:

  • 如果 input_ids_length = 10max_new_tokens = 20,则 max_length = 10 + 20 = 30

# if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
# otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`
elif (
    model_input_name == "inputs_embeds"
    and input_ids_length != inputs_tensor.shape[1]
    and not self.config.is_encoder_decoder
):
    generation_config.max_length -= inputs_tensor.shape[1]

解释:

  • 情况 2:模型输入是 inputs_embeds,且存在输入长度不匹配

    • 逻辑:

      • 条件判断:

        • model_input_name == "inputs_embeds":表示模型的输入是嵌入表示。

        • input_ids_length != inputs_tensor.shape[1]:输入的 input_ids 长度与 inputs_tensor 的长度(序列维度大小)不一致。

        • not self.config.is_encoder_decoder:模型不是编码器-解码器模型。

      • 处理:

        • 需要调整 max_length,减去 inputs_tensor.shape[1],即输入的序列长度。
      • 原因:

        • 当用户提供了 inputs_embeds 而非 input_ids,且两者长度不一致,为了确保生成的总长度不超过用户预期,需要调整 max_length

elif has_default_max_length:  # by default let's always generate 20 new tokens
    if generation_config.max_length == GenerationConfig().max_length:
        generation_config.max_length = generation_config.max_length + input_ids_length
        max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
        if max_position_embeddings is not None:
            generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

解释:

  • 情况 3:用户未设置 max_length,使用默认值

    • 逻辑:

      • has_default_max_lengthTrue,即用户未显式设置 max_length

      • 如果 generation_config.max_length 等于默认的 max_length,则执行以下操作:

        • 计算新的 max_length

          • max_length 设置为原来的 max_length 加上 input_ids_length

            • 这样,默认情况下,会在输入的基础上生成 20 个新 tokens(假设默认 max_length20)。
        • 考虑模型的最大位置嵌入长度:

          • 获取模型配置中的 max_position_embeddings,它表示模型能处理的最大序列长度。

          • 如果存在,就将 generation_config.max_length 限制在 max_position_embeddings 之内,防止生成长度超过模型的能力。


2. 处理 min_length
if generation_config.min_new_tokens is not None:
    if not has_default_min_length:
        logger.warning(
            f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
            f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
            "Please refer to the documentation for more information. "
            "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
        )
    generation_config.min_length = generation_config.min_new_tokens + input_ids_length

解释:

  • 情况 1:用户设置了 min_new_tokens

    • 逻辑:

      • 如果 generation_config.min_new_tokens 不为 None,表示用户希望指定要生成的最小新 token 数量。

      • 冲突处理:

        • 如果用户同时设置了 min_length(并且不是默认值),则发出警告,提示两者存在冲突,以 min_new_tokens 为准。
      • 计算 min_length

        • min_length 设置为 min_new_tokens + input_ids_length

elif (
    model_input_name == "inputs_embeds"
    and input_ids_length != inputs_tensor.shape[1]
    and not self.config.is_encoder_decoder
):
    generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)

解释:

  • 情况 2:模型输入是 inputs_embeds,且存在输入长度不匹配

    • 逻辑:

      • 与处理 max_length 时类似,但这里调整的是 min_length

      • generation_config.min_length 减去 inputs_tensor.shape[1],并取最大值 0(防止出现负值)。

      • 这样可以确保生成的最小长度不会超过用户预期。


3. 返回更新后的 generation_config
return generation_config
  • 返回调整过后的 generation_config,包含更新的 max_lengthmin_length,以供生成过程使用。

_validate_generated_length

def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
    """Performs validation related to the resulting generated length"""
    # 函数主体从这里开始

功能说明

  • 目的:该函数用于对生成的长度进行验证,确保 generation_config 中的参数设置合理,避免在生成过程中出现不可预期的行为或错误。

  • 主要任务

    • 检查 max_lengthmax_new_tokens 之间的关系,给出适当的警告或错误。
    • 检查输入序列长度 input_ids_lengthmax_length 之间的关系。
    • 检查 min_lengthmax_length 之间的关系,给出警告。
    • 确保 min_new_tokens 加上 input_ids_length 不超过 max_length,并给出警告。

参数说明

  • self:类的实例,典型的 Python 类方法的第一个参数。
  • generation_config:生成配置对象,包含了生成过程中使用的各项参数设置,例如 max_lengthmin_lengthmax_new_tokens 等。
  • input_ids_length:整数,输入序列 input_ids 的长度,即序列的长度。
  • has_default_max_length:布尔值,指示是否使用了默认的 max_length 设置(即用户没有在调用时显式指定 max_length)。

代码详细解释

1. 编译时不进行警告或异常抛出

# Can't throw warnings/exceptions during compilation
if is_torchdynamo_compiling():
    return

解释

  • 目的:在使用 TorchDynamo(PyTorch 编译器)进行编译时,不要抛出警告或异常。

  • 逻辑

    • is_torchdynamo_compiling():检查当前是否在使用 TorchDynamo 进行编译。
    • if is_torchdynamo_compiling(): return:如果正在编译,直接返回,不进行后续的验证。
  • 原因:在编译过程中,抛出异常或者发出警告可能会导致编译失败或行为异常,因此在编译时跳过验证。

2. 第一部分:与参数设置相关的 max_length 警告

# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
    # 20 is the default max_length of the generation config
    warnings.warn(
        f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
        "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
        "generation.",
        UserWarning,
    )

解释

  • 目的:当用户未指定 max_length 且未设置 max_new_tokens 时,发出警告提示。

  • 逻辑

    • 条件判断

      • has_default_max_length:用户未显式指定 max_length,使用了默认值。
      • generation_config.max_new_tokens is Nonemax_new_tokens 未设置。
      • generation_config.max_length == 20max_length 等于 20,这是 GenerationConfig 的默认值。
    • 处理

      • 如果以上条件都满足,使用 warnings.warn() 发出警告,提示用户正在使用模型无关的默认 max_length 来控制生成长度,建议用户设置 max_new_tokens 来控制生成的最大长度。
  • 原因

    • 如果用户没有明确设置生成长度的参数,可能会导致生成的序列长度与预期不符。
    • 建议用户使用 max_new_tokens,因为它更直观地控制生成的新 tokens 数量。

3. 检查输入序列长度是否超过或等于 max_length

if input_ids_length >= generation_config.max_length:
    input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
    raise ValueError(
        f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
        f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
        " increasing `max_length` or, better yet, setting `max_new_tokens`."
    )

解释

  • 目的:如果输入序列的长度大于或等于 max_length,抛出 ValueError,因为这样会导致生成过程中的问题。

  • 逻辑

    • 条件判断

      • if input_ids_length >= generation_config.max_length:如果输入序列的长度 input_ids_length 大于或等于 max_length
    • 处理

      • 根据模型类型确定输入 IDs 的名称:

        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        
        • 如果是编码器-解码器模型,使用 'decoder_input_ids'
        • 否则,使用 'input_ids'
      • 抛出 ValueError,包含详细的错误信息:

        • 提示输入序列的长度与 max_length 的值,以及可能导致意外行为。
        • 建议用户增加 max_length 或者更好的,设置 max_new_tokens
  • 原因

    • 如果输入序列的长度已经达到或超过 max_length,模型将无法生成新的 tokens,可能导致生成过程立即结束或出现错误。
    • 为了避免这种情况,必须确保 max_length 大于输入序列的长度。

4. 第二部分:由于不可行的参数组合导致的 min_length 警告

# 2. Min length warnings due to unfeasible parameter combinations
min_length_error_suffix = (
    " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
    "increase the maximum length."
)
if has_default_max_length:
    min_length_error_suffix += (
        f" Note that `max_length` is set to {generation_config.max_length}, its default value."
    )

解释

  • 目的:准备 min_length 相关的错误信息,供后续警告使用。

  • 逻辑

    • 定义错误信息的后缀

      • min_length_error_suffix:提示用户需要降低最小长度和/或增加最大长度。
    • 如果使用了默认的 max_length

      • 如果 has_default_max_lengthTrue,则在 min_length_error_suffix 后面追加一句,说明 max_length 被设置为其默认值。

5. 检查 min_length 是否大于 max_length

if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
    warnings.warn(
        f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
        f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
        UserWarning,
    )

解释

  • 目的:如果 min_length 大于 max_length,发出警告,因为这是不可行的约束。

  • 逻辑

    • 条件判断

      • generation_config.min_length is not Nonemin_length 已被设置。
      • generation_config.min_length > generation_config.max_lengthmin_length 大于 max_length
    • 处理

      • 使用 warnings.warn() 发出警告,提示 min_length 大于 max_length,并附加之前准备的 min_length_error_suffix
  • 原因

    • 如果最小生成长度大于最大生成长度,模型无法满足这样的约束,可能导致生成过程在达到最大长度时停止,而未达到最小长度。
    • 提醒用户调整参数,使其合理。

6. 检查 min_new_tokens 加上 input_ids_length 是否超过 max_length

if generation_config.min_new_tokens is not None:
    min_length = generation_config.min_new_tokens + input_ids_length
    if min_length > generation_config.max_length:
        warnings.warn(
            f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
            f"added to the prompt length ({input_ids_length}), is larger than"
            f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
            UserWarning,
        )

解释

  • 目的:如果 min_new_tokens 加上输入序列长度超过 max_length,发出警告。

  • 逻辑

    • 条件判断

      • if generation_config.min_new_tokens is not Nonemin_new_tokens 已被设置。
    • 计算最小长度

      • min_length = generation_config.min_new_tokens + input_ids_length:计算最小生成长度,即 min_new_tokens 加上输入序列长度。
    • 检查是否超过 max_length

      • if min_length > generation_config.max_length:如果计算得到的 min_length 大于 max_length
    • 处理

      • 使用 warnings.warn() 发出警告,提示 min_new_tokens 加上输入长度超过了可能的最大长度,并附加 min_length_error_suffix
  • 原因

    • 如果 min_new_tokens 与输入序列长度之和超过 max_length,模型无法生成满足最小新 tokens 数量的序列。
    • 提醒用户调整参数,降低 min_new_tokens、增加 max_length,或减少输入序列的长度。

_supports_logits_to_keep

方法定义
def _supports_logits_to_keep(self) -> bool:
    """
    Return True if the current model supports the keyword argument `logits_to_keep` in forward()
    to save memory. Checking it in this way allows to avoid using a new model attribute.
    """
    return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())

方法功能概述
  • 目的:确定当前模型的 forward 方法是否支持 logits_to_keep 参数。
  • 返回值:布尔值
    • True:如果 forward 方法的参数中包含 logits_to_keep
    • False:如果 forward 方法的参数中不包含 logits_to_keep

逐步详解
  1. 使用 inspect 模块分析 forward 方法的签名

    inspect.signature(self.forward)
    
    • 作用:获取 self.forward 方法的签名信息,包括参数列表和参数默认值等。
    • inspect 模块:Python 内置模块,提供了检查和获取对象(如函数、类、模块等)信息的功能。
  2. 获取 forward 方法的参数字典

    inspect.signature(self.forward).parameters
    
    • 返回值:一个有序字典(OrderedDict),键为参数名称,值为参数对应的 Parameter 对象。
  3. 获取参数名称列表并转换为集合

    set(inspect.signature(self.forward).parameters.keys())
    
    • parameters.keys():返回参数名称的可迭代对象。
    • set(...):将参数名称转换为集合,方便后续进行快速查找(in 操作)。
  4. 检查是否包含 logits_to_keep 参数

    "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
    
    • 作用:判断字符串 "logits_to_keep" 是否在参数名称集合中。
    • 返回值:布尔值。
  5. 返回判断结果

    • 如果包含:返回 True
    • 如果不包含:返回 False

方法用途和背景
  • 节省内存

    在生成任务中,模型可能会生成大量的 logits(每个时间步长预测下一个 token 的概率分布,通常是一个包含了整个词汇表大小的张量)。如果能够限制保留的 logits 数量(例如只保留 top-k 个 logits),可以大大节省内存。

  • 动态检查模型功能

    不同的模型可能实现了不同的功能。通过这种动态检查的方法,可以在不修改模型代码的情况下,了解模型是否支持某个特定的参数或功能。这有助于编写通用的代码,适用于多种模型。

  • 避免使用额外的模型属性

    通过检查方法签名,而不是增加一个模型属性,可以减少模型类的复杂性和维护成本。


举例说明

假设我们有一个模型,其 forward 方法定义如下:

def forward(self, input_ids, attention_mask=None, logits_to_keep=None):
    # 模型的前向计算逻辑
    logits = self.compute_logits(input_ids, attention_mask)
    if logits_to_keep is not None:
        # 只保留指定数量的 logits
        logits = logits[:, -1, :logits_to_keep]
    return logits
  • 模型支持 logits_to_keep 参数:在这种情况下,_supports_logits_to_keep 方法会返回 True,因为 forward 方法的参数中包含 logits_to_keep

  • 使用示例

    if self._supports_logits_to_keep():
        outputs = self.forward(input_ids, attention_mask=attention_mask, logits_to_keep=10)
    else:
        outputs = self.forward(input_ids, attention_mask=attention_mask)
    
    • 解释:代码首先检查模型是否支持 logits_to_keep 参数,如果支持,则在调用 forward 方法时传入该参数,以只保留 top-10 的 logits,从而节省内存。

_prepare_cache_for_generation

def _prepare_cache_for_generation(
    self,
    generation_config: GenerationConfig,
    model_kwargs: Dict,
    assistant_model: "PreTrainedModel",
    batch_size: int,
    max_cache_length: int,
    device: torch.device,
) -> bool:
    """
    Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
    instantiated, writes it to `model_kwargs`, under the name expected by the model.
    """
    # 函数主体从这里开始

功能说明

  • 目的:准备生成过程中使用的缓存(cache),根据给定的 generation_config 和其他参数,初始化或调整缓存。如果缓存被实例化,它将被写入到 model_kwargs 中,使用模型期望的缓存名称。

  • 背景:在文本生成任务中,使用缓存可以加速生成过程,特别是在自回归模型中,缓存先前的计算结果可以避免重复计算。在不同的模型或配置下,缓存的实现方式可能不同,因此需要根据情况准备合适的缓存。

参数说明

  • self:当前类的实例,典型的 Python 类方法的第一个参数。

  • generation_configGenerationConfig 类型,表示生成配置,其中包含生成过程中的各种参数设置,如是否使用缓存、缓存的实现方式等。

  • model_kwargsDict 类型,包含传递给模型的关键字参数。在函数中,可能会对其进行修改,添加缓存相关的参数。

  • assistant_modelPreTrainedModel 类型,可选的辅助模型,用于加速生成或其他目的。

  • batch_size:整数,表示批次大小。

  • max_cache_length:整数,表示缓存的最大长度,即缓存可以存储的最大序列长度。

  • devicetorch.device 类型,表示在何种设备(CPU 或 GPU)上运行。


1. 确定缓存名称

cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
  • 解释

    • 这行代码根据当前模型类的名称,确定缓存在 model_kwargs 中的键名称。

    • 如果类名中不包含 "mamba",则缓存名称为 "past_key_values";否则,缓存名称为 "cache_params"

  • 原因

    • 不同的模型可能期望的缓存名称不同。模型需要从 model_kwargs 中获取缓存,如果名称不一致,可能导致缓存无法正确工作。

2. 确定是否需要跨注意力缓存(cross-attention cache)

requires_cross_attention_cache = (
    self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
  • 解释

    • requires_cross_attention_cache 是一个布尔值,表示是否需要准备跨注意力缓存。

    • 条件:

      • self.config.is_encoder_decoder:如果模型是编码器-解码器架构,则需要跨注意力缓存。

      • model_kwargs.get("encoder_outputs") is not None:如果在 model_kwargs 中提供了编码器的输出,则也需要跨注意力缓存。

  • 原因

    • 在编码器-解码器模型(如 BART、T5)中,解码器需要访问编码器的输出,因此需要跨注意力缓存。

3. 快速退出路径 1:用户已在 model_kwargs 中指定了缓存

# 快速退出路径 1:如果用户指定了缓存,我们只需要:
# a) 检查是否有冲突的 `generate` 参数
# b) 如果用户传递了旧的缓存格式,并且模型支持,将其转换为新的缓存格式
user_defined_cache = model_kwargs.get(cache_name)
if user_defined_cache is not None:
    if generation_config.cache_implementation is not None:
        raise ValueError(
            f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
            "Cache object) is unsupported. Please use only one of the two."
        )
    if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
        model_kwargs[cache_name] = (
            DynamicCache.from_legacy_cache(user_defined_cache)
            if not requires_cross_attention_cache
            else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
        )
    return
  • 解释

    • 获取用户定义的缓存

      • user_defined_cache = model_kwargs.get(cache_name):从 model_kwargs 中获取用户可能提供的缓存。
    • 检查用户是否同时指定了 cache_implementation

      • 如果用户既在 model_kwargs 中提供了缓存,又在 generation_config 中指定了 cache_implementation,这是冲突的,会引发错误。
    • 处理旧的缓存格式

      • 如果 user_defined_cache 是一个元组(旧的缓存格式),并且模型支持默认的动态缓存(self._supports_default_dynamic_cache() 返回 True),则将旧的缓存转换为新的缓存格式。

      • 根据是否需要跨注意力缓存,使用不同的缓存类:

        • 如果不需要跨注意力缓存,使用 DynamicCache.from_legacy_cache(user_defined_cache)

        • 如果需要跨注意力缓存,使用 EncoderDecoderCache.from_legacy_cache(user_defined_cache)

    • 返回

      • 在处理完用户提供的缓存后,直接返回,不再进行后续的缓存准备。

4. 快速退出路径 2:用户指定不使用缓存

# 快速退出路径 2:如果用户指定不使用缓存。(冲突的参数已在 `generation_config.validate()` 中处理)
if generation_config.use_cache is False:
    return
  • 解释

    • 如果在 generation_config 配置中,用户设置了 use_cache=False,表示不使用缓存。

    • 直接返回,不需要准备缓存。

5. 快速退出路径 3:模型仅支持旧的缓存格式

# 快速退出路径 3:模型仅支持旧的缓存格式,无需准备
if not self._supports_default_dynamic_cache():
    if generation_config.cache_implementation is not None:
        warnings.warn(
            "This model does not support `Cache` instances, it only supports the legacy cache format (tuple "
            f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be "
            "ignored.",
            UserWarning,
        )
    return
  • 解释

    • 如果模型不支持默认的动态缓存(self._supports_default_dynamic_cache() 返回 False),则无法使用新的缓存实现。

    • 如果用户在 generation_config 中指定了 cache_implementation,则发出警告,指出模型仅支持旧的缓存格式,cache_implementation 将被忽略。

    • 直接返回,不需要进一步准备缓存。

6. 需要准备缓存,根据 generation_config.cache_implementation

# 否则,我们需要根据 `generation_config.cache_implementation` 准备缓存
# TODO(joao): 在辅助生成中支持静态缓存。辅助生成需要回滚缓存,目前只有动态缓存支持
if assistant_model is not None and generation_config.cache_implementation is not None:
    logger.warning_once(
        "An assistant model is provided, using a dynamic cache instead of a cache of type="
        f"'{generation_config.cache_implementation}'."
    )
    generation_config.cache_implementation = None
  • 解释

    • 如果上述快速退出条件都不满足,且需要准备缓存,则需要根据 generation_config.cache_implementation 的值来准备缓存。

    • 特殊情况:辅助模型和缓存实现的冲突

      • 如果提供了 assistant_model,并且指定了 cache_implementation,则发出警告,指出由于提供了辅助模型,将使用动态缓存,而不是指定类型的缓存。

      • generation_config.cache_implementation 设置为 None,以确保使用动态缓存。

  • 原因

    • 在辅助生成过程中,需要回滚缓存,目前只有动态缓存支持回滚。因此,即使用户指定了其他缓存实现,也需要使用动态缓存。

7. 根据缓存实现方式准备缓存

if generation_config.cache_implementation is not None:
    if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
        if generation_config.cache_implementation == "static" and not self._supports_static_cache:
            raise ValueError(
                "This model does not support `cache_implementation='static'`. Please check the following "
                "issue: https://github.com/huggingface/transformers/issues/28981"
            )
        model_kwargs[cache_name] = self._get_cache(
            cache_implementation=generation_config.cache_implementation,
            batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
            max_cache_len=max_cache_length,
            device=device,
            model_kwargs=model_kwargs,
        )
    elif generation_config.cache_implementation == "quantized":
        if not self._supports_quantized_cache:
            raise ValueError(
                "This model does not support the quantized cache. If you want your model to support quantized "
                "cache, please open an issue and tag @zucchini-nlp."
            )
        cache_config = (
            generation_config.cache_config
            if generation_config.cache_config is not None
            else QuantizedCacheConfig()
        )
        cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
        if cache_config.backend == "quanto" and not is_optimum_quanto_available():
            raise ImportError(
                "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
                "Please install it via  with `pip install optimum-quanto`"
            )
        elif cache_config.backend == "HQQ" and not is_hqq_available():
            raise ImportError(
                "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                "Please install it via  with `pip install hqq`"
            )
        model_kwargs[cache_name] = cache_class(cache_config)
    elif generation_config.cache_implementation == "offloaded":
        model_kwargs[cache_name] = OffloadedCache()
  • 解释

    • 检查缓存实现方式是否需要特别的准备

      • NEED_SETUP_CACHE_CLASSES_MAPPING:一个映射,包含需要特殊设置的缓存类。

      • 如果 generation_config.cache_implementationNEED_SETUP_CACHE_CLASSES_MAPPING 中,则需要调用 _get_cache 方法来获取缓存实例。

    • 处理静态缓存

      • 如果 cache_implementation"static",并且模型不支持静态缓存(not self._supports_static_cache),则抛出 ValueError,提示模型不支持静态缓存。

      • 提供一个 GitHub Issue 链接,供用户了解更多信息。

    • 获取缓存实例

      • 调用 self._get_cache() 方法,传入缓存实现方式、批次大小、最大缓存长度、设备等参数,获取缓存实例。

      • 批次大小计算

        • max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size:计算缓存所需的实际批次大小。

          • 在束搜索或其他情况下,批次大小可能需要乘以 num_beamsnum_return_sequences
    • 处理量化缓存(quantized cache)

      • 如果 cache_implementation"quantized",需要特殊处理。

      • 检查模型是否支持量化缓存(self._supports_quantized_cache)。

      • 如果不支持,抛出 ValueError

      • 获取缓存配置 cache_config,如果用户未提供 generation_config.cache_config,则使用默认的 QuantizedCacheConfig()

      • 根据缓存配置的后端,获取对应的缓存类 cache_class

      • 检查所需的包是否已安装:

        • 对于 "quanto" 后端,检查 is_optimum_quanto_available()

        • 对于 "HQQ" 后端,检查 is_hqq_available()

      • 如果未安装,抛出 ImportError,提示用户安装相应的包。

      • 实例化缓存类,并将其存储在 model_kwargs[cache_name] 中。

    • 处理离线缓存(offloaded cache)

      • 如果 cache_implementation"offloaded",则实例化 OffloadedCache(),并存储在 model_kwargs[cache_name] 中。

8. 默认情况下,使用动态缓存

# 默认情况下,使用 DynamicCache() 实例。这将避免在旧格式之间来回转换,从而避免复制缓存,节省内存
else:
    model_kwargs[cache_name] = (
        DynamicCache()
        if not requires_cross_attention_cache
        else EncoderDecoderCache(DynamicCache(), DynamicCache())
    )
  • 解释

    • 如果 generation_config.cache_implementationNone,即用户未指定特定的缓存实现方式,并且上述条件都不满足,则默认使用动态缓存。

    • 根据是否需要跨注意力缓存,实例化不同的缓存类:

      • 如果不需要跨注意力缓存,使用 DynamicCache()

      • 如果需要跨注意力缓存,使用 EncoderDecoderCache(DynamicCache(), DynamicCache()),即分别为编码器和解码器缓存初始化动态缓存。

  • 原因

    • 动态缓存可以在生成过程中动态扩展,并支持回滚等特性。

    • 使用默认的动态缓存,可以避免在旧缓存格式与新缓存格式之间来回转换,减少内存使用和数据复制。


_get_logits_processor

def _get_logits_processor(
    self,
    generation_config: GenerationConfig,
    input_ids_seq_length: int,
    encoder_input_ids: torch.LongTensor,
    prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
    logits_processor: Optional[LogitsProcessorList],
    device: str = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:
    """
    This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
    instances used to modify the scores of the language model head.
    """
    # 函数主体从这里开始

功能说明

  • 目的:该函数用于构建一个LogitsProcessorList对象,其中包含了一系列LogitsProcessor实例,这些实例用于在生成过程中修改语言模型头(language model head)的logits(模型输出的分数),以实现各种生成控制策略。

  • 背景:在生成任务中,通过调整模型输出的logits,可以引入各种生成策略,如重复惩罚、温度调节、Top-K采样、禁用某些词汇等,从而控制生成文本的风格、内容和长度。

参数说明

  • self:类的实例,典型的Python类方法的第一个参数。

  • generation_configGenerationConfig对象,包含了生成过程中的各种配置参数,如重复惩罚系数、最小生成长度、温度等。

  • input_ids_seq_length:整数,表示输入序列的长度,即input_ids的长度。

  • encoder_input_idstorch.LongTensor,编码器的输入IDs。如果模型是编码器-解码器架构,这对应于编码器的输入。

  • prefix_allowed_tokens_fn:可选的函数,类型为Callable[[int, torch.Tensor], List[int]]。用于在生成过程中限制每个位置上允许生成的tokens,通常用于受限生成任务。

  • logits_processor:可选的LogitsProcessorList对象,用户自定义的LogitsProcessor列表,可用于补充或覆盖默认的processor。

  • device:字符串,可选参数,指定设备(如'cpu''cuda')。如果未提供,默认为None

  • model_kwargs:可选的字典,包含了传递给模型的其他关键字参数。

  • negative_prompt_ids:可选的torch.Tensor,用于一些生成策略(如Classifier-Free Guidance)中的负面提示IDs。

  • negative_prompt_attention_mask:可选的torch.Tensor,对应negative_prompt_ids的注意力掩码。


1. 初始化LogitsProcessorList

# instantiate processors list
processors = LogitsProcessorList()
  • 解释:创建一个空的LogitsProcessorList对象processors,用于存储将要应用的所有LogitsProcessor实例。

2. 处理Classifier-Free Guidance(CFG)

if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
    processors.append(
        UnbatchedClassifierFreeGuidanceLogitsProcessor(
            generation_config.guidance_scale,
            self,
            unconditional_ids=negative_prompt_ids,
            unconditional_attention_mask=negative_prompt_attention_mask,
            use_cache=generation_config.use_cache,
        )
    )
  • 解释

    • 条件:如果guidance_scale不为None且不等于1,则说明需要应用Classifier-Free Guidance(CFG)。

      • guidance_scale是CFG的缩放因子,通常大于1,用于调整生成的多样性与准确性。
    • 操作:向processors中添加一个UnbatchedClassifierFreeGuidanceLogitsProcessor实例。

      • 参数说明

        • generation_config.guidance_scale:CFG的缩放因子。

        • self:模型实例,用于在LogitsProcessor中调用模型的其他方法。

        • unconditional_ids:负面提示的IDs,即negative_prompt_ids

        • unconditional_attention_mask:负面提示的注意力掩码,即negative_prompt_attention_mask

        • use_cache:是否使用缓存,来自generation_config


3. 处理序列偏置(Sequence Bias)

if generation_config.sequence_bias is not None:
    processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
  • 解释

    • 条件:如果sequence_bias不为None,则需要应用序列偏置。

      • sequence_bias是一种机制,可对特定的token序列施加偏置,提高或降低它们在生成中的概率。
    • 操作:向processors中添加一个SequenceBiasLogitsProcessor实例,传入sequence_bias参数。


4. 处理多样性惩罚(Diversity Penalty)

if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
    processors.append(
        HammingDiversityLogitsProcessor(
            diversity_penalty=generation_config.diversity_penalty,
            num_beams=generation_config.num_beams,
            num_beam_groups=generation_config.num_beam_groups,
        )
    )
  • 解释

    • 条件:如果diversity_penalty不为None且大于0,则需要应用多样性惩罚。

      • 多样性惩罚用于在束搜索中鼓励生成更多样化的序列。
    • 操作:向processors中添加一个HammingDiversityLogitsProcessor实例。

      • 参数说明

        • diversity_penalty:多样性惩罚系数。

        • num_beams:束搜索的束宽,即同时考虑的序列数量。

        • num_beam_groups:束搜索的组数,用于分组束搜索。


5. 处理编码器重复惩罚(Encoder Repetition Penalty)

if (
    generation_config.encoder_repetition_penalty is not None
    and generation_config.encoder_repetition_penalty != 1.0
):
    if len(encoder_input_ids.shape) == 2:
        processors.append(
            EncoderRepetitionPenaltyLogitsProcessor(
                penalty=generation_config.encoder_repetition_penalty,
                encoder_input_ids=encoder_input_ids,
            )
        )
    else:
        warnings.warn(
            "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to "
            "`generate`, ignoring the argument.",
            UserWarning,
        )
  • 解释

    • 条件:如果encoder_repetition_penalty不为None且不等于1.0,则需要应用编码器重复惩罚。

      • 编码器重复惩罚用于减少模型在生成时重复输入内容的可能性。
    • 检查:如果encoder_input_ids的形状为二维(即存在有效的编码器输入),则应用惩罚。

    • 操作:向processors中添加一个EncoderRepetitionPenaltyLogitsProcessor实例。

      • 参数说明

        • penalty:重复惩罚系数。

        • encoder_input_ids:编码器的输入IDs。

    • 否则:发出警告,提示需要提供input_ids以应用该惩罚,忽略该参数。


6. 处理重复惩罚(Repetition Penalty)

if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
    processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
  • 解释

    • 条件:如果repetition_penalty不为None且不等于1.0,则需要应用重复惩罚。

      • 重复惩罚用于减少模型在生成时重复之前生成内容的可能性。
    • 操作:向processors中添加一个RepetitionPenaltyLogitsProcessor实例,传入penalty参数。


7. 处理禁止重复的n-gram(No Repeat N-Gram)

if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
    processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
  • 解释

    • 条件:如果no_repeat_ngram_size不为None且大于0,则需要禁止重复的n-gram。

      • 这用于防止模型在生成时重复生成相同的n-gram,提高生成的多样性。
    • 操作:向processors中添加一个NoRepeatNGramLogitsProcessor实例,传入no_repeat_ngram_size参数。


8. 处理编码器禁止重复的n-gram(Encoder No Repeat N-Gram)

if (
    generation_config.encoder_no_repeat_ngram_size is not None
    and generation_config.encoder_no_repeat_ngram_size > 0
):
    if len(encoder_input_ids.shape) == 2:
        processors.append(
            EncoderNoRepeatNGramLogitsProcessor(
                generation_config.encoder_no_repeat_ngram_size,
                encoder_input_ids,
            )
        )
    else:
        warnings.warn(
            "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to "
            "`generate`, ignoring the argument.",
            UserWarning,
        )
  • 解释

    • 条件:如果encoder_no_repeat_ngram_size不为None且大于0,则需要在生成时避免重复输入中的n-gram。

      • 这用于防止模型在生成时重复输入序列中的n-gram。
    • 检查:如果encoder_input_ids的形状为二维(存在有效的编码器输入),则应用该处理器。

    • 操作:向processors中添加一个EncoderNoRepeatNGramLogitsProcessor实例。

      • 参数说明

        • encoder_no_repeat_ngram_size:禁止重复的n-gram大小。

        • encoder_input_ids:编码器的输入IDs。

    • 否则:发出警告,提示需要提供input_ids以应用该处理器,忽略该参数。


9. 处理坏词(Bad Words)

if generation_config.bad_words_ids is not None:
    processors.append(
        NoBadWordsLogitsProcessor(
            generation_config.bad_words_ids,
            generation_config._eos_token_tensor,
        )
    )
  • 解释

    • 条件:如果bad_words_ids不为None,则需要在生成过程中禁止某些词。

      • bad_words_ids是一个列表,包含需要禁止的词的token IDs。
    • 操作:向processors中添加一个NoBadWordsLogitsProcessor实例。

      • 参数说明

        • bad_words_ids:需要禁止的词的token IDs。

        • _eos_token_tensor:结束标记的token张量,用于在必要时停止生成。


10. 处理最小长度(Minimum Length)

if (
    generation_config.min_length is not None
    and generation_config._eos_token_tensor is not None
    and generation_config.min_length > 0
):
    processors.append(
        MinLengthLogitsProcessor(
            generation_config.min_length,
            generation_config._eos_token_tensor,
            device=device,
        )
    )
  • 解释

    • 条件:如果min_length不为None_eos_token_tensor不为None,且min_length大于0,则需要在生成达到最小长度之前禁止生成结束标记。

    • 操作:向processors中添加一个MinLengthLogitsProcessor实例。

      • 参数说明

        • min_length:最小生成长度。

        • _eos_token_tensor:结束标记的token张量。

        • device:设备信息。


11. 处理最小新tokens的长度(Minimum New Tokens Length)

if (
    generation_config.min_new_tokens is not None
    and generation_config._eos_token_tensor is not None
    and generation_config.min_new_tokens > 0
):
    processors.append(
        MinNewTokensLengthLogitsProcessor(
            input_ids_seq_length,
            generation_config.min_new_tokens,
            generation_config._eos_token_tensor,
            device=device,
        )
    )
  • 解释

    • 条件:如果min_new_tokens不为None_eos_token_tensor不为None,且min_new_tokens大于0,则需要在生成新tokens达到最小数量之前禁止生成结束标记。

    • 操作:向processors中添加一个MinNewTokensLengthLogitsProcessor实例。

      • 参数说明

        • input_ids_seq_length:输入序列的长度。

        • min_new_tokens:最小新生成的tokens数量。

        • _eos_token_tensor:结束标记的token张量。

        • device:设备信息。


12. 处理前缀限制(Prefix Allowed Tokens Function)

if prefix_allowed_tokens_fn is not None:
    processors.append(
        PrefixConstrainedLogitsProcessor(
            prefix_allowed_tokens_fn,
            generation_config.num_beams // generation_config.num_beam_groups,
        )
    )
  • 解释

    • 条件:如果prefix_allowed_tokens_fn不为None,则需要在生成过程中限制每个位置上允许生成的tokens。

      • 这通常用于受限生成任务,例如自动补全或基于前缀的约束生成。
    • 操作:向processors中添加一个PrefixConstrainedLogitsProcessor实例。

      • 参数说明

        • prefix_allowed_tokens_fn:用于限制每个位置上允许生成的tokens的函数。

        • generation_config.num_beams // generation_config.num_beam_groups:计算每个组中的束宽。


13. 处理强制起始token(Forced BOS Token)

if generation_config.forced_bos_token_id is not None:
    processors.append(
        ForcedBOSTokenLogitsProcessor(
            generation_config.forced_bos_token_id,
        )
    )
  • 解释

    • 条件:如果forced_bos_token_id不为None,则需要在生成的第一个位置强制生成指定的起始token。

      • 这用于确保生成的序列以特定的token开始,例如在某些任务中需要强制生成特定的起始标记。
    • 操作:向processors中添加一个ForcedBOSTokenLogitsProcessor实例,传入forced_bos_token_id


14. 处理强制结束token(Forced EOS Token)

if generation_config.forced_eos_token_id is not None:
    processors.append(
        ForcedEOSTokenLogitsProcessor(
            generation_config.max_length,
            generation_config.forced_eos_token_id,
            device=device,
        )
    )
  • 解释

    • 条件:如果forced_eos_token_id不为None,则需要在生成达到最大长度时强制生成指定的结束token。

      • 这用于确保生成的序列以特定的token结束。
    • 操作:向processors中添加一个ForcedEOSTokenLogitsProcessor实例。

      • 参数说明

        • generation_config.max_length:最大生成长度。

        • forced_eos_token_id:强制的结束token ID。

        • device:设备信息。


15. 处理无效值移除(Remove Invalid Values)

if generation_config.remove_invalid_values is True:
    processors.append(InfNanRemoveLogitsProcessor())
  • 解释

    • 条件:如果remove_invalid_valuesTrue,则需要在生成过程中移除infnan等无效值。

      • 这用于确保生成过程的稳定性,防止由于无效值导致的错误。
    • 操作:向processors中添加一个InfNanRemoveLogitsProcessor实例。


16. 处理指数衰减长度惩罚(Exponential Decay Length Penalty)

if generation_config.exponential_decay_length_penalty is not None:
    processors.append(
        ExponentialDecayLengthPenalty(
            generation_config.exponential_decay_length_penalty,
            generation_config._eos_token_tensor,
            input_ids_seq_length,
        )
    )
  • 解释

    • 条件:如果exponential_decay_length_penalty不为None,则需要应用指数衰减的长度惩罚。

      • 这用于在生成过程中,对句子长度施加惩罚,鼓励模型生成特定长度的句子。
    • 操作:向processors中添加一个ExponentialDecayLengthPenalty实例。

      • 参数说明

        • exponential_decay_length_penalty:指数衰减长度惩罚的参数。

        • _eos_token_tensor:结束标记的token张量。

        • input_ids_seq_length:输入序列的长度。


17. 处理抑制特定tokens(Suppress Tokens)

if generation_config.suppress_tokens is not None:
    processors.append(
        SuppressTokensLogitsProcessor(
            generation_config.suppress_tokens,
            device=device,
        )
    )
  • 解释

    • 条件:如果suppress_tokens不为None,则需要在生成过程中抑制特定的tokens,不让它们生成。

      • suppress_tokens是需要抑制的token IDs列表。
    • 操作:向processors中添加一个SuppressTokensLogitsProcessor实例。

      • 参数说明

        • suppress_tokens:需要抑制的token IDs。

        • device:设备信息。


18. 处理在开头抑制特定tokens(Suppress Tokens at Begin)

if generation_config.begin_suppress_tokens is not None:
    begin_index = input_ids_seq_length
    begin_index = (
        begin_index
        if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
        else begin_index + 1
    )
    processors.append(
        SuppressTokensAtBeginLogitsProcessor(
            generation_config.begin_suppress_tokens,
            begin_index,
            device=device,
        )
    )
  • 解释

    • 条件:如果begin_suppress_tokens不为None,则需要在生成的开头位置抑制特定的tokens。

      • 这用于避免模型在一开始生成某些不期望的tokens。
    • 计算起始索引

      • begin_index初始值为input_ids_seq_length,表示当前生成的位置。

      • 如果input_ids_seq_length <= 1forced_bos_token_id不为None,则begin_index += 1

    • 操作:向processors中添加一个SuppressTokensAtBeginLogitsProcessor实例。

      • 参数说明

        • begin_suppress_tokens:需要抑制的token IDs。

        • begin_index:开始抑制的位置索引。

        • device:设备信息。


19. 处理强制解码器IDs(Forced Decoder IDs)

if generation_config.forced_decoder_ids is not None:
    # TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
    raise ValueError(
        "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
        "in favour of `input_ids` or `decoder_input_ids` respectively.",
    )
  • 解释

    • 条件:如果forced_decoder_ids不为None,则抛出异常。

    • 原因:当前不支持forced_decoder_ids,建议用户使用input_idsdecoder_input_ids来替代。

    • 备注:注释中提到,当TensorFlow和FLAX版本与PyTorch版本对齐后,可以将此异常移动到GenerationConfig.validate()中。


20. 合并用户自定义的logits_processor

# TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor)
  • 解释

    • 操作:调用self._merge_criteria_processor_list()方法,将之前构建的processors列表与用户自定义的logits_processor进行合并。

    • 备注:注释中提到需要找到一种策略来指定处理器的顺序。


21. 处理采样策略下的LogitsWarper

# 只有在使用采样策略时,才应用之前被称为`LogitsWarpers`的处理器
if generation_config.do_sample:
    # 在beam方法中,我们需要至少保留一个非eos token,以探索可能具有更好得分的连续序列
    if generation_config.num_beams > 1:
        if isinstance(generation_config._eos_token_tensor, list):
            min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
        elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
            min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
        else:
            min_tokens_to_keep = 2
    else:
        min_tokens_to_keep = 1
    # 以下思想主要来自PR:https://github.com/huggingface/transformers/pull/5420/files
    # 所有的sampler都在`generation_utils_samplers.py`中
    if generation_config.temperature is not None and generation_config.temperature != 1.0:
        processors.append(TemperatureLogitsWarper(generation_config.temperature))
    if generation_config.top_k is not None and generation_config.top_k != 0:
        processors.append(
            TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
        )
    if generation_config.top_p is not None and generation_config.top_p < 1.0:
        processors.append(
            TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
        )
    if generation_config.min_p is not None:
        # 在温度缩放之后应用(见:https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
        processors.append(
            MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
        )
    if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
        processors.append(
            TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
        )
    if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
        processors.append(
            EpsilonLogitsWarper(
                epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
            )
        )
    if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
        processors.append(
            EtaLogitsWarper(
                epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
            )
        )
  • 解释

    • 条件:只有在do_sampleTrue时,才应用这些处理器,因为它们与采样策略相关。

    • 计算min_tokens_to_keep

      • 在束搜索等方法中,需要保留至少一个非结束token,以确保能够探索可能更好的序列。

      • 根据结束token的类型和数量,计算需要保留的最小token数量。

    • 添加采样相关的LogitsWarper

      • 根据generation_config中的配置,向processors中添加相应的LogitsWarper,包括:

        • TemperatureLogitsWarper:调整温度参数,控制生成的随机性。

        • TopKLogitsWarper:仅保留概率最高的top_k个tokens。

        • TopPLogitsWarper:仅保留累计概率达到top_p的tokens。

        • MinPLogitsWarper:应用最小概率阈值。

        • TypicalLogitsWarper:使用Typical采样策略。

        • EpsilonLogitsWarperEtaLogitsWarper:应用epsilon或eta截断。


22. 处理水印(Watermarking)

# Watermarking应该在所有logits处理完成后再应用(参见#34630)
if generation_config.watermarking_config is not None:
    processors.append(
        generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
    )
  • 解释

    • 条件:如果watermarking_config不为None,则需要在生成过程中应用水印策略。

      • 水印策略通常用于在生成的文本中嵌入隐式标记,以便于后续识别生成内容。
    • 操作:向processors中添加一个由watermarking_config构建的处理器。

      • 使用模型的词汇表大小vocab_size和设备信息device
    • 备注:水印处理器应当在所有logits处理完成后再应用。


23. 处理Logits重归一化(Logit Renormalization)

# `LogitNormalization`应该始终是最后一个logit处理器(如果存在的话)
if generation_config.renormalize_logits is True:
    processors.append(LogitNormalization())
  • 解释

    • 条件:如果renormalize_logitsTrue,则需要在处理完所有logits后重新归一化。

      • 这用于确保logits在所有处理后仍然形成有效的概率分布。
    • 操作:向processors中添加一个LogitNormalization实例。

      • 备注LogitNormalization应当始终是最后应用的logits处理器。

24. 返回处理器列表

return processors
  • 解释:返回构建好的LogitsProcessorList对象processors,供生成过程使用。

_get_stopping_criteria

def _get_stopping_criteria(
    self,
    generation_config: GenerationConfig,
    stopping_criteria: Optional[StoppingCriteriaList],
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
    **kwargs,
) -> StoppingCriteriaList:
    # 方法体...

参数说明

  • generation_configGenerationConfig 对象,包含生成过程中所需的各种配置参数,如 max_lengthmax_timestop_strings 等。
  • stopping_criteria:可选的 StoppingCriteriaList 对象,用户可以传入自定义的停止条件列表,与默认的停止条件合并。
  • tokenizer:可选的 PreTrainedTokenizerBase 对象,模型对应的 tokenizer,用于处理 stop_strings 等需要 tokenizer 的功能。
  • **kwargs:其他可能的关键字参数,供扩展使用。

返回值

  • StoppingCriteriaList:该方法返回一个 StoppingCriteriaList 对象,包含生成过程中需要检查的停止条件。

该方法的主要作用是:

  1. 创建默认的停止条件列表:根据 generation_config 中的配置,生成对应的停止条件(如最大长度、最大时间、特殊字符串、结束标记等)。

  2. 合并用户提供的停止条件:如果用户在 stopping_criteria 中提供了自定义的停止条件,方法会将其与默认的停止条件合并。

  3. 返回完整的停止条件列表:生成过程会根据这个列表,在满足任何一个停止条件时结束生成。


步骤 1:初始化停止条件列表

criteria = StoppingCriteriaList()
  • 说明:创建一个空的 StoppingCriteriaList 对象,用于存储后续添加的停止条件。

步骤 2:处理 max_length 停止条件

if generation_config.max_length is not None:
    max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
    criteria.append(
        MaxLengthCriteria(
            max_length=generation_config.max_length,
            max_position_embeddings=max_position_embeddings,
        )
    )
  • 解释

    • 获取 max_length:从 generation_config 中获取 max_length,即生成的最大长度。

    • 获取模型的最大位置嵌入长度

      • max_position_embeddings = getattr(self.config, "max_position_embeddings", None)

      • 从模型配置中获取 max_position_embeddings,表示模型能处理的最大序列长度。

    • 添加 MaxLengthCriteria

      • 创建一个 MaxLengthCriteria 对象,传入 max_lengthmax_position_embeddings

      • 将其添加到 criteria 列表中。

  • 作用:当生成的序列长度达到 max_length 或模型的最大位置嵌入长度时,停止生成。

步骤 3:处理 max_time 停止条件

if generation_config.max_time is not None:
    criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
  • 解释

    • 获取 max_time:从 generation_config 中获取 max_time,即生成的最长时间(以秒为单位)。

    • 添加 MaxTimeCriteria

      • 创建一个 MaxTimeCriteria 对象,传入 max_time

      • 将其添加到 criteria 列表中。

  • 作用:当生成过程运行时间超过 max_time 秒时,停止生成。

步骤 4:处理 stop_strings 停止条件

if generation_config.stop_strings is not None:
    if tokenizer is None:
        raise ValueError(
            "There are one or more stop strings, either in the arguments to `generate` or in the "
            "model's generation config, but we could not locate a tokenizer. When generating with "
            "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
        )
    criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
  • 解释

    • 获取 stop_strings:从 generation_config 中获取 stop_strings,这是一个字符串列表,包含用于停止生成的特殊字符串。

    • 检查 tokenizer:因为处理字符串需要使用 tokenizer,如果 tokenizerNone,抛出 ValueError

    • 添加 StopStringCriteria

      • 创建一个 StopStringCriteria 对象,传入 stop_stringstokenizer

      • 将其添加到 criteria 列表中。

  • 作用:当生成的文本中出现指定的字符串时,停止生成。

步骤 5:处理 eos_token_id(结束标记)停止条件

if generation_config._eos_token_tensor is not None:
    criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
  • 解释

    • 获取 _eos_token_tensor:从 generation_config 中获取 _eos_token_tensor,即结束标记的 token ID。

    • 添加 EosTokenCriteria

      • 创建一个 EosTokenCriteria 对象,传入 eos_token_id

      • 将其添加到 criteria 列表中。

  • 作用:当生成的序列中出现结束标记时,停止生成。

步骤 6:处理辅助模型的置信度停止条件(可选)

if (
    generation_config.is_assistant
    and generation_config.assistant_confidence_threshold is not None
    and generation_config.assistant_confidence_threshold > 0
):
    criteria.append(
        ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)
    )
  • 解释

    • 条件检查

      • generation_config.is_assistant:判断是否使用了辅助模型。

      • generation_config.assistant_confidence_threshold 不为空且大于 0。

    • 添加 ConfidenceCriteria

      • 创建一个 ConfidenceCriteria 对象,传入 assistant_confidence_threshold

      • 将其添加到 criteria 列表中。

  • 作用:当辅助模型的置信度达到一定阈值时,停止生成。

步骤 7:合并用户提供的停止条件

criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
  • 解释

    • 合并默认和用户提供的停止条件

      • 调用 _merge_criteria_processor_list 方法,将默认的 criteria 与用户提供的 stopping_criteria 合并。

      • 这个方法通常会去除重复的条件,或者根据某些规则进行合并。

  • 作用:确保生成过程中考虑所有相关的停止条件。

步骤 8:返回最终的停止条件列表

return criteria
  • 解释:返回包含所有停止条件的 StoppingCriteriaList,供生成过程使用。

整体流程总结

  • 目的:为生成过程准备一个完整的停止条件列表,当满足任何一个条件时,生成过程将停止。

  • 处理逻辑

    1. 初始化:创建一个空的停止条件列表。

    2. 添加默认停止条件:根据 generation_config 中的配置,添加对应的停止条件,包括最大长度、最大时间、结束标记、特殊字符串等。

    3. 辅助模型条件:如果使用了辅助模型,并且设置了置信度阈值,添加对应的停止条件。

    4. 合并用户停止条件:将用户提供的 stopping_criteria 合并到默认条件列表中。

    5. 返回:输出完整的停止条件列表。


http://www.dtcms.com/a/113187.html

相关文章:

  • MIPI与DVP接口摄像头:深度解析与应用指南
  • 素数的判断方法
  • Mysql explain中列的解析
  • SortedSet结构之用户积分实时榜单实战
  • WordPress图标设置插件,免费功能小巧
  • 武装自己的Kali
  • 轨道交通装备三维检测与轻量化设计
  • Cookie、Session、Token、JWT的区别和使用场景
  • 深度测评 | 聚铭下一代智慧安全运营中心如何破解电力行业安全运维难题?
  • C++ 判断字符是否为数字或字母:isalpha、isdigit 和 isalnum 函数详解
  • 【2-8】同步通信与异步通信
  • 数据库性能优化(sql优化)_子查询02_yxy
  • 二十种中药果实识别分类系统,Python/resnet18/pytorch
  • C++_类和对象(下)
  • 无状态版的DHCPv6是不是SLAAC? 笔记250405
  • 【LeetCode Solutions】LeetCode 146 ~ 150 题解
  • JVM深入原理(六)(二):双亲委派机制
  • 元宇宙概念下,UI 设计如何打造沉浸式体验?
  • 从零开始玩python--python版植物大战僵尸来袭
  • 数字化转型中的开源AI智能客服与S2B2C商城小程序的融合创新
  • ☕️ 关于本博客 ☀️
  • OSCP - Proving Grounds- SoSimple
  • VUE+SPRINGBOOT+语音技术实现智能语音歌曲管理系统
  • 交换机与路由器的区别
  • 故障矩阵像素照片效果ps标题文本特效滤镜样机 Glitched Arcade Text Logo Effect
  • 【Python】数组的条件逻辑统计运算元素排序
  • Java的Selenium的特殊元素操作与定位之window切换
  • 推荐系统的注意力进化:从 Self-Attention 到 Target-Attention
  • BT-Basic函数之首字母T
  • 《打破SQL与AI框架对接壁垒,解锁融合新路径》