【arXiv 2025】新颖方法:基于快速傅里叶变换的高效自注意力,即插即用!
一、整体介绍
The FFT Strikes Again: An Efficient Alternative to Self-Attention
FFT再次出击:一种高效的自注意力替代方案
图1:FFTNet整体流程,包括局部窗口处理(STFT或小波变换,可选)和全局FFT,随后在频率/变换域进行等距融合(或门控)。
朋友们,今天为大家介绍一个非常有潜力,未来可能会在自然语言处理、计算机视觉、图像处理等领域发挥重大作用的方法。
中心思想:该方法来源arXiv[1],是2025年3月16日最新公开论文,提出了一种名为FFTNet的自适应频谱滤波框架,该框架利用快速傅里叶变换(FFT)在O(nlogn)时间内实现全局标记混合,有效解决了传统自注意力机制在处理长序列时的二次复杂度问题,把自注意力机制(Self-Attention)的时间复杂度从O(n²)降到O(nlogn)。
实现动机:传统的注意力机制在计算成对标记交互时,随着序列长度n的增加,成本呈二次方增长,这使得处理长序列变得昂贵。相比之下,离散傅里叶变换(DFT)在O(nlogn)时间内自然编码全局交互,因为它将标记序列分解为正交频率分量。
核心原理:根据帕塞瓦尔定理,在傅里叶变换下,对于输入序列X及其傅里叶变换F=FFT(X),信号的总能量保持不变,除了一个常数缩放因子。这一能量保持保证了自适应滤波和非线性操作不会意外扭曲输入信号的固有信息。(学过数字信号处理课程的朋友应该更容易理解,总结起来就是一句话:把信号转换到频域进行处理,不会丢失信号信息,但是可以减少计算量)
证明公式和复杂度计算的公式较为枯燥,本文省略。
下面以代码为例,展示原理及用法。
二、代码与原理解读
1. 基于快速傅里叶变换的基础网络块——FFTNetBlock
import torch
import torch.nn as nn
import torch.nn.functional as F
class ModReLU(nn.Module):def __init__(self, features):super().__init__()self.b = nn.Parameter(torch.Tensor(features))self.b.data.uniform_(-0.1, 0.1)def forward(self, x):return torch.abs(x) * F.relu(torch.cos(torch.angle(x) + self.b))
class FFTNetBlock(nn.Module):def __init__(self, dim):super().__init__()self.dim = dimself.filter = nn.Linear(dim, dim)self.modrelu = ModReLU(dim)def forward(self, x):# x: [batch_size, seq_len, dim]x_fft = torch.fft.fft(x, dim=1) # FFT along the sequence dimensionx_filtered = self.filter(x_fft.real) + 1j * self.filter(x_fft.imag)x_filtered = self.modrelu(x_filtered)x_out = torch.fft.ifft(x_filtered, dim=1).realreturn x_out
if __name__ == '__main__':# 参数设置batch_size = 1 # 批量大小seq_len = 224 * 224 # 序列长度(Transformer 中的 token 数量)dim = 32 # 维度# 创建随机输入张量,形状为 (batch_size, seq_len, embed_dim)x = torch.randn(batch_size, seq_len, dim)# 初始化 FFTNetBlock 模块model = FFTNetBlock(dim = dim)print(model)print("微信公众号: AI缝合术!")output = model(x)print(x.shape)print(output.shape)
运行结果:
该代码实现了一个基于 FFT(快速傅里叶变换)的神经网络块,称为 FFTNetBlock,并在 forward 过程中对输入信号进行频域处理。
实现流程:
①使用 FFT 进行频域转换:输入 x 通过 FFT 转换到频域,在频域进行操作。
②使用可学习的滤波器:通过 nn.Linear 进行频域的线性变换,相当于卷积核在频域对信号进行加权处理。
③使用 ModReLU 进行非线性处理:由于 FFT 产生的结果是复数,传统的 ReLU 不能直接作用,因此使用 ModReLU 进行非线性变换。ModReLU为修正的 ReLU 激活函数,作用类似于ReLU在实数域上的作用,但应用于复数域,通过修改相位角(angle)并结合 ReLU 进行修正。
④最终通过 iFFT 还原回时序空间:经过处理的频域信息通过逆 FFT(ifft)变换回时序域,得到最终输出。
2. 基于快速傅里叶变换的ViT网络——FFTNetViT
import torch
import torch.nn as nn
import torch.nn.functional as F
def drop_path(x, drop_prob: float = 0., training: bool = False):"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_prob# Generate binary tensor mask; shape: (batch_size, 1, 1, ..., 1)shape = (x.shape[0],) + (1,) * (x.ndim - 1)random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_() # binarizeoutput = x.div(keep_prob) * random_tensorreturn output
class DropPath(nn.Module):"""DropPath module that performs stochastic depth."""def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)class MultiHeadSpectralAttention(nn.Module):def __init__(self, embed_dim, seq_len, num_heads=4, dropout=0.1, adaptive=True):"""频谱注意力模块,在保持 O(n log n) 计算复杂度的同时,引入额外的非线性和自适应能力。参数:- embed_dim: 总的嵌入维度。- seq_len: 序列长度(例如 Transformer 中 token 的数量,包括类 token)。- num_heads: 注意力头的数量。- dropout: 逆傅里叶变换(iFFT)后的 dropout 率。- adaptive: 是否启用自适应 MLP 以生成乘法和加法的自适应调制参数。"""super().__init__()if embed_dim % num_heads != 0:raise ValueError("embed_dim 必须能被 num_heads 整除")self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.seq_len = seq_lenself.adaptive = adaptive# 频域的 FFT 频率桶数量: (seq_len//2 + 1)self.freq_bins = seq_len // 2 + 1# 基础乘法滤波器: 每个注意力头和频率桶一个self.base_filter = nn.Parameter(torch.ones(num_heads, self.freq_bins, 1))# 基础加性偏置: 作为频率幅度的学习偏移self.base_bias = nn.Parameter(torch.full((num_heads, self.freq_bins, 1), -0.1))if adaptive:# 自适应 MLP: 每个头部和频率桶生成 2 个值(缩放因子和偏置)self.adaptive_mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim),nn.GELU(),nn.Linear(embed_dim, num_heads * self.freq_bins * 2))self.dropout = nn.Dropout(dropout)# 预归一化层,提高傅里叶变换的稳定性self.pre_norm = nn.LayerNorm(embed_dim)def complex_activation(self, z):"""对复数张量应用非线性激活函数。该函数计算 z 的幅度,将其传递到 GELU 进行非线性变换,并按比例缩放 z,以保持相位不变。参数:z: 形状为 (B, num_heads, freq_bins, head_dim) 的复数张量返回:经过非线性变换的复数张量,形状相同。"""mag = torch.abs(z)# 对幅度进行非线性变换,GELU 提供平滑的非线性mag_act = F.gelu(mag)# 计算缩放因子,防止除零错误scale = mag_act / (mag + 1e-6)return z * scaledef forward(self, x):"""增强型频谱注意力模块的前向传播。参数:x: 输入张量,形状为 (B, seq_len, embed_dim)返回:经过频谱调制和残差连接的张量,形状仍为 (B, seq_len, embed_dim)"""B, N, D = x.shape# 预归一化,提高频域变换的稳定性x_norm = self.pre_norm(x)# 重新排列张量以分离不同的注意力头,形状变为 (B, num_heads, seq_len, head_dim)x_heads = x_norm.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# 沿着序列维度计算 FFT,结果为复数张量,形状为 (B, num_heads, freq_bins, head_dim)F_fft = torch.fft.rfft(x_heads, dim=2, norm='ortho')# 计算自适应调制参数(如果启用)if self.adaptive:# 全局上下文:对 token 维度求均值,形状为 (B, embed_dim)context = x_norm.mean(dim=1)# 经过 MLP 计算自适应参数,输出形状为 (B, num_heads*freq_bins*2)adapt_params = self.adaptive_mlp(context)adapt_params = adapt_params.view(B, self.num_heads, self.freq_bins, 2)# 划分为乘法缩放因子和加法偏置adaptive_scale = adapt_params[..., 0:1] # 形状: (B, num_heads, freq_bins, 1)adaptive_bias = adapt_params[..., 1:2] # 形状: (B, num_heads, freq_bins, 1)else:# 如果不使用自适应机制,则缩放因子和偏置设为 0adaptive_scale = torch.zeros(B, self.num_heads, self.freq_bins, 1, device=x.device)adaptive_bias = torch.zeros(B, self.num_heads, self.freq_bins, 1, device=x.device)# 结合基础滤波器和自适应调制参数# effective_filter: 影响频谱响应的缩放因子effective_filter = self.base_filter * (1 + adaptive_scale)# effective_bias: 影响频谱响应的偏置effective_bias = self.base_bias + adaptive_bias# 在频域进行自适应调制# 先进行乘法缩放,再添加偏置(在 head_dim 维度上广播)F_fft_mod = F_fft * effective_filter + effective_bias# 在频域应用非线性激活F_fft_nl = self.complex_activation(F_fft_mod)# 逆傅里叶变换(iFFT)还原到时序空间# 需要指定 n=self.seq_len 以确保输出长度匹配输入x_filtered = torch.fft.irfft(F_fft_nl, dim=2, n=self.seq_len, norm='ortho')# 重新排列张量,将注意力头合并回嵌入维度x_filtered = x_filtered.permute(0, 2, 1, 3).reshape(B, N, D)# 残差连接并应用 Dropoutreturn x + self.dropout(x_filtered)
class TransformerEncoderBlock(nn.Module):def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.1, attention_module=None, drop_path=0.0):"""一个通用的 Transformer 编码器块,集成了 drop path 随机深度 。- embed_dim: 嵌入维度。- mlp_ratio: MLP 的扩展因子。- dropout: dropout 比率。- attention_module: 处理自注意力的模块。- drop_path: 随机深度的 drop path 比率。"""super().__init__()if attention_module is None:raise ValueError("必须提供一个注意力模块! 此处应调用 MultiHeadSpectralAttention")self.attention = attention_moduleself.mlp = nn.Sequential(nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),nn.GELU(),nn.Dropout(dropout),nn.Linear(int(embed_dim * mlp_ratio), embed_dim),nn.Dropout(dropout))self.norm = nn.LayerNorm(embed_dim)# 用于随机深度的 drop path 层self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()def forward(self, x):# 在残差连接中应用带有 drop path 的注意力。x = x + self.drop_path(self.attention(x))# 在残差连接中应用 MLP(经过层归一化)并加入 drop path。x = x + self.drop_path(self.mlp(self.norm(x)))return xif __name__ == '__main__':# 参数设置batch_size = 1 # 批大小seq_len = 224 * 224 # 序列长度embed_dim = 32 # 嵌入维度num_heads = 4 # 注意力头数# 创建随机输入张量 (batch_size, seq_len, embed_dim)x = torch.randn(batch_size, seq_len, embed_dim)# 初始化 MultiHeadSpectralAttentionattention_module = MultiHeadSpectralAttention(embed_dim=embed_dim, seq_len=seq_len, num_heads=num_heads)# 初始化 TransformerEncoderBlocktransformer_block = TransformerEncoderBlock(embed_dim=embed_dim, attention_module=attention_module)print(transformer_block)print("微信公众号: AI缝合术!")# 前向传播测试output = transformer_block(x)# 打印输出形状print("输入形状:", x.shape)print("输出形状:", output.shape)
运行结果:
乍一看代码比较多,其实原理非常简单,该代码实现了一个标准的Transformer编码器结构,除去两个固定操作的随机深度DropPath,剩下仅有两个类组成,MultiHeadSpectralAttention实现了基于快速傅里叶变换的高效多头自注意力,TransformerEncoderBlock是一个通用的Transformer编码器模块。
上图是ViT的经典结构图,我们只看右侧编码器部分,上述代码实现的就是右侧的编码器,只是将多头注意力转换到频域来进行计算,非常容易理解。
采用上面方法构建的FFTNetViT在LRA和ImageNet两个数据集上的广泛评估确认,FFTNet不仅实现了有竞争力的准确性,而且与固定傅里叶方法和标准自注意力相比,显著提高了计算效率。
以上两个模块均可即插即用,应用在自然语言处理、图像处理、计算机视觉等各类任务上,是非常好的创新!
https://github.com/AIFengheshu/Plug-play-modules
2025年全网最全即插即用模块,免费分享!包含人工智能全领域(机器学习、深度学习等),适用于图像分类、目标检测、实例分割、语义分割、全景分割、姿态识别、医学图像分割、视频目标分割、图像抠图、图像编辑、单目标跟踪、多目标跟踪、行人重识别、RGBT、图像去噪、去雨、去雾、去阴影、去模糊、超分辨率、去反光、去摩尔纹、图像恢复、图像修复、高光谱图像恢复、图像融合、图像上色、高动态范围成像、视频与图像压缩、3D点云、3D目标检测、3D语义分割、3D姿态识别等各类计算机视觉和图像处理任务,以及自然语言处理、大语言模型、多模态等其他各类人工智能相关任务。持续更新中.....