提交 67231825 编写于 作者: W WenmuZhou

删除eval多余的参数

上级 4eba6c0d
...@@ -231,7 +231,7 @@ def train(config, ...@@ -231,7 +231,7 @@ def train(config,
if global_step > start_eval_step and \ if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
cur_metirc = eval(model, valid_dataloader, post_process_class, cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class, logger, print_batch_step) eval_class)
cur_metirc_str = 'cur metirc, {}'.format(', '.join( cur_metirc_str = 'cur metirc, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metirc.items()])) ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
logger.info(cur_metirc_str) logger.info(cur_metirc_str)
...@@ -293,8 +293,7 @@ def train(config, ...@@ -293,8 +293,7 @@ def train(config,
return return
def eval(model, valid_dataloader, post_process_class, eval_class, logger, def eval(model, valid_dataloader, post_process_class, eval_class):
print_batch_step):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -315,9 +314,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger, ...@@ -315,9 +314,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
eval_class(post_result, batch) eval_class(post_result, batch)
pbar.update(1) pbar.update(1)
total_frame += len(images) total_frame += len(images)
# if idx % print_batch_step == 0 and dist.get_rank() == 0:
# logger.info('tackling images for eval: {}/{}'.format(
# idx, len(valid_dataloader)))
# Get final metirc,eg. acc or hmean # Get final metirc,eg. acc or hmean
metirc = eval_class.get_metric() metirc = eval_class.get_metric()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册