PyTorch入门笔记-复制数据repeat函数

2021-02-12 01:28:04 浏览数 (1)

repeat

前面提到过 input.expand(*sizes) 函数能够实现 input 输入张量中单维度(singleton dimension)上数据的复制操作。「对于非单维度上的复制操作,expand 函数就无能为力了,此时就需要使用 input.repeat(*sizes)。」

input.repeat(*sizes) 可以对 input 输入张量中的单维度和非单维度进行复制操作,并且会真正的复制数据保存到内存中。input.expand(*sizes)input.repeat(*sizes) 两个函数的区别如下表所示。

input.repeat(*sizes) 函数中的 *sizes 参数分别指定了各个维度上复制的倍数,对于不需要复制的维度需要指定为 1。(在expand函数中对于不需要(或非单维度)进行复制的维度,对应位置上可以写上原始维度的大小或者直接写 -1)

对单维度上的数据进行复制,repeat 函数和 expand 函数类似,和 expand 函数一样,repeat 函数也融合了插入批量维度并在新插入的批量维度上复制数据的操作。

代码语言:txt复制
import torch

# 创建偏置b
b = torch.tensor([1, 2, 3])
# 为张量b插入新的维度
B = torch.unsqueeze(b, 0)

print(B.size())
# torch.Size([1, 3])

print(B)
# tensor([[1, 2, 3]])

在批量维度上复制数据 1 份,实现如下:

代码语言:txt复制
# 1意味着不对对应维度进行复制
B = B.repeat([2, 1])
print(B)
# tensor([[1, 2, 3],
#         [1, 2, 3]])

由于 repeat 函数也融合了插入批量维度并在新插入的批量维度上复制数据的操作,所以对于上面的偏置 b,我们可以省略 torch.unsqueeze(b, dim = 0) 插入批量维度的操作,直接使用 repeat 函数。

代码语言:txt复制
import torch

# 创建偏置b
b = torch.tensor([1, 2, 3])
# 直接插入批量维度并复制2份
B = b.repeat([2, 1])

print(B.size())
# torch.Size([2, 3])

print(B)
# tensor([[1, 2, 3],
#         [1, 2, 3]])

「使用 repeat 函数对非单维度进行复制,简单来说就是对非单维度的所有元素整体进行复制。」 以下面形状为 (2, 2) 的 2D 张量为例。

  • Step1: 将 dim = 0 维度上的数据复制 1 份,dim = 1 维度上的数据保持不变。
  • Step2: Step1 得到的形状为 (4, 2) 的 2D 张量的 dim = 0 维度上的数据保持不变,dim = 1 维度上的数据复制 1 份。

上面操作使用 repeat 函数的具体实现如下。

代码语言:txt复制
import torch

a = torch.arange(4).reshape([2, 2])
print(a)
# tensor([[0, 1],
#         [2, 3]]) 


# dim=0维度的数据复制1份,dim=1维度的数据保持不变
step1_a = a.repeat([2, 1])
print(step1_a)
# tensor([[0, 1],
#         [2, 3],
#         [0, 1],
#         [2, 3]])


# 将dim=0维度的数据保持不变,dim=1维度的数据复制1份
step2_a = step1_a.repeat([1, 2])
print(step2_a)
# tensor([[0, 1, 0, 1],
#         [2, 3, 2, 3],
#         [0, 1, 0, 1],
#         [2, 3, 2, 3]])

0 人点赞