__init__.py 2.4 KB
Newer Older
B
Bin Lu 已提交
1
import copy
D
dongshuilong 已提交
2

B
Bin Lu 已提交
3 4
import paddle
import paddle.nn as nn
D
dongshuilong 已提交
5
from ppcls.utils import logger
B
Bin Lu 已提交
6

C
cuicheng01 已提交
7
from .celoss import CELoss, MixCELoss
C
cuicheng01 已提交
8
from .googlenetloss import GoogLeNetLoss
D
dongshuilong 已提交
9
from .centerloss import CenterLoss
B
Bin Lu 已提交
10
from .emlloss import EmlLoss
D
dongshuilong 已提交
11 12
from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
B
Bin Lu 已提交
13
from .trihardloss import TriHardLoss
D
dongshuilong 已提交
14
from .triplet import TripletLoss, TripletLossV2
D
dongshuilong 已提交
15
from .supconloss import SupConLoss
F
Felix 已提交
16
from .pairwisecosface import PairwiseCosface
17 18
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
D
dongshuilong 已提交
19

20 21 22
from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss
23 24
from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss
C
cuicheng01 已提交
25
from .multilabelloss import MultiLabelLoss
D
dongshuilong 已提交
26

S
stephon 已提交
27 28
from .deephashloss import DSHSDLoss, LCDSHLoss

B
Bin Lu 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45

class CombinedLoss(nn.Layer):
    """Build several loss objects from a config list and combine their outputs.

    Args:
        config_list (list): one single-key dict per loss. The key is the name
            of the loss class and the value holds its keyword arguments. The
            arguments must contain a "weight" entry; every other entry is
            forwarded to the loss constructor.
    """

    def __init__(self, config_list):
        super().__init__()
        # NOTE(review): the losses are kept in a plain Python list; confirm
        # whether they should be registered as sublayers (e.g. nn.LayerList)
        # when a loss owns trainable parameters.
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(config_list, list), (
            'operator config should be a list')
        for config in config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param["weight"])
            # Build the constructor arguments in a fresh dict instead of
            # popping "weight" from `param`, so the caller's config is left
            # untouched and can be reused.
            loss_kwargs = {
                key: value
                for key, value in param.items() if key != "weight"
            }
            # NOTE(review): eval() resolves the class from a config-supplied
            # string; it will execute arbitrary expressions, so the config
            # must come from a trusted source.
            self.loss_func.append(eval(name)(**loss_kwargs))

    def __call__(self, input, batch):
        """Run every configured loss and return the collected results.

        Args:
            input: forwarded unchanged to each loss as its first argument.
            batch: forwarded unchanged to each loss as its second argument.

        Returns:
            dict: the entries returned by the losses plus a "loss" entry
            holding the combined value.
        """
        loss_dict = {}
        # just for accelerate classification training speed
        if len(self.loss_func) == 1:
            # The configured weight is not applied on this path: "loss" is
            # the first value returned by the only loss.
            loss = self.loss_func[0](input, batch)
            loss_dict.update(loss)
            loss_dict["loss"] = list(loss.values())[0]
        else:
            for idx, loss_func in enumerate(self.loss_func):
                loss = loss_func(input, batch)
                weight = self.loss_weight[idx]
                loss = {key: loss[key] * weight for key in loss}
                # Entries with the same key from different losses overwrite
                # each other here.
                loss_dict.update(loss)
            # The total is the sum of every weighted entry collected above.
            loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
        return loss_dict
B
Bin Lu 已提交
62

D
dongshuilong 已提交
63

B
Bin Lu 已提交
64
def build_loss(config):
    """Create a CombinedLoss from a private deep copy of ``config``.

    Args:
        config: loss configuration passed on to ``CombinedLoss``; the object
            given by the caller is copied first and not handed over directly.

    Returns:
        CombinedLoss: the constructed loss object.
    """
    config_copy = copy.deepcopy(config)
    combined = CombinedLoss(config_copy)
    message = "build loss {} success.".format(combined)
    logger.debug(message)
    return combined