# __init__.py
import copy

import paddle
import paddle.nn as nn
from ppcls.utils import logger

from .celoss import CELoss, MixCELoss
from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss
from .emlloss import EmlLoss
from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2
from .supconloss import SupConLoss
from .pairwisecosface import PairwiseCosface
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss

from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss
from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss

from .deephashloss import DSHSDLoss
from .deephashloss import LCDSHLoss
from .deephashloss import DCHLoss


class CombinedLoss(nn.Layer):
    """Combine several losses into one weighted objective.

    ``config_list`` is a list of single-key dicts, each mapping a loss class
    name to its parameters. The mandatory "weight" entry is popped off and
    stored in ``self.loss_weight``; the remaining parameters are forwarded
    to the loss constructor.
    """

    def __init__(self, config_list):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(config_list, list), (
            'operator config should be a list')
        for config in config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, \
                "weight must be in param, but param only contains {}".format(
                    param.keys())
            self.loss_weight.append(param.pop("weight"))
            # resolve the loss class by name among the classes imported above
            self.loss_func.append(eval(name)(**param))

    def __call__(self, input, batch):
        loss_dict = {}
        # fast path to accelerate classification training: with a single
        # loss function there is nothing to weight and sum (note that the
        # configured weight is not applied on this branch)
        if len(self.loss_func) == 1:
            loss = self.loss_func[0](input, batch)
            loss_dict.update(loss)
            loss_dict["loss"] = list(loss.values())[0]
        else:
            # weight each sub-loss, collect them under their own keys, and
            # expose the weighted sum under the "loss" key
            for idx, loss_func in enumerate(self.loss_func):
                loss = loss_func(input, batch)
                weight = self.loss_weight[idx]
                loss = {key: loss[key] * weight for key in loss}
                loss_dict.update(loss)
            loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
        return loss_dict
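
# Illustrative sketch (not part of this module): a hypothetical config_list
# mirroring the Loss section of a PaddleClas YAML config. Each entry is a
# one-key dict whose value must contain "weight"; the remaining keys are
# forwarded to the loss constructor. The parameter values are assumptions.
#
#     loss_config = [
#         {"CELoss": {"weight": 1.0}},
#         {"TripletLossV2": {"weight": 1.0, "margin": 0.5}},
#     ]
#     loss_fn = CombinedLoss(copy.deepcopy(loss_config))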


def build_loss(config):
    # deep-copy the config: CombinedLoss pops the "weight" keys and would
    # otherwise mutate the caller's config
    module_class = CombinedLoss(copy.deepcopy(config))
    logger.debug("build loss {} success.".format(module_class))
    return module_class
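
if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the module: it assumes
    # CELoss accepts plain logits and an integer label tensor via the
    # (input, batch) convention used by CombinedLoss.__call__ above.
    loss_fn = build_loss([{"CELoss": {"weight": 1.0}}])
    logits = paddle.rand([4, 10])           # fake batch: 4 samples, 10 classes
    labels = paddle.randint(0, 10, [4])     # fake integer class labels
    loss_dict = loss_fn(logits, labels)     # e.g. {"CELoss": ..., "loss": ...}
    print(loss_dict["loss"])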