From e222bc9e64b603e351d055a7592db0cd3cc3bb6b Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Tue, 7 Jul 2020 02:35:17 +0000 Subject: [PATCH] set evaluation interval --- doc/doc_ch/config.md | 2 +- doc/doc_en/config_en.md | 2 +- tools/program.py | 18 ++++++++++++++++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/doc/doc_ch/config.md b/doc/doc_ch/config.md index ae16263e..56e5c7f5 100644 --- a/doc/doc_ch/config.md +++ b/doc/doc_ch/config.md @@ -22,7 +22,7 @@ | print_batch_step | 设置打印log间隔 | 10 | \ | | save_model_dir | 设置模型保存路径 | output/{算法名称} | \ | | save_epoch_step | 设置模型保存间隔 | 3 | \ | -| eval_batch_step | 设置模型评估间隔 | 2000 | \ | +| eval_batch_step | 设置模型评估间隔 | 2000 或 [1000, 2000] | 2000 表示每2000次迭代评估一次,[1000, 2000]表示从1000次迭代开始,每2000次评估一次 | |train_batch_size_per_card | 设置训练时单卡batch size | 256 | \ | | test_batch_size_per_card | 设置评估时单卡batch size | 256 | \ | | image_shape | 设置输入图片尺寸 | [3, 32, 100] | \ | diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md index b9ad0394..932de26c 100644 --- a/doc/doc_en/config_en.md +++ b/doc/doc_en/config_en.md @@ -22,7 +22,7 @@ Take `rec_chinese_lite_train.yml` as an example | print_batch_step | Set print log interval | 10 | \ | | save_model_dir | Set model save path | output/{model_name} | \ | | save_epoch_step | Set model save interval | 3 | \ | -| eval_batch_step | Set the model evaluation interval | 2000 | \ | +| eval_batch_step | Set the model evaluation interval |2000 or [1000, 2000] |runing evaluation every 2000 iters or evaluation is run every 2000 iterations after the 1000th iteration | |train_batch_size_per_card | Set the batch size during training | 256 | \ | | test_batch_size_per_card | Set the batch size during testing | 256 | \ | | image_shape | Set input image size | [3, 32, 100] | \ | diff --git a/tools/program.py b/tools/program.py index 3c71065a..57447caa 100755 --- a/tools/program.py +++ b/tools/program.py @@ -219,6 +219,13 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): epoch_num = config['Global']['epoch_num'] print_batch_step = config['Global']['print_batch_step'] eval_batch_step = config['Global']['eval_batch_step'] + start_eval_step = 0 + if type(eval_batch_step) == list and len(eval_batch_step) >= 2: + start_eval_step = eval_batch_step[0] + eval_batch_step = eval_batch_step[1] + logger.info( + "During the training process, after the {}th iteration, an evaluation is run every {} iterations". + format(start_eval_step, eval_batch_step)) save_epoch_step = config['Global']['save_epoch_step'] save_model_dir = config['Global']['save_model_dir'] if not os.path.exists(save_model_dir): @@ -246,7 +253,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): t2 = time.time() train_batch_elapse = t2 - t1 train_stats.update(stats) - if train_batch_id > 0 and train_batch_id \ + if train_batch_id > start_eval_step and train_batch_id \ % print_batch_step == 0: logs = train_stats.log() strs = 'epoch: {}, iter: {}, {}, time: {:.3f}'.format( @@ -286,6 +293,13 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): epoch_num = config['Global']['epoch_num'] print_batch_step = config['Global']['print_batch_step'] eval_batch_step = config['Global']['eval_batch_step'] + start_eval_step = 0 + if type(eval_batch_step) == list and len(eval_batch_step) >= 2: + start_eval_step = eval_batch_step[0] + eval_batch_step = eval_batch_step[1] + logger.info( + "During the training process, after the {}th iteration, an evaluation is run every {} iterations". + format(start_eval_step, eval_batch_step)) save_epoch_step = config['Global']['save_epoch_step'] save_model_dir = config['Global']['save_model_dir'] if not os.path.exists(save_model_dir): @@ -324,7 +338,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): train_batch_elapse = t2 - t1 stats = {'loss': loss, 'acc': acc} train_stats.update(stats) - if train_batch_id > 0 and train_batch_id \ + if train_batch_id > start_eval_step and train_batch_id \ % print_batch_step == 0: logs = train_stats.log() strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format( -- GitLab