我们正在尝试使用pytorch中的CNN实现多标签分类。我们有8个标签和大约260个图像,使用90/10分割用于训练/验证集。
这些类别高度不平衡,最频繁的类别出现在140多张图像中。另一方面,最少的类出现在不到5张图像中。
我们最初尝试了BCEWellLogitsLoss函数,这导致模型预测所有图像的相同标签。
然后,我们实施了一种焦点损失方法来处理类不平衡,如下所示:
import torch.nn as nn
import torch
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, outputs, targets):
bce_criterion = nn.BCEWithLogitsLoss()
bce_loss = bce_criterion(outputs, targets)
pt = torch.exp(-bce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
return focal_loss
这导致模型预测每个图像的空集(无标签),因为它无法为任何类获得大于0.5的置信度。
pytorch中是否有方法来帮助解决这种情况?
基本上有三种方法可以解决这个问题。
>
从更常见的类中丢弃数据
权重少数类损失值更重
对少数类进行过采样
选项1是通过选择您包含在数据集中的文件来实现的。
选项2使用pos_weight
参数实现,用于BCEWellLogitsLoss
选项3使用传递给您的Dataloader的自定义Sampler
来实现
对于深度学习,过采样通常效果最好。