当前位置: 首页 > news >正文

【小白笔记】PyTorch 和 Python 基础的这些问题

1. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 这句话是固定的吗?

  • 人话解释: 这句话是 PyTorch 代码中用于实现设备无关性(Device Agnostic)的标准写法,它几乎是固定的模板。

  • 功能拆解:

    1. torch.cuda.is_available():询问 PyTorch:“机器上有没有英伟达的 GPU?”
    2. "cuda" if ... else "cpu":这是一个 Python 的三元表达式
      • 如果 GPU 可用,选择 "cuda"
      • 如果 GPU 不可用,选择 "cpu"
    3. torch.device(...):根据上一步的结果,创建一个 PyTorch 设备对象
  • 固定性: 这行代码是 PyTorch 社区高度推荐的,因为它能让你的代码在有 GPU 的机器上自动加速,在没有 GPU 的机器上自动回退到 CPU,无需修改代码。因此,在 PyTorch 项目中,这是应该背诵和使用的固定写法


2. y_train = torch.from_numpy(y_train).long() 为什么不用 int

  • 人话解释: long()int() 在 PyTorch 中都代表整数类型,但它们分别对应不同的位宽(bit width)。

    • long() (即 torch.int64): 这是一个 64 位的整数类型。它是 PyTorch 中处理索引标签、以及大型计数默认和推荐类型。
    • int() (即 torch.int32): 这是一个 32 位的整数类型。
  • 主要原因 (惯例和兼容性):

    1. 标签类型要求: 在 PyTorch 的许多内置函数中(例如损失函数 nn.CrossEntropyLoss),要求输入的目标标签张量必须是 torch.long (即 64 位整数) 类型。
    2. 安全范围: 64 位整数可以表示更大的数字,尽管像鸢尾花这种简单的分类任务用 32 位足够,但使用 long() 更安全、更符合 PyTorch 的习惯。
  • 记忆点: 在 PyTorch 中,特征 (X) 用 float()标签/索引 (Y) 用 long()


3. .to(self.device): 数据上设备。这个是什么用法?to 这个前面没有定义这个功能啊?

  • 人话解释: .to() 是 PyTorch 张量(Tensor)对象自带的“移动”能力。它不是一个需要您在类中定义的方法,而是 PyTorch 库已经给所有张量写好的内置方法

  • 功能: .to(目标) 方法用于将一个张量移动到指定的设备(如 CPU 或 CUDA/GPU),或转换为指定的数据类型。

  • 用法:

    my_tensor = torch.tensor([1, 2, 3])
    # 移动到 GPU
    my_tensor_on_gpu = my_tensor.to('cuda') # 移动到我们在 __init__ 中设置好的 self.device 上
    self.X_train = X_train.to(self.device)
    
  • 记忆点: 张量.to(device) 是 PyTorch 中数据上/下设备(GPU/CPU)的标准动作。


4. predictions.append(pred_label) 这个 append 是啥?经常见这个用法,为什么不用 add

  • 人话解释: append 是 Python 列表(List)对象的一个内置方法,意思是“在列表的末尾添加一个新元素”。

  • 为什么不用 add

    • 在 Python 中,加法运算 + 具有不同的语义,例如:
      • 数值: 1 + 2 得到 3
      • 集合 (Set): 没有 add 方法,使用 set.add(element)
      • 列表: [1] + [2] 会得到 [1, 2](这是连接两个列表)。
    • 为了明确“向列表中添加一个元素”这个操作,Python 的设计者选择了 append 这个词。add 通常用于集合 (set) 或用于表示数值相加。
  • 记忆点:

    • List(列表)的末尾添加元素用:.append()
    • Set(集合)中添加元素用:.add()

5. unsqueeze(0): 增加维度进行广播。是啥意思?

  • 人话解释: unsqueeze(0) 的意思是**“在第 0 个位置(最前面)增加一个维度,把这个向量变成一个矩阵”**。

  • 目的: PyTorch/NumPy 中的广播机制要求参与运算的张量维度能匹配。

  • 举例:

    • 原始样本 x_new 是一个特征向量,比如 [3.0, 3.0]。它的维度是 (2,)
    • 训练集 X_train 是一个特征矩阵,比如 6 个样本,维度是 (6, 2)

    如果直接相减 X_train - x_new,PyTorch 不知道怎么对齐。

    • x_new.unsqueeze(0) 后: [3.0, 3.0] 变成了 [[3.0, 3.0]]。维度从 (2,) 变成了 (1, 2)
    • 现在: 一个 (6, 2) 的矩阵和一个 (1, 2) 的行向量就可以使用广播机制进行减法了。
  • 记忆点: unsqueeze 是在不改变数据的前提下,增加一个维度(通常是为了满足广播或函数输入的要求)。


6. unsqueeze(0): 广播关键步骤。是啥意思?

  • 人话解释: 这里的“关键步骤”指的是,unsqueeze(0)激活 PyTorch 广播机制的关键。

  • 原因: 正如上一点所说,如果不增加这个维度,PyTorch 不会知道如何将 x_new 的值与 X_train 中的所有行(样本)进行匹配。一旦维度变成 (1, 2),PyTorch 就理解了:“哦,需要把这个 (1, 2) 的向量复制 6 次,然后进行逐元素相减。”

  • 记忆点: unsqueeze(0) 是我们手动调整维度,以便让 PyTorch 的自动广播机制能够工作的前置条件


7. 广播机制是啥?广播机制,计算新样本与所有训练样本的特征差。?

  • 人话解释: 广播机制 (Broadcasting) 是 PyTorch 和 NumPy 中一种聪明地处理不同形状数组之间运算的机制。它的核心思想是:在不实际复制数据的情况下,让维度较小的数组“伸展”到和维度较大的数组一样大,然后进行运算。

  • 在 KNN 中的应用:

    • 目标:计算 X_train(所有样本)与 x_new(新样本)之间的差。
    • 没有广播: 你必须写一个循环,对 X_train 中的每一行都减去 x_new,或者手动创建一个和 X_train 一样大的新样本矩阵。这效率很低。
    • 使用广播:
      • 原始:X_train (6, 2)x_new_expanded (1, 2)
      • 广播过程:PyTorch 发现第 0 维不匹配 (6 和 1),但可以扩展。它将 x_new_expanded 逻辑上复制 6 次,变成一个 (6, 2) 的张量。
      • 最终效果:differences = X_train - x_new_expanded 等价于:
        KaTeX parse error: Expected 'EOF', got '&' at position 49: …\text{new}, 1} &̲ x\_{1,2} - x\_…
        $$
      • 效率高: 实际上 PyTorch 并没有真的创建和存储那 6 个复制品,它只在计算时执行了正确的逻辑,大大节省了内存和时间。
  • 记忆点: 广播机制就是让你可以用一个小尺寸的张量(如单个样本)直接对一个大尺寸的张量(如整个数据集)进行运算(如加减乘除)。它是实现向量化计算的关键。

http://www.dtcms.com/a/494875.html

相关文章:

  • linux学习笔记(35)C语言连接mysql
  • 消息推送策略:如何在营销与用户体验间找到最佳平衡点
  • go资深之路笔记(九)kafka浅析
  • Java String 性能优化与内存管理:现代开发实战指南
  • 【软考备考】 NoSQL数据库有哪些,键值型、文档型、列族型、图数据库的特点与适用场景
  • 论《素数的几种筛法》
  • html静态页面怎么放在网站上原平的旅游网站怎么做的
  • 网页设计与网站建设作业公众号小程序制作步骤
  • 律师怎么做网站简单大气网站模板
  • 偏振相机在半导体制造的领域的应用
  • Uniapp微信小程序开发:EF Core 中级联删除
  • Java从入门到精通 - 集合框架(二)
  • 3proxy保姆级教程:WIN连接远端HTTPS代理
  • 大厂AI各走“开源”路
  • 室内装修效果图网站有哪些百度网盟推广是什么
  • grootN1 grootN1.5 gr00t安装方法以及使用(学习)
  • Typora(跨平台MarkDown编辑器) v1.12.2 中文绿色版
  • Unity开发抖音小游戏的震动
  • 团队作业——概要设计和数据库设计
  • 在Spring Boot开发中,HEAD、OPTIONS和 TRACE这些HTTP方法各有其特定的应用场景和实现方式
  • Flink DataStream「全分区窗口处理」mapPartition / sortPartition / aggregate / reduce
  • 网站备案号码查询大连网页设计哪家好
  • Next.js 入门指南
  • arcgis api for javascript 修改地图图层要素默认的高亮效果
  • 【论文速递】2025年第28周(Jul-06-12)(Robotics/Embodied AI/LLM)
  • 宁波市鄞州区建设局网站怎么做网站静态布局
  • 一文掌握 CodeX CLI 安装以及使用!
  • Android实战进阶 - 用户闲置超时自动退出登录功能详解
  • 2二、u-boot移植
  • 淄博网站建设哪家好常德网站建设技术