当前位置: 首页 > news >正文

大模型推理--temperature、top_k、top_p理解

LLM推理的最后一步是要从众多候选token中选择一个输出,一般可以选择softmax概率最大的token输出,这样相同的输入都会获得确定的输出。不过,在很多情况下,最优输出不见得是最好的输出,尤其在当下LLM还不完美的情况下。为此我们需要让LLM的输出在保证靠谱的前提下尽可能多样,temperature、top_k、top_p这三个变量就是出于此目的设计出来的。当然很多博客中已经介绍了这三个变量的作用,但是很多人可能对细节还不了解,正好最近看了一个Python实现,借此给大家详细介绍一下这三个变量的作用。

1. 源码实现

我参考的源码是Freeze Omni这个项目中的post_process,并进行了简化,源码如下:

def do_sampling(logits: torch.Tensor, temperature=1.0, top_k=20, top_p=0.8):
    if temperature != 1.0:
        logits = logits / temperature

    probs = F.softmax(logits, dim=-1)

    if top_k != 0:
        top_k_probs, top_k_indices = torch.topk(probs, top_k)
        probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
        probs = probs / probs.sum()

    if top_p > 0.0:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p #前0后1的数组
        sorted_indices_to_remove[0] = 0 #确保要保留一个

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        probs[indices_to_remove] = 0
        probs = probs / probs.sum()

    token_index = torch.multinomial(probs, 1)
return token_index

2. temperature作用

从源码中可以看到,涉及temperature的部分很简单,只有一句: logits = logits / temperature。虽然只有这简单的一句,但是内涵相当丰富。Temperature可以用来调节输入logits不同值之间的差值(比例关系不变),这个在后续求softmax的时候会有很大影响。可以设想一下,假设我们将temperature设的很大,经过这个除法之后logits的值就都非常小,在求softmax的时候各个位置的概率就会很均匀,在后续利用top_k和top_p进行采样时就会让采样具有多样性,结果也会更富有变化;如果temperature设的很小,则会导致logits更容易出现一个很大的值,在求softmax的时候就会导致某个token的概率很大,进而就会导致后续的采样结果更确定。

3. top_k的作用

top_k的作用就是从softmax的输出中选择前k大值,如果top_k等于0相当于softmax的所有输出均参与后面的采样。由于softmax的输出分布极为不均匀,往往只有一个或者几个较大的值,其他的值都接近于0。如果不设置top_k,一些概率很小的token也会参与采样,可能会导致结果过于发散(当然这种情况的概率也很低),所以一般都会设置top_k。
代码在求得top_k的概率和位置之后,将这k个值散布到和原始probs相同大小的零tensor中再归一化,这相当于把top_k之外的所有位置都置零。概率为0的位置在后续采样的时候就不会被选中。

4. top_p的作用

top_p相关的代码较长,作用解释起来稍微有点复杂。它是从累积概率的角度对softmax或者top_k之后的token位置进行进一步的筛选。代码首先会对probs进行降序排序,然后计算累积概率,找到累积概率首次超过top_p的位置,截断此后的所有概率。后续几步运算就是把累积概率超过top_p的所有位置置零,确保这些位置在采样时被排除。这样,top_p能在保持多样性的同时,避免极端小概率事件的影响,使结果更可控。通过合理设置top_k和top_p,能在精确性和多样性间找到平衡。例如,当top_p设为0.9时,意味着只保留累积概率达到90%的前几个token,其余的则被舍弃。这样既保证了采样结果的丰富性,又避免了低概率token的干扰,使得生成文本在可控范围内更具质量和连贯性。通过细致调整这两个参数,模型输出将更加符合预期,满足不同场景下的需求。
代码的最后利用multinomial来进行采样。在经过top_k和top_p的筛选之后,multinomial的输入是只有几个位置概率大于0,其他位置均为0的一个概率分布。
multinomial函数会根据调整后的概率分布进行随机采样,选择最可能的token,确保生成文本既符合预期又具备一定随机性,从而提升整体的自然性和可读性。

5. 性能优化

上述代码对理解这三个参数的含义比较好,但是在性能方面却存在不少问题,我们尝试对其进行优化。

5.1 优化一

原始代码在求得top_k之后会重新构造一个和原始logits一样大小的tensor,然后再进行排序。但实际上排完序之后的tensor前k个元素和top_k的结果一样,完全没必要构造新的tensor,我们可以直接利用top_k的结果求累积概率,这样我们就把sort的时间给省掉了。
利用top_k_probs对top_p进行筛选,找到首次超过top_p的位置进行截断归一化,相比之前对完整的probs进行筛选现在只需要在top_k个位置上进行筛选,速度提升了不少。按照该种优化方法实现的采样代码如下:

def do_sampling(logits: torch.Tensor, temperature=1.0, top_k=20, top_p=0.8):
    if temperature != 1.0:
        logits = logits / temperature

    probs = F.softmax(logits, dim=-1)

    if top_k != 0:
        top_k_probs, top_k_indices = torch.topk(probs, top_k)
		top_k_probs /= top_k_probs.sum()  

    if top_p > 0.0:
        cumulative_probs = torch.cumsum(top_k_probs, dim=-1)
        mask = cumulative_probs > top_p 
		mask[0] = 1 #确保要保留至少一个位置
		Probs = top_k_probs[mask]
        probs = probs / probs.sum()

    token_index = torch.multinomial(probs, 1)
	return top_k_indices[token_index]

上述代码在虽然优化了性能,但是也将top_k和top_p绑定在一起,大家酌情使用。

5.2 优化二

优化一相当于原始代码的等价变换,相同的输入得到相同的输出。还有一个略微改变输出结果的方法,主要是在softmax身上做文章。原始代码是对完整的logits求softmax,然后再求topk。可以改为先对logits求topk,再对筛选后的topk结果求softmax,这样top_k的复杂度没变,但是softmax的复杂度则大幅降低。这样计算可行的原因是利用了softmax的单调性。不过这样计算会导致softmax的输出与原始代码不一样,需要我们重新调整top_k和top_p的取值,以确保结果的可靠性。代码与优化一类似,在此不再给出代码。

相关文章:

  • 未授权访问
  • neo4j知识图谱常用命令
  • 在Java中使用JDK8创建SpringBoot项目时无法选择Java8
  • es6 fetch
  • Flutter快速搭建聊天
  • eNSP中华为S5700交换机基础配置命令
  • Android Compose 框架物理动画之弹簧动画(Spring、SpringSpec)深入剖析(二十七)
  • SEO长尾关键词精准布局策略
  • JAVA多线程
  • 物联网平台架构介绍
  • redis 学习笔记
  • 初教六双击编队特技动作解析
  • 【产品小白】需求分析的进阶
  • DeepSeek-V3 模型更新,加量不加价
  • 2025 polarctf春季个人挑战赛web方向wp
  • 向量数据库学习笔记(1) —— 基础概念
  • 1.基于TCP的简单套接字服务器实现
  • TiDB与Doris实操对比:深度剖析数据库选型要点
  • 期权合约到期了还能继续持有吗?
  • 至此,他19岁青春烙印上了苦涩的烧痕。
  • 赡养纠纷个案推动类案监督,检察机关保障特殊群体胜诉权
  • 我国城市规划“全面体检”套餐出台,城市体检将逐步与供地计划等挂钩
  • 广西等地旱情缓解,水利部针对甘肃启动干旱防御Ⅳ级响应
  • 澳大利亚首例“漂绿”诉讼开庭:能源巨头因“碳中和”承诺遭起诉
  • 专访|导演刘江:给谍战题材注入现实主义的魂
  • 线下哪些商家支持无理由退货?查询方法公布