未验证 提交 796fde66 编写于 作者: W wangguanzhong 提交者: GitHub

fix training on CPU (#3418)

上级 adf7eca6
......@@ -99,7 +99,7 @@ def main():
device_id = int(env['FLAGS_selected_gpus'])
else:
device_id = 0
place = fluid.CUDAPlace(device_id)
place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
lr_builder = create('LearningRate')
......@@ -164,8 +164,8 @@ def main():
# local execution scopes can be deleted after each iteration.
exec_strategy.num_iteration_per_drop_scope = 1
if FLAGS.dist:
dist_utils.prepare_for_multi_process(
exe, build_strategy, startup_prog, train_prog)
dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
train_prog)
exec_strategy.num_threads = 1
exe.run(startup_prog)
......@@ -187,10 +187,8 @@ def main():
elif cfg.pretrain_weights:
checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights)
train_reader = create_reader(
train_feed,
(cfg.max_iters - start_iter) * devices_num,
FLAGS.dataset_dir)
train_reader = create_reader(train_feed, (cfg.max_iters - start_iter) *
devices_num, FLAGS.dataset_dir)
train_pyreader.decorate_sample_list_generator(train_reader, place)
# whether output bbox is normalized in model output layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册