提交 637477f7 编写于 作者: D Dang Qingqing

Follow comments.

上级 9a2f88be
......@@ -18,6 +18,7 @@ add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('num_passes', int, 25, "Epoch number.")
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_nccl', bool, False, "Whether use NCCL.")
add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
......@@ -57,7 +58,7 @@ def parallel_do(args,
if args.parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
pd = fluid.layers.ParallelDo(places, use_nccl=args.use_nccl)
with pd.do():
image_ = pd.read_input(image)
gt_box_ = pd.read_input(gt_box)
......@@ -224,7 +225,8 @@ def parallel_exe(args,
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name)
train_exe = fluid.ParallelExecutor(use_cuda=args.use_gpu,
loss_name=loss.name)
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册