diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml index a1ff0d67d09d355c9a98bea2391c13e757c47180..8e8acd8b107384dd79ed28b8d3588ff0e76b3679 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml @@ -98,7 +98,7 @@ Loss: PostProcess: name: DistillationCTCLabelDecode - model_name: ["Student", "Teacher"] + model_name: ["Student"] key: head_out Metric: diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d3d095a7d384bf8cdc69b97f8109c359ac2b5b --- /dev/null +++ b/ppocr/metrics/distillation_metric.py @@ -0,0 +1,76 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import copy + +from .rec_metric import RecMetric +from .det_metric import DetMetric +from .e2e_metric import E2EMetric +from .cls_metric import ClsMetric + + +class DistillationMetric(object): + def __init__(self, + key=None, + base_metric_name="RecMetric", + main_indicator='acc', + **kwargs): + self.main_indicator = main_indicator + self.key = key + self.main_indicator = main_indicator + self.base_metric_name = base_metric_name + self.kwargs = kwargs + self.metrics = None + + def _init_metrcis(self, preds): + self.metrics = dict() + mod = importlib.import_module(__name__) + for key in preds: + self.metrics[key] = getattr(mod, self.base_metric_name)( + main_indicator=self.main_indicator, **self.kwargs) + self.metrics[key].reset() + + def __call__(self, preds, *args, **kwargs): + assert isinstance(preds, dict) + if self.metrics is None: + self._init_metrcis(preds) + output = dict() + for key in preds: + metric = self.metrics[key].__call__(preds[key], *args, **kwargs) + for sub_key in metric: + output["{}_{}".format(key, sub_key)] = metric[sub_key] + return output + + def get_metric(self): + """ + return metrics { + 'acc': 0, + 'norm_edit_dis': 0, + } + """ + output = dict() + for key in self.metrics: + metric = self.metrics[key].get_metric() + # main indicator + if key == self.key: + output.update(metric) + else: + for sub_key in metric: + output["{}_{}".format(key, sub_key)] = metric[sub_key] + return output + + def reset(self): + for key in self.metrics: + self.metrics[key].reset() diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 2563f5a8197ed39b1b5d44c7cfee32797e760758..6bd8f1429072e533e8c449c8c8a439ed51f521a3 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -46,8 +46,14 @@ def main(): # build model if hasattr(post_process_class, 'character'): - config['Architecture']["Head"]['out_channels'] = len( - getattr(post_process_class, 'character')) + char_num = len(getattr(post_process_class, 'character')) + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture'])