一、前言
本文将介绍PyTorch中张量的索引和切片操作。
二、实验环境
本系列实验使用如下环境
代码语言:javascript复制conda create -n DL python==3.11
代码语言:javascript复制conda activate DL
代码语言:javascript复制conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
三、PyTorch数据结构
1、Tensor(张量)
Tensor(张量)是PyTorch中用于表示多维数据的主要数据结构,类似于多维数组,可以存储和操作数字数据。
1. 维度(Dimensions)
Tensor(张量)的维度(Dimensions)是指张量的轴数或阶数。在PyTorch中,可以使用size()方法获取张量的维度信息,使用dim()方法获取张量的轴数。
2. 数据类型(Data Types)
PyTorch中的张量可以具有不同的数据类型:
- torch.float32或torch.float:32位浮点数张量。
- torch.float64或torch.double:64位浮点数张量。
- torch.float16或torch.half:16位浮点数张量。
- torch.int8:8位整数张量。
- torch.int16或torch.short:16位整数张量。
- torch.int32或torch.int:32位整数张量。
- torch.int64或torch.long:64位整数张量。
- torch.bool:布尔张量,存储True或False。
【深度学习】Pytorch 系列教程(一):PyTorch数据结构:1、Tensor(张量)及其维度(Dimensions)、数据类型(Data Types)
3. GPU加速(GPU Acceleration)
【深度学习】Pytorch 系列教程(二):PyTorch数据结构:1、Tensor(张量): GPU加速(GPU Acceleration)
2、张量的数学运算
PyTorch提供了丰富的操作函数,用于对Tensor进行各种操作,如数学运算、统计计算、张量变形、索引和切片等。这些操作函数能够高效地利用GPU进行并行计算,加速模型训练过程。
1. 向量运算
【深度学习】Pytorch 系列教程(三):PyTorch数据结构:2、张量的数学运算(1):向量运算(加减乘除、数乘、内积、外积、范数、广播机制)
2. 矩阵运算
【深度学习】Pytorch 系列教程(四):PyTorch数据结构:2、张量的数学运算(2):矩阵运算及其数学原理(基础运算、转置、行列式、迹、伴随矩阵、逆、特征值和特征向量)
3. 向量范数、矩阵范数、与谱半径详解
【深度学习】Pytorch 系列教程(五):PyTorch数据结构:2、张量的数学运算(3):向量范数(0、1、2、p、无穷)、矩阵范数(弗罗贝尼乌斯、列和、行和、谱范数、核范数)与谱半径详解
4. 一维卷积运算
【深度学习】Pytorch 系列教程(六):PyTorch数据结构:2、张量的数学运算(4):一维卷积及其数学原理(步长stride、零填充pad;宽卷积、窄卷积、等宽卷积;卷积运算与互相关运算)
5. 二维卷积运算
【深度学习】Pytorch 系列教程(七):PyTorch数据结构:2、张量的数学运算(5):二维卷积及其数学原理
6. 高维张量
【深度学习】pytorch教程(八):PyTorch数据结构:2、张量的数学运算(6):高维张量:乘法、卷积(conv2d~ 四维张量;conv3d~五维张量)
3、张量的统计计算
【深度学习】Pytorch教程(九):PyTorch数据结构:3、张量的统计计算详解
4、张量操作
1. 张量变形
【深度学习】Pytorch教程(十):PyTorch数据结构:4、张量操作(1):张量变形
2. 索引
在PyTorch中,可以使用索引和切片操作来访问和修改张量的特定元素或子集。
a. 使用整数索引访问单个元素
代码语言:javascript复制import torch
x = torch.tensor([1, 2, 3, 4, 5])
element = x[0] # 访问第一个元素
print(element)
输出:
代码语言:javascript复制tensor(1)
b. 使用多个整数索引访问多个元素
代码语言:javascript复制import torch
x = torch.tensor([1, 2, 3, 4, 5])
elements = x[[0, 2, 4]] # 访问第1、3、5个元素
print(elements)
输出:
代码语言:javascript复制tensor([1, 3, 5])
c. 使用负数索引从张量的末尾开始计数
代码语言:javascript复制import torch
x = torch.tensor([1, 2, 3, 4, 5])
print(x[-1]) # 访问最后一个元素
输出:
代码语言:javascript复制tensor(5)
d. 使用布尔索引访问满足条件的元素
代码语言:javascript复制import torch
x = torch.tensor([1, 2, 3, 4, 5])
mask = x > 2 # 创建一个布尔掩码
elements = x[mask] # 访问大于2的元素
print(elements)
输出:
代码语言:javascript复制tensor([3, 4, 5])
e. torch.where()函数根据条件选择元素
代码语言:javascript复制import torch
x = torch.tensor([1, 2, 3, 4, 5])
indices = torch.where(x > 2) # 找到大于2的元素的索引
selected = x[indices] # 根据索引选择元素
print(selected)
输出:
代码语言:javascript复制tensor([3, 4, 5])
代码语言:javascript复制import torch
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([10, 20, 30, 40, 50])
condition = torch.tensor([True, False, True, False, True])
result = torch.where(condition, x, y) # 根据条件选择x或y中的元素
print(result)
输出:
代码语言:javascript复制tensor([ 1, 20, 3, 40, 5])
f. torch.take()函数按索引从张量中选择元素
代码语言:javascript复制import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
indices = torch.tensor([0, 4, 8])
selected = torch.take(x, indices)
print(selected)
输出:
代码语言:javascript复制tensor([1, 5, 9])
g. torch.nonzero()函数找到张量中非零元素的索引
代码语言:javascript复制import torch
x = torch.tensor([0, 1, 0, 2, 3, 0])
indices = torch.nonzero(x)
print(indices)
输出:
代码语言:javascript复制tensor([[1],
[3],
[4]])
3. 切片操作
a. 使用start:end切片操作访问子集
代码语言:javascript复制import torch
x = torch.tensor([1, 2, 3, 4, 5])
subset = x[1:4] # 获取索引1到3的子集
print(subset)
输出:
代码语言:javascript复制tensor([2, 3, 4])