提问者:小点点

Pytorch类别交叉熵损失函数行为


我对Pytorch的分类交叉熵损失的计算有疑问。我制作了这个简单的代码片段,因为我使用输出张量的argmax作为目标,我不明白为什么损失仍然很高。

import torch
import torch.nn as nn
ce_loss = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
targets = torch.argmax(output, dim=1)
loss = ce_loss(outputs, targets)
print(loss)

谢谢你帮我理解这件事向杰罗姆问好


共1个答案

匿名用户

因此,这是来自您的代码的示例数据,其中输出标签损失具有以下值

outputs =  tensor([[ 0.5968, -0.8249,  1.5018,  2.7888, -0.6125],
                   [-1.1534, -0.4921,  1.0688,  0.2241, -0.0257],
                   [ 0.3747,  0.8957,  0.0816,  0.0745,  0.2695]], requires_grad=True)requires_grad=True)

labels = tensor([3, 2, 1])
loss = tensor(0.7354, grad_fn=<NllLossBackward>)

所以让我们检查这些值,

如果您计算日志(输出)的softmax输出,使用类似于torch. softmax(输出,轴=1)的东西,您将获得

probs = tensor([[0.0771, 0.0186, 0.1907, 0.6906, 0.0230],
                [0.0520, 0.1008, 0.4801, 0.2063, 0.1607],
                [0.1972, 0.3321, 0.1471, 0.1461, 0.1775]], grad_fn=<SoftmaxBackward>)

这些就是你的预测概率。

现在交叉熵损失只不过是softmax负对数似然损失的组合。因此,您的损失可以简单地使用

loss = (torch.log(1/probs[0,3]) +  torch.log(1/probs[1,2]) + torch.log(1/probs[2,1])) / 3

,这是您的真实标签概率的负对数的平均值。上述等式的计算结果为0.7354,相当于从nn. CrossEntropyLoss模块返回的值。