multilabelloss.py 2.1 KB
Newer Older
C
cuicheng01 已提交
1 2 3 4 5
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


Z
zhiboniu 已提交
6 7 8 9 10 11 12 13 14 15 16
def ratio2weight(targets, ratio):
    """Compute per-element loss weights from targets and label ratios.

    The exponent for each element is
    ``targets * (1 - ratio) + (1 - targets) * ratio``. A positive element
    (``targets == 1``) therefore gets weight ``exp(1 - ratio)`` and a
    negative element (``targets == 0``) gets ``exp(ratio)``. Elements whose
    target is greater than 1 are given zero weight.

    Args:
        targets (Tensor): target values. Presumably 0/1 labels, possibly
            smoothed — TODO confirm for other dataloaders.
        ratio (Tensor): per-label ratio, assumed broadcastable against
            ``targets`` (not checked here).

    Returns:
        Tensor: the element-wise weights.
    """
    pos_weights = targets * (1. - ratio)
    neg_weights = (1. - targets) * ratio
    weights = paddle.exp(neg_weights + pos_weights)

    # Some dataloaders (e.g. RAP) may emit target values greater than 1
    # (such as 2), with or without label smoothing. Those elements get zero
    # weight. The boolean mask is cast explicitly rather than relying on
    # implicit bool -> float promotion inside the multiplication.
    ignore_mask = paddle.cast(targets > 1, weights.dtype)
    weights = weights - weights * ignore_mask

    return weights


C
cuicheng01 已提交
17 18 19 20 21
class MultiLabelLoss(nn.Layer):
    """Multi-label loss based on binary cross-entropy with logits.

    Each class logit is scored independently with
    ``F.binary_cross_entropy_with_logits``. Label smoothing and per-label
    ratio weighting are optional.

    Args:
        epsilon (float|None): label-smoothing factor. Values outside the
            open interval (0, 1) disable smoothing (replaced by ``None``).
        size_sum (bool): if True, sum the element-wise costs over axis 1
            for each sample, then average over the batch. Otherwise every
            element is averaged.
        weight_ratio (bool): if True, ``target`` must unpack into
            ``(target, label_ratio)``, and the cost is multiplied by weights
            from ``ratio2weight``.
    """

    def __init__(self, epsilon=None, size_sum=False, weight_ratio=False):
        super().__init__()
        # Out-of-range epsilon silently falls back to "no smoothing".
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        self.epsilon = epsilon
        self.weight_ratio = weight_ratio
        self.size_sum = size_sum

    def _labelsmoothing(self, target, class_num):
        """Return label-smoothed targets reshaped to ``[-1, class_num]``.

        Targets that are 1-D, or whose last dimension differs from
        ``class_num``, are one-hot encoded first. Otherwise they are used
        as-is.
        """
        if target.ndim == 1 or target.shape[-1] != class_num:
            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

    def _binary_crossentropy(self, input, target, class_num):
        """Compute the (optionally smoothed and weighted) BCE cost.

        Returns the element-wise cost tensor. When ``size_sum`` is set, it
        instead returns the batch mean of the per-sample sum over axis 1.
        """
        if self.weight_ratio:
            # NOTE(review): assumes ``target`` is a 2-item sequence
            # (labels, label_ratio) — confirm against the dataloader.
            target, label_ratio = target
        if self.epsilon is not None:
            target = self._labelsmoothing(target, class_num)
        cost = F.binary_cross_entropy_with_logits(
            logit=input, label=target, reduction='none')

        if self.weight_ratio:
            targets_mask = paddle.cast(target > 0.5, 'float32')
            weight = ratio2weight(targets_mask, paddle.to_tensor(label_ratio))
            # Zero the weight wherever target <= -1. The boolean mask is
            # cast explicitly instead of relying on implicit promotion.
            weight = weight * paddle.cast(target > -1, weight.dtype)
            cost = cost * weight

        if self.size_sum:
            # The original nested ``... if self.size_sum else cost.mean()``
            # had an unreachable else-branch. When size_sum is off, the
            # element-wise cost is averaged later in ``forward``.
            cost = cost.sum(1).mean()

        return cost

    def forward(self, x, target):
        """Return ``{"MultiLabelLoss": loss}`` with ``loss`` reduced by mean.

        ``x`` is either the logits tensor or a dict holding it under the
        ``"logits"`` key. The class count is read from its last dimension.
        """
        if isinstance(x, dict):
            x = x["logits"]
        class_num = x.shape[-1]
        loss = self._binary_crossentropy(x, target, class_num)
        # Final mean reduction. If size_sum already produced a scalar, the
        # mean leaves its value unchanged.
        loss = loss.mean()
        return {"MultiLabelLoss": loss}