diff --git a/tools/program.py b/tools/program.py index f3ba49450a21f600589b6888710a2420ccdaa321..2daf309a55245b92856c1352121c638374699461 100755 --- a/tools/program.py +++ b/tools/program.py @@ -251,7 +251,7 @@ def train(config, min_average_window=10000, max_average_window=15625) Model_Average.apply() - cur_metirc = eval(model, valid_dataloader, post_process_class, + cur_metric = eval(model, valid_dataloader, post_process_class, # 原cur_metirc修改为 cur_metric eval_class) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))