loss.py 3.9 KB
Newer Older
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
W
WuHaobo 已提交
2
#
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WuHaobo 已提交
14

littletomatodonkey's avatar
littletomatodonkey 已提交
15 16
import paddle
import paddle.nn.functional as F
W
WuHaobo 已提交
17

littletomatodonkey's avatar
littletomatodonkey 已提交
18
# Public loss classes exported by this module.
__all__ = [
    "CELoss",
    "MixCELoss",
    "GoogLeNetLoss",
    "JSDivLoss",
]
W
WuHaobo 已提交
19 20 21 22 23 24 25 26 27 28


class Loss(object):
    """
    Base class for the classification losses in this module.

    Stores the number of classes and the optional label-smoothing factor, and
    provides the shared cross-entropy and divergence helpers used by the
    concrete subclasses.

    Args:
        class_dim (int): number of classes; must be larger than 1.
        epsilon (float|None): label-smoothing factor. Smoothing is enabled
            only when ``0.0 <= epsilon <= 1.0``; ``None`` or any out-of-range
            value disables it.
    """

    def __init__(self, class_dim=1000, epsilon=None):
        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
        self._class_dim = class_dim
        if epsilon is not None and 0.0 <= epsilon <= 1.0:
            self._epsilon = epsilon
            self._label_smoothing = True
        else:
            self._epsilon = None
            self._label_smoothing = False

    def _labelsmoothing(self, target):
        """
        Turn ``target`` into label-smoothed soft targets of shape
        ``[-1, class_dim]``.

        Index labels are one-hot encoded first. A target of rank >= 2 whose
        last dimension already equals ``class_dim`` is used as-is (treated as
        one-hot / soft labels).
        """
        # A 1-D target can only be a vector of class indices. Checking the
        # rank first keeps it from being mistaken for a one-hot matrix when
        # its length happens to equal class_dim (batch size == class_dim).
        if len(target.shape) == 1 or target.shape[-1] != self._class_dim:
            one_hot_target = F.one_hot(target, self._class_dim)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
        # Flatten to 2-D; presumably to drop the extra axis that one_hot
        # leaves when labels come in as [N, 1] — confirm against callers.
        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
        return soft_target

    def _crossentropy(self, input, target):
        """
        Return the mean cross entropy between logits ``input`` and ``target``.

        With label smoothing enabled, the loss is computed explicitly against
        the smoothed soft targets. Otherwise ``F.cross_entropy`` is used
        directly.
        """
        if self._label_smoothing:
            target = self._labelsmoothing(target)
            input = -F.log_softmax(input, axis=-1)
            cost = paddle.sum(target * input, axis=-1)
        else:
            cost = F.cross_entropy(input=input, label=target)
        # NOTE(review): F.cross_entropy presumably already reduces to a mean
        # by default, making this a no-op on that branch — verify.
        avg_cost = paddle.mean(cost)
        return avg_cost

    def _kldiv(self, input, target, name=None):
        """
        Return the element-wise KL term ``target * log(target / input)``,
        scaled by ``class_dim``.

        ``eps`` keeps the log and the division finite when a probability is
        0. ``name`` is accepted but unused.
        """
        eps = 1.0e-10
        # Multiplying by class_dim turns a later mean over all elements into
        # a per-sample sum over classes (assumes last dim == class_dim).
        cost = target * paddle.log(
            (target + eps) / (input + eps)) * self._class_dim
        return cost

    def _jsdiv(self, input, target):
        """
        Return the mean symmetrised KL divergence between two sets of logits.

        Both arguments are softmax-normalised, and the result is
        ``(KL(p || q) + KL(q || p)) / 2``. Note that this is not the textbook
        Jensen-Shannon divergence, which compares each distribution against
        their mixture ``(p + q) / 2``.
        """
        input = F.softmax(input)
        target = F.softmax(target)
        cost = self._kldiv(input, target) + self._kldiv(target, input)
        cost = cost / 2
        avg_cost = paddle.mean(cost)
        return avg_cost

    def __call__(self, input, target):
        """Compute the loss; concrete subclasses must override this."""
        raise NotImplementedError(
            "%s must implement __call__" % type(self).__name__)


class CELoss(Loss):
    """
    Cross entropy loss, with optional label smoothing controlled by
    ``epsilon``.
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super().__init__(class_dim, epsilon)

    def __call__(self, input, target):
        """Return the mean cross entropy of logits ``input`` against ``target``."""
        return self._crossentropy(input, target)


class MixCELoss(Loss):
    """
    Cross entropy loss for mixed-sample augmentation (e.g. mixup or cutmix).

    The result is the ``lam``-weighted combination of the cross entropy
    against each of the two mixed targets.
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super().__init__(class_dim, epsilon)

    def __call__(self, input, target0, target1, lam):
        loss_a = self._crossentropy(input, target0)
        loss_b = self._crossentropy(input, target1)
        mixed = lam * loss_a + (1.0 - lam) * loss_b
        # Reduce to a scalar (a no-op when `mixed` already is one).
        return paddle.mean(mixed)


class GoogLeNetLoss(Loss):
    """
    Cross entropy loss for GoogLeNet's three classifier outputs.

    The first output is weighted 1 and the other two are weighted by
    ``aux_weight``. Presumably these are the main head and the two auxiliary
    heads.

    Args:
        class_dim (int): number of classes.
        epsilon (float|None): label-smoothing factor, forwarded to ``Loss``.
        aux_weight (float): weight applied to each of ``input1`` and
            ``input2``. Defaults to 0.3, the previously hard-coded value.
    """

    def __init__(self, class_dim=1000, epsilon=None, aux_weight=0.3):
        super(GoogLeNetLoss, self).__init__(class_dim, epsilon)
        self._aux_weight = aux_weight

    def __call__(self, input0, input1, input2, target):
        """
        Return ``CE(input0) + aux_weight * (CE(input1) + CE(input2))``,
        averaged. All three logits share ``target``.
        """
        cost0 = self._crossentropy(input0, target)
        cost1 = self._crossentropy(input1, target)
        cost2 = self._crossentropy(input2, target)
        # Same operation order as before, so default results are unchanged.
        cost = cost0 + self._aux_weight * cost1 + self._aux_weight * cost2
        avg_cost = paddle.mean(cost)
        return avg_cost
littletomatodonkey's avatar
littletomatodonkey 已提交
117 118 119 120 121 122 123 124 125 126 127 128 129


class JSDivLoss(Loss):
    """
    Symmetric divergence loss between two sets of logits.

    Both ``input`` and ``target`` are softmax-normalised by ``_jsdiv``, and
    the two KL directions are averaged. The ``epsilon`` label-smoothing
    setting is stored but not used by this loss.
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super().__init__(class_dim, epsilon)

    def __call__(self, input, target):
        return self._jsdiv(input, target)