NumPy的广播机制

2022-09-03 20:15:37 浏览数 (1)

目录

一、广播(Broadcasting)简介

二、广播(Broadcasting)的机制


一、广播(Broadcasting)简介

在线性代数中我们曾经学到过如下规则:

a1 = 1 ,a2 = 2,a1,a2是0维张量,即标量,

,b1,b2是1维张量,即向量,c1,c2是如下所示的2维张量,即矩阵:

a1与a2之间可以进行加减乘除,b1与b2可以进行逐元素的加减乘除以及点积运算,c1与c2之间可以进行逐元素的加减乘除以及矩阵相乘运算(矩阵相乘必须满足维度的对应关系),而a与b,或者b与c之间不能进行逐元素的加减乘除运算,原因是他们的维度不匹配。而在NumPy中,通过广播可以完成这项操作。

广播(Boardcasting)是NumPy中用于在不同大小的阵列(包括标量与向量,标量与二维数组,向量与二维数组,二维数组与高维数组等)之间进行逐元素运算(例如,逐元素 加法,减法,乘法,赋值等)的一组规则。尽管该技术是为NumPy开发的,但它在其他数值计算库中也得到了更广泛的应用,例如深度学习框架TensorFlow和Pytorch。

NumPy在广播的时候实际上并没有复制较小的数组; 相反,它使存储器和计算上有效地使用存储器中的现有结构,实际上实现了相同的结果。

注意:

代码语言:javascript复制
import numpy as np
A = np.zeros((3,4))
B = np.zeros((5,6))
print(np.dot(A, B))

报错如下: 在这里插入图片描述 并没有显示 broadcast的错误,说明dot,即点积(不是逐元素运算,对于两个向量,计算的是内积,对于两个数组,则尝试计算他们的矩阵乘积)并不能运用广播机制。

代码语言:javascript复制
import numpy as np
A = np.zeros((2,4))
B = np.zeros((3,4))
C = A*B

报错如下: 在这里插入图片描述 这种是逐元素相乘,会运用广播机制,只不过,此时当前两个元素的维度不能广播,所以报错。

二、广播(Broadcasting)的机制

  1. 让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐
  2. 输出数组的shape是输入数组shape的各个轴上的最大值
  3. 如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错
  4. 当输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值

简单来说,我总结为两条规则:

两个array的shape长度与shape的每个对应值都相等的时候,那么结果就是对应元素逐元素运算,运算的结果shape不变。shape长度不相等时,先把短的shape前面一直补1,直到与长的shape长度相等时,此时,两个array的shape对应位置上的值 :1、相等 或 2、其中一个为1,这样才能进行广播。

例子如下:

代码语言:javascript复制
Image (3d array):  256 x 256 x 3
Scale (1d array):              3
Result (3d array): 256 x 256 x 3

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5

A      (2d array):  5 x 4
B      (1d array):      1
Result (2d array):  5 x 4

A      (2d array):  15 x 3 x 5
B      (1d array):  15 x 1 x 5
Result (2d array):  15 x 3 x 5

再来看一些不能进行broadcast的例子:

A  (1d array): 3
B  (1d array): 4        # 最后一维(trailing dimension)不匹配

A  (2d array):      2 x 1
B  (3d array):  8 x 4 x 3(倒数第二维不匹配)

输出数组的维度是每一个维度的最大值,广播将值为1的维度进行“复制”、“拉伸”,如图所示

0 人点赞