提问者:小点点

如何使用pytorch处理多标签分类中的类不平衡


我们正在尝试使用pytorch中的CNN实现多标签分类。我们有8个标签和大约260个图像,使用90/10分割用于训练/验证集。

这些类别高度不平衡,最频繁的类别出现在140多张图像中。另一方面,最少的类出现在不到5张图像中。

我们最初尝试了BCEWellLogitsLoss函数,这导致模型预测所有图像的相同标签。

然后,我们实施了一种焦点损失方法来处理类不平衡,如下所示:

    import torch.nn as nn
    import torch

    class FocalLoss(nn.Module):
        def __init__(self, alpha=1, gamma=2):
            super(FocalLoss, self).__init__()
            self.alpha = alpha
            self.gamma = gamma

        def forward(self, outputs, targets):
            bce_criterion = nn.BCEWithLogitsLoss()
            bce_loss = bce_criterion(outputs, targets)
            pt = torch.exp(-bce_loss)
            focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
            return focal_loss 

这导致模型预测每个图像的空集(无标签),因为它无法为任何类获得大于0.5的置信度。

pytorch中是否有方法来帮助解决这种情况?


共1个答案

匿名用户

基本上有三种方法可以解决这个问题。

>

  • 从更常见的类中丢弃数据

    权重少数类损失值更重

    对少数类进行过采样

    选项1是通过选择您包含在数据集中的文件来实现的。

    选项2使用pos_weight参数实现,用于BCEWellLogitsLoss

    选项3使用传递给您的Dataloader的自定义Sampler来实现

    对于深度学习,过采样通常效果最好。