From 672318256cc0d1a4cb2d245950de1f368a2bc7e6 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Mon, 9 Nov 2020 13:28:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4eval=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E7=9A=84=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/program.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tools/program.py b/tools/program.py index 41acb866..8bae0fd5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -231,7 +231,7 @@ def train(config, if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: cur_metirc = eval(model, valid_dataloader, post_process_class, - eval_class, logger, print_batch_step) + eval_class) cur_metirc_str = 'cur metirc, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metirc.items()])) logger.info(cur_metirc_str) @@ -293,8 +293,7 @@ def train(config, return -def eval(model, valid_dataloader, post_process_class, eval_class, logger, - print_batch_step): +def eval(model, valid_dataloader, post_process_class, eval_class): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -315,9 +314,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger, eval_class(post_result, batch) pbar.update(1) total_frame += len(images) - # if idx % print_batch_step == 0 and dist.get_rank() == 0: - # logger.info('tackling images for eval: {}/{}'.format( - # idx, len(valid_dataloader))) # Get final metirc,eg. acc or hmean metirc = eval_class.get_metric() -- GitLab