下面对dim和keepdim在统计中所起到的作用进行介绍
代码语言:javascript复制import torch
a = torch.rand([4, 10])
print(a)
print(a.max(dim=1))
# 在dim=1,即[10]维度上找出最大值,返回的是dim=1、shape=4的数组
print('不加keepdim', a.argmax(dim=1))
# 输出最大值所在的位置
print('加keepdim', a.argmax(dim=1, keepdim=True))
# 输出最大值所在的位置
分别输出为
代码语言:javascript复制tensor([[0.0568, 0.3917, 0.5959, 0.9399, 0.1483, 0.0496, 0.5185, 0.0032, 0.7857,
0.5175],
[0.9679, 0.3518, 0.0875, 0.5074, 0.3734, 0.6795, 0.1170, 0.2051, 0.8028,
0.6152],
[0.5434, 0.8552, 0.4144, 0.5518, 0.0127, 0.2684, 0.3239, 0.1878, 0.4649,
0.8533],
[0.9886, 0.6360, 0.9998, 0.0884, 0.5477, 0.0661, 0.7822, 0.3943, 0.2967,
0.3295]])
torch.return_types.max(
values=tensor([0.9399, 0.9679, 0.8552, 0.9998]),
indices=tensor([3, 0, 1, 2]))
不加keepdim tensor([3, 0, 1, 2])
加keepdim tensor([[3],
[0],
[1],
[2]])
由输出结果可见,加了keepdim=True后,输出矩阵的shape为[4, 2]而不是[4, 1]。本身统计信息是带有改变dim功能的,添加该语句后,可以保持前后的din一致。当然也可以使用.unsqueeze函数添加列,但不如直接加keepdim=True简单。
再如topk函数应用也较多,top为‘位居前列的’、k代表‘具体数值’。topk可以比.max和.min返回更多的数据。其API为:.topk(self, k, dim, largest, sorted)
举例
代码语言:javascript复制a = torch.tensor([[0.9082],
[0.8063],
[0.5034],
[0.9467],
[0.8170],
[0.3109],
[0.7786],
[0.7719],
[0.7661],
[0.1433]])
print(a)
print(a.topk(3, dim=0, largest=True))
# 设置了k=3, 输出最大的3个元素的数值和其相应的位置
# 上式中,若想求得最小的几个,将largest=True,改为largest=False即可
输出
代码语言:javascript复制tensor([[0.9082],
[0.8063],
[0.5034],
[0.9467],
[0.8170],
[0.3109],
[0.7786],
[0.7719],
[0.7661],
[0.1433]])
torch.return_types.topk(
values=tensor([[0.9467],
[0.9082],
[0.8170]]),
indices=tensor([[3],
[0],
[4]]))
由输出结果看出最大的三个数值按照由大到小的顺序排列依次是0.9467、0.9082和0.8170,他们所在的位置分别是第3位、第0位和第4位。
与topk功能类似的kthvalue,即第k个的value。
.kthvalue(self, k, dim, keepdim)
默认选择最小的
代码语言:javascript复制print('最小数值的位置', a.kthvalue(8, dim=0))
# 第8个最小的即为第3个大的 (第9小->第2大,第10小->第一大)
输出
代码语言:javascript复制最小数值的位置 torch.return_types.kthvalue(
values=tensor([0.8170]),
indices=tensor([4]))
由结果看出第8个小的(第3个大的)数值为0.8170,位置为第4个元素。
另外当想进行元素间比较时,与数学一样要用到各类数学符号:>、>=、<、<=、!=、==。
而返回的结果为0和1。
当想比较两矩阵是否相同时,用torch.eq()函数,相同返回1,不同返回0。
代码语言:javascript复制a = torch.tensor([21, 21])
b = torch.tensor([21, 2])
c = torch.tensor([21, 21])
print('a与b比较', torch.eq(a, b))
print('a与c比较', torch.eq(a, c))
分别输出为
代码语言:javascript复制a与b比较 tensor([1, 0], dtype=torch.uint8)
a与c比较 tensor([1, 1], dtype=torch.uint8)
由上可知Torch.eq会挨个元素进行对比。
而torch.equal()会对比整体是否相同,相同返回True,不同返回False。
代码语言:javascript复制a = torch.tensor([21, 21])
b = torch.tensor([21, 2])
c = torch.tensor([21, 21])
print('a与b比较', torch.equal(a, b))
print('a与c比较', torch.equal(a, c))
分别输出
代码语言:javascript复制a与b比较 False
a与c比较 True