From b8a65d4333cead891b6320b6f627f7aeb4fb155c Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Thu, 8 Jul 2021 14:32:44 +0000 Subject: [PATCH] fix eval bug --- ppocr/metrics/det_metric.py | 6 ++--- ppocr/metrics/distillation_metric.py | 11 +++----- .../architectures/distillation_model.py | 2 +- ppocr/postprocess/db_postprocess.py | 26 +++++++++---------- ppocr/utils/save_load.py | 2 +- tools/eval.py | 7 ++--- tools/program.py | 1 + tools/train.py | 4 +-- 8 files changed, 28 insertions(+), 31 deletions(-) diff --git a/ppocr/metrics/det_metric.py b/ppocr/metrics/det_metric.py index 811ee2fa..e68cb390 100644 --- a/ppocr/metrics/det_metric.py +++ b/ppocr/metrics/det_metric.py @@ -55,9 +55,9 @@ class DetMetric(object): result = self.evaluator.evaluate_image(gt_info_list, det_info_list) self.results.append(result) - metircs = self.evaluator.combine_results(self.results) - self.reset() - return metircs + # metircs = self.evaluator.combine_results(self.results) + # self.reset() + # return metircs def get_metric(self): """ diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py index a7d3d095..c440cebd 100644 --- a/ppocr/metrics/distillation_metric.py +++ b/ppocr/metrics/distillation_metric.py @@ -24,8 +24,8 @@ from .cls_metric import ClsMetric class DistillationMetric(object): def __init__(self, key=None, - base_metric_name="RecMetric", - main_indicator='acc', + base_metric_name=None, + main_indicator=None, **kwargs): self.main_indicator = main_indicator self.key = key @@ -42,16 +42,13 @@ class DistillationMetric(object): main_indicator=self.main_indicator, **self.kwargs) self.metrics[key].reset() - def __call__(self, preds, *args, **kwargs): + def __call__(self, preds, batch, **kwargs): assert isinstance(preds, dict) if self.metrics is None: self._init_metrcis(preds) output = dict() for key in preds: - metric = self.metrics[key].__call__(preds[key], *args, **kwargs) - for sub_key in metric: - output["{}_{}".format(key, sub_key)] = metric[sub_key] - return output + self.metrics[key].__call__(preds[key], batch, **kwargs) def get_metric(self): """ diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py index 1e95fe57..2b1d3aae 100644 --- a/ppocr/modeling/architectures/distillation_model.py +++ b/ppocr/modeling/architectures/distillation_model.py @@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): pretrained = model_config.pop("pretrained") model = BaseModel(model_config) if pretrained is not None: - load_pretrained_params(model, pretrained) + model = load_pretrained_params(model, pretrained) if freeze_params: for param in model.parameters(): param.trainable = False diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index e318c525..d9c9869d 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -189,29 +189,27 @@ class DBPostProcess(object): return boxes_batch -class DistillationDBPostProcess(DBPostProcess): - def __init__(self, - model_name=["student"], +class DistillationDBPostProcess(object): + def __init__(self, model_name=["student"], key=None, thresh=0.3, - box_thresh=0.7, + box_thresh=0.6, max_candidates=1000, - unclip_ratio=2.0, + unclip_ratio=1.5, use_dilation=False, score_mode="fast", **kwargs): - super().__init__() - if not isinstance(model_name, list): - model_name = [model_name] self.model_name = model_name self.key = key + self.post_process = DBPostProcess(thresh=thresh, + box_thresh=box_thresh, + max_candidates=max_candidates, + unclip_ratio=unclip_ratio, + use_dilation=use_dilation, + score_mode=score_mode) def __call__(self, predicts, shape_list): results = {} - for name in self.model_name: - pred = predicts[name] - if self.key is not None: - pred = pred[self.key] - results[name] = super().__call__(pred, shape_list=shape_list) - + for k in self.model_name: + results[k] = self.post_process(predicts[k], shape_list=shape_list) return results diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 4ee4b29f..b3724c2d 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -136,7 +136,7 @@ def load_pretrained_params(model, path): ) model.set_state_dict(new_state_dict) print(f"load pretrain successful from {path}") - return True + return model def save_model(model, optimizer, diff --git a/tools/eval.py b/tools/eval.py index 022498bb..c99c7d47 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -27,7 +27,7 @@ from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import init_model, load_pretrained_params from ppocr.utils.utility import print_dict import tools.program as program @@ -59,7 +59,8 @@ def main(): model_type = config['Architecture']['model_type'] else: model_type = None - best_model_dict = init_model(config, model) + + best_model_dict = init_model(config, model, model_type) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): @@ -70,7 +71,7 @@ def main(): # start eval metric = program.eval(model, valid_dataloader, post_process_class, - eval_class, model_type, use_srn) + eval_class, model_type, use_srn) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) diff --git a/tools/program.py b/tools/program.py index 595fe4cb..4c12bc09 100755 --- a/tools/program.py +++ b/tools/program.py @@ -374,6 +374,7 @@ def eval(model, eval_class(preds, batch) else: post_result = post_process_class(preds, batch[1]) + # post_result = post_result_["Student"] eval_class(post_result, batch) pbar.update(1) total_frame += len(images) diff --git a/tools/train.py b/tools/train.py index 20f5a670..2091ff48 100755 --- a/tools/train.py +++ b/tools/train.py @@ -97,8 +97,8 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) - + #pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) + pre_best_model_dict = {} logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format( -- GitLab