pytorch基础知识 切片与索引-上

2019-11-17 23:26:20 浏览数 (1)

切片和索引是pytorch中经常使用的操作

为后续讲解方便,这里先介绍CNN的基本图片的概念,一般将图片设定为[batch_size, channel, height, width]的四维矩阵。

这里先随机建立一个矩阵

代码语言:javascript复制
import torch
a = torch.rand(4, 3, 28, 28)
print(a.size())

输出size为:

代码语言:javascript复制
torch.Size([4, 3, 28, 28])

再对第一维进行索引

代码语言:javascript复制
# 对第一维进行索引
print(a[0].size())
代码语言:javascript复制
torch.Size([3, 28, 28])

这里的输出可以认为是第一个图片的三个维度通道的28*28的像素点。

代码语言:javascript复制
print(a[0, 0].size())
代码语言:javascript复制
torch.Size([28, 28])

这里的输出可以认为是第一个图片的第一个维度通道的28*28的像素点。

当具体到某一个像素点时

代码语言:javascript复制
print(a[0, 0, 2, 3])
代码语言:javascript复制
tensor(0.4736)

这里的输出代表第一个图片的第一个维度通道的[2,3]的像素点张量为(0.4736)。

若想取连续的索引,

需要用到:

代码语言:javascript复制
# 取连续索引
print(a.shape)
print(a[:2].shape)
代码语言:javascript复制
torch.Size([2, 3, 28, 28])
# 这里的:相当于→(箭头),表明batch从第一个到第二个,不写默认写全部

同理

代码语言:javascript复制
print(a[:2, 1:, :, :].shape)
# 1写在:前面,表明从1个通道开始到末尾,,不包括1
代码语言:javascript复制
torch.Size([2, 2, 28, 28])

另外

当索引出现-1时,要提到一个知识点

代码语言:javascript复制
print(a[:2, -1:, :, :].shape)
# 默认索引的顺序为[0, 1, 2],当倒着写时变为[-3, -2, -1]。由于这里取-1,因此为最后一位。

此时输出

代码语言:javascript复制
torch.Size([2, 1, 28, 28])

当想隔点取样输出时

代码语言:javascript复制
print(a[:, :, 0:28:2, 0:28:2].shape)
# 输出全部batch和channel,对每个高和宽间隔2个点采样
代码语言:javascript复制
torch.Size([4, 3, 14, 14])

也可简化为

代码语言:javascript复制
print(a[:, :, ::2, ::2].shape)

同样输出为

代码语言:javascript复制
torch.Size([4, 3, 14, 14])

这里需要注意 当写为[0:28:]则等同于[0:28:1]因此可以认为[start:end:steps]

0 人点赞