切片和索引是pytorch中经常使用的操作
为后续讲解方便,这里先介绍CNN的基本图片的概念,一般将图片设定为[batch_size, channel, height, width]的四维矩阵。
这里先随机建立一个矩阵
代码语言:javascript复制import torch
a = torch.rand(4, 3, 28, 28)
print(a.size())
输出size为:
代码语言:javascript复制torch.Size([4, 3, 28, 28])
再对第一维进行索引
代码语言:javascript复制# 对第一维进行索引
print(a[0].size())
代码语言:javascript复制torch.Size([3, 28, 28])
这里的输出可以认为是第一个图片的三个维度通道的28*28的像素点。
代码语言:javascript复制print(a[0, 0].size())
代码语言:javascript复制torch.Size([28, 28])
这里的输出可以认为是第一个图片的第一个维度通道的28*28的像素点。
当具体到某一个像素点时
代码语言:javascript复制print(a[0, 0, 2, 3])
代码语言:javascript复制tensor(0.4736)
这里的输出代表第一个图片的第一个维度通道的[2,3]的像素点张量为(0.4736)。
若想取连续的索引,
需要用到:
代码语言:javascript复制# 取连续索引
print(a.shape)
print(a[:2].shape)
代码语言:javascript复制torch.Size([2, 3, 28, 28])
# 这里的:相当于→(箭头),表明batch从第一个到第二个,不写默认写全部
同理
代码语言:javascript复制print(a[:2, 1:, :, :].shape)
# 1写在:前面,表明从1个通道开始到末尾,,不包括1
代码语言:javascript复制torch.Size([2, 2, 28, 28])
另外
当索引出现-1时,要提到一个知识点
代码语言:javascript复制print(a[:2, -1:, :, :].shape)
# 默认索引的顺序为[0, 1, 2],当倒着写时变为[-3, -2, -1]。由于这里取-1,因此为最后一位。
此时输出
代码语言:javascript复制torch.Size([2, 1, 28, 28])
当想隔点取样输出时
代码语言:javascript复制print(a[:, :, 0:28:2, 0:28:2].shape)
# 输出全部batch和channel,对每个高和宽间隔2个点采样
代码语言:javascript复制torch.Size([4, 3, 14, 14])
也可简化为
代码语言:javascript复制print(a[:, :, ::2, ::2].shape)
同样输出为
代码语言:javascript复制torch.Size([4, 3, 14, 14])
这里需要注意 当写为[0:28:]则等同于[0:28:1]因此可以认为[start:end:steps]