未验证 提交 b863528d 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #1133 from WenmuZhou/dygraph_rc

fix bug and update save_load to rc version
...@@ -89,7 +89,8 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): ...@@ -89,7 +89,8 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
"Given dir {}.pdparams not exist.".format(checkpoints) "Given dir {}.pdparams not exist.".format(checkpoints)
assert os.path.exists(checkpoints + ".pdopt"), \ assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints) "Given dir {}.pdopt not exist.".format(checkpoints)
para_dict, opti_dict = paddle.load(checkpoints) para_dict = paddle.load(checkpoints + '.pdparams')
opti_dict = paddle.load(checkpoints + '.pdopt')
model.set_dict(para_dict) model.set_dict(para_dict)
if optimizer is not None: if optimizer is not None:
optimizer.set_state_dict(opti_dict) optimizer.set_state_dict(opti_dict)
...@@ -133,8 +134,8 @@ def save_model(net, ...@@ -133,8 +134,8 @@ def save_model(net,
""" """
_mkdir_if_not_exist(model_path, logger) _mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
paddle.save(net.state_dict(), model_prefix) paddle.save(net.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix) paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
# save metric and config # save metric and config
with open(model_prefix + '.states', 'wb') as f: with open(model_prefix + '.states', 'wb') as f:
......
...@@ -23,12 +23,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -23,12 +23,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
import paddle
# paddle.manual_seed(2)
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model
...@@ -39,8 +35,7 @@ import tools.program as program ...@@ -39,8 +35,7 @@ import tools.program as program
def main(): def main():
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
eval_loader, _ = build_dataloader(config['EVAL'], device, False, valid_dataloader = build_dataloader(config, 'Eval', device, logger)
global_config)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
...@@ -63,16 +58,13 @@ def main(): ...@@ -63,16 +58,13 @@ def main():
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# start eval # start eval
metirc = program.eval(model, eval_loader, post_process_class, eval_class) metirc = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metirc.items(): for k, v in metirc.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
if __name__ == '__main__': if __name__ == '__main__':
device, config = program.preprocess() config, device, logger, vdl_writer = program.preprocess()
paddle.disable_static(device)
logger = get_logger()
print_dict(config, logger)
main() main()
...@@ -231,7 +231,7 @@ def train(config, ...@@ -231,7 +231,7 @@ def train(config,
if global_step > start_eval_step and \ if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
cur_metirc = eval(model, valid_dataloader, post_process_class, cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class, logger, print_batch_step) eval_class)
cur_metirc_str = 'cur metirc, {}'.format(', '.join( cur_metirc_str = 'cur metirc, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metirc.items()])) ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
logger.info(cur_metirc_str) logger.info(cur_metirc_str)
...@@ -293,8 +293,7 @@ def train(config, ...@@ -293,8 +293,7 @@ def train(config,
return return
def eval(model, valid_dataloader, post_process_class, eval_class, logger, def eval(model, valid_dataloader, post_process_class, eval_class):
print_batch_step):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -315,9 +314,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger, ...@@ -315,9 +314,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
eval_class(post_result, batch) eval_class(post_result, batch)
pbar.update(1) pbar.update(1)
total_frame += len(images) total_frame += len(images)
# if idx % print_batch_step == 0 and dist.get_rank() == 0:
# logger.info('tackling images for eval: {}/{}'.format(
# idx, len(valid_dataloader)))
# Get final metirc,eg. acc or hmean # Get final metirc,eg. acc or hmean
metirc = eval_class.get_metric() metirc = eval_class.get_metric()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册