# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn.functional as F

__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']


class Loss(object):
    """
    Base loss class; a valid ``epsilon`` in [0, 1] enables label smoothing
    """
    def __init__(self, class_dim=1000, epsilon=None):
        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
        self._class_dim = class_dim
        if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
            self._epsilon = epsilon
            self._label_smoothing = True  # use label smoothing (the target becomes a soft label)
        else:
            self._epsilon = None
            self._label_smoothing = False

    #do label_smoothing
    def _labelsmoothing(self, target):
        if target.shape[-1] != self._class_dim:
            one_hot_target = F.one_hot(target, self._class_dim)  # convert integer labels to one-hot, e.g. [23, 34, 46] -> a [3, class_dim] matrix
        else:
            one_hot_target = target

        # smooth the one-hot target: (1 - epsilon) * one_hot + epsilon / class_dim,
        # e.g. with class_dim=5 and epsilon=0.1: on-target 1 -> 0.92, off-target 0 -> 0.02
        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
        return soft_target

    def _crossentropy(self, input, target, use_pure_fp16=False):
        if self._label_smoothing:
            target = self._labelsmoothing(target)
            input = -F.log_softmax(input, axis=-1)      # negative log-softmax of the logits
            cost = paddle.sum(target * input, axis=-1)  # soft-label cross entropy: sum over classes
        else:
            cost = F.cross_entropy(input=input, label=target) 

        if use_pure_fp16:
            avg_cost = paddle.sum(cost)
        else:
            avg_cost = paddle.mean(cost)
        return avg_cost

    def _kldiv(self, input, target, name=None):
        eps = 1.0e-10  # avoids log(0) and division by zero
        # elementwise KL term target * log(target / input), scaled by class_dim
        cost = target * paddle.log(
            (target + eps) / (input + eps)) * self._class_dim
        return cost

    def _jsdiv(self, input, target):  # input and target are raw FC outputs (logits); softmax is applied here
        input = F.softmax(input)
        target = F.softmax(target)

        # symmetric KL divergence between the two distributions
        cost = self._kldiv(input, target) + self._kldiv(target, input)
        cost = cost / 2
        avg_cost = paddle.mean(cost)
        return avg_cost

    def __call__(self, input, target):
        pass


class CELoss(Loss):
    """
    Cross entropy loss with optional label smoothing
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(CELoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target, use_pure_fp16=False):
        logits = input["logits"]
        cost = self._crossentropy(logits, target, use_pure_fp16)
        return {"CELoss": cost}
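
# Minimal usage sketch for CELoss (the batch size, class count and epsilon
# below are illustrative assumptions, not taken from this file):
#
#     loss_fn = CELoss(class_dim=1000, epsilon=0.1)  # epsilon enables label smoothing
#     logits = paddle.randn([8, 1000])               # raw FC outputs
#     labels = paddle.randint(0, 1000, [8, 1])       # integer class ids
#     out = loss_fn({"logits": logits}, labels)      # -> {"CELoss": scalar Tensor}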


class JSDivLoss(Loss):
    """
    Jensen-Shannon divergence loss (expects raw logits)
    """
    def __init__(self, class_dim=1000, epsilon=None):
        super(JSDivLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._jsdiv(input, target)
        return cost


class KLDivLoss(paddle.nn.Layer):
    """
    KL-divergence-style loss between two distributions p and q.

    Note: the returned value is the cross entropy -sum(p * log(q)), which
    differs from KL(p || q) only by the entropy of p (constant w.r.t. q).
    """

    def __init__(self):
        super(KLDivLoss, self).__init__()

    def __call__(self, p, q, is_logit=True):
        if is_logit:
            # convert raw logits to probability distributions
            p = paddle.nn.functional.softmax(p)
            q = paddle.nn.functional.softmax(q)
        return -(p * paddle.log(q + 1e-8)).sum(1).mean()
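

# Quick smoke-test sketch for the divergence losses; the batch size, class
# count and the student/teacher naming are illustrative assumptions only.
if __name__ == "__main__":
    batch_size, class_dim = 8, 1000
    student_logits = paddle.randn([batch_size, class_dim])
    teacher_logits = paddle.randn([batch_size, class_dim])

    # both losses take raw logits here; softmax is applied internally
    # (always in JSDivLoss._jsdiv, and when is_logit=True in KLDivLoss)
    js_loss = JSDivLoss(class_dim=class_dim)
    kl_loss = KLDivLoss()
    print("JSDivLoss:", js_loss(student_logits, teacher_logits))
    print("KLDivLoss:", kl_loss(student_logits, teacher_logits, is_logit=True))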