1. torch中以index_select为例子
torch.index_select(input, dim, index, out=None) - 功能:在维度dim上,按index索引数据 - 返回值:依index索引数据拼接的张量 - index:要索引的张量 - dim:要索引的维度 - index:要索引数据的序号
代码语言:javascript复制x = torch.randn(3, 4)
print(x)
indices = torch.tensor([0, 2])
torch.index_select(x, 1, indices)
#把1改为0
y = torch.randn(3, 4)
print(y)
indices = torch.tensor([0, 2])
torch.index_select(y, 0, indices)
输出如下,可以看出,dim=1时按照列索引;dim=0时,按照行索引
代码语言:javascript复制tensor([[ 1.9626, 0.1007, -1.2005, 1.2650],
[ 0.3603, 0.6343, -0.6197, 0.5740],
[-0.0798, 0.9674, -0.7761, 0.5552]])
tensor([[ 1.9626, -1.2005],
[ 0.3603, -0.6197],
[-0.0798, -0.7761]])
tensor([[ 0.2274, -2.1934, -0.3129, 0.3869],
[ 0.3831, -0.7156, -1.0765, -2.1098],
[-0.8007, -0.0095, 0.8703, -0.8797]])
tensor([[ 0.2274, -2.1934, -0.3129, 0.3869],
[-0.8007, -0.0095, 0.8703, -0.8797]])
2.numpy 中 以mean为例
代码语言:javascript复制x = numpy.random.randint(1,10,(3,4))
print(x)
print(x.mean(0))
y = numpy.random.randint(1,10,(3,4))
print(y)
print(y.mean(1))
输出如下,axis = 0时,按照竖直方向从上往下计算均值,输出4个数;axis=1时,按照水平方向从左往右计算均值,输出三个数。
代码语言:javascript复制[[6 8 4 9]
[7 5 9 3]
[1 7 6 1]]
[4.66666667 6.66666667 6.33333333 4.33333333]
[[3 3 6 5]
[4 3 1 5]
[7 2 2 5]]
[4.25 3.25 4. ]