From b5f4d5ed0e24f5dae293991bf6c267dc695c8f78 Mon Sep 17 00:00:00 2001
From: chengduo
Date: Fri, 24 May 2019 13:20:41 +0800
Subject: [PATCH] Add broadcast operators (#17503)

* This PR adds broadcast for multi-process, and it can be used in dynamic
  graph mode to broadcast parameters.
---
 .../operators/distributed_ops/broadcast_op.cc |  76 ++++++++++++
 .../distributed_ops/broadcast_op.cu.cc        |  81 +++++++++++++
 python/paddle/fluid/dygraph/layers.py         |   3 +
 python/paddle/fluid/dygraph/parallel.py       | 110 +++++++++++++++---
 .../paddle/fluid/dygraph/parallel_helper.py   |  43 +++++++
 python/paddle/fluid/layers/collective.py      |  11 ++
 6 files changed, 309 insertions(+), 15 deletions(-)
 create mode 100644 paddle/fluid/operators/distributed_ops/broadcast_op.cc
 create mode 100644 paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
 create mode 100644 python/paddle/fluid/dygraph/parallel_helper.py

diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cc
new file mode 100644
index 000000000..6ae98af1e
--- /dev/null
+++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cc
@@ -0,0 +1,76 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <ostream>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class BroadcastOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of BroadcastOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of BroadcastOp should not be null.");
+  }
+};
+
+class BroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor), tensor to be broadcast.");
+    AddOutput("Out", "(Tensor) the result of broadcast.");
+    AddAttr<bool>(
+        "sync_mode",
+        "(bool) whether to synchronize the CUDA stream after nccl call.")
+        .SetDefault(false);
+    AddAttr<int>("root", "(int) root device id of the broadcast.").SetDefault(0).EqualGreaterThan(0);
+    AddComment(R"DOC(
+***Broadcast Operator***
+
+Calls NCCL Broadcast internally. Note that this op must be used when one
+thread is managing one GPU device.
+)DOC"); + } +}; + +template +class BroadcastOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW("Broadcast op can run on gpu place only for now."); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(broadcast, ops::BroadcastOp, + ops::BroadcastOpMaker); + +REGISTER_OP_CPU_KERNEL(broadcast, ops::BroadcastOpKernel, + ops::BroadcastOpKernel, + ops::BroadcastOpKernel, + ops::BroadcastOpKernel, + ops::BroadcastOpKernel); diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc new file mode 100644 index 000000000..c9b40e686 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +template +class NCCLBroadcastOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "The place of ExecutionContext should be CUDAPlace."); + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + int dev_id = boost::get(ctx.GetPlace()).device; + int root_dev_id = ctx.Attr("root"); + + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + PADDLE_ENFORCE(out->IsInitialized(), + "Currently, the output of broadcast op must be initialized, " + "because this op can only be an In-Place operation."); + void* send_recv_buffer = out->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE_EQ( + send_recv_buffer, in->data(), + "Currently, the broadcast op can only be an In-Place operation."); + + auto& dev_ctx = ctx.template device_context(); + auto comm = dev_ctx.nccl_comm(); + auto stream = dev_ctx.stream(); + + PADDLE_ENFORCE(platform::dynload::ncclBcast( + send_recv_buffer, static_cast(in->numel()), + platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream)); + + VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" << in->numel() << ")" + << " From " << root_dev_id << " to " << dev_id; + + if (ctx.Attr("sync_mode")) { + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + } +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(broadcast, ops::NCCLBroadcastOpKernel, + ops::NCCLBroadcastOpKernel, + ops::NCCLBroadcastOpKernel, + ops::NCCLBroadcastOpKernel, + ops::NCCLBroadcastOpKernel); diff --git 
index 7ddf94146..54b34919e 100644
--- a/python/paddle/fluid/dygraph/layers.py
+++ b/python/paddle/fluid/dygraph/layers.py
@@ -18,6 +18,7 @@ import sys
 import numpy as np
 import collections
 import six
+from . import parallel_helper
 from .. import unique_name
 from paddle.fluid import core
 from .layer_object_helper import LayerObjectHelper
@@ -154,6 +155,8 @@ class Layer(core.Layer):
     def __call__(self, *inputs):
         if not self._built:
             self.build_once(*inputs)
+            if parallel_helper._is_data_parallel_mode():
+                parallel_helper._broadcast_parameters(self._parameters.values())
 
         outputs = self.forward(*inputs)
         self._built = True
diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py
index 1378f9140..37716cea1 100644
--- a/python/paddle/fluid/dygraph/parallel.py
+++ b/python/paddle/fluid/dygraph/parallel.py
@@ -17,8 +17,8 @@ import numpy as np
 
 from .. import core
 from . import layers
+from . import parallel_helper
 from .. import framework
-
 from ..layers import collective
 from . import to_variable
 
@@ -26,24 +26,29 @@ __all__ = ["prepare_context"]
 
 ParallelStrategy = core.ParallelStrategy
 
-__parallel_ctx__clz__ = None
-
 
-def prepare_context(parallel_strategy):
-    global __parallel_ctx__clz__
-    assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once."
-    assert framework.in_dygraph_mode(
-    ) is True, "dygraph.parallel.prepare_context should be used with dygrahp mode."
+def prepare_context(strategy=None):
+    if strategy is None:
+        strategy = ParallelStrategy()
+        strategy.nranks = Env().nranks
+        strategy.local_rank = Env().local_rank
+        strategy.trainer_endpoints = Env().trainer_endpoints
+        strategy.current_endpoint = Env().current_endpoint
+    if strategy.nranks < 2:
+        return
+    assert framework.in_dygraph_mode() is True,\
+        "dygraph.parallel.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
-    assert place is not None, "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."
-
+    assert place is not None, \
+        "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."
     if isinstance(place, core.CUDAPlace):
-        __parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy,
-                                                         place)
+        parallel_helper._set_parallel_ctx(
+            core.NCCLParallelContext(strategy, place))
     else:
         # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
         assert ("Only support CUDAPlace for now.")
-    __parallel_ctx__clz__.init()
+    parallel_helper._init_parallel_ctx()
+    return strategy
 
 
 class Env(object):
@@ -77,9 +82,65 @@ class Env(object):
 
 
 class DataParallel(layers.Layer):
+    """
+    Runs the module with data parallelism.
+
+    Currently, DataParallel only supports running the dynamic graph with
+    multiple processes. The usage is:
+    `python -m paddle.distributed.launch --gpus 2 dynamic_graph_test.py`,
+    where the content of `dynamic_graph_test.py` is the example code below.
+
+    Examples:
+        .. code-block:: python
+
+           import numpy as np
+           import paddle.fluid as fluid
+           import paddle.fluid.dygraph as dygraph
+           from paddle.fluid.optimizer import AdamOptimizer
+           from paddle.fluid.dygraph.nn import FC
+           from paddle.fluid.dygraph.base import to_variable
+
+           place = fluid.CUDAPlace(0)
+           with fluid.dygraph.guard(place=place):
+
+               # prepare the data parallel context
+               strategy = dygraph.parallel.prepare_context()
+
+               fc_layer = FC("FC", 10, act="softmax")
+               adam = AdamOptimizer()
+
+               # make the module a data parallel module
+               fc_layer = dygraph.parallel.DataParallel(fc_layer, strategy)
+
+               x_data = np.random.random(size=[10, 1]).astype(np.float32)
+               data = to_variable(x_data)
+
+               hidden = fc_layer(data)
+               avg_loss = fluid.layers.mean(hidden)
+
+               # scale the loss according to the number of trainers.
+               avg_loss = fc_layer.scale_loss(avg_loss)
+
+               avg_loss.backward()
+
+               # collect the gradients of trainers.
+               fc_layer.apply_collective_grads()
+
+               adam.minimize(avg_loss)
+               fc_layer.clear_gradients()
+
+    Args:
+        layers(Layer): The module that should be executed by data parallelism.
+        strategy(ParallelStrategy): The strategy of data parallelism.
+
+    Returns:
+        Layer: The data-parallel module.
+    """
+
     def __init__(self, layers, strategy):
         super(DataParallel,
               self).__init__(layers.full_name() + "_data_parallel")
+
         self._layers = layers
         self._strategy = strategy
 
@@ -87,8 +148,20 @@ class DataParallel(layers.Layer):
         return self._layers(*inputs, **kwargs)
 
     def scale_loss(self, loss):
-        if self._strategy.nranks < 2:
+        """
+        Scale the loss. In data parallel mode, the loss should be scaled by
+        the number of trainers. If not in data parallel mode, return the loss
+        directly.
+
+        Args:
+            loss(Layer): The loss of the current model.
+
+        Returns:
+            Layer: the scaled loss.
+        """
+        if not self._is_data_parallel_mode():
             return loss
+
         loss_scale = to_variable(
             np.array([self._strategy.nranks]).astype("float32"))
         loss_scale.stop_gradient = True
@@ -96,10 +169,14 @@ class DataParallel(layers.Layer):
         return loss
 
     def apply_collective_grads(self):
-        if self._strategy.nranks < 2:
+        """
+        AllReduce the parameters' gradients.
+        """
+        if not self._is_data_parallel_mode():
             return
 
         for param in self._layers.parameters():
+            # NOTE(zcd): The grad_ivar may not be generated.
             if param.trainable and param._ivar._grad_ivar():
                 g_var = framework.Variable(
                     block=self._helper.main_program.current_block(),
@@ -107,3 +184,6 @@ class DataParallel(layers.Layer):
                     stop_gradient=True,
                     ivar=param._ivar._grad_ivar())
                 collective._allreduce(g_var, g_var, sync_mode=True)
+
+    def _is_data_parallel_mode(self):
+        return self._strategy.nranks > 1
diff --git a/python/paddle/fluid/dygraph/parallel_helper.py b/python/paddle/fluid/dygraph/parallel_helper.py
new file mode 100644
index 000000000..7932c327e
--- /dev/null
+++ b/python/paddle/fluid/dygraph/parallel_helper.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from ..layers import collective
+
+__parallel_ctx__clz__ = None
+
+
+def _is_data_parallel_mode():
+    global __parallel_ctx__clz__
+    return __parallel_ctx__clz__ is not None and int(
+        os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1
+
+
+def _set_parallel_ctx(nccl_parallel_context):
+    global __parallel_ctx__clz__
+    assert __parallel_ctx__clz__ is None, \
+        "ParallelContext can only be initialized once."
+    __parallel_ctx__clz__ = nccl_parallel_context
+
+
+def _init_parallel_ctx():
+    global __parallel_ctx__clz__
+    assert __parallel_ctx__clz__ is not None, \
+        "ParallelContext should be initialized."
+    __parallel_ctx__clz__.init()
+
+
+def _broadcast_parameters(parameters):
+    for param in parameters:
+        if param.trainable:
+            collective._broadcast(param, 0, sync_mode=True)
diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py
index 97c290f5a..4fa0d1eb2 100644
--- a/python/paddle/fluid/layers/collective.py
+++ b/python/paddle/fluid/layers/collective.py
@@ -46,3 +46,14 @@ def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
         attrs={"reduce_type": red_typ_int,
                "sync_mode": sync_mode})
     return out
+
+
+def _broadcast(x, root, sync_mode=False):
+    helper = LayerHelper("broadcast", **locals())
+    helper.append_op(
+        type='broadcast',
+        inputs={'X': [x]},
+        outputs={'Out': [x]},
+        attrs={"sync_mode": sync_mode,
+               "root": root})
+    return x
-- 
GitLab
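
Usage note (illustrative, not part of the patch): the DataParallel docstring above already shows the full training flow, so the minimal sketch below only exercises the new broadcast path directly, by calling `fluid.layers.collective._broadcast` on a dygraph variable after `prepare_context`. It assumes the script is started with `python -m paddle.distributed.launch`, so the environment variables read by `Env()` are set, and it assumes one GPU per local rank, which is a choice made here for illustration rather than something the patch prescribes.

```python
# Minimal sketch: in-place broadcast of a tensor from rank 0.
# Assumes a multi-process launch via `python -m paddle.distributed.launch`
# and one GPU per local rank (illustrative assumption).
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.layers import collective

place = fluid.CUDAPlace(dygraph.parallel.Env().local_rank)
with fluid.dygraph.guard(place=place):
    # Sets up the NCCL parallel context; returns None when nranks < 2.
    strategy = dygraph.parallel.prepare_context()

    value = to_variable(np.random.random(size=[10, 10]).astype(np.float32))
    if strategy is not None:
        # In-place broadcast: after this call every rank holds rank 0's values,
        # which is exactly what _broadcast_parameters does for trainable params.
        collective._broadcast(value, 0, sync_mode=True)
```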