numpy中对axis的理解

2023-11-26 15:14:53 浏览数 (1)

axis在Python的numpy库中是一个基本概念,出现的非常多,特别是在函数调用、合并数据等操作的时候,本文对axis的作用和规律做一下梳理,加深对Python中的numpy库的axis理解。

axis的作用

在numpy中,有很多的函数都涉及到axis,很多函数根据axis的取值不同,得到的结果也完全不同。可以说,axis让numpy的多维数组变的更加灵活,但也让numpy变得越发难以理解。这里通过详细的例子来学习下,axis到底是什么,它在numpy中的作用到底如何。

为什么会有axis这个东西,原因很简单:numpy是针对矩阵或者多为数组进行运算的,而在多维数组中,对数据的操作有太多的可能,特别是数组有多个维度,对于不同维度的操作会有不同的结果,我们先来看一个例子。比如我们有一个二维数组:

代码语言:txt复制
import numpy as np
>>> data = np.array([
... [1,2,1],
... [0,3,1],
... [2,1,4],
... [1,3,1]])

这个数组代表了样本数据的特征,其中每一行代表一个样本的三个特征,每一列是不同样本的特征。

如果在分析样本的过程中需要对每个样本的三个特征求和,该如何处理?简单:

代码语言:txt复制
np.sum(data, axis=1)
array([4, 4, 7, 5])

那如果想求每种特征的最小值,该如何处理?也简单:

代码语言:txt复制
np.min(data, axis=0)
array([0, 1, 1])

又如果想得知所有样本所有特征的平均值呢?还是很简单:

代码语言:txt复制
np.average(data)
1.6666666666666667

由此可以看出,通过不同的axis,numpy会沿着不同的方向进行操作:

  • 如果不设置,那么对所有的元素操作
  • 如果axis=0,则沿着纵轴进行操作
  • 如果axis=1,则沿着横轴进行操作

但这只是简单的二位数组,如果是多维的呢?可以总结为一句话:设axis=i,则numpy沿着第i个下标变化的放下进行操作。这是非常重要的,理解了这个也就理解了axis的作用:表示数组的维度。那么在函数中引入axis也就是表示,对axis所在的维度的数据进行处理。

下面我们举一个四维的求sum的例子来验证一下:

代码语言:txt复制
data = np.random.randint(0, 5, [4,3,2,3])
>>> data
array([[[[4, 1, 0],
         [4, 3, 0]],
        [[1, 2, 4],
         [2, 2, 3]],
        [[4, 3, 3],
         [4, 2, 3]]],

       [[[4, 0, 1],
         [1, 1, 1]],
        [[0, 1, 0],
         [0, 4, 1]],
        [[1, 3, 0],
         [0, 3, 0]]],

       [[[3, 3, 4],
         [0, 1, 0]],
        [[1, 2, 3],
         [4, 0, 4]],
        [[1, 4, 1],
         [1, 3, 2]]],

       [[[0, 1, 1],
         [2, 4, 3]],
        [[4, 1, 4],
         [1, 4, 1]],
        [[0, 1, 0],
         [2, 4, 3]]]])

当axis=0时,numpy验证第0维的方向来求和,也就是第一个元素值=a0000 a1000 a2000 a3000=11,第二个元素=a0001 a1001 a2001 a3001=5,同理可得最后的结果如下:

代码语言:txt复制
data.sum(axis=0)  
array([[[11, 5, 6],  
[ 7, 9, 4]],  
  
[[ 6, 6, 11],  
[ 7, 10, 9]],  
  
[[ 6, 11, 4],  
[ 7, 12, 8]]])

当axis=3时,numpy验证第3维的方向来求和,也就是第一个元素值=a0000 a0001 a0002=5,第二个元素=a0010 a0011 a0012=7,同理可得最后的结果如下:

代码语言:txt复制
data.sum(axis=3)
array([[[ 5,  7],
        [ 7,  7],
        [10,  9]],

       [[ 5,  3],
        [ 1,  5],
        [ 4,  3]],

       [[10,  1],
        [ 6,  8],
        [ 6,  6]],

       [[ 2,  9],
        [ 9,  6],
        [ 1,  9]]])

axis相关函数举例

在numpy中,使用的axis的地方非常多,处理上文已经提到的average、max、min、sum,比较常见的还有sort和prod,下面分别举几个例子看一下:

sort

代码语言:txt复制
data = np.random.randint(0, 5, [3,2,3])
>>> data
array([[[4, 2, 0],
        [0, 0, 4]],

       [[2, 1, 1],
        [1, 0, 2]],

       [[3, 0, 4],
        [0, 1, 3]]])
>>> np.sort(data)  ## 默认对最大的axis进行排序,这里即是axis=2
array([[[0, 2, 4],
        [0, 0, 4]],

       [[1, 1, 2],
        [0, 1, 2]],

       [[0, 3, 4],
        [0, 1, 3]]])
>>> np.sort(data, axis=0)  # 沿着第0维进行排序,原先的a000->a100->a200转变为a100->a200->a000
array([[[2, 0, 0],
        [0, 0, 2]],

       [[3, 1, 1],
        [0, 0, 3]],

       [[4, 2, 4],
        [1, 1, 4]]])
>>> np.sort(data, axis=1)  # 沿着第1维进行排序
array([[[0, 0, 0],
        [4, 2, 4]],

       [[1, 0, 1],
        [2, 1, 2]],

       [[0, 0, 3],
        [3, 1, 4]]])
>>> np.sort(data, axis=2)  # 沿着第2维进行排序
array([[[0, 2, 4],
        [0, 0, 4]],

       [[1, 1, 2],
        [0, 1, 2]],

       [[0, 3, 4],
        [0, 1, 3]]])
>>> np.sort(data, axis=None)  # 对全部数据进行排序
array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4])

prod(即product,乘积)

代码语言:txt复制
np.prod([[1.,2.],[3.,4.]])
24.0

>>> np.prod([[1.,2.],[3.,4.]], axis=1)
array([  2.,  12.])

>>> np.prod([[1.,2.],[3.,4.]], axis=0)
array([ 3.,  8.])

我正在参与2023腾讯技术创作特训营第三期有奖征文,组队打卡瓜分大奖!

0 人点赞