diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index c6d20651283954000241f80a28d22c7821af2ff0..e74d8faa6f7c48d1ea91f47cbe47d2d3cf5bf704 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -89,7 +89,8 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
             "Given dir {}.pdparams not exist.".format(checkpoints)
         assert os.path.exists(checkpoints + ".pdopt"), \
             "Given dir {}.pdopt not exist.".format(checkpoints)
-        para_dict, opti_dict = paddle.load(checkpoints)
+        para_dict = paddle.load(checkpoints + '.pdparams')
+        opti_dict = paddle.load(checkpoints + '.pdopt')
         model.set_dict(para_dict)
         if optimizer is not None:
             optimizer.set_state_dict(opti_dict)
@@ -133,8 +134,8 @@ def save_model(net,
     """
     _mkdir_if_not_exist(model_path, logger)
     model_prefix = os.path.join(model_path, prefix)
-    paddle.save(net.state_dict(), model_prefix)
-    paddle.save(optimizer.state_dict(), model_prefix)
+    paddle.save(net.state_dict(), model_prefix + '.pdparams')
+    paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
     # save metric and config
     with open(model_prefix + '.states', 'wb') as f:
diff --git a/requirements.txt b/requirements.txt
index 76305d0dc69968e771619b181d040e34e502d30e..132189634919760156b421570284faa2c6fc957f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ shapely
 imgaug
 pyclipper
 lmdb
+opencv-python==4.2.0.32
 tqdm
 numpy
 visualdl
diff --git a/tools/eval.py b/tools/eval.py
index 07181ee75d44e3f5f0676f55f89330b215ef5220..16cfe532aae49ce98bc9503ca73e009bf206caa7 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -23,12 +23,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
 sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
 
-import paddle
-# paddle.manual_seed(2)
-
-from ppocr.utils.logging import get_logger
 from ppocr.data import build_dataloader
-from ppocr.modeling import build_model
+from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
 from ppocr.utils.save_load import init_model
@@ -39,8 +35,7 @@ import tools.program as program
 def main():
     global_config = config['Global']
     # build dataloader
-    eval_loader, _ = build_dataloader(config['EVAL'], device, False,
-                                      global_config)
+    valid_dataloader = build_dataloader(config, 'Eval', device, logger)
 
     # build post process
     post_process_class = build_post_process(config['PostProcess'],
@@ -63,16 +58,13 @@ def main():
     eval_class = build_metric(config['Metric'])
 
     # start eval
-    metirc = program.eval(model, eval_loader, post_process_class, eval_class)
+    metirc = program.eval(model, valid_dataloader, post_process_class,
+                          eval_class)
     logger.info('metric eval ***************')
     for k, v in metirc.items():
         logger.info('{}:{}'.format(k, v))
 
 
 if __name__ == '__main__':
-    device, config = program.preprocess()
-    paddle.disable_static(device)
-
-    logger = get_logger()
-    print_dict(config, logger)
+    config, device, logger, vdl_writer = program.preprocess()
     main()
diff --git a/tools/program.py b/tools/program.py
index 41acb8665a345d1251fbf721e557212cd2771e04..8bae0fd5d16f4b17520c1162f6cd9bd54f032a73 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -231,7 +231,7 @@ def train(config,
         if global_step > start_eval_step and \
             (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
             cur_metirc = eval(model, valid_dataloader, post_process_class,
-                              eval_class, logger, print_batch_step)
+                              eval_class)
             cur_metirc_str = 'cur metirc, {}'.format(', '.join(
                 ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
             logger.info(cur_metirc_str)
@@ -293,8 +293,7 @@ def train(config,
     return
 
 
-def eval(model, valid_dataloader, post_process_class, eval_class, logger,
-         print_batch_step):
+def eval(model, valid_dataloader, post_process_class, eval_class):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
@@ -315,9 +314,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
                 eval_class(post_result, batch)
                 pbar.update(1)
                 total_frame += len(images)
-                # if idx % print_batch_step == 0 and dist.get_rank() == 0:
-                #     logger.info('tackling images for eval: {}/{}'.format(
-                #         idx, len(valid_dataloader)))
             # Get final metirc,eg. acc or hmean
             metirc = eval_class.get_metric()
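
Note on the save_load.py hunks: paddle.load() returns the single object saved at the given path, so parameter and optimizer state dicts cannot be unpacked from one call against the bare prefix; they have to be saved to and loaded from separate '.pdparams' and '.pdopt' files. A minimal sketch of that round trip under those assumptions, using a hypothetical toy model and output path rather than PaddleOCR's own builders:

    import os
    import paddle

    # Hypothetical stand-ins for the model/optimizer normally built from the config.
    model = paddle.nn.Linear(8, 8)
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=model.parameters())

    os.makedirs('./output', exist_ok=True)
    model_prefix = './output/ppocr'  # hypothetical checkpoint prefix

    # Save each state dict under its own suffix, as save_model() now does.
    paddle.save(model.state_dict(), model_prefix + '.pdparams')
    paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

    # Load them back with one paddle.load() call per file, as init_model() now does.
    para_dict = paddle.load(model_prefix + '.pdparams')
    opti_dict = paddle.load(model_prefix + '.pdopt')
    model.set_state_dict(para_dict)
    optimizer.set_state_dict(opti_dict)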