技术背景
在前一篇文章中,我们提到了关于Numpy中的各种取index的方法,可以用于取出数组里面的元素,也可以用于做切片,甚至可以用来做排序。但是遇到对于高维矩阵的某一个维度取多个值的时候,单纯的使用下标已经无法完成相关的操作了。如果找不到相应的接口,对于性能要求不高的场景可以使用一个for循环进行替代,但是对于性能要求比较高的场景下,我们还是尽可能的使用Numpy本身自带的接口,比如本文将要提到的take_along_axis操作。
使用案例
我们考虑这样的一个场景,给定一个维度为(4,11,3)的矩阵a作为数据,和一个维度为(4,2)的矩阵b作为下标,意味着从a中第二条轴的11个元素中每次取两个元素,也就是希望得到一个维度为(4,2,3)的结果:
代码语言:javascript复制In [11]: a = np.arange(132).reshape((4,11,3))
In [12]: a
Out[12]:
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[ 12, 13, 14],
[ 15, 16, 17],
[ 18, 19, 20],
[ 21, 22, 23],
[ 24, 25, 26],
[ 27, 28, 29],
[ 30, 31, 32]],
[[ 33, 34, 35],
[ 36, 37, 38],
[ 39, 40, 41],
[ 42, 43, 44],
[ 45, 46, 47],
[ 48, 49, 50],
[ 51, 52, 53],
[ 54, 55, 56],
[ 57, 58, 59],
[ 60, 61, 62],
[ 63, 64, 65]],
[[ 66, 67, 68],
[ 69, 70, 71],
[ 72, 73, 74],
[ 75, 76, 77],
[ 78, 79, 80],
[ 81, 82, 83],
[ 84, 85, 86],
[ 87, 88, 89],
[ 90, 91, 92],
[ 93, 94, 95],
[ 96, 97, 98]],
[[ 99, 100, 101],
[102, 103, 104],
[105, 106, 107],
[108, 109, 110],
[111, 112, 113],
[114, 115, 116],
[117, 118, 119],
[120, 121, 122],
[123, 124, 125],
[126, 127, 128],
[129, 130, 131]]])
In [13]: b = np.array([[0,1],[1,2],[2,3],[3,4]])
In [14]: b
Out[14]:
array([[0, 1],
[1, 2],
[2, 3],
[3, 4]])
为了方便展示我们就定义了这样两个比较简单的矩阵a和b,那么在这个结果中,我们理想的结果应该是:
代码语言:javascript复制[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 36, 37, 38],
[ 39, 40, 41]],
[[ 72, 73, 74],
[ 75, 76, 77]],
[[108, 109, 110],
[111, 112, 113]]]
这样的一个矩阵。关于这个结果的来源,可以对b这个定义进行展开解释,b的值为:
代码语言:javascript复制[[0, 1],
[1, 2],
[2, 3],
[3, 4]]
它所表示的是在a0下取第0个元素和第1个元素,在a1下取第1个元素和第2个元素,以此类推。然而如果我们直接把定义好的b放到a的索引中或者直接使用numpy.take的方法的话,得到的结果是这样的:
代码语言:javascript复制In [16]: a[:,b]
Out[16]:
array([[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 3, 4, 5],
[ 6, 7, 8]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[ 9, 10, 11],
[ 12, 13, 14]]],
[[[ 33, 34, 35],
[ 36, 37, 38]],
[[ 36, 37, 38],
[ 39, 40, 41]],
[[ 39, 40, 41],
[ 42, 43, 44]],
[[ 42, 43, 44],
[ 45, 46, 47]]],
[[[ 66, 67, 68],
[ 69, 70, 71]],
[[ 69, 70, 71],
[ 72, 73, 74]],
[[ 72, 73, 74],
[ 75, 76, 77]],
[[ 75, 76, 77],
[ 78, 79, 80]]],
[[[ 99, 100, 101],
[102, 103, 104]],
[[102, 103, 104],
[105, 106, 107]],
[[105, 106, 107],
[108, 109, 110]],
[[108, 109, 110],
[111, 112, 113]]]])
显然这不是我们想要的结果。需要额外申明的是,这个执行操作中,最后一个维度的冒号加与不加是一样的效果,跟numpy.take本质上也是同样的操作,因此就需要使用到numpy中的另外一个接口:take_along_axis
,如下是其官方的API文档:
还有相关的使用案例:
需要注意的是,输入的indices必须要跟原始的数据矩阵保持同样的维度,因此在我们自己的案例中,对b进行了扩维,最终的代码如下所示:
代码语言:javascript复制In [23]: np.take_along_axis(a,b[:,:,None],axis=1)
Out[23]:
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 36, 37, 38],
[ 39, 40, 41]],
[[ 72, 73, 74],
[ 75, 76, 77]],
[[108, 109, 110],
[111, 112, 113]]])
最后得到的就是我们想要的结果了,并且是直接使用下标无法实现的操作(当然,也可能是我还没研究出来这样的操作)。这里axis设置为1,就表示a的第0个维度和b的第0个维度是一致的取法,也可以理解成全取的意思。
总结概要
Numpy是在Python中用于各种矩阵运算非常强大的工具之一,而快速的通过下标取出所需位置的元素也是numpy所支持的强大功能之一。常规的元素取法都可以通过numpy的下标或者是numpy.take函数来实现,比如array0,:可用于取第一条轴的所有元素,array:,0可以用于取第二条轴的所有第二个元素,放在一个2维的矩阵里面就分别是取第一行的所有元素和取第一列的所有元素。但是本文更加关注于更高维的矩阵,当我们想从多个维度中取多个元素时,是不太容易直接用下标去取的,比如同时取a0,a0,a1,a1的话,那么就只能使用numpy所支持的另外一个函数numpy.take_along_axis来实现。
版权声明
本文首发链接为:https://www.cnblogs.com/dechinphy/p/take_along_axis.html
作者ID:DechinPhy
更多原著文章请参考:https://www.cnblogs.com/dechinphy/
打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958
参考链接
- https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html#numpy.take_along_axis