提交 f83561e4 编写于 作者: W wangguanzhong 提交者: GitHub

fix training on CPU (#3418)

上级 1a13bc56
...@@ -99,7 +99,7 @@ def main(): ...@@ -99,7 +99,7 @@ def main():
device_id = int(env['FLAGS_selected_gpus']) device_id = int(env['FLAGS_selected_gpus'])
else: else:
device_id = 0 device_id = 0
place = fluid.CUDAPlace(device_id) place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
lr_builder = create('LearningRate') lr_builder = create('LearningRate')
...@@ -164,8 +164,8 @@ def main(): ...@@ -164,8 +164,8 @@ def main():
# local execution scopes can be deleted after each iteration. # local execution scopes can be deleted after each iteration.
exec_strategy.num_iteration_per_drop_scope = 1 exec_strategy.num_iteration_per_drop_scope = 1
if FLAGS.dist: if FLAGS.dist:
dist_utils.prepare_for_multi_process( dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
exe, build_strategy, startup_prog, train_prog) train_prog)
exec_strategy.num_threads = 1 exec_strategy.num_threads = 1
exe.run(startup_prog) exe.run(startup_prog)
...@@ -187,10 +187,8 @@ def main(): ...@@ -187,10 +187,8 @@ def main():
elif cfg.pretrain_weights: elif cfg.pretrain_weights:
checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights) checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights)
train_reader = create_reader( train_reader = create_reader(train_feed, (cfg.max_iters - start_iter) *
train_feed, devices_num, FLAGS.dataset_dir)
(cfg.max_iters - start_iter) * devices_num,
FLAGS.dataset_dir)
train_pyreader.decorate_sample_list_generator(train_reader, place) train_pyreader.decorate_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer # whether output bbox is normalized in model output layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册