diff --git a/tools/program.py b/tools/program.py index 6a51e5c37175f0f45e87571122adc2aba04d491c..ff8743f15e2925a92fff76e7430e761d7baa720e 100755 --- a/tools/program.py +++ b/tools/program.py @@ -256,15 +256,15 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): t2 = time.time() train_batch_elapse = t2 - t1 train_stats.update(stats) - if train_batch_id > start_eval_step and (train_batch_id -start_eval_step) \ + if train_batch_id > 0 and train_batch_id \ % print_batch_step == 0: logs = train_stats.log() strs = 'epoch: {}, iter: {}, {}, time: {:.3f}'.format( epoch, train_batch_id, logs, train_batch_elapse) logger.info(strs) - if train_batch_id > 0 and\ - train_batch_id % eval_batch_step == 0: + if train_batch_id > start_eval_step and\ + (train_batch_id - start_eval_step) % eval_batch_step == 0: metrics = eval_det_run(exe, config, eval_info_dict, "eval") hmean = metrics['hmean'] if hmean >= best_eval_hmean: