From 09fd94e781d8ddc5e6cfa3b9e652401f0499b526 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Tue, 26 Jan 2021 15:16:02 +0800 Subject: [PATCH] fix typo --- tools/program.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tools/program.py b/tools/program.py index cbca715a..fb9e3802 100755 --- a/tools/program.py +++ b/tools/program.py @@ -212,7 +212,7 @@ def train(config, stats['lr'] = lr train_stats.update(stats) - if cal_metric_during_train: # onlt rec and cls need + if cal_metric_during_train: # only rec and cls need batch = [item.numpy() for item in batch] post_result = post_process_class(preds, batch[1]) eval_class(post_result, batch) @@ -238,21 +238,21 @@ def train(config, # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: - cur_metirc = eval(model, valid_dataloader, post_process_class, + cur_metric = eval(model, valid_dataloader, post_process_class, eval_class) - cur_metirc_str = 'cur metirc, {}'.format(', '.join( - ['{}: {}'.format(k, v) for k, v in cur_metirc.items()])) - logger.info(cur_metirc_str) + cur_metric_str = 'cur metric, {}'.format(', '.join( + ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) + logger.info(cur_metric_str) # logger metric if vdl_writer is not None: - for k, v in cur_metirc.items(): + for k, v in cur_metric.items(): if isinstance(v, (float, int)): vdl_writer.add_scalar('EVAL/{}'.format(k), - cur_metirc[k], global_step) - if cur_metirc[main_indicator] >= best_model_dict[ + cur_metric[k], global_step) + if cur_metric[main_indicator] >= best_model_dict[ main_indicator]: - best_model_dict.update(cur_metirc) + best_model_dict.update(cur_metric) best_model_dict['best_epoch'] = epoch save_model( model, @@ -263,7 +263,7 @@ def train(config, prefix='best_accuracy', best_model_dict=best_model_dict, epoch=epoch) - best_str = 'best metirc, {}'.format(', '.join([ + best_str = 'best metric, {}'.format(', '.join([ '{}: {}'.format(k, v) for k, v in best_model_dict.items() ])) logger.info(best_str) @@ -294,7 +294,7 @@ def train(config, prefix='iter_epoch_{}'.format(epoch), best_model_dict=best_model_dict, epoch=epoch) - best_str = 'best metirc, {}'.format(', '.join( + best_str = 'best metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) logger.info(best_str) if dist.get_rank() == 0 and vdl_writer is not None: @@ -323,13 +323,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class): eval_class(post_result, batch) pbar.update(1) total_frame += len(images) - # Get final metirc,eg. acc or hmean - metirc = eval_class.get_metric() + # Get final metric,eg. acc or hmean + metric = eval_class.get_metric() pbar.close() model.train() - metirc['fps'] = total_frame / total_time - return metirc + metric['fps'] = total_frame / total_time + return metric def preprocess(is_train=False): -- GitLab