二元交叉熵计算中的pos_weight

当我们处理不平衡的训练数据(负样本较多,正样本较少)时,通常会使用pos_weight参数。 pos_weight 的期望是当 positive sample 得到错误标签时,模型会比 negative sample 得到更高的损失。 当我使用 binary_cross_entropy_with_logits 函数时,我发现:

import torch
import torch.nn.functional as F

pos_weight = torch.FloatTensor([5])

preds_pos_wrong =  torch.FloatTensor([0.5,1.5])
label_pos = torch.FloatTensor([1,0])
loss_pos_wrong = F.binary_cross_entropy_with_logits(preds_pos_wrong,label_pos,pos_weight=pos_weight)

'''
loss_pos_wrong = tensor(2.0359)
'''

preds_neg_wrong =  torch.FloatTensor([1.5,0.5])
label_neg = torch.FloatTensor([0,1])
loss_neg_wrong = F.binary_cross_entropy_with_logits(preds_neg_wrong,label_neg,pos_weight=pos_weight)

'''
loss_neg_wrong = tensor(2.0359)
'''

错误的正样本和负样本得到的损失是一样的,那么pos_weight在不平衡数据损失计算中是如何工作的?

sgmk3738 回答:二元交叉熵计算中的pos_weight

暂时没有好的解决方案,如果你有好的解决方案,请发邮件至:iooj@foxmail.com
本文链接:https://www.f2er.com/1723.html

大家都在问