chunk
torch.chunk(input, chunks, dim = 0)
函数会将输入张量(input)沿着指定维度(dim)均匀的分割成特定数量的张量块(chunks),并返回元素为张量块的元组。torch.chunk 函数有三个参数:
- input(Tensor)- 待分割的输入张量
- chunks(int)- 均匀分割张量块的数量
- dim(int)- 进行分割的维度
以包含批量维度的图像张量为例,设张量
保存了 128 张,长和宽为 32 的三通道像素矩阵,则张量
的形状为
(PyTorch将通道维度放在前面,即
)。
现在我们想将张量
这 128 张图片均匀的分割成 16 块,每块包含 8 张图片。可以使用 torch.chunk 函数沿着第 0 个维度(批量维度,dim = 0)均匀的将张量
(input = A)分割成 16 块(chunks = 16)。
代码语言:javascript复制import torch
A = torch.randint(0, 255, (128, 3, 32, 32))
result = torch.chunk(input=A,
chunks=16,
dim=0)
print(type(result))
# <class 'tuple'>
print(len(result))
# 16
print(type(result[0]))
# <class 'torch.Tensor'>
print(result[0].size())
# torch.Size([8, 3, 32, 32])
将形状为
的张量
,沿着第 0 个维度(批量维度)均匀分割成 16 块(
),其中每一块都是形状为
的张量。
如果将将张量
这 128 张图片均匀的分割成 14 块(
),显然不能像分割成 16 块那样能够均匀的分割。在这种情况下,torch.chunk 函数会先按照每块 10 张图片进行分割,即每一块都是形状为
的张量,余下的作为最后一块。
代码语言:javascript复制import torch
A = torch.randint(0, 255, (128, 3, 32, 32))
result = torch.chunk(input=A,
chunks=14,
dim=0)
print(len(result))
# 13
print(result[0].size())
# torch.Size([10, 3, 32, 32])
print(result[-1].size())
# torch.Size([8, 3, 32, 32])
小结
可以沿着输入张量的任意维度均匀分割。使用 torch.chunk 函数沿着 dim 维度将张量均匀的分割成 chunks 块,若式子
结果为:
- 整数(整除),表示能够将其均匀的分割成 chunks 块,直接进行分割即可;
- 浮点数(不能够整除),先按每块
(
为向上取整)进行分割,余下的作为最后一块;
比如,将形状为
的张量
,现在沿着第 1 个维度均匀的分割成 2 块。B.size(1) = 3
、chunks = 2,即:
1.5 不是整数,则将其向上取整
,先将 3 按每块 2 个进行分割,余下的作为最后一块。
代码语言:javascript复制import torch
B = torch.arange(6).reshape(2, 3)
result = torch.chunk(input = B,
chunks = 2,
dim = 1)
print(B)
# tensor([[0, 1, 2],
# [3, 4, 5]])
print(result)
# tensor([[0, 1],
# [3, 4]]), tensor([[2],
# [5]]))