From f83561e4a3223e20c140bf17eb9883f0cf396fee Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 25 Sep 2019 20:38:08 +0800 Subject: [PATCH] fix training on CPU (#3418) --- tools/train.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tools/train.py b/tools/train.py index 11167a327..6dc0f1ede 100644 --- a/tools/train.py +++ b/tools/train.py @@ -99,7 +99,7 @@ def main(): device_id = int(env['FLAGS_selected_gpus']) else: device_id = 0 - place = fluid.CUDAPlace(device_id) + place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) lr_builder = create('LearningRate') @@ -164,8 +164,8 @@ def main(): # local execution scopes can be deleted after each iteration. exec_strategy.num_iteration_per_drop_scope = 1 if FLAGS.dist: - dist_utils.prepare_for_multi_process( - exe, build_strategy, startup_prog, train_prog) + dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog, + train_prog) exec_strategy.num_threads = 1 exe.run(startup_prog) @@ -187,10 +187,8 @@ def main(): elif cfg.pretrain_weights: checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights) - train_reader = create_reader( - train_feed, - (cfg.max_iters - start_iter) * devices_num, - FLAGS.dataset_dir) + train_reader = create_reader(train_feed, (cfg.max_iters - start_iter) * + devices_num, FLAGS.dataset_dir) train_pyreader.decorate_sample_list_generator(train_reader, place) # whether output bbox is normalized in model output layer -- GitLab