diff --git a/tools/eval.py b/tools/eval.py index 0120baab0f34d5fadbbf4df20d92d6b62dd176a2..7d6fb94f387da47466200f1e819394b0ffd03dfd 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -27,7 +27,7 @@ from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model, load_pretrained_params +from ppocr.utils.save_load import init_model, load_dygraph_params from ppocr.utils.utility import print_dict import tools.program as program @@ -60,7 +60,7 @@ def main(): else: model_type = None - best_model_dict = init_model(config, model) + best_model_dict = load_dygraph_params(config, model, logger, None) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): @@ -71,7 +71,7 @@ def main(): # start eval metric = program.eval(model, valid_dataloader, post_process_class, - eval_class, model_type, use_srn) + eval_class, model_type, use_srn) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v))