From 0343756e468521443d0ac2560e78a76742740784 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Thu, 3 Jun 2021 13:31:25 +0000 Subject: [PATCH] fix metric --- ...c_chinese_lite_train_distillation_v2.1.yml | 8 +++---- ppocr/losses/basic_loss.py | 21 ++++++------------- ppocr/losses/distillation_loss.py | 13 +++++++----- ppocr/metrics/__init__.py | 21 +++++++++++-------- ppocr/postprocess/rec_postprocess.py | 16 +++++++++----- 5 files changed, 41 insertions(+), 38 deletions(-) diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml index 38aeffcb..a1ff0d67 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml @@ -95,17 +95,17 @@ Loss: model_name_pairs: - ["Student", "Teacher"] key: backbone_out - - PostProcess: name: DistillationCTCLabelDecode - model_name: "Student" + model_name: ["Student", "Teacher"] key: head_out Metric: - name: RecMetric + name: DistillationMetric + base_metric_name: RecMetric main_indicator: acc + key: "Student" Train: dataset: diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index 022ae5c6..4f9a9133 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -22,9 +22,8 @@ from paddle.nn import SmoothL1Loss class CELoss(nn.Layer): - def __init__(self, name="loss_ce", epsilon=None): + def __init__(self, epsilon=None): super().__init__() - self.name = name if epsilon is not None and (epsilon <= 0 or epsilon >= 1): epsilon = None self.epsilon = epsilon @@ -52,9 +51,7 @@ class CELoss(nn.Layer): else: soft_label = False loss = F.cross_entropy(x, label=label, soft_label=soft_label) - - loss_dict[self.name] = paddle.mean(loss) - return loss_dict + return loss class DMLLoss(nn.Layer): @@ -62,11 +59,10 @@ class DMLLoss(nn.Layer): DMLLoss """ - def __init__(self, act=None, name="loss_dml"): + def __init__(self, act=None): super().__init__() if act is not None: assert act in ["softmax", "sigmoid"] - self.name = name if act == "softmax": self.act = nn.Softmax(axis=-1) elif act == "sigmoid": @@ -75,7 +71,6 @@ class DMLLoss(nn.Layer): self.act = None def forward(self, out1, out2): - loss_dict = {} if self.act is not None: out1 = self.act(out1) out2 = self.act(out2) @@ -85,18 +80,16 @@ class DMLLoss(nn.Layer): loss = (F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div( log_out2, log_out1, reduction='batchmean')) / 2.0 - loss_dict[self.name] = loss - return loss_dict + return loss class DistanceLoss(nn.Layer): """ DistanceLoss: mode: loss mode - name: loss key in the output dict """ - def __init__(self, mode="l2", name="loss_dist", **kargs): + def __init__(self, mode="l2", **kargs): super().__init__() assert mode in ["l1", "l2", "smooth_l1"] if mode == "l1": @@ -106,7 +99,5 @@ class DistanceLoss(nn.Layer): elif mode == "smooth_l1": self.loss_func = nn.SmoothL1Loss(**kargs) - self.name = "{}_{}".format(name, mode) - def forward(self, x, y): - return {self.name: self.loss_func(x, y)} + return self.loss_func(x, y) diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 539680d9..1e8aa0d8 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -26,10 +26,11 @@ class DistillationDMLLoss(DMLLoss): def __init__(self, model_name_pairs=[], act=None, key=None, name="loss_dml"): - super().__init__(act=act, name=name) + super().__init__(act=act) assert isinstance(model_name_pairs, list) self.key = key self.model_name_pairs = model_name_pairs + self.name = name def forward(self, predicts, batch): loss_dict = dict() @@ -42,8 +43,8 @@ class DistillationDMLLoss(DMLLoss): loss = super().forward(out1, out2) if isinstance(loss, dict): for key in loss: - loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ - key] + loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], + idx)] = loss[key] else: loss_dict["{}_{}".format(self.name, idx)] = loss return loss_dict @@ -82,10 +83,11 @@ class DistillationDistanceLoss(DistanceLoss): key=None, name="loss_distance", **kargs): - super().__init__(mode=mode, name=name, **kargs) + super().__init__(mode=mode, **kargs) assert isinstance(model_name_pairs, list) self.key = key self.model_name_pairs = model_name_pairs + self.name = name + "_l2" def forward(self, predicts, batch): loss_dict = dict() @@ -101,5 +103,6 @@ class DistillationDistanceLoss(DistanceLoss): loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ key] else: - loss_dict["{}_{}".format(self.name, idx)] = loss + loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], + idx)] = loss return loss_dict diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index f913010d..9e9060fa 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -19,20 +19,23 @@ from __future__ import unicode_literals import copy -__all__ = ['build_metric'] +__all__ = ["build_metric"] +from .det_metric import DetMetric +from .rec_metric import RecMetric +from .cls_metric import ClsMetric +from .e2e_metric import E2EMetric +from .distillation_metric import DistillationMetric -def build_metric(config): - from .det_metric import DetMetric - from .rec_metric import RecMetric - from .cls_metric import ClsMetric - from .e2e_metric import E2EMetric - support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric'] +def build_metric(config): + support_dict = [ + "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric" + ] config = copy.deepcopy(config) - module_name = config.pop('name') + module_name = config.pop("name") assert module_name in support_dict, Exception( - 'metric only support {}'.format(support_dict)) + "metric only support {}".format(support_dict)) module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index e5729ea5..ae5470a5 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -135,19 +135,25 @@ class DistillationCTCLabelDecode(CTCLabelDecode): character_dict_path=None, character_type='ch', use_space_char=False, - model_name="student", + model_name=["student"], key=None, **kwargs): super(DistillationCTCLabelDecode, self).__init__( character_dict_path, character_type, use_space_char) + if not isinstance(model_name, list): + model_name = [model_name] self.model_name = model_name + self.key = key def __call__(self, preds, label=None, *args, **kwargs): - pred = preds[self.model_name] - if self.key is not None: - pred = pred[self.key] - return super().__call__(pred, label=label, *args, **kwargs) + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output class AttnLabelDecode(BaseRecLabelDecode): -- GitLab