提交 560f2f49 编写于 作者: T tink2123

fix eval

上级 e885b57e
......@@ -54,8 +54,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
use_sar = config['Architecture']['algorithm'] == "SAR"
extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
......@@ -72,7 +71,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn, use_sar)
eval_class, model_type, extra_input)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册