我有一个多标签分类问题。我有11个类,大约4k个例子。每个例子可以有1到4-5个标签。目前,我正在用log_loss为每个类单独训练一个分类器。正如你所料,训练11个分类器需要相当长的时间,我想尝试另一种方法,只训练1个分类器。这个想法是,这个分类器的最后一层将有11个节点,并将输出一个实数的类,这些类将被转换为一个sigmoid。我想优化的损失是所有类log_loss的平均值。
不幸的是,我是pytorch的新手,即使通过阅读损失的源代码,我也无法弄清楚已经存在的损失是否正是我想要的,或者我是否应该创建一个新的损失,如果是这样,我真的不知道如何去做。
非常具体地说,我想为批处理的每个元素提供一个大小为11的向量(其中包含每个标签的实数(越接近无穷大,该类预测越接近1)和一个大小为11的向量(其中每个真实标签都包含1),并且能够计算所有11个标签的平均log_loss,并根据该损失优化我的分类器。
任何帮助将不胜感激:)
您正在寻找torch. nn.BCELoss
。这是示例代码:
import torch
batch_size = 2
num_classes = 11
loss_fn = torch.nn.BCELoss()
outputs_before_sigmoid = torch.randn(batch_size, num_classes)
sigmoid_outputs = torch.sigmoid(outputs_before_sigmoid)
target_classes = torch.randint(0, 2, (batch_size, num_classes)) # randints in [0, 2).
loss = loss_fn(sigmoid_outputs, target_classes)
# alternatively, use BCE with logits, on outputs before sigmoid.
loss_fn_2 = torch.nn.BCEWithLogitsLoss()
loss2 = loss_fn_2(outputs_before_sigmoid, target_classes)
assert loss == loss2