未验证 提交 b863528d 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #1133 from WenmuZhou/dygraph_rc

fix bug and update save_load to rc version
......@@ -89,7 +89,8 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
"Given dir {}.pdparams not exist.".format(checkpoints)
assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints)
para_dict, opti_dict = paddle.load(checkpoints)
para_dict = paddle.load(checkpoints + '.pdparams')
opti_dict = paddle.load(checkpoints + '.pdopt')
model.set_dict(para_dict)
if optimizer is not None:
optimizer.set_state_dict(opti_dict)
......@@ -133,8 +134,8 @@ def save_model(net,
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
paddle.save(net.state_dict(), model_prefix)
paddle.save(optimizer.state_dict(), model_prefix)
paddle.save(net.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
# save metric and config
with open(model_prefix + '.states', 'wb') as f:
......
......@@ -23,12 +23,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
import paddle
# paddle.manual_seed(2)
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
from ppocr.modeling import build_model
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
......@@ -39,8 +35,7 @@ import tools.program as program
def main():
global_config = config['Global']
# build dataloader
eval_loader, _ = build_dataloader(config['EVAL'], device, False,
global_config)
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process
post_process_class = build_post_process(config['PostProcess'],
......@@ -63,16 +58,13 @@ def main():
eval_class = build_metric(config['Metric'])
# start eval
metirc = program.eval(model, eval_loader, post_process_class, eval_class)
metirc = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info('metric eval ***************')
for k, v in metirc.items():
logger.info('{}:{}'.format(k, v))
if __name__ == '__main__':
device, config = program.preprocess()
paddle.disable_static(device)
logger = get_logger()
print_dict(config, logger)
config, device, logger, vdl_writer = program.preprocess()
main()
......@@ -231,7 +231,7 @@ def train(config,
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class, logger, print_batch_step)
eval_class)
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
logger.info(cur_metirc_str)
......@@ -293,8 +293,7 @@ def train(config,
return
def eval(model, valid_dataloader, post_process_class, eval_class, logger,
print_batch_step):
def eval(model, valid_dataloader, post_process_class, eval_class):
model.eval()
with paddle.no_grad():
total_frame = 0.0
......@@ -315,9 +314,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
eval_class(post_result, batch)
pbar.update(1)
total_frame += len(images)
# if idx % print_batch_step == 0 and dist.get_rank() == 0:
# logger.info('tackling images for eval: {}/{}'.format(
# idx, len(valid_dataloader)))
# Get final metirc,eg. acc or hmean
metirc = eval_class.get_metric()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册