From 21fedfe5cddad486ece59fc5c7cb2665e4bc11ed Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Thu, 10 Jun 2021 09:02:50 +0000 Subject: [PATCH] fix loss --- ppcls/loss/distanceloss.py | 43 ++++++++++ ppcls/loss/distillationloss.py | 141 +++++++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 ppcls/loss/distanceloss.py create mode 100644 ppcls/loss/distillationloss.py diff --git a/ppcls/loss/distanceloss.py b/ppcls/loss/distanceloss.py new file mode 100644 index 00000000..aad4cc5c --- /dev/null +++ b/ppcls/loss/distanceloss.py @@ -0,0 +1,43 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.nn import L1Loss +from paddle.nn import MSELoss as L2Loss +from paddle.nn import SmoothL1Loss + + +class DistanceLoss(nn.Layer): + """ + DistanceLoss: + mode: loss mode + """ + + def __init__(self, mode="l2", **kargs): + super().__init__() + assert mode in ["l1", "l2", "smooth_l1"] + if mode == "l1": + self.loss_func = nn.L1Loss(**kargs) + elif mode == "l2": + self.loss_func = nn.MSELoss(**kargs) + elif mode == "smooth_l1": + self.loss_func = nn.SmoothL1Loss(**kargs) + self.mode = mode + + def forward(self, x, y): + loss = self.loss_func(x, y) + return {"loss_{}".format(self.mode): loss} \ No newline at end of file diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py new file mode 100644 index 00000000..54dc601b --- /dev/null +++ b/ppcls/loss/distillationloss.py @@ -0,0 +1,141 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn + +from .celoss import CELoss +from .dmlloss import DMLLoss +from .distanceloss import DistanceLoss + + +class DistillationCELoss(CELoss): + """ + DistillationCELoss + """ + + def __init__(self, + model_name_pairs=[], + epsilon=None, + key=None, + name="loss_ce"): + super().__init__(epsilon=epsilon) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + for key in loss: + loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key] + return loss_dict + + +class DistillationGTCELoss(CELoss): + """ + DistillationGTCELoss + """ + + def __init__(self, + model_names=[], + epsilon=None, + key=None, + name="loss_gt_ce"): + super().__init__(epsilon=epsilon) + assert isinstance(model_names, list) + self.key = key + self.model_names = model_names + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, name in enumerate(self.model_names): + out = predicts[name] + if self.key is not None: + out = out[self.key] + loss = super().forward(out, batch) + for key in loss: + loss_dict["{}_{}".format(key, name)] = loss[key] + return loss_dict + + +class DistillationDMLLoss(DMLLoss): + """ + """ + + def __init__(self, + model_name_pairs=[], + act=None, + key=None, + name="loss_dml"): + super().__init__(act=act) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss + return loss_dict + + +class DistillationDistanceLoss(DistanceLoss): + """ + """ + + def __init__(self, + mode="l2", + model_name_pairs=[], + key=None, + name="loss_", + **kargs): + super().__init__(mode=mode, **kargs) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + "_l2" + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + for key in loss: + loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[key] + return loss_dict -- GitLab