__init__.py 3.1 KB
Newer Older
B
Bin Lu 已提交
1
import copy
D
dongshuilong 已提交
2

B
Bin Lu 已提交
3 4
import paddle
import paddle.nn as nn
D
dongshuilong 已提交
5
from ppcls.utils import logger
B
Bin Lu 已提交
6

C
cuicheng01 已提交
7
from .celoss import CELoss, MixCELoss
C
cuicheng01 已提交
8
from .googlenetloss import GoogLeNetLoss
D
dongshuilong 已提交
9
from .centerloss import CenterLoss
H
add xbm  
HydrogenSulfate 已提交
10 11
from .contrasiveloss import ContrastiveLoss
from .contrasiveloss import ContrastiveLoss_XBM
B
Bin Lu 已提交
12
from .emlloss import EmlLoss
D
dongshuilong 已提交
13 14
from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
B
Bin Lu 已提交
15
from .trihardloss import TriHardLoss
D
dongshuilong 已提交
16
from .triplet import TripletLoss, TripletLossV2
H
HydrogenSulfate 已提交
17
from .tripletangularmarginloss import TripletAngularMarginLoss
D
dongshuilong 已提交
18
from .supconloss import SupConLoss
F
Felix 已提交
19
from .pairwisecosface import PairwiseCosface
20 21
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
悟、's avatar
悟、 已提交
22
from .softtargetceloss import SoftTargetCrossEntropy
D
dongshuilong 已提交
23

24 25 26
from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss
27 28
from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss
wc晨曦's avatar
wc晨曦 已提交
29
from .distillationloss import DistillationKLDivLoss
wc晨曦's avatar
wc晨曦 已提交
30
from .distillationloss import DistillationDKDLoss
31
from .distillationloss import DistillationWSLLoss
U
add skd  
user3984 已提交
32
from .distillationloss import DistillationSKDLoss
33
from .distillationloss import DistillationMultiLabelLoss
littletomatodonkey's avatar
littletomatodonkey 已提交
34
from .distillationloss import DistillationDISTLoss
littletomatodonkey's avatar
littletomatodonkey 已提交
35
from .distillationloss import DistillationPairLoss
littletomatodonkey's avatar
littletomatodonkey 已提交
36

C
cuicheng01 已提交
37
from .multilabelloss import MultiLabelLoss
wc晨曦's avatar
wc晨曦 已提交
38
from .afdloss import AFDLoss
D
dongshuilong 已提交
39

L
lubin 已提交
40 41 42
from .deephashloss import DSHSDLoss
from .deephashloss import LCDSHLoss
from .deephashloss import DCHLoss
S
stephon 已提交
43

B
Bin Lu 已提交
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60

class CombinedLoss(nn.Layer):
    """Weighted combination of one or more loss layers built from config.

    Each entry of ``config_list`` must be a single-key dict mapping a loss
    class name (resolved by name in this module's namespace) to the kwargs
    for that class. The kwargs must contain a ``weight`` entry, which is
    removed and used to scale every value the loss returns.

    Calling the instance returns a dict holding each (weighted) sub-loss
    value plus a ``"loss"`` entry with the total.
    """

    def __init__(self, config_list):
        """Build every configured loss and record its weight.

        Args:
            config_list: list of ``{LossName: {"weight": w, **kwargs}}`` dicts.
                Note: ``weight`` is popped from each params dict, so the
                caller's config is mutated (``build_loss`` passes a deepcopy).

        Raises:
            AssertionError: if the config is not a non-empty list of
                single-key dicts, or a params dict lacks ``weight``.
        """
        super().__init__()
        self.loss_weight = []
        assert isinstance(config_list, list), (
            'operator config should be a list')
        # Fail fast here: an empty list would otherwise only surface later
        # as an error inside paddle.add_n([]) on the first forward call.
        assert len(config_list) > 0, "loss config list should not be empty"
        loss_funcs = []
        for config in config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            # Pop the weight so the remaining params can be forwarded
            # verbatim to the loss constructor.
            self.loss_weight.append(param.pop("weight"))
            # NOTE(review): ``eval`` resolves the loss class from a name in
            # the config file, so loss configs must be trusted input.
            loss_funcs.append(eval(name)(**param))
        # Wrap once, after all losses are constructed, so the sub-losses are
        # held in a LayerList container (previously this was re-done on
        # every loop iteration).
        self.loss_func = nn.LayerList(loss_funcs)

    def __call__(self, input, batch):
        """Evaluate all configured losses and combine them.

        Args:
            input: model output, forwarded unchanged to each loss layer.
            batch: data batch, forwarded unchanged to each loss layer.

        Returns:
            dict: every weighted value returned by the sub-losses (each
            sub-loss returns a dict), plus ``"loss"``. With a single loss,
            ``"loss"`` is that loss's first returned value; otherwise it is
            the ``paddle.add_n`` sum of all collected values.
        """
        loss_dict = {}
        # Fast path for the common single-loss case (e.g. plain
        # classification) to accelerate training: avoid add_n.
        if len(self.loss_func) == 1:
            loss = self.loss_func[0](input, batch)
            weight = self.loss_weight[0]
            # Honour the configured weight, but skip the extra multiply
            # when it is the identity so the fast path stays cheap.
            if weight != 1.0:
                loss = {key: loss[key] * weight for key in loss}
            loss_dict.update(loss)
            loss_dict["loss"] = list(loss.values())[0]
        else:
            for loss_func, weight in zip(self.loss_func, self.loss_weight):
                loss = loss_func(input, batch)
                loss = {key: loss[key] * weight for key in loss}
                loss_dict.update(loss)
            loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
        return loss_dict
B
Bin Lu 已提交
78

D
dongshuilong 已提交
79

B
Bin Lu 已提交
80
def build_loss(config):
    """Create the combined loss module described by ``config``.

    Args:
        config: list of single-key loss config dicts, or None.

    Returns:
        A ``CombinedLoss`` built from a deep copy of ``config`` (the copy
        protects the caller's config, since ``CombinedLoss`` pops the
        ``weight`` entries), or None when ``config`` is None.
    """
    if config is not None:
        loss_module = CombinedLoss(copy.deepcopy(config))
        logger.debug("build loss {} success.".format(loss_module))
        return loss_module
    return None