说明:移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片。
参数:
- tensor(Tensor) -- 输入张量
- dim(int) -- 删除的维度
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.4775, 0.0161, -0.9403],
[ 1.6109, 2.1144, 1.1833],
[-0.2656, 0.7772, 0.5989]])
>>> torch.unbind(x, dim=1)(tensor([ 0.4775, 1.6109, -0.2656]), tensor([0.0161, 2.1144, 0.7772]), tensor([-0.9403, 1.1833, 0.5989]))