From 0d57ca46ea06257447cc2a82839d64d94fc5e421 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Sat, 10 Feb 2018 23:31:12 +0000 Subject: [PATCH] nccl pass parallel_do test --- paddle/operators/nccl_op.cc | 21 +++++++++- paddle/operators/nccl_op.cu.cc | 8 ++++ paddle/operators/parallel_do_op.cc | 24 ++++++++++- python/paddle/v2/fluid/backward.py | 41 +++++++++++-------- python/paddle/v2/fluid/layers/control_flow.py | 6 ++- .../paddle/v2/fluid/tests/test_parallel_op.py | 33 +++++++++------ 6 files changed, 99 insertions(+), 34 deletions(-) diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index 83ac67f353d..a906223f38c 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/framework/op_registry.h" #include "paddle/operators/nccl/nccl_gpu_common.h" @@ -49,6 +50,22 @@ class NCCLInitOp : public framework::OperatorBase { } }; +class NCCLInitOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto out_var_name = op_desc.Output("Communicator").front(); + auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarDesc::NCCL_COM; + out_var.SetType(var_type); + } +}; + +class NCCLInitOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override {} +}; + class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { public: NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker) @@ -214,7 +231,9 @@ Bcast the tensors. namespace ops = paddle::operators; REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp, - paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker); + paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker, + ops::NCCLInitOpVarTypeInference, + ops::NCCLInitOpShapeInference); REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp, ops::NCCLAllReduceOpMaker); diff --git a/paddle/operators/nccl_op.cu.cc b/paddle/operators/nccl_op.cu.cc index 1b986a13650..b6db63ac6ae 100644 --- a/paddle/operators/nccl_op.cu.cc +++ b/paddle/operators/nccl_op.cu.cc @@ -47,8 +47,11 @@ class NCCLAllReduceKernel : public framework::OpKernel { auto ins = ctx.MultiInput("X"); auto outs = ctx.MultiOutput("Out"); + LOG(INFO) << "------------------"; std::string reduction = ctx.Attr("reduction"); + LOG(INFO) << "------------------"; ncclRedOp_t reduction_op_ = ncclSum; + LOG(INFO) << "------------------"; if (reduction == "ncclMin") { reduction_op_ = ncclMin; @@ -62,14 +65,19 @@ class NCCLAllReduceKernel : public framework::OpKernel { PADDLE_THROW("Invalid reduction. default ncclSum."); } + LOG(INFO) << "------------------"; auto* comm = ctx.Input("Communicator"); + LOG(INFO) << "------------------"; auto stream = ctx.cuda_device_context().stream(); + LOG(INFO) << "------------------"; // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); + LOG(INFO) << "------------------"; int idx = comm->GetCommId(gpu_id); + LOG(INFO) << "------------------"; for (size_t i = 0; i < ins.size(); ++i) { VLOG(1) << "gpu : " << " invoke allreduce. send " << ins[i]->numel() << " recv " diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 89045923f9f..950a95ae360 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -30,6 +30,7 @@ static constexpr char kOutputs[] = "outputs"; static constexpr char kParallelScopes[] = "parallel_scopes"; static constexpr char kParallelBlock[] = "sub_block"; +static constexpr char kUseNCCL[] = "use_nccl"; using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; @@ -159,6 +160,7 @@ class ParallelDoOp : public framework::OperatorBase { } WaitOnPlaces(places); + // PADDLE_ENFORCE_EQ(places.size(), sub_scopes.size()); std::vector> workers; workers.reserve(places.size()); for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) { @@ -202,6 +204,8 @@ class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker { AddOutput(kOutputs, "").AsDuplicable(); AddOutput(kParallelScopes, ""); AddAttr(kParallelBlock, ""); + AddAttr(kUseNCCL, "true if we use nccl on backward") + .SetDefault(false); AddComment(R"DOC( ParallelDo Operator. )DOC"); @@ -223,20 +227,22 @@ class ParallelDoGradOp : public framework::OperatorBase { auto &sub_scopes = scope.FindVar(Input(kParallelScopes)) ->Get>(); - auto &places = scope.FindVar(Input(kPlaces))->Get(); + // PADDLE_ENFORCE_EQ(places.size(), sub_scopes.size()); // feed output@grad SplitTensorAndMoveTensorToScopes( scope, const_cast *>(&sub_scopes), places, Inputs(framework::GradVarName(kOutputs))); WaitOnPlaces(places); + LOG(INFO) << "places " << places.size(); // exe run std::vector> workers; for (size_t i = 0; i < sub_scopes.size(); ++i) { auto &place = places[i]; auto *cur_scope = sub_scopes[i]; + LOG(INFO) << place; // execute workers.emplace_back(framework::Async([program, cur_scope, place, block] { @@ -245,12 +251,26 @@ class ParallelDoGradOp : public framework::OperatorBase { false /*create_local_scope*/); })); } + LOG(INFO) << "places " << places.size(); for (auto &worker : workers) { worker.wait(); } WaitOnPlaces(places); - AccumulateGrad(scope, place, sub_scopes, places); + // NCCL allreduce op will be added by backward, + // so no need to explicitly accumulate grad + if (!(Attr(kUseNCCL))) { + AccumulateGrad(scope, place, sub_scopes, places); + } else { + for (auto &place : places) { + PADDLE_ENFORCE(platform::is_gpu_place(place), + "NCCL only supports cuda place"); + } + } + for (auto &s : Outputs(framework::GradVarName(kParameters))) { + CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s)); + } + WaitOnPlaces(places); } void AccumulateGrad(const framework::Scope &scope, diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 40c54bf220f..28768ef07fc 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -218,7 +218,7 @@ def _callback_lookup_(op): :param op: :return: callback function """ - if op.type == 'parallel_do': + if op.type == 'parallel_do' and op.attr('use_nccl'): param_names = set(op.input('parameters')) param_grad_names = [n + "@GRAD" for n in param_names] @@ -229,18 +229,25 @@ def _callback_lookup_(op): def __call__(self, block, context): if not self.has_inserted_nccl_init: - global_block = block.program.global_block() - op_desc = global_block.desc.append_op() - var_desc = global_block.desc.var('nccl_com') - var_desc.set_type(core.VarDesc.VarType.NCCL_COM) - self.nccl_com = global_block.create_var( - name='nccl_com', type=core.VarDesc.VarType.NCCL_COM) - framework.Operator( - global_block, - type='ncclInit', - desc=op_desc, - inputs={}, - outputs={'Communicator': [self.nccl_com]}) + # global_block = block.program.global_block() + # op_desc = global_block.desc.append_op() + # var_desc = global_block.desc.var('nccl_com__do_not_change_') + # var_desc.set_type(core.VarDesc.VarType.NCCL_COM) + # self.nccl_com = global_block.create_var( + # name='nccl_com', type=core.VarDesc.VarType.NCCL_COM) + # framework.Operator( + # global_block, + # type='ncclInit', + # desc=op_desc, + # inputs={}, + # outputs={'Communicator': [self.nccl_com]}) + op_desc = _create_op_desc_( + "ncclInit", {}, + {"Communicator": ['nccl_com__do_not_change_']}, {}) + # block.desc.append_op().copy_from(op_desc) + print(serialize_op_decs(op_desc)) + block.program.global_block().desc.append_op().copy_from( + op_desc) self.has_inserted_nccl_init = True current_op_desc = context["__current_op_desc__"] @@ -263,7 +270,8 @@ def _callback_lookup_(op): op_desc = _create_op_desc_( "ncclAllReduce", { "X": [o_argu], - "Communicator": ['nccl_com_0'] + "Communicator": + ['nccl_com__do_not_change_'] }, {"Out": [allreduce_out_name]}, {"reduction": "ncclSum"}) block.desc.append_op().copy_from(op_desc) @@ -375,10 +383,11 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): continue grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block) # infer_shape and infer_type - if op_desc.type() == 'ncclInit': - continue op_desc.infer_var_type(block.desc) op_desc.infer_shape(block.desc) + # ncclInit dones't need to set data_type + if op_desc.type() == 'ncclInit': + continue for arg in op_desc.output_arg_names(): if arg in new_vars: _infer_var_data_type_(arg, block) diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index 71a9459d556..5c9c2470661 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -237,12 +237,13 @@ class ParallelDo(object): ParallelDo class is used to create a ParallelDo. """ - def __init__(self, places, name=None): + def __init__(self, places, use_nccl=False, name=None): self.helper = LayerHelper("parallel_do", name=name) self.inputs = [] self.places = places self.outputs = [] self.status = StaticRNN.BEFORE_RNN_BLOCK + self.use_nccl = use_nccl def do(self): return BlockGuardWithCompletion(self) @@ -325,7 +326,8 @@ class ParallelDo(object): }, outputs={'outputs': outputs, 'parallel_scopes': [step_scope]}, - attrs={'sub_block': current_block}) + attrs={'sub_block': current_block, + 'use_nccl': self.use_nccl}) class BlockGuardWithCompletion(BlockGuard): diff --git a/python/paddle/v2/fluid/tests/test_parallel_op.py b/python/paddle/v2/fluid/tests/test_parallel_op.py index 367cc8b1aaf..8452d6835fa 100644 --- a/python/paddle/v2/fluid/tests/test_parallel_op.py +++ b/python/paddle/v2/fluid/tests/test_parallel_op.py @@ -67,12 +67,25 @@ class BaseParallelForTest(unittest.TestCase): fetch=fetch, place=gpu, use_parallel=True) + result_gpu_nccl = self._run_test_impl_( + callback=callback, + feed=feed, + fetch=fetch, + place=gpu, + use_parallel=True, + use_nccl=True) self._assert_same_(fetch, result_cpu, result_cpu_parallel, - result_gpu, result_gpu_parallel) + result_gpu, result_gpu_parallel, result_gpu_nccl) else: self._assert_same_(fetch, result_cpu, result_cpu_parallel) - def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False): + def _run_test_impl_(self, + callback, + feed, + fetch, + place, + use_parallel=False, + use_nccl=False): """ Run a single test, returns the fetch values Args: @@ -96,7 +109,7 @@ class BaseParallelForTest(unittest.TestCase): # Automatically insert parallel do if use_parallel = True if use_parallel: places = fluid.layers.get_places() - pd = fluid.layers.ParallelDo(places) + pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl) data = next(generator) if isinstance(data, fluid.Variable): @@ -137,7 +150,9 @@ class BaseParallelForTest(unittest.TestCase): """ def _impl_(a, b, fetch_id, item_id): - item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU'] + item_str = [ + 'CPU', 'ParallelCPU', 'GPU', 'ParallelGPU', 'ParallelGPUNCCL' + ] flag = numpy.allclose(a, b, rtol=0.1) self.assertTrue(flag, "The {0} are different in {1}".format( fetch[fetch_id], item_str[item_id])) @@ -157,18 +172,10 @@ class ParallelOpTest(BaseParallelForTest): loss = fluid.layers.mean(x=hidden) yield loss - def test_simple_fc(self): - self.run_test( - callback=self.__network__, - feed={ - 'img': numpy.random.random(size=(51, 784)).astype('float32') - }, - fetch=['fc1.w@GRAD']) - def test_fc_with_tiny_data(self): self.run_test( callback=self.__network__, - feed={'img': numpy.random.random(size=(1, 784)).astype('float32')}, + feed={'img': numpy.random.random(size=(8, 784)).astype('float32')}, fetch=['fc1.w@GRAD']) -- GitLab