# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn.functional as F

__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']


class Loss(object):
    """
    Loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
        self._class_dim = class_dim
        if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
            self._epsilon = epsilon
            self._label_smoothing = True  # use label smoothing (the hard target becomes a soft label)
        else:
            self._epsilon = None
            self._label_smoothing = False

    # apply label smoothing to the target
    def _labelsmoothing(self, target):
        if target.shape[-1] != self._class_dim:
            one_hot_target = F.one_hot(
                target,
                self._class_dim)  # one-hot encode, e.g. labels (23, 34, 46) -> shape [3, class_dim]
        else:
            one_hot_target = target

        #do label_smooth
        soft_target = F.label_smooth(
            one_hot_target,
            epsilon=self._epsilon)  # (1 - epsilon) * input + epsilon / K
        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
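        # Illustrative example with assumed values: for class_dim=4 and epsilon=0.1,
        # a hard label 2 is smoothed to roughly [0.025, 0.025, 0.925, 0.025],
        # i.e. (1 - epsilon) * one_hot + epsilon / class_dim per entry.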
        return soft_target

    def _crossentropy(self, input, target, use_pure_fp16=False):
        if self._label_smoothing:
            target = self._labelsmoothing(target)
            input = -F.log_softmax(input, axis=-1)  # negative log-softmax of the logits
            cost = paddle.sum(target * input, axis=-1)  # cross entropy with the soft target
        else:
            cost = F.cross_entropy(input=input, label=target)

        if use_pure_fp16:
            avg_cost = paddle.sum(cost)
        else:
            avg_cost = paddle.mean(cost)
        return avg_cost

    def _kldiv(self, input, target, name=None):
        eps = 1.0e-10
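        # Elementwise KL term; the extra *self._class_dim factor makes the later
        # paddle.mean over the [batch, class_dim] cost behave like a sum over
        # classes averaged over the batch.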
        cost = target * paddle.log(
            (target + eps) / (input + eps)) * self._class_dim
        return cost

    def _jsdiv(self, input,
               target):  # input and target are raw FC outputs (logits); softmax is applied below
        input = F.softmax(input)
        target = F.softmax(target)

        # KL divergence in both directions between the two distributions, averaged (symmetrised KL)
        cost = self._kldiv(input, target) + self._kldiv(target, input)
        cost = cost / 2
        avg_cost = paddle.mean(cost)
        return avg_cost

    def __call__(self, input, target):
        pass


class CELoss(Loss):
    """
    Cross entropy loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(CELoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target, use_pure_fp16=False):
        if type(input) is dict:
            logits = input["logits"]
        else:
            logits = input
        cost = self._crossentropy(logits, target, use_pure_fp16)
        return {"CELoss": cost}


class JSDivLoss(Loss):
    """
    JSDiv loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(JSDivLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._jsdiv(input, target)
        return cost


class KLDivLoss(paddle.nn.Layer):
    def __init__(self):
        super(KLDivLoss, self).__init__()

    def __call__(self, p, q, is_logit=True):
        if is_logit:
            p = paddle.nn.functional.softmax(p)
            q = paddle.nn.functional.softmax(q)
        return -(p * paddle.log(q + 1e-8)).sum(1).mean()
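

# Minimal usage sketch (assumed, not part of the original module): the batch size,
# class count, and epsilon below are illustrative values showing how CELoss,
# JSDivLoss, and KLDivLoss can be called on random tensors.
if __name__ == "__main__":
    batch_size, num_classes = 8, 10

    logits = paddle.randn([batch_size, num_classes])        # model outputs (logits)
    labels = paddle.randint(0, num_classes, [batch_size])   # integer class ids

    # Cross entropy with label smoothing (epsilon=0.1 is an assumed value).
    ce_loss = CELoss(class_dim=num_classes, epsilon=0.1)
    print(ce_loss(logits, labels))  # {"CELoss": Tensor}

    # JS divergence between two sets of logits, e.g. student/teacher outputs.
    teacher_logits = paddle.randn([batch_size, num_classes])
    js_loss = JSDivLoss(class_dim=num_classes)
    print(js_loss(logits, teacher_logits))

    # KLDivLoss applies softmax internally when is_logit=True.
    kl_loss = KLDivLoss()
    print(kl_loss(logits, teacher_logits, is_logit=True))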