当前位置: 首页 > news >正文

机器学习常见问题之numpy维度问题


解密 NumPy 维度之谜:从 (100,)(100, 1) 的深度剖析

前言

在学习机器学习的道路上,我们常常将精力集中在理解复杂的算法模型上,比如逻辑回归的代价函数、梯度下降的推导过程等。然而,在实际的代码实现中,真正让我们头疼的,往往是一些看似微不足道的基础问题。今天,我就遇到了这样一个由 NumPy 数组维度引发的“血案”,它让我深刻理解了 shape=(100,)shape=(100, 1) 之间那微妙而关键的区别。

核心困惑:shape=(100,)shape=(100, 1) 的本质区别

初看之下,两者似乎都代表着“100个元素”。但对于 NumPy 来说,它们在数学和结构上的意义天差地别。

  • shape=(100,):这是一个一维数组 (1D Array)。它没有行和列的概念,更像是一个扁平的“数据列表”。它既不是行向量,也不是列向量。
  • shape=(100, 1):这是一个二维数组 (2D Array)。它有明确的 100 行和 1 列,是一个标准的列向量

关键点:线性代数中的向量运算(如矩阵乘法)通常是基于二维的行/列向量进行的。一个一维数组在参与这些运算时,可能会因为维度不明确而导致问题。

“意外”的广播机制 (Broadcasting)

如果我们将一个 shape=(100,) 的一维数组和一个 shape=(100, 1) 的列向量直接相加,会发生什么?

import numpy as nparr_1d = np.array([1, 2, 3])      # shape=(3,)
arr_2d = np.array([[10], [20], [30]]) # shape=(3, 1)result = arr_1d + arr_2d
print(result)

输出:

[[11 12 13][21 22 23][31 32 33]]

我们期望得到一个 (3, 1) 的列向量,结果却得到了一个 (3, 3) 的矩阵!这就是 NumPy 的广播机制在起作用。它会将 arr_1d “拉伸”成一个 (3, 3) 的矩阵,然后与同样被“拉伸”的 arr_2d 进行元素级相加。这在很多时候非常有用,但如果不是我们期望的行为,就会导致难以察觉的 bug。

解决方案:自由转换维度

为了避免广播带来的意外,我们需要确保参与运算的数组维度是一致的。

1. 从一维到二维:reshapenp.newaxis

要把 shape=(100,) 转换成 shape=(100, 1),你有两种优雅的方式:

  • 使用 .reshape():这是最直观的方法。-1 是一个强大的占位符,表示“让 NumPy 根据总元素数量和另一个维度自动计算”。

    arr_1d_col = arr_1d.reshape(-1, 1) # 转换成列向量
    # arr_1d_col.shape -> (3, 1)
    
  • 使用 np.newaxis (更推荐):这是一种更简洁、更具 NumPy 风格的做法,它可以在指定位置增加一个新维度。

    arr_1d_col = arr_1d[:, np.newaxis] # 在列的位置增加一个维度
    # arr_1d_col.shape -> (3, 1)
    

2. 从二维到一维:.flatten().squeeze()

反之,如果你想把列向量“压平”成一维数组,可以使用:

arr_2d_flat = arr_2d.flatten() # 返回一个拷贝
arr_2d_squeezed = arr_2d.squeeze() # 移除所有维度为1的轴
问题的根源:索引如何影响维度

在我的代码中,问题的根源在于对二维数组 X 进行索引的方式。这揭示了 NumPy 另一个至关重要的特性:

黄金法则:整数索引降维,切片索引保维。

  • X[:, i] (整数索引):这表示“取所有行的第 i 列”。由于 i 是一个整数,NumPy 会将结果的维度降低,返回一个一维数组 shape=(100,)
  • X[:, i:i+1] (切片索引):这表示“取所有行,以及从第 i 列到第 i+1 列的切片”。即使这个切片只包含一列,NumPy 依然会将其作为一个二维数组返回,shape=(100, 1),完美地保留了列向量的结构。
X = np.arange(12).reshape(4, 3)# 整数索引 -> 降维
col_1d = X[:, 1] 
print(col_1d.shape) # 输出: (4,)# 切片索引 -> 保维
col_2d = X[:, 1:2]
print(col_2d.shape) # 输出: (4, 1)
总结与最佳实践

这次踩坑经历让我收获颇丰。总结一下,为了写出更健壮、可预测的机器学习代码:

  1. 时刻对维度保持敏感:在进行数组运算前,脑中要清楚每个变量的 shape
  2. 显式处理向量:在线性代数运算中,尽量将向量显式地表示为二维数组(行向量 (1, n) 或列向量 (n, 1)),而不是使用一维数组。
  3. 善用索引技巧:当你需要从矩阵中提取列并保持其二维结构时,请使用切片索引 X[:, i:i+1] 或列表索引 X[:, [i]]
  4. 明确运算类型:在 np.array 上,使用 @ 进行矩阵乘法,使用 * 进行元素级乘法,避免混淆。

掌握了这些看似微小的细节,才能在实现复杂算法时游刃有余,让我们的机器学习之旅更加顺畅。

http://www.dtcms.com/a/617979.html

相关文章:

  • Redis 原理与实验
  • 网站开发职责与要求软件开发专业好吗
  • 【Linux驱动开发】Linux 设备驱动中的总线机制
  • 电压基准芯片详解:从原理到选型,附 TLV431 应用解析
  • 住房和城乡建设部网站监理工程师网站发送邮件功能
  • 开发第一个python程序
  • obet(Oracle Block Editor Tool)第二版发布
  • 【gas优化】2.11 Calldata 替换 Memory
  • 深度学习周报(11.10~11.16)
  • 阿里云建站论坛网站区块链网站建设方案
  • 李宏毅NLP-14-NLP任务
  • 惠普LaserJet Pro MFP M126a如何打印自检页
  • 南京大学cpp复习——面向对象第一部分(构造函数,拷贝构造函数,析构函数,移动构造函数,友元)
  • Stream 流核心速查表
  • 网站建设设计服务公司优化大师绿色版
  • STM32通信协议学习--I2C通信(了解)
  • 【技术选型】Go后台框架选型
  • 电子商务网站建设策划书范文西安优秀的集团门户网站建设费用
  • AI人工智能-语言模型-第六周(小白)
  • 找工作经验分享
  • 提供网站建设哪家好佛山外贸seo
  • 建站官网怎么做网页?
  • Qt编写28181推流分发服务/统计访问数量/无人观看超时关闭/等待重新点播/复用点播
  • CodexField Marketplace:重建内容与智能资产的链上市场结构
  • 微网站素材外贸家具网站
  • 后端服务发现工具,Consul与Eureka Consul vs Eureka:后端服务发现工具全面对比
  • 《动手学深度学习》6.5~6.6
  • 初识RabbitMQ
  • 历史数据分析——中国铝业
  • 网站建设设计公司类网站织梦模板 带手机端展厅设计图片