所以我有一些带注释的图像,我想用它来训练二值图像分类器,但是我在创建数据集和实际训练测试模型时遇到了问题。每个图像要么属于某个类,要么不属于某个类,所以我想使用PyTorch设置一个二分类数据集/模型。我有一些问题:
提前谢谢
报价删除
二进制分类与多标签分类略有不同:对于多标签,您的模型预测每个样本的“logits”向量,并使用softmax将logits转换为概率;在二进制情况下,模型预测每个样本的标量“logit”,并使用sigmoid函数将其转换为类概率。
在pytorch中,softmax和sigmoind被“折叠”到损失层(出于数值稳定性考虑),因此对于两种情况有不同的交叉熵损失层:nn. BCEWellLogitsLoss
对于二进制情况(使用sigmoid)和nn.CrossEntropyLoss
对于多标签情况(使用softmax)。
在您的示例中,您希望使用二进制版本(带有sigmoid):nn. BCEWellLogitsLoss
。
因此,您的标签应该是torch.float32
类型(与网络输出的float32
类型相同)而不是整数。每个样本应该有一个标签。因此,如果您的批处理大小为200,则目标应该具有形状(200,1)
。
我将把它留在这里作为一个练习,以表明训练一个具有两个输出和CEsoftmax的模型相当于二进制输出sigmoid;)