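From Rotary Position Embedding (RoPE) to YaRN: Principles and Implementation
Rotary Position Embedding
The RoPE formula and a comparison with sinusoidal position encoding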
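Sinusoidal position encoding is the standard method used in the original Transformer, while Rotary Position Embedding (RoPE) is the mainstream choice in current large models. RoPE encodes position as a rotation angle, which lets it extrapolate to longer sequences more smoothly and mitigates the main weakness of sinusoidal encoding: poor extrapolation to very long contexts.
The design goal is an encoding that adapts to different sequence lengths while keeping the same relative distance consistent regardless of sequence length.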
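Sinusoidal position encoding versus RoPE:
Both rely on the translation property that rotation matrices provide: multiplying the encoding at one position by a fixed matrix yields the encoding at a shifted position. The difference is how position enters. Sinusoidal encoding adds a position term to the embedding, so the $QK^T$ product gains an additive relative-position term. RoPE multiplies by a rotation matrix, so the $QK^T$ product gains a multiplicative relative-position factor.
What they share is that, as the embedding dimension index increases, the frequency decreases and the period grows, so higher dimensions track longer distances.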
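- Sequence position increases → rotate forward by the same angle at every step
- Embedding dimension index increases → smaller rotation angle, longer period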
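Sinusoidal encoding introduces absolute position by addition and relies on the model to learn relative position implicitly from the attention computation. It uses the angle-addition formulas for sine and cosine.
RoPE introduces absolute position by multiplication (rotation) and uses a property of rotation matrices so that the attention score explicitly depends only on relative position. It uses the composition property of rotations, $R_m^T R_n = R_{n-m}$.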
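The composition property $R_m^T R_n = R_{n-m}$ can be checked numerically for a single 2x2 rotation block. This is a minimal, dependency-free sketch; the frequency `theta` and the positions `m`, `n` are arbitrary illustrative values, not taken from the article.

```python
import math

def rot(angle):
    """2x2 counter-clockwise rotation matrix for `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s], [s, c]]

def matmul(a, b):
    """Product of two 2x2 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def transpose(a):
    return [[a[j][i] for j in range(2)] for i in range(2)]

theta = 0.3    # one illustrative frequency
m, n = 5, 12   # two absolute positions

lhs = matmul(transpose(rot(m * theta)), rot(n * theta))  # R_m^T R_n
rhs = rot((n - m) * theta)                               # R_{n-m}
max_err = max(abs(lhs[i][j] - rhs[i][j]) for i in range(2) for j in range(2))
print(f"max |R_m^T R_n - R_(n-m)| = {max_err:.2e}")
```

The absolute positions cancel and only the offset `n - m` survives, which is the property the rest of the derivation builds on.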
Formula:
For position $t$ and dimensions $2i$ and $2i+1$:
$$PE_{(t,2i)} = \sin\left(\frac{t}{10000^{2i/d}}\right), \quad PE_{(t,2i+1)} = \cos\left(\frac{t}{10000^{2i/d}}\right)$$
Translation property (the signs follow from the angle-addition formulas $\sin(\Delta t + t) = \cos(\Delta t)\sin(t) + \sin(\Delta t)\cos(t)$ and $\cos(\Delta t + t) = -\sin(\Delta t)\sin(t) + \cos(\Delta t)\cos(t)$):
$$\begin{bmatrix} \sin(\Delta t + t) \\ \cos(\Delta t + t) \end{bmatrix} = \begin{bmatrix} \cos(\Delta t) & \sin(\Delta t) \\ -\sin(\Delta t) & \cos(\Delta t) \end{bmatrix} \begin{bmatrix} \sin(t) \\ \cos(t) \end{bmatrix}$$
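A quick numeric check of the translation property: the pair $(\sin, \cos)$ at $t + \Delta t$ is obtained from the pair at $t$ by a matrix that depends only on $\Delta t$. The sketch below writes the matrix entries out by hand from the angle-addition formulas; the values of `t` and `dt` are arbitrary.

```python
import math

def shift_pair(sin_t, cos_t, dt):
    """Map (sin t, cos t) to (sin(t+dt), cos(t+dt)) using only dt."""
    new_sin = math.cos(dt) * sin_t + math.sin(dt) * cos_t
    new_cos = -math.sin(dt) * sin_t + math.cos(dt) * cos_t
    return new_sin, new_cos

t, dt = 1.7, 0.4
shifted = shift_pair(math.sin(t), math.cos(t), dt)   # via the matrix
direct = (math.sin(t + dt), math.cos(t + dt))        # computed directly
shift_err = max(abs(a - b) for a, b in zip(shifted, direct))
print(f"max difference = {shift_err:.2e}")
```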
Formula (the frequency of dimension pair $i$):
$$\theta_i = 10000^{-2i/d}$$
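As a worked example of $\theta_i = 10000^{-2i/d}$, the sketch below lists the frequency and wavelength ($2\pi/\theta_i$) of each dimension pair for a small, illustrative $d = 8$. With these values the frequencies are 1, 0.1, 0.01 and 0.001, so each successive pair rotates ten times more slowly than the previous one.

```python
import math

d = 8            # illustrative head dimension
base = 10000.0

thetas = [base ** (-2 * i / d) for i in range(d // 2)]
wavelengths = [2 * math.pi / th for th in thetas]   # positions per full turn

for i, (th, wl) in enumerate(zip(thetas, wavelengths)):
    print(f"pair {i}: theta = {th:.4f}, wavelength = {wl:.1f} positions")
```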
Rotary position embedding (RoPE)
$$R^d_{\Theta,m} \mathbf{x} = \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-2} \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1} \end{pmatrix} + \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ \vdots \\ -x_{d-1} \\ x_{d-2} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{pmatrix}$$
Sinusoidal position encoding (Sinusoidal PE)
$$\mathbf{x}_{embedded} = \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-2} \\ x_{d-1} \end{pmatrix} + \begin{pmatrix} \sin(m\theta_0) \\ \cos(m\theta_0) \\ \sin(m\theta_1) \\ \cos(m\theta_1) \\ \vdots \\ \sin(m\theta_{d/2-1}) \\ \cos(m\theta_{d/2-1}) \end{pmatrix}$$
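To make the contrast concrete, here is a small sketch that applies both encodings to the same vector at position $m$. The helper names (`thetas`, `add_sinusoidal`, `apply_rope`) and the sample vector are illustrative assumptions. One visible difference: the additive encoding changes the vector's norm, while the rotation leaves it unchanged.

```python
import math

def thetas(d, base=10000.0):
    """theta_i = base^(-2i/d) for i = 0 .. d/2 - 1."""
    return [base ** (-2 * i / d) for i in range(d // 2)]

def add_sinusoidal(x, m):
    """Sinusoidal PE: add (sin(m*theta_i), cos(m*theta_i)) to each pair."""
    out = []
    for i, th in enumerate(thetas(len(x))):
        out += [x[2 * i] + math.sin(m * th), x[2 * i + 1] + math.cos(m * th)]
    return out

def apply_rope(x, m):
    """RoPE: rotate each adjacent pair (x[2i], x[2i+1]) by the angle m*theta_i."""
    out = []
    for i, th in enumerate(thetas(len(x))):
        c, s = math.cos(m * th), math.sin(m * th)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, b * c + a * s]
    return out

def norm(v):
    return math.sqrt(sum(a * a for a in v))

x = [1.0, 2.0, 3.0, 4.0]
m = 3
x_add = add_sinusoidal(x, m)
x_rot = apply_rope(x, m)
print(f"|x| = {norm(x):.4f}, |x + PE| = {norm(x_add):.4f}, |R x| = {norm(x_rot):.4f}")
```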
Deriving the relative position information
Relative position in RoPE
Let the query vector be at position $m$ and the key vector at position $n$:
$$\mathbf{q}_m = R^d_{\Theta,m} \mathbf{x}_m, \quad \mathbf{k}_n = R^d_{\Theta,n} \mathbf{x}_n$$
Attention score:
$$\mathbf{q}_m^T \mathbf{k}_n = \mathbf{x}_m^T (R^d_{\Theta,m})^T R^d_{\Theta,n} \mathbf{x}_n$$
By the rotation matrix property $(R^d_{\Theta,m})^T R^d_{\Theta,n} = R^d_{\Theta,n-m}$:
$$\mathbf{q}_m^T \mathbf{k}_n = \mathbf{x}_m^T R^d_{\Theta,n-m} \mathbf{x}_n$$
Result: the score directly contains a rotation by the relative position $(n-m)$.
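The result can be checked numerically: rotating a fixed query and key at different absolute positions gives the same score whenever the offset $n - m$ is the same. This is a dependency-free sketch using the interleaved pairing from the formula; the vectors `q`, `k` and the positions are arbitrary illustrative values.

```python
import math

def apply_rope(x, pos, base=10000.0):
    """Rotate each adjacent pair (x[2i], x[2i+1]) by pos * theta_i."""
    d = len(x)
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, b * c + a * s]
    return out

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

q = [0.3, -1.2, 0.8, 2.0, -0.5, 0.1, 1.4, -0.9]
k = [1.1, 0.4, -0.7, 0.6, 2.2, -1.3, 0.2, 0.5]

score_a = dot(apply_rope(q, 2), apply_rope(k, 7))      # n - m = 5
score_b = dot(apply_rope(q, 100), apply_rope(k, 105))  # n - m = 5, shifted by 98
score_c = dot(apply_rope(q, 2), apply_rope(k, 9))      # n - m = 7
print(score_a, score_b, score_c)
```

The first two scores agree up to floating-point error, while the third differs because the offset changed.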
Relative position in Sinusoidal PE
Let the embedded vectors be $\mathbf{x}'_m = \mathbf{x}_m + \mathbf{PE}(m)$ and $\mathbf{x}'_n = \mathbf{x}_n + \mathbf{PE}(n)$.
Attention score:
$$(\mathbf{x}'_m)^T \mathbf{x}'_n = (\mathbf{x}_m + \mathbf{PE}(m))^T (\mathbf{x}_n + \mathbf{PE}(n)) = \mathbf{x}_m^T \mathbf{x}_n + \mathbf{x}_m^T \mathbf{PE}(n) + \mathbf{PE}(m)^T \mathbf{x}_n + \mathbf{PE}(m)^T \mathbf{PE}(n)$$
The key term is:
$$\mathbf{PE}(m)^T \mathbf{PE}(n) = \sum_{i=0}^{d/2-1} \left[\sin(m\theta_i)\sin(n\theta_i) + \cos(m\theta_i)\cos(n\theta_i)\right] = \sum_{i=0}^{d/2-1} \cos((m-n)\theta_i)$$
Result: the relative position $(m-n)$ appears only indirectly, through a trigonometric identity in the last term; the two cross terms still depend on the absolute positions $m$ and $n$.
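The identity for the $\mathbf{PE}(m)^T \mathbf{PE}(n)$ term is easy to confirm numerically. The sketch below builds the sinusoidal vectors directly and compares their dot product with the closed form $\sum_i \cos((m-n)\theta_i)$; the dimension `d = 16` and the position pairs are illustrative choices.

```python
import math

BASE = 10000.0

def sinusoidal_pe(pos, d):
    """Vector [sin(pos*theta_0), cos(pos*theta_0), sin(pos*theta_1), ...]."""
    pe = []
    for i in range(d // 2):
        theta = BASE ** (-2 * i / d)
        pe += [math.sin(pos * theta), math.cos(pos * theta)]
    return pe

def pe_dot(m, n, d):
    return sum(a * b for a, b in zip(sinusoidal_pe(m, d), sinusoidal_pe(n, d)))

def closed_form(m, n, d):
    return sum(math.cos((m - n) * BASE ** (-2 * i / d)) for i in range(d // 2))

d = 16
near = pe_dot(3, 10, d)      # m - n = -7
far = pe_dot(103, 110, d)    # m - n = -7, both positions shifted by 100
expected = closed_form(3, 10, d)
print(near, far, expected)
```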
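RoPE implementation: walkthrough and code
Output: rotary position encodings (the cos and sin tables) of shape (batch_size, seq_len, head_dim), used to rotate the Query/Key vectors.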
1. Compute the inverse frequencies
$$\text{inv\_freq}[k] = \frac{1}{\text{base}^{2k/\text{dim}}}, \quad k = 0, 1, \dots, \text{dim}/2 - 1$$
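2. Get the position information
- Default: consecutive positions [0, 1, 2, ..., seq_len-1]
- Or custom position_ids (supports non-consecutive positions, batching, and so on)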
3. Compute the angle matrix with an outer product
$$\text{angles} = \text{position\_ids} \otimes \text{inv\_freq} \quad \Rightarrow \quad (\text{batch\_size}, \text{seq\_len}, \text{dim}/2)$$
4. Duplicate the angles and compute sin/cos
- Method A (concatenate; the usual implementation):
  - Split the vector into its first half and second half:
    $$x = [x_1, x_2, x_3, x_4] \quad \Rightarrow \quad \text{first half} = [x_1, x_2],\ \text{second half} = [x_3, x_4]$$
  - Rotation:
    $$x' = [x_1 \cos\theta_1 - x_3 \sin\theta_1,\ x_2 \cos\theta_2 - x_4 \sin\theta_2,\ x_3 \cos\theta_1 + x_1 \sin\theta_1,\ x_4 \cos\theta_2 + x_2 \sin\theta_2]$$
  - The corresponding rotate_half in the code swaps the two halves and negates the half that moves to the front, giving $[-x_3, -x_4, x_1, x_2]$.
  $$\text{angles\_full} = \text{cat}([\text{angles}, \text{angles}], \text{dim}=-1) \quad \Rightarrow \quad (\text{batch}, \text{seq\_len}, \text{dim})$$
- Method B (interleave; matches the mathematical formula exactly):
  - Group every two adjacent dimensions into a complex pair:
    $$x = [x_1, x_2, x_3, x_4] \quad \Rightarrow \quad [(x_1, x_2), (x_3, x_4)]$$
  - Rotation:
    $$x' = [x_1 \cos\theta_1 - x_2 \sin\theta_1,\ x_2 \cos\theta_1 + x_1 \sin\theta_1,\ x_3 \cos\theta_2 - x_4 \sin\theta_2,\ x_4 \cos\theta_2 + x_3 \sin\theta_2]$$
  $$\text{angles\_full} = \text{interleave}([\text{angles}, \text{angles}])$$
- Compute:
  $$\cos = \cos(\text{angles\_full}), \quad \sin = \sin(\text{angles\_full})$$
5. Apply the rotation
- For an input vector $x$ (Q or K):
  $$x' = x \cdot \cos + \text{rotate\_half}(x) \cdot \sin$$
- rotate_half picks its dimension-swapping strategy according to the layout in use (cat or interleave).
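Key points
- Frequency layering: different dimension pairs use different rotation frequencies
- Relative position encoding: token pairs at the same relative distance have the same relative angle difference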
import torch
import torch.nn as nn


def demonstrate_rope_complete_process():
    """Demonstrate the complete RoPE pipeline."""
    print("=== Rotary Position Embedding (RoPE): complete pipeline ===\n")

    # Parameters
    seq_len = 4
    head_dim = 8
    batch_size = 1
    print(f"Parameters: seq_len={seq_len}, head_dim={head_dim}, batch_size={batch_size}")

    # ============ Step 1: compute the inverse frequencies ============
    print("\nStep 1: compute the inverse frequencies")
    base = 10000.0
    # Inverse frequency: 1 / (base^(2i/dim)); only dim/2 values are needed
    dim_pairs = torch.arange(0, head_dim, 2, dtype=torch.float32)  # [0, 2, 4, 6]
    inv_freq = 1.0 / (base ** (dim_pairs / head_dim))
    print(f"dim_pairs: {dim_pairs}")
    print(f"inv_freq shape: {inv_freq.shape} = {inv_freq}")

    # ============ Step 2: get the position information ============
    print("\nStep 2: get the position information")
    # Option 1: derive positions from seq_len
    position_ids = torch.arange(seq_len, dtype=torch.float32).unsqueeze(0)  # [1, seq_len]
    # Option 2: pass position_ids directly (more flexible)
    # position_ids = torch.tensor([[0, 1, 5, 8]], dtype=torch.float32)  # non-consecutive positions
    print(f"position_ids shape: {position_ids.shape} = {position_ids}")

    # ============ Step 3: outer product to get the angle matrix ============
    print("\nStep 3: outer product to get the angle matrix")
    # position_ids: [batch_size, seq_len] -> [batch_size, seq_len, 1]
    # inv_freq: [dim/2] -> [1, 1, dim/2]
    pos_expanded = position_ids.unsqueeze(-1)  # [1, 4, 1]
    freq_expanded = inv_freq.unsqueeze(0).unsqueeze(0)  # [1, 1, 4]
    # Outer product via broadcasting: [1, 4, 1] * [1, 1, 4] = [1, 4, 4]
    angles = pos_expanded * freq_expanded  # [batch_size, seq_len, dim/2]
    print(f"angles shape: {angles.shape}")
    print(f"angles[0]:\n{angles[0].round(decimals=3)}")

    # ============ Step 4: duplicate and compute sin/cos ============
    print("\nStep 4: duplicate and compute sin/cos")
    # Two ways to duplicate the angles:
    # Method A: concatenate (more common)
    angles_cat = torch.cat([angles, angles], dim=-1)  # [1, 4, 8]
    cos_cat = angles_cat.cos()
    sin_cat = angles_cat.sin()
    print(f"concatenate - cos shape: {cos_cat.shape}")
    print(f"cos_cat[0, 0]: {cos_cat[0, 0].round(decimals=3)}")

    # Method B: interleave
    angles_interleave = torch.stack([angles, angles], dim=-1)  # [1, 4, 4, 2]
    angles_interleave = angles_interleave.flatten(start_dim=-2)  # [1, 4, 8]
    cos_interleave = angles_interleave.cos()
    sin_interleave = angles_interleave.sin()
    print(f"interleave - cos shape: {cos_interleave.shape}")
    print(f"cos_interleave[0, 0]: {cos_interleave[0, 0].round(decimals=3)}")

    # ============ Step 5: apply the rotation ============
    print("\nStep 5: apply the rotation")
    # Simulated query vector Q
    q = torch.randn(batch_size, seq_len, head_dim)  # [1, 4, 8]
    print(f"original Q shape: {q.shape}")
    print(f"Q[0, 0]: {q[0, 0].round(decimals=3)}")

    # Rotation helper for Method A
    def rotate_half_cat(x):
        """Concatenate layout: negate the second half and move it to the front."""
        x1 = x[..., :x.shape[-1] // 2]  # first half
        x2 = x[..., x.shape[-1] // 2:]  # second half
        return torch.cat([-x2, x1], dim=-1)  # [-x2, x1]

    # Rotation helper for Method B
    def rotate_half_interleave(x):
        """Interleave layout: swap each even/odd pair and negate the odd element."""
        x = x.reshape(*x.shape[:-1], -1, 2)  # [..., dim/2, 2]
        x_rotated = torch.stack([-x[..., 1], x[..., 0]], dim=-1)  # swap and negate
        return x_rotated.flatten(start_dim=-2)

    # Apply the rotation
    print("\nResult with the concatenate layout:")
    q_rotated_cat = rotate_half_cat(q)
    q_embed_cat = q * cos_cat + q_rotated_cat * sin_cat
    print(f"rotated Q shape: {q_embed_cat.shape}")
    print(f"Q_embed_cat[0, 0]: {q_embed_cat[0, 0].round(decimals=3)}")

    print("\nResult with the interleave layout:")
    q_rotated_interleave = rotate_half_interleave(q)
    q_embed_interleave = q * cos_interleave + q_rotated_interleave * sin_interleave
    print(f"rotated Q shape: {q_embed_interleave.shape}")
    print(f"Q_embed_interleave[0, 0]: {q_embed_interleave[0, 0].round(decimals=3)}")


if __name__ == "__main__":
    demonstrate_rope_complete_process()
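The concatenate and interleave layouts differ only by a fixed reordering of dimensions: Method A pairs $(x_i, x_{i+d/2})$ and Method B pairs $(x_{2i}, x_{2i+1})$, each rotated by the same angle. The dependency-free sketch below checks this; `to_cat_layout` is a hypothetical helper introduced for the illustration, and the sample vector and position are arbitrary.

```python
import math

def thetas(d, base=10000.0):
    return [base ** (-2 * i / d) for i in range(d // 2)]

def rope_interleave(x, pos):
    """Method B: rotate adjacent pairs (x[2i], x[2i+1]) by pos * theta_i."""
    out = []
    for i, th in enumerate(thetas(len(x))):
        c, s = math.cos(pos * th), math.sin(pos * th)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, b * c + a * s]
    return out

def rope_cat(x, pos):
    """Method A: rotate split pairs (x[i], x[i + d/2]) by pos * theta_i."""
    h = len(x) // 2
    out = [0.0] * len(x)
    for i, th in enumerate(thetas(len(x))):
        c, s = math.cos(pos * th), math.sin(pos * th)
        a, b = x[i], x[i + h]
        out[i] = a * c - b * s
        out[i + h] = b * c + a * s
    return out

def to_cat_layout(x):
    """Reorder [a0, b0, a1, b1, ...] into [a0, a1, ..., b0, b1, ...]."""
    return x[0::2] + x[1::2]

x = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 3.0, -2.0]
pos = 7
lhs = rope_cat(to_cat_layout(x), pos)          # reorder, then rotate with Method A
rhs = to_cat_layout(rope_interleave(x, pos))   # rotate with Method B, then reorder
layout_err = max(abs(a - b) for a, b in zip(lhs, rhs))
print(f"max difference between layouts = {layout_err:.2e}")
```

Because a dot product is unchanged when both vectors are reordered the same way, either layout yields the same attention scores as long as Q and K use the same one.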
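Simplified version of how RoPE is used in Transformers
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class ModelConfig:
    # Minimal stand-in config (an assumption: the original does not define ModelConfig)
    hidden_size: int = 64
    num_attention_heads: int = 4
    base: float = 10000.0


def rotate_half(x):
    """Rotate half of the input dimensions: [x1, x2] -> [-x2, x1]."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply the rotary position embedding to q and k."""
    # Add the head dimension so cos/sin broadcast over heads
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def rope_init_fn(config):
    dim = config.hidden_size // config.num_attention_heads
    inv_freq = 1.0 / (config.base ** (torch.arange(0, dim, 2) / dim))
    return inv_freq


class RotaryEmbedding(nn.Module):
    def __init__(self, config: ModelConfig, device=None):
        super().__init__()
        self.inv_freq = rope_init_fn(config)

    @torch.no_grad()
    def forward(self, position_ids):
        """Forward pass: compute cos and sin."""
        # position_ids: (batch_size, seq_len), inv_freq: (dim//2,)
        # inv_freq_expanded: (batch_size, dim//2, 1), position_ids_expanded: (batch_size, 1, seq_len)
        inv_freq_expanded = self.inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
        # Cast to float: matmul between a float tensor and an integer tensor fails
        position_ids_expanded = position_ids[:, None, :].float()
        # frequency * position = angle
        # freqs: (batch_size, seq_len, dim//2)
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        # Duplicate the angles to match the full head dimension
        emb = torch.cat((freqs, freqs), dim=-1)  # (batch_size, seq_len, dim)
        # Compute cos and sin (no attention scaling in this simplified version)
        cos = emb.cos()
        sin = emb.sin()
        return cos, sin


class Attention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        # The original left __init__ empty; the fields below are the minimum
        # that forward() needs and are filled in here as an assumption.
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.rotary_emb = RotaryEmbedding(config)

    def forward(self, x):
        b, s, h = x.shape  # (batch_size, seq_len, hidden_size)
        # Reshape to (batch_size, num_heads, seq_len, head_dim)
        query = self.q_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        # Add the position encoding
        position_ids = torch.arange(s, device=x.device).unsqueeze(0)
        cos, sin = self.rotary_emb(position_ids)
        query, key = apply_rotary_pos_emb(query, key, cos, sin, unsqueeze_dim=1)
        ...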
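YaRN: a long-context extrapolation technique
RoPE can become unstable when its mid-to-high dimensions (the low-frequency ones) are extrapolated beyond the training length. YaRN addresses this by scaling those mid-to-high dimensions while keeping the original high frequencies of the low dimensions, which gives more stable modeling of longer contexts. On long-context text, RoPE can still handle short-range relations, since it saw them during training, but it never learned long-range ones. Position Interpolation compresses the position indices of the longer context back into the trained range, applying the same scaling to every dimension, whereas YaRN adapts to long context by interpolating the frequencies dimension by dimension.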
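Position Interpolation can be read in two equivalent ways: divide the position index by the scaling factor before computing angles, or keep the index and divide every frequency by the factor. The sketch below checks that both give the same angles. The values `trained_len = 4096`, `factor = 4.0` and `pos = 10000` are hypothetical numbers chosen for illustration.

```python
def thetas(d, base=10000.0):
    return [base ** (-2 * i / d) for i in range(d // 2)]

d = 8
trained_len = 4096   # hypothetical training context length
factor = 4.0         # hypothetical scaling factor
pos = 10000          # a position beyond the trained range

# View 1: compress the position index back into the trained range.
scaled_pos = pos / factor
angles_scaled_pos = [scaled_pos * th for th in thetas(d)]

# View 2: keep the index, slow every frequency down by the same factor.
angles_scaled_freq = [pos * (th / factor) for th in thetas(d)]

pi_err = max(abs(a - b) for a, b in zip(angles_scaled_pos, angles_scaled_freq))
print(f"scaled position = {scaled_pos}, inside trained range: {scaled_pos < trained_len}")
print(f"max angle difference between the two views = {pi_err:.2e}")
```

The second view is the one that generalizes: once the scaling is expressed per frequency, it can be applied to some dimensions and not others.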
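The core of YaRN is to compute two sets of frequencies (extrapolation frequencies and interpolation frequencies) and blend them per dimension according to a weight. The weight changes gradually, so that high-frequency dimensions mostly use the original extrapolation frequency and low-frequency dimensions mostly use the interpolation frequency.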
Minimal implementation (modifying rope_init_fn above)
# config.factor is the scaling factor (the extrapolation multiple); it reduces the rotation angle of the low frequencies
def rope_init_fn_yarn(config):
    dim = config.hidden_size // config.num_attention_heads  # head dimension
    # 1. Compute the base frequencies
    pos_freqs = config.base ** (torch.arange(0, dim, 2) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (config.factor * pos_freqs)  # reduces the low-frequency rotation angle
    # 2. Simple gradual blend (linear ramp across dimensions)
    extrapolation_factor = 1 - torch.linspace(0, 1, dim // 2)  # (1 -> 0)
    # 3. Blend into the final frequencies
    inv_freq = inv_freq_interpolation * (1 - extrapolation_factor) + inv_freq_extrapolation * extrapolation_factor
    return inv_freq
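To see what the linear blend does per dimension, the sketch below recomputes it in plain Python and reports how much each dimension pair is slowed down relative to the original RoPE frequency. The values `dim = 8` and `factor = 4.0` are illustrative; with them the slowdown works out to 1, 4/3, 2 and 4 across the four pairs.

```python
def blended_inv_freq(dim, base, factor):
    """Plain-Python version of the linear blend between the two frequency sets."""
    n = dim // 2
    out = []
    for k in range(n):
        pos_freq = base ** (2 * k / dim)
        extrapolation = 1.0 / pos_freq             # original RoPE frequency
        interpolation = 1.0 / (factor * pos_freq)  # frequency slowed by `factor`
        w = 1.0 - k / (n - 1)                      # weight on extrapolation, 1 -> 0
        out.append(interpolation * (1 - w) + extrapolation * w)
    return out

dim, base, factor = 8, 10000.0, 4.0
original = [1.0 / base ** (2 * k / dim) for k in range(dim // 2)]
blended = blended_inv_freq(dim, base, factor)
slowdown = [o / b for o, b in zip(original, blended)]

for k, s in enumerate(slowdown):
    print(f"pair {k}: slowed down by a factor of {s:.3f}")
```

The highest-frequency pair is untouched and the lowest-frequency pair is fully interpolated, which is the gradual transition described above.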
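Core flow of the implementation in transformers:
- Generate an amplification coefficient (the attention factor) that grows logarithmically with factor, so that the attention scores are scaled up moderately as the context window grows. This uses the parameters mscale (controls the scaling magnitude) and mscale_all_dim (the normalizing term in the ratio).
- beta_fast and beta_slow are thresholds on the number of rotations, with empirical values of 32 and 1. The corresponding range of dimensions is computed from these thresholds.
- Using the lower and upper bounds, generate a piecewise coefficient between 0 and 1 for each dimension: dimensions to the left of the range use the extrapolation frequency only (the original frequency), dimensions to the right use the interpolation frequency only (the frequency slowed down by multiplying the denominator by factor), and dimensions inside the range are blended linearly.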
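As a worked example of the last two steps, the sketch below computes the dimension range and the per-dimension blend weight in plain Python. The inputs (`dim = 128`, `base = 10000`, an original context length of 4096) are illustrative assumptions, and the formulas mirror the `find_correction_dim` and linear-ramp logic described above.

```python
import math

def find_correction_dim(num_rotations, dim, base, max_pos):
    """Dimension index whose frequency completes `num_rotations` turns over max_pos positions."""
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

# Illustrative values (assumptions, not from a specific model config)
dim, base, original_max = 128, 10000.0, 4096
beta_fast, beta_slow = 32, 1

low = max(math.floor(find_correction_dim(beta_fast, dim, base, original_max)), 0)
high = min(math.ceil(find_correction_dim(beta_slow, dim, base, original_max)), dim - 1)

# Linear ramp over the dim/2 frequency indices, clamped to [0, 1]
ramp = [min(max((k - low) / (high - low), 0.0), 1.0) for k in range(dim // 2)]
extrapolation_weight = [1.0 - r for r in ramp]   # 1 = original frequency, 0 = interpolated

print(f"blend range: low = {low}, high = {high}")
print(f"weights at k = 0, 33, 63: {extrapolation_weight[0]}, "
      f"{extrapolation_weight[33]}, {extrapolation_weight[63]}")
```

With these inputs, indices up to 20 keep the original frequency, indices from 46 upward are fully interpolated, and the indices in between are blended linearly.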
import math

import torch


# `device` was referenced but never defined in the flattened original; it is taken as a parameter here.
def _compute_yarn_parameters(config, device=None):
    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    factor = config.rope_scaling["factor"]
    attention_factor = config.rope_scaling.get("attention_factor")
    mscale = config.rope_scaling.get("mscale")
    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
    original_max_position_embeddings = (
        config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
    )

    def get_mscale(scale, mscale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if mscale and mscale_all_dim:
            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
        else:
            attention_factor = get_mscale(factor)

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = config.rope_scaling.get("beta_fast") or 32
    beta_slow = config.rope_scaling.get("beta_slow") or 1

    # Compute the inverse frequencies
    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations"""
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
        """Find dimension range bounds based on rotations"""
        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
        if truncate:
            low = math.floor(low)
            high = math.ceil(high)
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        if min == max:
            max += 0.001  # Prevent singularity
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # to expand the possible context length. In other words, interpolation = apply scaling factor.
    pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    truncate = config.rope_scaling.get("truncate", True)
    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)

    # Get n-dimensional rotational scaling corrected for extrapolation
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
        + inv_freq_extrapolation * inv_freq_extrapolation_factor
    )
    return inv_freq, attention_factor
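The attention factor returned above follows the `get_mscale` formula, $0.1 \cdot \text{mscale} \cdot \ln(\text{factor}) + 1$. The sketch below restates that helper on its own and tabulates it for a few illustrative scaling factors, to show how slowly it grows.

```python
import math

def get_mscale(scale, mscale=1):
    """Same formula as in the function above: 0.1 * mscale * ln(scale) + 1 for scale > 1."""
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

factors = (1, 2, 4, 8, 16, 32)   # illustrative scaling factors
attention_factors = [get_mscale(f) for f in factors]
for f, a in zip(factors, attention_factors):
    print(f"factor = {f:>2}: attention_factor = {a:.4f}")

# When mscale and mscale_all_dim are both set and equal, the ratio used above is 1.
ratio_equal = get_mscale(8, 0.7) / get_mscale(8, 0.7)
print(f"ratio with equal mscale and mscale_all_dim = {ratio_equal}")
```

Even a 32x context extension raises the factor only to about 1.35, which fits the description of a moderate, logarithmic amplification.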