diff --git a/ppocr/metrics/det_metric.py b/ppocr/metrics/det_metric.py
index 811ee2fad5fd2af417541e36914bbe7446429cb3..e68cb3905b17ee7f40c199fdc527fe159c587fc2 100644
--- a/ppocr/metrics/det_metric.py
+++ b/ppocr/metrics/det_metric.py
@@ -55,9 +55,9 @@ class DetMetric(object):
             result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
             self.results.append(result)
 
-        metircs = self.evaluator.combine_results(self.results)
-        self.reset()
-        return metircs
+        # metircs = self.evaluator.combine_results(self.results)
+        # self.reset()
+        # return metircs
 
     def get_metric(self):
         """
diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py
index a7d3d095a7d384bf8cdc69b97f8109c359ac2b5b..c440cebdd0f96493fc33000a0d304cbe5e3f0624 100644
--- a/ppocr/metrics/distillation_metric.py
+++ b/ppocr/metrics/distillation_metric.py
@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
 class DistillationMetric(object):
     def __init__(self,
                  key=None,
-                 base_metric_name="RecMetric",
-                 main_indicator='acc',
+                 base_metric_name=None,
+                 main_indicator=None,
                  **kwargs):
         self.main_indicator = main_indicator
         self.key = key
@@ -42,16 +42,13 @@ class DistillationMetric(object):
                 main_indicator=self.main_indicator, **self.kwargs)
             self.metrics[key].reset()
 
-    def __call__(self, preds, *args, **kwargs):
+    def __call__(self, preds, batch, **kwargs):
         assert isinstance(preds, dict)
         if self.metrics is None:
             self._init_metrcis(preds)
         output = dict()
         for key in preds:
-            metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
-            for sub_key in metric:
-                output["{}_{}".format(key, sub_key)] = metric[sub_key]
-        return output
+            self.metrics[key].__call__(preds[key], batch, **kwargs)
 
     def get_metric(self):
         """
diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py
index 1e95fe574433eaca6f322ff47c8547cc1a29a248..2b1d3aae3b7303a61b20db15df5ce4bd9bb7b235 100644
--- a/ppocr/modeling/architectures/distillation_model.py
+++ b/ppocr/modeling/architectures/distillation_model.py
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
             pretrained = model_config.pop("pretrained")
             model = BaseModel(model_config)
             if pretrained is not None:
-                load_pretrained_params(model, pretrained)
+                model = load_pretrained_params(model, pretrained)
             if freeze_params:
                 for param in model.parameters():
                     param.trainable = False
diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py
index e318c5254b92e09a086309761529eb787dbc8d96..d9c9869dfcd35cb9b491db826f3bff5f766723f4 100755
--- a/ppocr/postprocess/db_postprocess.py
+++ b/ppocr/postprocess/db_postprocess.py
@@ -189,29 +189,27 @@ class DBPostProcess(object):
         return boxes_batch
 
 
-class DistillationDBPostProcess(DBPostProcess):
-    def __init__(self,
-                 model_name=["student"],
+class DistillationDBPostProcess(object):
+    def __init__(self, model_name=["student"],
                  key=None,
                  thresh=0.3,
-                 box_thresh=0.7,
+                 box_thresh=0.6,
                  max_candidates=1000,
-                 unclip_ratio=2.0,
+                 unclip_ratio=1.5,
                  use_dilation=False,
                  score_mode="fast",
                  **kwargs):
-        super().__init__()
-        if not isinstance(model_name, list):
-            model_name = [model_name]
         self.model_name = model_name
         self.key = key
+        self.post_process = DBPostProcess(thresh=thresh,
+                                          box_thresh=box_thresh,
+                                          max_candidates=max_candidates,
+                                          unclip_ratio=unclip_ratio,
+                                          use_dilation=use_dilation,
+                                          score_mode=score_mode)
 
     def __call__(self, predicts, shape_list):
         results = {}
-        for name in self.model_name:
-            pred = predicts[name]
-            if self.key is not None:
-                pred = pred[self.key]
-            results[name] = super().__call__(pred, shape_list=shape_list)
-
+        for k in self.model_name:
+            results[k] = self.post_process(predicts[k], shape_list=shape_list)
         return results
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 4ee4b29f435246eb53b0e1864b75fc35f197af16..b3724c2dddf2f12fd6f4f4a5c46aa34595104582 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -136,7 +136,7 @@ def load_pretrained_params(model, path):
         )
     model.set_state_dict(new_state_dict)
     print(f"load pretrain successful from {path}")
-    return True
+    return model
 
 
 def save_model(model,
               optimizer,
diff --git a/tools/eval.py b/tools/eval.py
index 022498bbef28bd1b00a33739bf935d10ae2f5bf2..c99c7d474d44905ed3af582a384d82986feb8d65 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_pretrained_params
 from ppocr.utils.utility import print_dict
 import tools.program as program
 
@@ -59,7 +59,8 @@ def main():
         model_type = config['Architecture']['model_type']
     else:
         model_type = None
-    best_model_dict = init_model(config, model)
+
+    best_model_dict = init_model(config, model, model_type)
     if len(best_model_dict):
         logger.info('metric in ckpt ***************')
         for k, v in best_model_dict.items():
@@ -70,7 +71,7 @@ def main():
 
     # start eval
     metric = program.eval(model, valid_dataloader, post_process_class,
-                          eval_class, model_type, use_srn)
+                          eval_class, model_type, use_srn)
     logger.info('metric eval ***************')
     for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
diff --git a/tools/program.py b/tools/program.py
index 595fe4cb96c0379b1a33504e0ebdd85e70086340..4c12bc09cc7ca480663208506a47269784c6f6d3 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -374,6 +374,7 @@ def eval(model,
                 eval_class(preds, batch)
             else:
                 post_result = post_process_class(preds, batch[1])
+                # post_result = post_result_["Student"]
                 eval_class(post_result, batch)
             pbar.update(1)
             total_frame += len(images)
diff --git a/tools/train.py b/tools/train.py
index 20f5a670d5c8e666678259e0042b3b790e528590..2091ff48b4b83c1e3955d0b9600c60815d4d99ec 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -97,8 +97,8 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
-
+    #pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
+    pre_best_model_dict = {}
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(
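For reference, a minimal sketch of the composition pattern introduced in the reworked DistillationDBPostProcess: the wrapper no longer inherits from DBPostProcess but holds a single configured instance of it and applies it to each named sub-model output. `FakeDBPostProcess`, `DistillationPostProcessSketch`, and the demo values below are hypothetical stand-ins for illustration only, not PaddleOCR code.

```python
# Sketch of the delegation pattern in the patched DistillationDBPostProcess:
# one shared base post-processor, applied per sub-model head.

class FakeDBPostProcess(object):
    """Stand-in for DBPostProcess; only the calling convention matters here."""

    def __init__(self, thresh=0.3, box_thresh=0.6, unclip_ratio=1.5, **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.unclip_ratio = unclip_ratio

    def __call__(self, pred, shape_list):
        # The real implementation converts probability maps into boxes;
        # echoing the inputs keeps this sketch self-contained and runnable.
        return {"boxes": pred, "shape_list": shape_list}


class DistillationPostProcessSketch(object):
    def __init__(self, model_name=None, **db_kwargs):
        self.model_name = model_name or ["student"]
        # Composition instead of inheritance: configure the base once.
        self.post_process = FakeDBPostProcess(**db_kwargs)

    def __call__(self, predicts, shape_list):
        # Run the same base post-processing on every requested sub-model head.
        return {
            name: self.post_process(predicts[name], shape_list=shape_list)
            for name in self.model_name
        }


if __name__ == "__main__":
    preds = {"student": "student_maps", "teacher": "teacher_maps"}
    pp = DistillationPostProcessSketch(model_name=["student"], box_thresh=0.6)
    print(pp(preds, shape_list=[(640, 640, 1.0, 1.0)]))
```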