1.作用
对tensor中元素排序 2.用法
dim = -1,按照行排序,dim= 1按照列排序,descending=True,则递减排序,否则递增 3.例子
按照行排序
代码语言:javascript复制logits = torch.tensor([[[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]]])
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) # 对logits进行递减排序
print(sorted_logits)
print(sorted_indices)
输出:
tensor([[[ 0.4053, -0.3873, -0.5816, -1.0145, -1.0215],
[ 1.8823, 1.4164, 1.3443, 1.2035, 0.7265],
[ 1.7255, 1.2590, 0.1673, -0.4451, -2.0757],
[ 0.3849, 0.3041, 0.2021, 0.1383, -1.6311]]])
tensor([[[4, 1, 0, 3, 2],
[4, 1, 2, 3, 0],
[4, 2, 1, 0, 3],
[3, 1, 0, 2, 4]]])
按照列排序:
代码语言:javascript复制sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=1) # 对logits进行递减排序
print(sorted_logits)
print(sorted_indices)
输出:
tensor([[[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[ 0.2021, 0.3041, 1.2590, 0.3849, 1.7255],
[-0.4451, 0.1673, 0.1383, -1.0145, 0.4053],
[-0.5816, -0.3873, -1.0215, -2.0757, -1.6311]]])
tensor([[[1, 1, 1, 1, 1],
[3, 3, 2, 3, 2],
[2, 2, 3, 0, 0],
[0, 0, 0, 2, 3]]])