From 637477f7b861854c5918402c79bc94580675a49b Mon Sep 17 00:00:00 2001 From: Dang Qingqing Date: Thu, 12 Apr 2018 11:24:19 +0800 Subject: [PATCH] Follow comments. --- fluid/object_detection/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fluid/object_detection/train.py b/fluid/object_detection/train.py index 8610500e..0f2856ca 100644 --- a/fluid/object_detection/train.py +++ b/fluid/object_detection/train.py @@ -18,6 +18,7 @@ add_arg('batch_size', int, 32, "Minibatch size.") add_arg('num_passes', int, 25, "Epoch number.") add_arg('parallel', bool, True, "Whether use parallel training.") add_arg('use_gpu', bool, True, "Whether use GPU.") +add_arg('use_nccl', bool, False, "Whether use NCCL.") add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.") add_arg('model_save_dir', str, 'model', "The path to save model.") add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.") @@ -57,7 +58,7 @@ def parallel_do(args, if args.parallel: places = fluid.layers.get_places() - pd = fluid.layers.ParallelDo(places) + pd = fluid.layers.ParallelDo(places, use_nccl=args.use_nccl) with pd.do(): image_ = pd.read_input(image) gt_box_ = pd.read_input(gt_box) @@ -224,7 +225,8 @@ def parallel_exe(args, return os.path.exists(os.path.join(pretrained_model, var.name)) fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) - train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name) + train_exe = fluid.ParallelExecutor(use_cuda=args.use_gpu, + loss_name=loss.name) train_reader = paddle.batch( reader.train(data_args, train_file_list), batch_size=batch_size) -- GitLab