diff --git a/tools/program.py b/tools/program.py index 2daf309a55245b92856c1352121c638374699461..a24d6ca7e627c6571592f7d4417ddfd661cfb920 100755 --- a/tools/program.py +++ b/tools/program.py @@ -251,7 +251,7 @@ def train(config, min_average_window=10000, max_average_window=15625) Model_Average.apply() - cur_metric = eval(model, valid_dataloader, post_process_class, # 原cur_metirc修改为 cur_metric + cur_metric = eval(model, valid_dataloader, post_process_class, eval_class) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))