当前位置: 首页 > news >正文

python/pytorch杂聊

Dataset

  • 是否需要自己定义:如果你使用的数据集不是 PyTorch 提供的标准数据集(如 MNIST、CIFAR-10 等),那么你需要继承 torch.utils.data.Dataset 类并实现两个方法:__len__() 和 __getitem__()
  • __len__() 应该返回数据集的总大小。
  • __getitem__() 应该根据索引返回一个数据样本。

DataLoader

  • 是否需要自己定义DataLoader 不需要自己定义,它是 PyTorch 提供的一个类,用于包装 Dataset 并在数据集上提供迭代功能。它支持批量处理、打乱数据、多线程加载等。
  • 使用 DataLoader 时,你可以指定批处理大小(batch_size)、是否打乱数据(shuffle)、数据加载的线程数(num_workers)等。

model定义【继承nn.module父类】

forward:input--forward-->output

forward(self,x)中x表示输入,即x->卷积->relu->卷积-->relu-->输出

class HeightPredictor(nn.Module):
    def __init__(self):
        super(HeightPredictor, self).__init__()
        self.conv1 = nn.Conv2d(1,20,5)
        self.conv2 = nn.Conv2d(20,20,5)
       
    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))

Dict

building_info = {}【dict,key--value】

这是一个字典(dictionary)的创建语句。在Python中,字典是一种可变的、无序的、键值对(key-value pairs)的集合。每个键(key)都是唯一的,且必须是不可变的类型(如字符串、数字或元组),而值(value)可以是任何类型的数据。字典通过键来访问对应的值,提供了快速查找和插入的能力。

特殊:defaultdict:defaultdict是Python标准库collections模块中的一个类。defaultdict与普通字典类似,但它在创建时提供了一个默认工厂函数【比如defaultdict(list):当访问一个不存在的键时,defaultdict会自动为该键创建一个空列表作为默认值。】,当尝试访问一个不存在的键时,defaultdict会自动为该键创建一个默认值,而不会抛出KeyError

整理csv

df = pd.read_csv(file_path, encoding="utf-8")#读取csv

#根据某个属性分组

area_bins = [0, 100, 200, 300, 400, np.inf]

area_labels = [f"{left}-{right}" if right != np.inf else f">{left}"

              for left, right in zip(area_bins[:-1], area_bins[1:])]

df['area_bins'] = pd.cut(df['area'], bins=area_bins, labels=area_labels)

methods = ["a","b"]

attributes = ['material', 'height_bin']

for attr in attributes:

    results = []

    for method in methods:

            Num_col = f"{method}_Num"

            predict_col = f"{method}_predict"

            if Num_col not in df.columns or predict_col not in df.columns:

                print(f"跳过 {method},缺少必要列")

                continue

           

            valid_data = df[['true_Num', Num_col, predict_col, attr]].dropna()

            if valid_data.empty:

                print(f"{method} 在属性 {attr} 下无有效数据")

                continue

           

            # 计算完整指标

            grouped = valid_data.groupby(attr).apply(

                lambda x: pd.Series({

                    'Ori_RMSE': np.sqrt(mean_squared_error(x['true_Num'], x[Num_col])),

                    'Pred_RMSE': np.sqrt(mean_squared_error(x['true_Num'], x[predict_col])),

                    'Ori_MAE': mean_absolute_error(x['true_Num'], x[Num_col]),

                    'Pred_MAE': mean_absolute_error(x['true_Num'], x[predict_col]),

                    'Group_Size': len(x),

                    'Sample_Optimized': np.sum(

                        np.abs(x[Num_col] - x['true_Num']) >

                        np.abs(x[predict_col] - x['true_height'])

                    )

                })

            ).reset_index()

       

        grouped['method'] = method

        results.append(grouped)

   

    if not results:

        print(f"属性 {attr} 无数据,跳过")

        continue

   

    # 合并结果

    combined_df = pd.concat(results, ignore_index=True)

   

    # 生成透视表

    pivot_df = combined_df.pivot(

        index=attr,

        columns='method',

        values=['Ori_RMSE', 'Pred_RMSE', 'Ori_MAE', 'Pred_MAE']

    )

   

    # 扁平化列名并填充NaN

    pivot_df.columns = [f"{method}_{metric}" for metric, method in pivot_df.columns]

    pivot_df = pivot_df.fillna(0)

   

    # 保存到CSV

    csv_path = os.path.join(output_dir, f"{attr}.csv")

    pivot_df.reset_index().to_csv(csv_path, index=False)

实现了分别对每个方法依据不同属性评估的功能

http://www.dtcms.com/a/107481.html

相关文章:

  • Nature旗下 | npj Digital Medicine | 图像+转录组+临床变量三合一,多模态AI预测化疗反应,值得复现学习的完整框架
  • 大智慧前端面试题及参考答案
  • 爬虫【feapder框架】
  • 【LeetCode基础算法】二叉树所有类型
  • ESLint语法报错
  • Mysql基础笔记
  • 论文:Generalized Category Discovery with Clustering Assignment Consistency
  • 获取各类基本因子
  • day21和day22学习Pandas库
  • Ray Flow Insight:让分布式系统调试不再“黑盒“
  • 【模型部署】onnx模型-LOOP 节点实例
  • 2.3.3 使用@Profile注解进行多环境配置
  • 高通将进军英国芯片 IP 业务 Alphawave
  • Qt线程等待条件QWaitCondition
  • 深入理解DRAM刷新机制:异步刷新为何无需扣除刷新时间?
  • 风电行业预测性维护解决方案:给风机装上 “智能医生”,实现故障 “秒级预警”
  • HTMX构建无重载闪烁的交互式页面
  • Vue开发系列——npm镜像问题
  • Frida Hook Native:jobjectArray 参数解析
  • SQL Server 增删改查详解
  • 使用pytesseract和Cookie登录古诗文网~(python爬虫)
  • 从Hugging Face下载Qwen/Qwen2-Audio-7B-Instruct模型到本地运行,使用python实现一个音频转文字的助手
  • 树莓派超全系列教程文档--(21)用户配置
  • 芋道源码——Spring Cloud Bus RocketMQ 入门
  • 《全栈+双客户端Turnkey方案》架构设计图
  • 软件版本号递增应该遵循的规范
  • 分层防御:对称与非对称加密如何守护数字世界
  • 0402-对象和类(访问器 更改器 日期类)
  • 北方算网获邀在中关村论坛发言 解析人工智能+产业落地核心路径
  • 【数据库原理及安全实验】实验一 数据库安装与创建