diff --git a/tools/train.py b/tools/train.py index b6215931ca146fa9b4e6d5f45be54c3cf133e32f..52062f33853bff574791c580cbfacae2a974bf77 100644 --- a/tools/train.py +++ b/tools/train.py @@ -103,10 +103,6 @@ def main(): optimizer = optim_builder(lr) optimizer.minimize(loss) - train_reader = create_reader(train_feed, cfg.max_iters * devices_num, - FLAGS.dataset_dir) - train_pyreader.decorate_sample_list_generator(train_reader, place) - # parse train fetches train_keys, train_values, _ = parse_fetches(train_fetches) train_values.append(lr) @@ -163,6 +159,13 @@ def main(): elif cfg.pretrain_weights: checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights) + train_reader = create_reader( + train_feed, + (cfg.max_iters - start_iter) * devices_num, + FLAGS.dataset_dir) + train_pyreader.decorate_sample_list_generator(train_reader, place) + + # whether output bbox is normalized in model output layer is_bbox_normalized = False if hasattr(model, 'is_bbox_normalized') and \