diff --git a/tools/eval.py b/tools/eval.py index 16cfe532aae49ce98bc9503ca73e009bf206caa7..4afed469c875ef8d2200cdbfd89e5a8af4c6b7c3 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -47,6 +47,7 @@ def main(): config['Architecture']["Head"]['out_channels'] = len( getattr(post_process_class, 'character')) model = build_model(config['Architecture']) + use_srn = config['Architecture']['algorithm'] == "SRN" best_model_dict = init_model(config, model, logger) if len(best_model_dict): @@ -59,7 +60,7 @@ def main(): # start eval metirc = program.eval(model, valid_dataloader, post_process_class, - eval_class) + eval_class, use_srn) logger.info('metric eval ***************') for k, v in metirc.items(): logger.info('{}:{}'.format(k, v))