From 6382b62f6bc3533bf29db043d209cc2840992056 Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Wed, 20 Mar 2019 08:00:29 +0800 Subject: [PATCH] Collective ops (#15572) * wip allreduce in op * wip * wip * wip * wip adding test * wip for conflict with mp mode * fix tests test=develop * fix cpu build test=develop * fix travis clang format test=develop * fix cpu build test=develop * update api.spec test=develop * delete comment test=develop * fix cpplint test=develop * fix test=develop * follow comment test=develop * add file test=develop * fix build test=develop * update test=develop * to be compatible with sync_bn, and fix mp mode in develop test=develop --- .../details/multi_devices_graph_pass.cc | 26 +++- .../details/multi_devices_graph_pass.h | 6 +- paddle/fluid/framework/parallel_executor.cc | 23 ++- .../operators/distributed_ops/allreduce_op.cc | 143 ++++++++++++++++++ paddle/fluid/platform/device_context.h | 1 + paddle/fluid/platform/nccl_helper.h | 9 +- python/paddle/fluid/layers/__init__.py | 1 + python/paddle/fluid/layers/collective.py | 47 ++++++ python/paddle/fluid/layers/control_flow.py | 18 +-- .../tests/unittests/dist_allreduce_op.py | 120 +++++++++++++++ .../tests/unittests/test_dist_allreduce_op.py | 35 +++++ .../fluid/tests/unittests/test_dist_base.py | 89 +++++++---- .../fluid/tests/unittests/test_dist_mnist.py | 2 +- 13 files changed, 458 insertions(+), 62 deletions(-) create mode 100644 paddle/fluid/operators/distributed_ops/allreduce_op.cc create mode 100644 python/paddle/fluid/layers/collective.py create mode 100644 python/paddle/fluid/tests/unittests/dist_allreduce_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_dist_allreduce_op.py diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 478d2ffbcf..ed2e0bc65e 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -166,7 +166,6 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( result.Set(kGraphOps, new GraphOps); bool is_forwarding = true; - bool insert_collection_ops = NeedCollectiveOps(); for (ir::Node *node : sorted_ops) { if (DealWithSpecialOp(&result, node)) { @@ -185,8 +184,8 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( CreateComputationalOps(&result, node, places_.size()); } - // Insert collection ops - if (!is_forwarding && insert_collection_ops) { + // Insert collective ops if nranks > 1 + if (!is_forwarding && Get(kNRanks) > 1) { try { bool is_bk_op = static_cast(boost::get(node->Op()->GetAttr( @@ -205,8 +204,9 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( auto &p_name = backward_vars[i]; auto &g_name = backward_vars[i + 1]; VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; - - InsertCollectiveOp(&result, p_name, g_name); + if (NeedCollectiveForGrad(g_name, sorted_ops)) { + InsertCollectiveOp(&result, p_name, g_name); + } } } catch (boost::bad_get e) { } @@ -271,8 +271,20 @@ bool MultiDevSSAGraphBuilderBase::UseGPU() const { return use_gpu; } -bool MultiDevSSAGraphBuilderBase::NeedCollectiveOps() const { - return Get(kNRanks) > 1; +bool MultiDevSSAGraphBuilderBase::NeedCollectiveForGrad( + const std::string &grad_name, std::vector ops) const { + // if we have allreduce_op for current gradient variable in the graph, + // then we don't need to add allreduce_op_handle for this gradient + // NOTE: This is for the case that all gradients should add collective ops + for (auto *node : ops) { + if (node->Op()->Type() != "allreduce") continue; + for (auto in_name : node->Op()->InputArgumentNames()) { + if (in_name == grad_name) { + return false; + } + } + } + return true; } void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result, diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 6d4386538e..9018243dd7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -14,7 +14,10 @@ #pragma once +#include #include +#include +#include #include #include @@ -55,7 +58,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { bool UseGPU() const; - bool NeedCollectiveOps() const; + bool NeedCollectiveForGrad(const std::string &grad_name, + std::vector ops) const; bool IsScaleLossOp(ir::Node *node) const; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 56f108cea2..20a8c47d5d 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -254,18 +254,29 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, member_->places_, nccl_id, build_strategy.num_trainers_, build_strategy.trainer_id_)); - std::unique_ptr dev_nccl_ctxs; - dev_nccl_ctxs.reset(new platform::NCCLContextMap(member_->places_)); - // Initialize device context's nccl comm - // Note, more than one ParallelExecutor with same place, the nccl comm will + // Initialize device context's nccl comm, will be used by normal + // Operators like sync_batch_norm, and collective ops. + // NOTE: more than one ParallelExecutor with same place, the nccl comm will // be rewrite and there will be some problem. + // NOTE: NCCL group-calls and non-group-calls can not use the same + // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use + // same communicators. + std::unique_ptr dev_nccl_ctxs; + if (nccl_id == nullptr) { + dev_nccl_ctxs.reset(new platform::NCCLContextMap(member_->places_)); + } for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) { - auto &nccl_ctx = dev_nccl_ctxs->at(dev_id); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto *dev_ctx = static_cast( pool.Get(member_->places_[dev_id])); - dev_ctx->set_nccl_comm(nccl_ctx.comm()); + if (nccl_id != nullptr) { + auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[dev_id]); + dev_ctx->set_nccl_comm(nccl_ctx.comm()); + } else { + auto &nccl_ctx = dev_nccl_ctxs->at(member_->places_[dev_id]); + dev_ctx->set_nccl_comm(nccl_ctx.comm()); + } } #else PADDLE_THROW("Not compiled with CUDA"); diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.cc b/paddle/fluid/operators/distributed_ops/allreduce_op.cc new file mode 100644 index 0000000000..0fbc27515c --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.cc @@ -0,0 +1,143 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +struct MutableDataFunctor { + MutableDataFunctor(void** data, framework::LoDTensor* tensor, + const platform::Place& place) + : data_(data), tensor_(tensor), place_(place) {} + + template + void apply() { + *data_ = tensor_->mutable_data(place_); + } + + void** data_; + framework::LoDTensor* tensor_; + platform::Place place_; +}; + +class AllReduceOp : public framework::OperatorBase { + using OperatorBase::OperatorBase; + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + PADDLE_ENFORCE(is_gpu_place(place), + "AllReduce op can run on gpu place only for now."); +#ifdef PADDLE_WITH_CUDA + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* ctx = pool.Get(place); + auto in_names = Inputs("X"); + auto out_names = Outputs("Out"); + PADDLE_ENFORCE_EQ(in_names.size(), 1, "Only support one input"); + PADDLE_ENFORCE_EQ(out_names.size(), 1, "Only support one output"); + + auto* in = scope.FindVar(in_names[0]); + auto* out = scope.FindVar(out_names[0]); + + PADDLE_ENFORCE(in->IsType() || + out->IsType(), + "Only support allreduce LoDTensors"); + + int dtype = -1; + auto in_tensor = in->Get(); + dtype = platform::ToNCCLDataType(in_tensor.type()); + + int64_t numel = in_tensor.numel(); + auto* sendbuff = in_tensor.data(); + auto* out_tensor = out->GetMutable(); + out_tensor->Resize(in_tensor.dims()); + void* recvbuff = nullptr; + framework::VisitDataType(in_tensor.type(), + MutableDataFunctor(&recvbuff, out_tensor, place)); + + auto cuda_ctx = static_cast(ctx); + auto* comm = cuda_ctx->nccl_comm(); + // FIXME(typhoonzero): should use nccl stream here. + auto stream = cuda_ctx->stream(); + + int reduce_type = Attr("reduce_type"); + ncclRedOp_t red_type = ncclSum; + switch (reduce_type) { + case 0: + red_type = ncclSum; + break; + case 1: + red_type = ncclProd; + break; + case 2: + red_type = ncclMax; + break; + case 3: + red_type = ncclMin; + break; + } + + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + sendbuff, recvbuff, numel, static_cast(dtype), red_type, + comm, stream)); +#endif + } +}; + +class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor), tensor to be allreduced."); + AddOutput("Out", "(Tensor) the result of allreduced."); + AddAttr("reduce_type", "(int) determin the reduce type.") + .SetDefault(0); + AddComment(R"DOC( +***AllReduce Operator*** + +Call NCCL AllReduce internally. Note that this op must be used when one +thread is managing one GPU device. + +For speed reasons, reduce_type should be an integer: + +0: sum +1: prod +2: max +3: min + +If input and output are the same variable, in-place allreduce will be used. +)DOC"); + } +}; + +class AllReduceOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(allreduce, ops::AllReduceOp, + paddle::framework::EmptyGradOpMaker, ops::AllReduceOpMaker, + ops::AllReduceOpShapeInference); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 1eb8d9691a..19ce769ab3 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" +#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/gpu_info.h" #endif diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 0428c40f98..b8b14b3d15 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -22,6 +22,7 @@ #include #include #include + #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/enforce.h" @@ -79,7 +80,6 @@ struct NCCLContext { : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {} cudaStream_t stream() const { return ctx_->stream(); } - ncclComm_t comm() const { return comm_; } int device_id() const { @@ -105,9 +105,6 @@ struct NCCLContextMap { order_.size(), contexts_.size(), "NCCL Context Map does not support contain two or more same device"); - if (places.size() <= 1 && num_trainers == 1) { - return; - } std::unique_ptr comms(new ncclComm_t[order_.size()]); // if num_trainers == 1, should create a new nccl id for local comms. if (num_trainers == 1 && nccl_id == nullptr) { @@ -127,8 +124,8 @@ struct NCCLContextMap { } else { rank = trainer_id; } - VLOG(30) << "init nccl rank: " << rank << " nranks: " << nranks - << "gpu id: " << gpu_id; + VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks + << " gpu id: " << gpu_id; PADDLE_ENFORCE(cudaSetDevice(gpu_id)); PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( comms.get() + i, nranks, *nccl_id, rank)); diff --git a/python/paddle/fluid/layers/__init__.py b/python/paddle/fluid/layers/__init__.py index a2a808777d..31effea378 100644 --- a/python/paddle/fluid/layers/__init__.py +++ b/python/paddle/fluid/layers/__init__.py @@ -33,6 +33,7 @@ from .detection import * from . import metric_op from .metric_op import * from .learning_rate_scheduler import * +from .collective import * __all__ = [] __all__ += nn.__all__ diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py new file mode 100644 index 0000000000..a9bce77b9d --- /dev/null +++ b/python/paddle/fluid/layers/collective.py @@ -0,0 +1,47 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from ..layer_helper import LayerHelper, unique_name + + +def _allreduce(x, out=None, reduce_type="sum"): + helper = LayerHelper("allreduce", **locals()) + # Convert string reduce type to op int type + red_typ_int = 0 + if reduce_type == "sum": + red_typ_int = 0 + elif reduce_type == "prod": + red_typ_int = 1 + elif reduce_type == "max": + red_typ_int = 2 + elif reduce_type == "min": + red_typ_int = 3 + else: + raise TypeError("reduce type can only be [sum|prod|max|min]") + + if out is None: + out = helper.create_variable( + name=unique_name.generate(".".join([x.name, 'tmp'])), + shape=x.shape, + dtype=x.dtype, + type=x.type, + persistable=x.persistable, + stop_gradient=True) + helper.append_op( + type='allreduce', + inputs={'X': [x]}, + outputs={'Out': [out]}, + attrs={"reduce_type": red_typ_int}) + return out diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index e6e13b0342..3277766171 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -28,21 +28,9 @@ import six from functools import reduce __all__ = [ - 'While', - 'Switch', - 'increment', - 'array_write', - 'create_array', - 'less_than', - 'equal', - 'array_read', - 'array_length', - 'IfElse', - 'DynamicRNN', - 'StaticRNN', - 'reorder_lod_tensor_by_rank', - 'Print', - 'is_empty', + 'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than', + 'equal', 'array_read', 'array_length', 'IfElse', 'DynamicRNN', 'StaticRNN', + 'reorder_lod_tensor_by_rank', 'Print', 'is_empty' ] diff --git a/python/paddle/fluid/tests/unittests/dist_allreduce_op.py b/python/paddle/fluid/tests/unittests/dist_allreduce_op.py new file mode 100644 index 0000000000..88a3cd14c4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_allreduce_op.py @@ -0,0 +1,120 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main + +DTYPE = "float32" +paddle.dataset.mnist.fetch() + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +def cnn_model(data): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=data, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu", + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.01))) + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + act="relu", + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.01))) + + SIZE = 10 + input_shape = conv_pool_2.shape + param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE] + scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5 + + predict = fluid.layers.fc( + input=conv_pool_2, + size=SIZE, + act="softmax", + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01))) + return predict + + +class TestDistMnist2x2(TestDistRunnerBase): + def get_model(self, batch_size=2, single_device=False): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + inference_program = fluid.default_main_program().clone() + + # Reader + train_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + + # Optimization + # TODO(typhoonzero): fix distributed adam optimizer + # opt = fluid.optimizer.AdamOptimizer( + # learning_rate=0.001, beta1=0.9, beta2=0.999) + opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9) + if single_device: + opt.minimize(avg_cost) + else: + # multi device or distributed multi device + params_grads = opt.backward(avg_cost) + data_parallel_param_grads = [] + for p, g in params_grads: + # NOTE: scale will be done on loss scale in multi_devices_graph_pass using nranks. + grad_reduce = fluid.layers.collective._allreduce(g) + data_parallel_param_grads.append([p, grad_reduce]) + opt.apply_gradients(data_parallel_param_grads) + + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestDistMnist2x2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_allreduce_op.py b/python/paddle/fluid/tests/unittests/test_dist_allreduce_op.py new file mode 100644 index 0000000000..fbeff20c63 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_allreduce_op.py @@ -0,0 +1,35 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + + +class TestDistMnistNCCL2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl2_mode = True + self._nccl2_reduce_layer = True + + def test_dist_train(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place("dist_allreduce_op.py", delta=1e-5) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index f4d14d4024..d98b839e9b 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -33,7 +33,10 @@ DEFAULT_BATCH_SIZE = 2 class TestDistRunnerBase(object): - def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1): + def get_model(self, + batch_size=DEFAULT_BATCH_SIZE, + lr=0.1, + single_device=False): raise NotImplementedError( "get_model should be implemented by child classes.") @@ -76,8 +79,12 @@ class TestDistRunnerBase(object): def run_trainer(self, args): self.lr = args.lr - test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ - self.get_model(batch_size=args.batch_size) + if args.nccl2_reduce_layer_local_run: + test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ + self.get_model(batch_size=args.batch_size, single_device=True) + else: + test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ + self.get_model(batch_size=args.batch_size) if args.mem_opt: fluid.memory_optimize(fluid.default_main_program(), skip_grads=True) @@ -87,7 +94,7 @@ class TestDistRunnerBase(object): args.endpoints, args.trainers, args.sync_mode, args.dc_asgd) trainer_prog = t.get_trainer_program() - elif args.update_method == "nccl2": + elif args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer": # transpile for nccl2 config = fluid.DistributeTranspilerConfig() config.mode = "nccl2" @@ -110,9 +117,9 @@ class TestDistRunnerBase(object): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - strategy = fluid.ExecutionStrategy() - strategy.num_threads = 1 - strategy.allow_op_delay = False + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = 1 + exec_strategy.allow_op_delay = False build_stra = fluid.BuildStrategy() # FIXME force disable enable_inplace and memory_optimize @@ -124,23 +131,25 @@ class TestDistRunnerBase(object): else: build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce + pass_builder = None if args.batch_merge_repeat > 1: pass_builder = build_stra._finalize_strategy_and_create_passes() mypass = pass_builder.insert_pass( len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass") mypass.set("num_repeats", args.batch_merge_repeat) - if args.update_method == "nccl2": + if args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer": build_stra.num_trainers = len(args.endpoints.split(",")) build_stra.trainer_id = args.trainer_id else: + # case args.update_method == "nccl2_reduce_layer": build_stra.num_trainers = 1 build_stra.trainer_id = 0 binary = compiler.CompiledProgram(trainer_prog).with_data_parallel( loss_name=avg_cost.name, build_strategy=build_stra, - exec_strategy=strategy) + exec_strategy=exec_strategy) feed_var_list = [ var for var in trainer_prog.global_block().vars.values() @@ -182,7 +191,7 @@ def runtime_main(test_class): '--update_method', type=str, default="local", - choices=["pserver", "nccl2", "local"]) + choices=["pserver", "nccl2", "local", "nccl2_reduce_layer"]) parser.add_argument('--trainer_id', type=int, required=False, default=0) parser.add_argument('--trainers', type=int, required=False, default=1) parser.add_argument( @@ -198,6 +207,11 @@ def runtime_main(test_class): parser.add_argument('--lr', required=False, type=float, default=0.001) parser.add_argument( '--batch_merge_repeat', required=False, type=int, default=1) + parser.add_argument( + '--nccl2_reduce_layer_local_run', + required=False, + type=bool, + default=False) args = parser.parse_args() @@ -242,6 +256,11 @@ class TestDistBase(unittest.TestCase): self._dc_asgd = False # must use with async mode self._use_reader_alloc = True self._nccl2_mode = False + # FIXME(typhoonzero): I added this stupid argument to enable + # testing allreduce layers, which users can call layers.allreduce + # to accumulate tensors at anywhere. Find a better way to do this + # test, reduce check this argument everywhere. + self._nccl2_reduce_layer = False self._lr = 0.001 self._setup_config() self._after_setup_config() @@ -307,10 +326,16 @@ class TestDistBase(unittest.TestCase): cmd += " --batch_size %d" % batch_size if batch_merge_repeat > 1: cmd += " --batch_merge_repeat %d" % batch_merge_repeat + if self._nccl2_reduce_layer: + cmd += " --nccl2_reduce_layer_local_run 1" if self.__use_cuda: cmd += " --use_cuda" - env_local = {"CUDA_VISIBLE_DEVICES": "0"} + env_local = { + "CUDA_VISIBLE_DEVICES": "0", + "PADDLE_TRAINERS_NUM": "1", + "PADDLE_TRAINER_ID": "0" + } else: env_local = {'CPU_NUM': '1'} @@ -427,29 +452,30 @@ class TestDistBase(unittest.TestCase): sys.stderr.write("ps1 stderr: %s\n" % fn.read()) # print log - if stat0 == 0: - sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out)) with open("/tmp/tr0_err.log", "r") as fn: sys.stderr.write('trainer 0 stderr: %s\n' % fn.read()) - if stat1 == 0: - sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out)) with open("/tmp/tr1_err.log", "r") as fn: sys.stderr.write('trainer 1 stderr: %s\n' % fn.read()) return pickle.loads(tr0_out), pickle.loads(tr1_out) - def _run_cluster_nccl2(self, model, envs, check_error_log): + def _run_cluster_nccl2(self, model, envs, nccl2_reduce_layer, + check_error_log): # NOTE: we reuse ps_endpoints as nccl2 worker endpoints worker_endpoints = self._ps_endpoints.split(",") w0_ep, w1_ep = worker_endpoints + if nccl2_reduce_layer: + update_method = "nccl2_reduce_layer" + else: + update_method = "nccl2" - tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f" + tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method %s --lr %f" tr0_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 0, w0_ep, self._lr) + 0, w0_ep, update_method, self._lr) tr1_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 1, w1_ep, self._lr) + 1, w1_ep, update_method, self._lr) if self._mem_opt: tr0_cmd += " --mem_opt" @@ -463,8 +489,17 @@ class TestDistBase(unittest.TestCase): if self.__use_cuda: tr0_cmd += " --use_cuda" tr1_cmd += " --use_cuda" - env0 = {"CUDA_VISIBLE_DEVICES": "0"} - env1 = {"CUDA_VISIBLE_DEVICES": "1"} + env0 = { + "CUDA_VISIBLE_DEVICES": "0", + # for test nccl2 layer + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ID": "0" + } + env1 = { + "CUDA_VISIBLE_DEVICES": "1", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ID": "1" + } else: env0 = {'CPU_NUM': '1'} env1 = {'CPU_NUM': '1'} @@ -498,8 +533,6 @@ class TestDistBase(unittest.TestCase): # print log sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err) sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err) - sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out) - sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out) return pickle.loads(tr0_out), pickle.loads(tr1_out) @@ -528,10 +561,14 @@ class TestDistBase(unittest.TestCase): local_losses\ = self._run_local(model_file, required_envs, - check_error_log) + check_error_log) if self._nccl2_mode: - tr0_losses, tr1_losses = self._run_cluster_nccl2( - model_file, required_envs, check_error_log) + if self._nccl2_reduce_layer: + tr0_losses, tr1_losses = self._run_cluster_nccl2( + model_file, required_envs, True, check_error_log) + else: + tr0_losses, tr1_losses = self._run_cluster_nccl2( + model_file, required_envs, False, check_error_log) else: tr0_losses, tr1_losses = self._run_cluster( model_file, required_envs, check_error_log) diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 49a2ca40e3..030860ec79 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -- GitLab