__init__.py 1.9 KB
Newer Older
B
Bin Lu 已提交
1
import copy
D
dongshuilong 已提交
2

B
Bin Lu 已提交
3 4
import paddle
import paddle.nn as nn
D
dongshuilong 已提交
5
from ppcls.utils import logger
B
Bin Lu 已提交
6

C
cuicheng01 已提交
7
from .celoss import CELoss, MixCELoss
C
cuicheng01 已提交
8
from .googlenetloss import GoogLeNetLoss
D
dongshuilong 已提交
9
from .centerloss import CenterLoss
B
Bin Lu 已提交
10
from .emlloss import EmlLoss
D
dongshuilong 已提交
11 12
from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
B
Bin Lu 已提交
13
from .trihardloss import TriHardLoss
D
dongshuilong 已提交
14
from .triplet import TripletLoss, TripletLossV2
D
dongshuilong 已提交
15
from .supconloss import SupConLoss
F
Felix 已提交
16
from .pairwisecosface import PairwiseCosface
17 18
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
D
dongshuilong 已提交
19

20 21 22
from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss
D
dongshuilong 已提交
23

B
Bin Lu 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40

class CombinedLoss(nn.Layer):
    """Weighted combination of several loss functions built from config.

    Each entry of ``config_list`` is a single-key dict that maps a loss class
    name to that loss's keyword arguments. The arguments must contain
    ``weight``, which scales every value the loss returns.
    """

    def __init__(self, config_list):
        """Instantiate every configured loss and record its weight.

        Args:
            config_list: list of ``{loss_name: {"weight": w, ...}}`` dicts.
                ``loss_name`` is resolved against the names visible in this
                module; the remaining entries are passed to the loss
                constructor as keyword arguments. The dicts are not modified.

        Raises:
            AssertionError: if ``config_list`` is not a list, an entry is not
                a single-key dict, or an entry has no ``weight``.
        """
        super().__init__()
        # NOTE(review): this is a plain Python list, so the losses stored here
        # are presumably not registered as sublayers of this Layer -- confirm
        # whether any loss owns trainable parameters that must be exposed.
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(config_list, list), (
            'operator config should be a list')
        for config in config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            # Pop from a shallow copy so the caller's config keeps its
            # "weight" entry and can be reused to build another instance.
            kwargs = dict(param)
            self.loss_weight.append(kwargs.pop("weight"))
            # SECURITY: eval() executes ``name`` as a Python expression. This
            # is only safe while the config comes from a trusted source.
            self.loss_func.append(eval(name)(**kwargs))

    def __call__(self, input, batch):
        """Evaluate every loss, apply its weight and add up the results.

        Args:
            input: forwarded unchanged to each loss function.
            batch: forwarded unchanged to each loss function.

        Returns:
            dict: the weighted entries returned by the individual losses,
            plus a ``"loss"`` entry holding the sum of all of them. Each loss
            is expected to return a dict; when two losses return the same
            key, the later one overwrites the earlier one.
        """
        loss_dict = {}
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch)
            weight = self.loss_weight[idx]
            loss = {key: loss[key] * weight for key in loss}
            loss_dict.update(loss)
        # The total is computed before "loss" is inserted, so it sums only
        # the weighted per-loss entries.
        loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
        return loss_dict
B
Bin Lu 已提交
51

D
dongshuilong 已提交
52

B
Bin Lu 已提交
53
def build_loss(config):
    """Build a ``CombinedLoss`` from a loss configuration list.

    Args:
        config: list of single-key loss config dicts, as accepted by
            ``CombinedLoss``. A deep copy is handed to the constructor, so
            the object passed in is left untouched.

    Returns:
        CombinedLoss: the constructed combined loss.
    """
    # Copy first: the constructor receives its own private config.
    private_config = copy.deepcopy(config)
    combined = CombinedLoss(private_config)
    message = "build loss {} success.".format(combined)
    logger.debug(message)
    return combined