diff --git a/tools/eval.py b/tools/eval.py index 39a26ffefff46a9a3fe1465e874d501a334921c7..28247bc57450aaf067fcb405674098eacb990166 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -54,8 +54,7 @@ def main(): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - use_srn = config['Architecture']['algorithm'] == "SRN" - use_sar = config['Architecture']['algorithm'] == "SAR" + extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"] if "model_type" in config['Architecture'].keys(): model_type = config['Architecture']['model_type'] else: @@ -72,7 +71,7 @@ def main(): # start eval metric = program.eval(model, valid_dataloader, post_process_class, - eval_class, model_type, use_srn, use_sar) + eval_class, model_type, extra_input) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v))