Pytorch RuntimeError 解决办法

2023-03-01 10:47:46 浏览数 (1)

问题描述

在Pytorch训练自定义数据集中发生如下错误:

RuntimeError: result type Float can't be cast to the desired output type Long

RuntimeError:结果类型 Float 无法转换为所需的输出类型 Long

代码语言:javascript复制
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights]))

问题解决

BCEWithLogitsLoss 要求它的目标是一个float 张量,而不是long。所以应该通过dtype=torch.float32指定张量的类型。

将上述代码修改如下:

代码语言:javascript复制
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights], dtype=torch.float32))

参考文章:Pytorch 抛出错误 RuntimeError: result type Float can’t be cast to the desired output type Long答案 - 爱码网 (likecs.com)

0 人点赞