问题描述
在Pytorch训练自定义数据集中发生如下错误:
RuntimeError: result type Float can't be cast to the desired output type Long
RuntimeError:结果类型 Float 无法转换为所需的输出类型 Long
代码语言:javascript复制loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights]))
问题解决
BCEWithLogitsLoss
要求它的目标是一个float
张量,而不是long
。所以应该通过dtype=torch.float32
指定张量的类型。
将上述代码修改如下:
代码语言:javascript复制loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights], dtype=torch.float32))
参考文章:Pytorch 抛出错误 RuntimeError: result type Float can’t be cast to the desired output type Long答案 - 爱码网 (likecs.com)