当我们处理不平衡的训练数据(负样本较多,正样本较少)时,通常会使用pos_weight
参数。
pos_weight
的期望是当 positive sample
得到错误标签时,模型会比 negative sample
得到更高的损失。
当我使用 binary_cross_entropy_with_logits
函数时,我发现:
import torch
import torch.nn.functional as F
pos_weight = torch.FloatTensor([5])
preds_pos_wrong = torch.FloatTensor([0.5,1.5])
label_pos = torch.FloatTensor([1,0])
loss_pos_wrong = F.binary_cross_entropy_with_logits(preds_pos_wrong,label_pos,pos_weight=pos_weight)
'''
loss_pos_wrong = tensor(2.0359)
'''
preds_neg_wrong = torch.FloatTensor([1.5,0.5])
label_neg = torch.FloatTensor([0,1])
loss_neg_wrong = F.binary_cross_entropy_with_logits(preds_neg_wrong,label_neg,pos_weight=pos_weight)
'''
loss_neg_wrong = tensor(2.0359)
'''
错误的正样本和负样本得到的损失是一样的,那么pos_weight
在不平衡数据损失计算中是如何工作的?