提问者:小点点

交叉熵损失Pytorch


我有一个关于在我的pytorch网络中交叉熵损失的最佳实现的问题。我正在构建一个预测卷图片3D分割的网络。我有后台类或一个前台类,但它也应该有可能预测两个或多个不同的前台类。

训练时输入到我的网络的形状是[batch_size,input_channels,宽度,高度,深度]。输入通道等于1,因为我只有灰度图像,深度是根据构成3D体积的2D图片的数量。

网络的输出是[batch_size,#类,宽度,高度,深度]。#类现在等于两个,这两个维度几乎是相反的,因为我现在只有背景类和一个前景类。

现在我想计算交叉熵损失。现在我的实现如下所示:

    def loss(output, gt):
    loss_fn = nn.CrossEntropyLoss()
    gt_temp = torch.tensor(np.zeros((output.shape[0], output.shape[2],output.shape[3], output.shape[4]))).long()
    gt_temp[:, :, :, :] = gt[:, 0, :, :, :]
    return loss_fn(output, gt_temp)

我必须对基础真理进行一些维度更改,因为nn. CrossEntropyLoss()期望基础真理的大小为(5,384,384,81),而输出的大小为(5,2,384,384,81)-

问题是这并没有真正起作用。它收敛到0.5-0.6的损失,它只是预测整个输出是背景。我该怎么办?

如果有人能帮忙,那就太棒了!!!我也不明白为什么输出和基础真理不应该是相同的形状?!

谢谢!


共1个答案

匿名用户

关于形状问题,交叉熵损失有两个pytorch损失函数:

>

  • 二进制交叉熵损失-期望每个目标和输出是形状[batch_size,num_classes,……]的张量,每个张量的值在[0,1]范围内。

    交叉熵损失-为简单起见,目标张量不是size[batch_size,…]。每个值在[0,num_classes-1]范围内。

    认为交叉熵损失函数中做的第一件事是使用目标张量计算二进制交叉熵损失的目标张量可能会有所帮助,然后损失的计算方式与此损失函数相同。唯一的区别是,根据您的目标数据的格式,一个或另一个公式可能稍微更方便,因为它需要更少的重塑。