大模型推理--temperature、top_k、top_p理解
LLM推理的最后一步是要从众多候选token中选择一个输出,一般可以选择softmax概率最大的token输出,这样相同的输入都会获得确定的输出。不过,在很多情况下,最优输出不见得是最好的输出,尤其在当下LLM还不完美的情况下。为此我们需要让LLM的输出在保证靠谱的前提下尽可能多样,temperature、top_k、top_p这三个变量就是出于此目的设计出来的。当然很多博客中已经介绍了这三个变量的作用,但是很多人可能对细节还不了解,正好最近看了一个Python实现,借此给大家详细介绍一下这三个变量的作用。
1. 源码实现
我参考的源码是Freeze Omni这个项目中的post_process,并进行了简化,源码如下:
def do_sampling(logits: torch.Tensor, temperature=1.0, top_k=20, top_p=0.8):
if temperature != 1.0:
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
if top_k != 0:
top_k_probs, top_k_indices = torch.topk(probs, top_k)
probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
probs = probs / probs.sum()
if top_p > 0.0:
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p #前0后1的数组
sorted_indices_to_remove[0] = 0 #确保要保留一个
indices_to_remove = sorted_indices[sorted_indices_to_remove]
probs[indices_to_remove] = 0
probs = probs / probs.sum()
token_index = torch.multinomial(probs, 1)
return token_index
2. temperature作用
从源码中可以看到,涉及temperature的部分很简单,只有一句: logits = logits / temperature。虽然只有这简单的一句,但是内涵相当丰富。Temperature可以用来调节输入logits不同值之间的差值(比例关系不变),这个在后续求softmax的时候会有很大影响。可以设想一下,假设我们将temperature设的很大,经过这个除法之后logits的值就都非常小,在求softmax的时候各个位置的概率就会很均匀,在后续利用top_k和top_p进行采样时就会让采样具有多样性,结果也会更富有变化;如果temperature设的很小,则会导致logits更容易出现一个很大的值,在求softmax的时候就会导致某个token的概率很大,进而就会导致后续的采样结果更确定。
3. top_k的作用
top_k的作用就是从softmax的输出中选择前k大值,如果top_k等于0相当于softmax的所有输出均参与后面的采样。由于softmax的输出分布极为不均匀,往往只有一个或者几个较大的值,其他的值都接近于0。如果不设置top_k,一些概率很小的token也会参与采样,可能会导致结果过于发散(当然这种情况的概率也很低),所以一般都会设置top_k。
代码在求得top_k的概率和位置之后,将这k个值散布到和原始probs相同大小的零tensor中再归一化,这相当于把top_k之外的所有位置都置零。概率为0的位置在后续采样的时候就不会被选中。
4. top_p的作用
top_p相关的代码较长,作用解释起来稍微有点复杂。它是从累积概率的角度对softmax或者top_k之后的token位置进行进一步的筛选。代码首先会对probs进行降序排序,然后计算累积概率,找到累积概率首次超过top_p的位置,截断此后的所有概率。后续几步运算就是把累积概率超过top_p的所有位置置零,确保这些位置在采样时被排除。这样,top_p能在保持多样性的同时,避免极端小概率事件的影响,使结果更可控。通过合理设置top_k和top_p,能在精确性和多样性间找到平衡。例如,当top_p设为0.9时,意味着只保留累积概率达到90%的前几个token,其余的则被舍弃。这样既保证了采样结果的丰富性,又避免了低概率token的干扰,使得生成文本在可控范围内更具质量和连贯性。通过细致调整这两个参数,模型输出将更加符合预期,满足不同场景下的需求。
代码的最后利用multinomial来进行采样。在经过top_k和top_p的筛选之后,multinomial的输入是只有几个位置概率大于0,其他位置均为0的一个概率分布。
multinomial函数会根据调整后的概率分布进行随机采样,选择最可能的token,确保生成文本既符合预期又具备一定随机性,从而提升整体的自然性和可读性。
5. 性能优化
上述代码对理解这三个参数的含义比较好,但是在性能方面却存在不少问题,我们尝试对其进行优化。
5.1 优化一
原始代码在求得top_k之后会重新构造一个和原始logits一样大小的tensor,然后再进行排序。但实际上排完序之后的tensor前k个元素和top_k的结果一样,完全没必要构造新的tensor,我们可以直接利用top_k的结果求累积概率,这样我们就把sort的时间给省掉了。
利用top_k_probs对top_p进行筛选,找到首次超过top_p的位置进行截断归一化,相比之前对完整的probs进行筛选现在只需要在top_k个位置上进行筛选,速度提升了不少。按照该种优化方法实现的采样代码如下:
def do_sampling(logits: torch.Tensor, temperature=1.0, top_k=20, top_p=0.8):
if temperature != 1.0:
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
if top_k != 0:
top_k_probs, top_k_indices = torch.topk(probs, top_k)
top_k_probs /= top_k_probs.sum()
if top_p > 0.0:
cumulative_probs = torch.cumsum(top_k_probs, dim=-1)
mask = cumulative_probs > top_p
mask[0] = 1 #确保要保留至少一个位置
Probs = top_k_probs[mask]
probs = probs / probs.sum()
token_index = torch.multinomial(probs, 1)
return top_k_indices[token_index]
上述代码在虽然优化了性能,但是也将top_k和top_p绑定在一起,大家酌情使用。
5.2 优化二
优化一相当于原始代码的等价变换,相同的输入得到相同的输出。还有一个略微改变输出结果的方法,主要是在softmax身上做文章。原始代码是对完整的logits求softmax,然后再求topk。可以改为先对logits求topk,再对筛选后的topk结果求softmax,这样top_k的复杂度没变,但是softmax的复杂度则大幅降低。这样计算可行的原因是利用了softmax的单调性。不过这样计算会导致softmax的输出与原始代码不一样,需要我们重新调整top_k和top_p的取值,以确保结果的可靠性。代码与优化一类似,在此不再给出代码。