导读
Numpy是Python中的一个基础的数据分析工具包,其提供了大量常用的数值计算功能,当然这些数值计算函数大多依赖于其核心的数据结构:ndarray,也就是N维数组。而关于这个ndarray,有一个重要特性是广播机制,也正是整个广播机制,使得Numpy中的数值计算功能更加丰富和强大。那么问题来了,你是否已经正确理解了这个广播机制呢?
本文选摘自numpy入门详细教程,近期有感而发,稍加修改后再次发文。
广播机制是Numpy中的一个重要特性,是指对ndarray执行某些数值计算时(这里是指矩阵间的数值计算,对应位置元素1对1执行标量运算,而非线性代数中的矩阵间运算),可以确保在数组间形状不完全相同时可以自动的通过广播机制扩散到相同形状,进而执行相应的计算功能。
当然,这里的广播机制是有条件的,而非对任意形状不同的数组都能完成自动广播,显然,理解这里的"条件"是理解广播机制的核心原理。
为了探究广播机制的限制条件,我们求助于numpy的官方文档,比如在numpy源码中打开doc文件夹,可以看到有一个numpy/doc/broadcasting.py的文件,里面其实全是注释性的文档,可以找到这样一段:
条件很简单,即从两个数组的最后维度开始比较,如果该维度满足维度相等或者其中一个大小为1,则可以实现广播。当然,维度相等时相当于无需广播,所以严格的说广播仅适用于某一维度从1广播到N;如果当前维度满足广播要求,则同时前移一个维度继续比较,直至首先完成其中一个矩阵的所有维度——另一矩阵如果还有剩余的话,其实也无所谓了,这很好理解。
为了直观理解这个广播条件,举个例子,下面的情况均满足广播条件:
而如下例子则无法完成广播:
当然,以上这几个例子其实都源自刚才的numpy/doc/broadcasting.py文件。另外,doc包下还包括很多说明文档,对于深刻理解numpy运行机制大有裨益。
再进一步探究:或许值得好奇,为什么必须要1对N才能广播,N的任意因数(比如N/2、N/3等)不是都可以"合理"广播到N吗?对此,个人也曾有此困惑,我的理解是这里的"合理"只停留于数学层面的合理,但若考虑数组背后的业务含义则往往不再合理:比如两个矩阵的同一维度取值分别为2和12,那如果将2广播到12,该怎样理解这其中的广播意义呢?比如说按照奇偶不同广播?那3广播到12呢?4广播到12呢?终究还是欠缺解释性。所以numpy限制必须是1广播到N或者二者相等,才可以广播。
实际上,不止是numpy,torch或者tf中的tensor其实也是存在类似的广播机制!