当前位置: 首页 > news >正文

计算Transformer的Flops

计算GCN的Flops

计算CNN的Flops

计算ResNet的Flops

计算MobileNet的Flops

计算Transformer的Flops

核心程序:

    # Core excerpt: count parameters and profile the computational cost of a
    # saved model with ptflops. `path` (checkpoint file) and `x` (2-D input
    # array) are expected to be defined by the surrounding script.
    model = torch.load(path)
    calculate_flops = 1
    if calculate_flops:
        total_params = sum(p.numel() for p in model.parameters())
        print("Total parameters:", total_params)
        trainable_params = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        print("Trainable parameters:", trainable_params)
        # Shape of ONE input sample as handed to ptflops.
        # NOTE(review): ptflops prepends its own batch dimension to this
        # shape, so the model is presumably profiled on a (1, 1, L) tensor.
        # Confirm that this is the layout the model expects.
        input_shape = (1, x.shape[1])
        macs, params = get_model_complexity_info(
            model,
            input_shape,
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )
        # ptflops reports multiply-accumulate operations (MACs), not FLOPs.
        print(f"MACs: {macs}  (FLOPs = 2 x MACs)")
        print(f"Parameters: {params}")

每一模块的计算量结果(ptflops 输出的是 MACs),其中 1 MAC = 2 FLOPs = 1 次乘法 + 1 次加法

最终结果:

完整程序,仅供参考:

import argparse
import torch
import h5py
import numpy as np
from transformer.Models import Transformer, MLP
import os
import re
import torch.nn as nn
from torch.nn import functional as F
from tensorboardX import SummaryWriter
import time
from ptflops import get_model_complexity_info

# ---------------------------------------------------------------------------
# Experiment constants
# ---------------------------------------------------------------------------
N = 30  # number of nodes indexed per graph (pairs are taken from range(N))
snr = 40  # not referenced in the visible part of this script
learning_rate = 0.003  # not referenced here; the optimizer uses the -lr option
M = 3000  # not referenced in the visible part of this script
epochs = 10000  # upper bound on the number of iterations run by train()
dataset_n = 100  # number of graphs read from the dataset file

# Location of the HDF5-based .mat dataset (Windows path, backslashes escaped
# explicitly so no invalid escape sequences are relied upon).
DATASET_PATH = (
    "D:\\无线通信网络认知\\论文1\\大修意见\\Reviewer2-7 多种深度学习方法对比实验"
    "\\test data 30 (mat)\\" + "30_nodes_dataset.mat"
)

# Directory prefix under which train() stores its checkpoints.
model_dir = "D:\\englist\\WCNA\\Transformer\\fn 7 4\\"

# Directory prefixes the evaluation / timing routines load checkpoints from.
# NOTE(review): test() reads from "english" while calculate() and model_dir
# use "englist" -- one of them is presumably a typo; confirm on disk.
EVAL_MODEL_DIR = r'D:\english\WCNA\Transformer\fn 59 8\\'
TIMING_MODEL_DIR = r'D:\englist\WCNA\Transformer\fn 59 8\\'

# Loss used by train(); the model output is compared against 0/1 labels.
loss_function = nn.BCELoss()


def graph_normalize(signals):
    """Scale every row of *signals* in place by its maximum absolute value.

    Rows whose peak is zero are left untouched instead of being divided by
    zero (which would fill them with NaN).
    """
    for row in signals:  # each row is a view, so the write below is in place
        peak = np.max(np.abs(row))
        if peak > 0:
            row[...] = row / peak


def _load_graph_arrays():
    """Read "Signals" and "Tp" from the dataset and swap axes 0 and 2.

    Returns:
        (Signals, Tp); the code below indexes both by graph on axis 0.
    """
    # The context manager closes the file even if one of the reads raises.
    with h5py.File(DATASET_PATH, 'r') as mat_file:
        signals = mat_file["Signals"][()]
        tp = mat_file["Tp"][()]
    return np.swapaxes(signals, 2, 0), np.swapaxes(tp, 2, 0)


def _ordered_pairs(tp, value):
    """Return the (i, j) pairs with i != j and tp[i, j] == value.

    Pairs are listed in row-major order (i outer, j inner) as a (k, 2)
    integer array, which is (0, 2) when nothing matches.
    """
    adjacency = np.asarray(tp)[:N, :N]
    off_diagonal = ~np.eye(N, dtype=bool)
    return np.argwhere((adjacency == value) & off_diagonal)


def _concat_pairs(signals, pairs):
    """Build one row per pair: signal of node i followed by signal of node j."""
    rows = np.concatenate((signals[pairs[:, 0]], signals[pairs[:, 1]]), axis=1)
    return rows.astype(np.float64, copy=False)


def generate_dataset():
    """Build a balanced pair-classification dataset from the graph file.

    For each of the first ``dataset_n`` graphs the node signals are
    normalized, every ordered node pair with ``Tp == 1`` becomes a positive
    sample, and an equal number (at most) of randomly chosen pairs with
    ``Tp == 0`` become negative samples.

    Returns:
        (x, y): x holds the concatenated signals of each pair, one pair per
        row; y is a column vector of 1.0 / 0.0 labels.
    """
    Signals, Tp = _load_graph_arrays()

    x_list = []
    y_list = []
    for n in range(dataset_n):
        signals = Signals[n, :, :]
        graph_normalize(signals)
        tp = Tp[n, :, :]

        positives = _ordered_pairs(tp, 1)
        negatives = _ordered_pairs(tp, 0)

        # Re-seeding for every graph keeps the negative sampling reproducible
        # (and, as a side effect, fixes the global NumPy RNG state afterwards).
        np.random.seed(42)
        order = np.arange(len(negatives))
        np.random.shuffle(order)
        # Keep as many negatives as there are positives.
        negatives = negatives[order][:len(positives)]

        pairs = np.concatenate((positives, negatives), axis=0)
        x_list.append(_concat_pairs(signals, pairs))
        y_list.append(
            np.concatenate(
                (np.ones((len(positives), 1)), np.zeros((len(negatives), 1))),
                axis=0,
            )
        )

    return np.vstack(x_list), np.vstack(y_list)


def generate_time_dataset():
    """Build the input used for timing: all N * N ordered pairs per graph.

    Unlike generate_dataset(), pairs with i == j are included and the
    signals are not normalized.

    Returns:
        Array with ``dataset_n * N * N`` rows; row ``i * N + j`` of each
        graph is the signal of node i followed by the signal of node j.
    """
    Signals, _ = _load_graph_arrays()

    x_list = []
    for n in range(dataset_n):
        nodes = Signals[n, :N, :]  # assumes each graph has at least N rows
        first = np.repeat(nodes, N, axis=0)  # node i, repeated for every j
        second = np.tile(nodes, (N, 1))  # node j, cycling fastest
        x_list.append(
            np.hstack((first, second)).astype(np.float64, copy=False)
        )
    return np.vstack(x_list)


def cal_performance(tra_pred, tra_true):
    """Return the mean squared error between prediction and target."""
    return F.mse_loss(tra_pred, tra_true)


def train(model, data, label, optimizer):
    """Train *model* full-batch and checkpoint every accuracy improvement.

    Each iteration runs one forward/backward pass over all of *data*. The
    accuracy of that same forward pass (threshold 0.5) is compared with the
    best so far; on improvement the whole model is saved under ``model_dir``
    and the patience counter is reset. Training stops after ``epochs``
    iterations or once 50 iterations pass without improvement.
    """
    patience = 50
    best_acc = 0
    count = 0
    model.train()
    for epoch in range(epochs):
        count = count + 1
        if count > patience:
            break

        optimizer.zero_grad()  # gradients accumulate unless cleared
        pre = model(data)
        loss = loss_function(pre, label)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # The output is a probability; threshold it to a 0/1 prediction.
            pred_class = (pre >= 0.5).float()
            correct = (pred_class == label).sum().item()
            total = label.size(0)
            accuracy = correct / total
            if accuracy > best_acc:
                best_acc = accuracy
                torch.save(
                    model,
                    model_dir + 'epoch ' + str(round(epoch)) + ' accuracy '
                    + str(round(accuracy, 3)) + '.pt',
                )
                count = 0
        print("epoch: ", epoch, "  loss: ", loss.item(), "accuracy: ", accuracy)


def find_max_val_acc_file(folder_path):
    """Find the file in *folder_path* whose name carries the highest accuracy.

    File names are searched for ``accuracy <float>`` (the naming used by
    train()).

    Returns:
        (filename, accuracy) of the best match, or ``(None, -1)`` when no
        file name matches.
    """
    pattern = re.compile(r'accuracy ([0-9]+\.[0-9]+)')
    max_val_acc = -1
    max_file = None
    for filename in os.listdir(folder_path):
        match = pattern.search(filename)
        if match is None:
            continue
        val_acc = float(match.group(1))
        if val_acc > max_val_acc:
            max_val_acc = val_acc
            max_file = filename
    return max_file, max_val_acc


def _best_checkpoint_path(folder):
    """Return the path of the best checkpoint in *folder*.

    Raises:
        FileNotFoundError: if no file name in *folder* carries an accuracy.
    """
    max_file, _ = find_max_val_acc_file(folder)
    if max_file is None:
        raise FileNotFoundError(
            f"no checkpoint named '... accuracy <value>' found in {folder!r}"
        )
    return folder + max_file


def _load_model(path, map_location=None):
    """Load a fully pickled model object saved with ``torch.save(model, ...)``.

    SECURITY: this unpickles arbitrary Python objects; only load checkpoints
    produced by this project.
    """
    try:
        # Newer PyTorch defaults to weights_only=True, which rejects a
        # pickled nn.Module, so it has to be disabled explicitly.
        return torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the weights_only keyword.
        return torch.load(path, map_location=map_location)


def _sync(tensor):
    """Wait for pending CUDA work so wall-clock timings are meaningful."""
    if tensor.is_cuda:
        torch.cuda.synchronize(tensor.device)


def test(data, label):
    """Evaluate the best saved checkpoint and report its complexity.

    Prints the parameter counts, the ptflops per-layer report and the
    classification accuracy (threshold 0.5) on *data* / *label*.

    Returns:
        The accuracy as a float.

    Raises:
        FileNotFoundError: if no checkpoint is found in EVAL_MODEL_DIR.
    """
    model = _load_model(
        _best_checkpoint_path(EVAL_MODEL_DIR), map_location=data.device
    )
    # Checkpoints are saved mid-training; switch off dropout for evaluation.
    model.eval()

    calculate_flops = 1
    if calculate_flops:
        total_params = sum(p.numel() for p in model.parameters())
        print("Total parameters:", total_params)
        trainable_params = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        print("Trainable parameters:", trainable_params)
        # Shape of one sample, taken from the evaluated data itself rather
        # than from a module-level variable.
        # NOTE(review): ptflops prepends its own batch dimension, so the
        # model is presumably profiled on a (1, 1, L) tensor -- confirm.
        input_shape = (1, data.shape[1])
        macs, params = get_model_complexity_info(
            model,
            input_shape,
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )
        # ptflops reports multiply-accumulate operations (MACs), not FLOPs.
        print(f"MACs: {macs}  (FLOPs = 2 x MACs)")
        print(f"Parameters: {params}")

    with torch.no_grad():
        pre = model(data)
        # The output is a probability; threshold it to a 0/1 prediction.
        pred_class = (pre >= 0.5).float()
        correct = (pred_class == label).sum().item()
        total = label.size(0)
        accuracy = correct / total
    print("accuracy: ", accuracy)
    return accuracy


def calculate(data):
    """Time one forward pass over *data* and print the time per graph.

    The elapsed time is divided by ``dataset_n``, the number of graphs
    produced by generate_time_dataset().

    Raises:
        FileNotFoundError: if no checkpoint is found in TIMING_MODEL_DIR.
    """
    model = _load_model(
        _best_checkpoint_path(TIMING_MODEL_DIR), map_location=data.device
    )
    model.eval()
    with torch.no_grad():
        _sync(data)  # do not time work queued before the measurement
        t1 = time.perf_counter()
        model(data)
        _sync(data)  # CUDA kernels are asynchronous; wait for completion
        t2 = time.perf_counter()
    delta_t = (t2 - t1) / dataset_n
    print("time: ", delta_t)


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` would turn every non-empty string, including "False",
    into True.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-epoch', type=int, default=epochs)
    parser.add_argument('-b', '--batch_size', type=int, default=40)
    parser.add_argument('-d_model', type=int, default=7)
    parser.add_argument('-d_inner_hid', type=int, default=236)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=2000)
    parser.add_argument('-lr_mul', type=float, default=2.0)
    parser.add_argument('-lr', type=float, default=0.001)
    parser.add_argument('-n_head', type=int, default=2)
    parser.add_argument('-n_layers', type=int, default=1)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-do_train', type=_str2bool, default=False)
    parser.add_argument('-do_retrain', type=_str2bool, default=False)
    parser.add_argument('-do_eval', type=_str2bool, default=True)
    parser.add_argument('-use_mlp', type=_str2bool, default=False)
    parser.add_argument('-calculate', type=_str2bool, default=False)
    return parser


def main():
    """Run training, evaluation and/or timing as selected on the command line."""
    global epochs

    opt = _build_arg_parser().parse_args()
    opt.d_word_vec = opt.d_model
    epochs = opt.epoch  # make the -epoch option effective for train()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # NOTE(review): the writer is created but nothing is logged to it.
    log_writer = SummaryWriter()

    if opt.do_train or opt.do_eval:
        x, y = generate_dataset()
        indices = np.arange(y.shape[0])
        np.random.shuffle(indices)
        x = x[indices]
        y = y[indices]

        # 80 % of the shuffled samples for training, the rest for evaluation.
        a = round(0.8 * len(indices))
        x_train = torch.tensor(x[:a, :]).float().to(device)
        y_train = torch.tensor(y[:a, :]).float().to(device)
        x_val = torch.tensor(x[a:, :]).float().to(device)
        y_val = torch.tensor(y[a:, :]).float().to(device)

    if opt.do_train:
        transformer = Transformer(
            2000,
            2000,
            d_k=opt.d_k,
            d_v=opt.d_v,
            d_model=opt.d_model,
            d_word_vec=opt.d_word_vec,
            d_inner=opt.d_inner_hid,
            n_layers=opt.n_layers,
            n_head=opt.n_head,
            dropout=opt.dropout,
            n_position=250,
        ).to(device)
        mlp = MLP(10, 10, 25, 50, use_extra_input=False).to(device)
        model_train = mlp if opt.use_mlp else transformer

        optimizer = torch.optim.Adam(model_train.parameters(), lr=opt.lr)
        if opt.do_retrain:  # only used for the transformer
            checkpoint = torch.load(
                "./checkpoint/ckpt.pth", map_location=device
            )
            transformer.load_state_dict(checkpoint['net'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        train(
            model=model_train,
            data=x_train,
            label=y_train,
            optimizer=optimizer,
        )

    if opt.do_eval:
        test(data=x_val, label=y_val)

    if opt.calculate:
        x_time = torch.tensor(generate_time_dataset()).float().to(device)
        calculate(data=x_time)

    log_writer.close()


if __name__ == '__main__':
    main()

相关文章:

  • 从 0 到 1 打造社区产品:短说社区助力开启社交新篇
  • Java编程中的设计模式:单例模式的深度剖析
  • 深度解析 Caffeine:高性能 Java 缓存库
  • LED-Merging: 无需训练的模型合并框架,兼顾LLM安全和性能!!
  • iOS App 上架步骤解析:适合资源有限团队的上架流程与注意事项
  • 【Verilog】Verilator的TestBench该用C++还是SystemC
  • OpenSSL 混合加密
  • 16.数据聚合
  • C++的前世今生-C++11
  • 进入python虚拟环境的方法
  • hive集群优化和治理常见的问题答案
  • 「ECG信号处理——(18)基于时空特征的心率变异性分析」2025年6月23日
  • 实时反欺诈:基于 Spring Boot 与 Flink 构建信用卡风控系统
  • 2025.06.23【甲基化】methylKit:甲基化测序数据分析安装与详细使用教程
  • 鸿蒙容器组件 Row 全解析:水平布局技术与多端适配指南
  • 《Effective Python》第十章 健壮性——善用 try/except/else/finally,写出更健壮的 Python 异常处理代码
  • 体制内写公文,用ai工具辅助写材料
  • Advent of Cyber 1 [2019] - [Day 13] | TryHackMe
  • Go 语言使用 excelize 库操作 Excel 的方法
  • FastAPI + PyMySQL 报错:“dict can not be used as parameter”的原因及解决方案
  • 静态网站源文件下载/互联广告精准营销
  • 贵阳汽车网站建设/今日nba比赛直播
  • 如何做微信小程序网站/网站注册查询
  • 做财经比较好的网站有哪些/网页搜索快捷键
  • 着陆页设计网站国内/申请百度收录网址
  • 网站怎么做数据库/外包网络推广公司