提交 b658e10b 编写于 作者: T tink2123

fix srn for eval

上级 0d74b46c
...@@ -47,6 +47,7 @@ def main(): ...@@ -47,6 +47,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = len( config['Architecture']["Head"]['out_channels'] = len(
getattr(post_process_class, 'character')) getattr(post_process_class, 'character'))
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
best_model_dict = init_model(config, model, logger) best_model_dict = init_model(config, model, logger)
if len(best_model_dict): if len(best_model_dict):
...@@ -59,7 +60,7 @@ def main(): ...@@ -59,7 +60,7 @@ def main():
# start eval # start eval
metirc = program.eval(model, valid_dataloader, post_process_class, metirc = program.eval(model, valid_dataloader, post_process_class,
eval_class) eval_class, use_srn)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metirc.items(): for k, v in metirc.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册