提交 637477f7 编写于 作者: D Dang Qingqing

Follow comments.

上级 9a2f88be
...@@ -18,6 +18,7 @@ add_arg('batch_size', int, 32, "Minibatch size.") ...@@ -18,6 +18,7 @@ add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('num_passes', int, 25, "Epoch number.") add_arg('num_passes', int, 25, "Epoch number.")
add_arg('parallel', bool, True, "Whether use parallel training.") add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.") add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_nccl', bool, False, "Whether use NCCL.")
add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.") add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.") add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.") add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
...@@ -57,7 +58,7 @@ def parallel_do(args, ...@@ -57,7 +58,7 @@ def parallel_do(args,
if args.parallel: if args.parallel:
places = fluid.layers.get_places() places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places) pd = fluid.layers.ParallelDo(places, use_nccl=args.use_nccl)
with pd.do(): with pd.do():
image_ = pd.read_input(image) image_ = pd.read_input(image)
gt_box_ = pd.read_input(gt_box) gt_box_ = pd.read_input(gt_box)
...@@ -224,7 +225,8 @@ def parallel_exe(args, ...@@ -224,7 +225,8 @@ def parallel_exe(args,
return os.path.exists(os.path.join(pretrained_model, var.name)) return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name) train_exe = fluid.ParallelExecutor(use_cuda=args.use_gpu,
loss_name=loss.name)
train_reader = paddle.batch( train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size) reader.train(data_args, train_file_list), batch_size=batch_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册