From 49958dca6176d2938d19e2a9c1196d8eb73621e5 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Mon, 9 Nov 2020 13:27:31 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=82=E9=85=8Drc=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/utils/save_load.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index c6d20651..e74d8faa 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -89,7 +89,8 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): "Given dir {}.pdparams not exist.".format(checkpoints) assert os.path.exists(checkpoints + ".pdopt"), \ "Given dir {}.pdopt not exist.".format(checkpoints) - para_dict, opti_dict = paddle.load(checkpoints) + para_dict = paddle.load(checkpoints + '.pdparams') + opti_dict = paddle.load(checkpoints + '.pdopt') model.set_dict(para_dict) if optimizer is not None: optimizer.set_state_dict(opti_dict) @@ -133,8 +134,8 @@ def save_model(net, """ _mkdir_if_not_exist(model_path, logger) model_prefix = os.path.join(model_path, prefix) - paddle.save(net.state_dict(), model_prefix) - paddle.save(optimizer.state_dict(), model_prefix) + paddle.save(net.state_dict(), model_prefix + '.pdparams') + paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') # save metric and config with open(model_prefix + '.states', 'wb') as f: -- GitLab