Asked by: 小点点

Manually computing the cross-entropy loss in PyTorch


I am trying to manually compute the cross_entropy loss of an encoder-decoder model in PyTorch.

I used the code posted here to compute it: Cross Entropy in PyTorch

I updated the code to discard padding tokens (-100). The final code looks like this:

import torch

class compute_crossentropyloss_manual:
    """
    y0 is the logits tensor with shape (batch_size, C)
    x has shape (batch_size,); its entries are class indices from 0 to C-1,
    or ignore_index for padding positions
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:  # <------ I added this if-statement
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
        loss = - loss/n_batch
        return loss
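
For reference, each loop iteration computes the negative log-softmax of the target class. The ratio torch.exp(y1[class_index]) / torch.exp(y1).sum() can overflow for large logits; here is a minimal sketch (mine, not from the original post) of a numerically safer equivalent using torch.log_softmax:

import torch

y1 = torch.randn(50, dtype=torch.double)  # logits for a single sample
c = 7                                     # its target class index
manual = torch.log(torch.exp(y1[c]) / torch.exp(y1).sum())
stable = torch.log_softmax(y1, dim=0)[c]  # same value, no overflow risk
print(torch.allclose(manual, stable))     # True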

To verify that it works correctly, I tested it on a text generation task and computed the loss both with the torch.nn implementation and with this code.

The loss values are not the same:

Using nn.CrossEntropyLoss:

Using the code from the link above:

Am I missing something?

I tried to get the source code of nn.CrossEntropyLoss but I couldn't. At line 2955 of nn/functional.py you will see that the function dispatches to another cross-entropy function called torch._C._nn.cross_entropy_loss; I can't find this function in the repo.
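
That is expected: torch._C._nn.cross_entropy_loss is implemented in the compiled C++ (ATen) backend, so there is no Python source for it in the repo. At the Python level it behaves like log_softmax followed by nll_loss; a quick equivalence check (mine, not from the original post):

import torch
import torch.nn.functional as F

y = torch.randn(8, 5)                       # logits
x = torch.randint(0, 5, (8,))               # targets
a = F.cross_entropy(y, x)
b = F.nll_loss(F.log_softmax(y, dim=1), x)
print(torch.allclose(a, b))                 # True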

EDIT:

I noticed that the difference appears only when the gold labels contain -100 tokens.

A demo example:

# criterion is nn.CrossEntropyLoss(); criterion2 is the manual class above
y = torch.randint(1, 50, (100, 50), dtype=float)
x = torch.randint(1, 50, (100,))

x[40:] = -100
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 25.55788695847976
> 10.223154783391905

When there are no -100 tokens:

x[40:] = 30 # any valid class index
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 24.684453267596453
> 24.684453267596453
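
The pattern in these numbers points at the divisor: with the default reduction='mean', nn.CrossEntropyLoss averages over the non-ignored targets only. A short check (mine, not from the original post) that makes this visible:

import torch
import torch.nn as nn

y = torch.randn(100, 50, dtype=torch.double)
x = torch.randint(0, 50, (100,))
x[40:] = -100                                   # ignore 60 of the 100 targets

mean_loss = nn.CrossEntropyLoss()(y, x)         # averaged over the 40 kept targets
sum_loss = nn.CrossEntropyLoss(reduction='sum')(y, x)
print(torch.allclose(mean_loss, sum_loss / (x != -100).sum()))  # True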

2 Answers

Anonymous user

I solved the problem by updating the code. I was discarding the -100 tokens (the if-statement above), but I forgot to shrink the divisor accordingly (called n_batch in the code above): the loss must be averaged over the non-ignored tokens only, not over the full batch size (in the demo above, 10.2232 × 100/40 = 25.5579). After doing that, the loss value is identical to the nn.CrossEntropyLoss value. Final code:

import torch

class CrossEntropyLossManual:
    """
    y0 is the logits tensor with shape (batch_size, C)
    x has shape (batch_size,); its entries are class indices from 0 to C-1,
    or ignore_index for padding positions
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                n_batch -= 1  # exclude ignored tokens from the divisor
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
        loss = - loss/n_batch
        return loss
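
A quick sanity check of the fixed class against nn.CrossEntropyLoss, along the lines of the demo in the question (a sketch, not part of the original answer):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
criterion2 = CrossEntropyLossManual()

y = torch.randn(100, 50, dtype=torch.double)
x = torch.randint(0, 50, (100,))
x[40:] = -100
print(criterion(y, x).item())   # the two values now agree
print(criterion2(y, x).item())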

Anonymous user

I needed this too, thank you for the manual cross-entropy loss code. It matches the PyTorch result perfectly (on my data). I have one small fix for your fix above: at the end you need to divide by the final count of non-ignored rows (those whose label is not -100), so you need a counter:

import torch

class compute_crossentropyloss_manual:
    """
    y0 is the logits tensor with shape (batch_size, C)
    x has shape (batch_size,); its entries are class indices from 0 to C-1,
    or ignore_index for padding positions
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        cnt = 0             # <----- I added this
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
            cnt += 1        # <----- I added this
        loss = - loss/cnt   # <---- I changed this from nbatch to 'cnt'
        return loss
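
For completeness, the same computation can be vectorized to avoid the Python loop. A sketch under the same assumptions (y0 holds logits of shape (batch_size, C), x holds class indices or ignore_index; manual_cross_entropy is a hypothetical helper name):

import torch

def manual_cross_entropy(y0, x, ignore_index=-100):
    keep = x != ignore_index                    # mask out the ignored rows
    logp = torch.log_softmax(y0[keep], dim=1)   # numerically stable log-probs
    picked = logp.gather(1, x[keep].unsqueeze(1)).squeeze(1)
    return -picked.mean()                       # mean over the kept rows only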