NumPy 数组排序
文章目录
- 数组排序
- NumPy 中的快速排序:np.sort 和 np.argsort
- 按行或按列排序
- 部分排序:分区(Partitioning)
- 示例:k-近邻算法
数组排序
到目前为止,我们主要关注的是使用 NumPy 访问和操作数组数据的工具。本文将介绍与 NumPy 数组中数值排序相关的算法。
这些算法是计算机科学入门课程中的热门话题:如果你上过这类课程,可能曾经梦到过(或者,取决于你的性格,做过噩梦)插入排序、选择排序、归并排序、快速排序、冒泡排序,以及更多其他排序方法。它们的目标都是完成类似的任务:对列表或数组中的数值进行排序。
Python 有一些内置函数和方法可以对列表和其他可迭代对象进行排序。sorted
函数接受一个列表并返回其排序后的版本:
L = [3, 1, 4, 1, 5, 9, 2, 6]
sorted(L) # 返回一个排好序的副本
[1, 1, 2, 3, 4, 5, 6, 9]
与之相对,列表的 sort
方法会直接在原列表上进行排序:
L.sort() # 就地排序,返回 None
print(L)
[1, 1, 2, 3, 4, 5, 6, 9]
Python 的排序方法非常灵活,可以处理任何可迭代对象。例如,下面我们对一个字符串进行排序:
sorted('python')
['h', 'n', 'o', 'p', 't', 'y']
这些内置的排序方法非常方便,但如前所述,由于 Python 值的动态特性,它们的性能不如专门为统一数值数组设计的例程。 这正是 NumPy 排序例程发挥作用的地方。
NumPy 中的快速排序:np.sort 和 np.argsort
np.sort
函数类似于 Python 内置的 sorted
函数,可以高效地返回数组排序后的副本:
import numpy as npx = np.array([2, 1, 4, 3, 5])
np.sort(x)
array([1, 2, 3, 4, 5])
与 Python 列表的 sort
方法类似,你也可以使用数组的 sort
方法对数组进行原地排序:
x.sort()
print(x)
[1 2 3 4 5]
一个相关的函数是 argsort
,它返回的是排序元素的索引:
x = np.array([2, 1, 4, 3, 5])
i = np.argsort(x) # 返回排序后的索引
print(i)
[1 0 3 2 4]
第一个元素表示最小元素的索引,第二个值表示第二小元素的索引,依此类推。
这些索引随后可以通过花式索引(fancy indexing)来构造排序后的数组:
x[i]
array([1, 2, 3, 4, 5])
你将在本章后面看到 argsort
的一个应用。
按行或按列排序
NumPy 排序算法的一个有用特性是可以使用 axis
参数沿多维数组的特定行或列进行排序。例如:
# 使用指定的随机种子创建随机数生成器
rng = np.random.default_rng(seed=42)
# 生成一个4行6列,元素在0到9之间的随机整数数组
X = rng.integers(0, 10, (4, 6))
print(X) # 打印生成的数组
[[0 7 6 4 4 8][0 6 2 0 5 9][7 7 7 7 5 1][8 4 5 3 1 9]]
# 对 X 的每一列进行排序
np.sort(X, axis=0)
array([[0, 4, 2, 0, 1, 1],[0, 6, 5, 3, 4, 8],[7, 7, 6, 4, 5, 9],[8, 7, 7, 7, 5, 9]])
# 对 X 的每一行进行排序
np.sort(X, axis=1)
array([[0, 4, 4, 6, 7, 8],[0, 0, 2, 5, 6, 9],[1, 5, 7, 7, 7, 7],[1, 3, 4, 5, 8, 9]])
请记住,这会将每一行或每一列视为独立的数组,行或列之间的任何关系都会丢失!
部分排序:分区(Partitioning)
有时候我们并不关心对整个数组进行排序,而只是想找到数组中最小的 k 个值。NumPy 提供了 np.partition
函数来实现这一点。np.partition
接受一个数组和一个数字 K;结果是一个新数组,其中最小的 K 个值被移动到分区的左侧,其余的值在右侧:
x = np.array([7, 1, 7, 7, 1, 5, 7, 2, 3, 2, 6, 2, 3, 0])
X_partitioned = np.partition(x, 3) # 返回一个数组,其中前3个元素是最小的,其他元素未排序
print("分区后的数组:", X_partitioned)arr_to_partition = rng.integers(0, 100, 10)
print("待排序数组:", arr_to_partition)
arr_partitioned = np.partition(arr_to_partition, 5) # 将数组分为两部分,前5个元素是最小的
print("分区后的数组:", arr_partitioned)
分区后的数组: [0 1 1 2 2 2 3 3 5 6 7 7 7 7]
待排序数组: [ 6 97 44 89 67 77 75 19 36 46]
分区后的数组: [ 6 19 36 44 46 67 75 77 89 97]
注意,结果数组中的前三个值是数组中最小的三个值,其余位置包含剩下的元素。
在这两个分区内部,元素的顺序是任意的。
与排序类似,我们也可以沿多维数组的任意轴进行分区:
np.partition(X, 2, axis=1)
array([[0, 4, 4, 6, 7, 8],[0, 0, 2, 5, 6, 9],[1, 5, 7, 7, 7, 7],[1, 3, 4, 5, 8, 9]])
结果是一个数组,每一行的前两个位置包含该行中最小的两个值,其余值填充在剩下的位置。
最后,正如有一个 np.argsort
函数用于计算排序索引一样,NumPy 还提供了 np.argpartition
函数用于计算分区的索引。
我们将在下一节中看到这两个函数的实际应用。
示例:k-近邻算法
让我们快速了解一下如何利用 argsort
函数在多个轴上查找一组点中每个点的最近邻。
我们将首先在二维平面上随机生成 10 个点。按照标准惯例,我们会将这些点排列成一个 10 × 2 10\times 2 10×2 的数组:
X = rng.random((10, 2))
为了直观展示这些点的分布,我们可以快速绘制一个散点图(见下图):
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8-whitegrid')
plt.scatter(X[:, 0], X[:, 1], s=100);
现在我们将计算每对点之间的距离。
回忆一下,两点之间的平方距离等于各维度差值的平方和;
利用 NumPy 提供的高效广播(数组的广播计算)和聚合(聚合:最小值、最大值及一切)操作,我们可以用一行代码计算出所有点对之间的平方距离矩阵:
dist_sq = np.sum((X[:, np.newaxis] - X[np.newaxis, :]) ** 2, axis=-1)
这行操作包含了很多内容,如果你还不熟悉 NumPy 的广播规则,可能会觉得有些困惑。当你遇到类似这样的代码时,将其拆解为各个组成步骤会很有帮助:
# 对每对点,计算它们坐标的差值
differences = X[:, np.newaxis] - X[np.newaxis, :]
differences.shape
(10, 10, 2)
# 对坐标差值进行平方
sq_differences = differences ** 2
sq_differences.shape
(10, 10, 2)
# 对坐标差值求和以获得平方距离
dist_sq = sq_differences.sum(-1)
dist_sq.shape
(10, 10)
我们的逻辑可以通过一个快速检查来验证:该矩阵的对角线(即每个点与自身之间的距离)应该全为零。
dist_sq.diagonal()
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
有了成对的平方距离后,我们现在可以使用 np.argsort
沿每一行进行排序。最左侧的列将给出最近邻的索引:
# 使用 np.argsort 沿每一行对 dist_sq 进行排序,得到每个点最近邻的索引
nearest = np.argsort(dist_sq, axis=1)
print(nearest)
[[0 4 5 3 6 8 9 1 2 7][1 7 6 9 8 3 2 5 4 0][2 9 6 7 5 1 3 4 8 0][3 5 6 8 4 1 9 0 7 2][4 0 5 3 6 9 8 1 2 7][5 3 6 4 9 0 2 1 7 8][6 9 5 3 1 7 2 8 4 0][7 1 6 9 2 3 8 5 4 0][8 3 1 6 7 5 9 4 0 2][9 6 2 7 5 1 3 8 4 0]]
请注意,第一列依次给出了数字 0 到 9:这是因为每个点距离自己最近,这正是我们所期望的。
在这里使用完整排序实际上做的工作比我们需要的要多。如果我们只对最近的 k k k 个邻居感兴趣,只需对每一行进行分区,使得最小的 k + 1 k+1 k+1 个平方距离排在前面,剩下的较大距离填充数组的其余位置即可。我们可以使用 np.argpartition
函数来实现这一点:
K = 2 # 最近邻的数量
# 对每一行进行分区,使得每行中最小的K+1个距离排在前面(包括自身距离为0的情况)
nearest_partition = np.argpartition(dist_sq, K + 1, axis=1)
为了可视化这个邻居网络,我们可以快速绘制这些点,并用线条表示每个点与其两个最近邻之间的连接(见下图):
plt.scatter(X[:, 0], X[:, 1], s=100)# 从每个点画线到它的两个最近邻
K = 2for i in range(X.shape[0]):for j in nearest_partition[i, :K+1]:# 从 X[i] 到 X[j] 画一条线# 用 zip 实现:plt.plot(*zip(X[j], X[i]), color='black')
每个点在图中都连接到了它的两个最近邻。
乍一看,你可能会觉得奇怪,为什么有些点会有超过两条线连出:这是因为如果点A是点B的两个最近邻之一,并不一定意味着点B也是点A的两个最近邻之一。
虽然这种利用广播和按行排序的方法看起来没有直接写循环那么直观,但它实际上是用Python处理这类数据非常高效的方式。
你可能会想通过手动循环遍历数据并分别对每组邻居排序来实现同样的操作,但这样几乎肯定会比我们用的矢量化方法慢得多。这种方法的优点在于它对输入数据的规模是无关的:无论是100个、1000个还是100万个点,甚至是任意维度的数据,我们都可以用同样的代码来计算最近邻。
最后需要说明的是,在进行非常大规模的最近邻搜索时,有一些基于树结构和/或近似的算法,其复杂度可以达到 O [ N log N ] \mathcal{O}[N\log N] O[NlogN] 或更优,而不是暴力算法的 O [ N 2 ] \mathcal{O}[N^2] O[N2]。其中一个例子是KD-Tree,在Scikit-Learn中有实现。