diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index 5c426bc677e1fdba3b34c04cf6b4e390f66c688a..4ea3da949db574e5b41e89f6d632797592182cb8 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -451,6 +451,21 @@ void Reducer::PrepareDeps(const std::unordered_set<GradOpNode *> &init_nodes) {
       PADDLE_ENFORCE_NOT_NULL(
           grad_pending_node,
           platform::errors::NotFound("Grad pending node should not be null"));
+      // py_layer is not supported in DataParallel
+      auto begin = grad_pending_node->begin();
+      auto end = grad_pending_node->end();
+      for (auto op_base = begin; op_base != end; op_base++) {
+        PADDLE_ENFORCE_EQ(
+            op_base->Type() != "py_layer", true,
+            platform::errors::PreconditionNotMet(
+                "Note: Currently PyLayer is not supported in DataParallel. To "
+                "use PyLayer in a DataParallel model, you can skip gradient "
+                "synchronization among multiple cards by 'no_sync', and "
+                "manually implement 'all_reduce' before model optimization. "
+                "There is an example showing the specific implementation in "
+                "the official docs: https://www.paddlepaddle.org.cn/documentation"
+                "/docs/api/paddle/DataParallel_cn.html"));
+      }
       ++node_deps_[grad_pending_node.get()];
       if (visited.count(grad_pending_node.get()) == 0) {
         visited.insert(grad_pending_node.get());
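
For reference, a minimal sketch (not part of the patch) of the behavior this check introduces: when a DataParallel model whose backward graph contains a ``py_layer`` op runs an ordinary, synchronized backward pass, the reducer is expected to raise the ``PreconditionNotMet`` error above. The sketch assumes it is launched on two or more GPUs (e.g. via ``paddle.distributed.launch``); the supported alternative, ``no_sync`` plus ``fused_allreduce_gradients``, is shown in the ``DataParallel`` docstring further down.

    import paddle
    import paddle.distributed as dist
    from paddle.autograd import PyLayer

    class cus_tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            y = paddle.tanh(x)
            ctx.save_for_backward(y)
            return y

        @staticmethod
        def backward(ctx, dy):
            y, = ctx.saved_tensor()
            return dy * (1 - paddle.square(y))

    class Net(paddle.nn.Layer):
        def __init__(self):
            super(Net, self).__init__()
            self.linear = paddle.nn.Linear(2, 2)

        def forward(self, x):
            return self.linear(cus_tanh.apply(x))

    if __name__ == '__main__':
        dist.init_parallel_env()
        model = paddle.DataParallel(Net())

        x = paddle.randn([2, 2])
        x.stop_gradient = False
        loss = model(x).mean()
        try:
            # Without no_sync(), backward prepares gradient synchronization and
            # hits the py_layer check added in Reducer::PrepareDeps above.
            loss.backward()
        except Exception as err:  # PreconditionNotMet surfaces as a Python exception
            print(err)
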
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 81bed60050de2992d1ad35190477a743a39b8f9f..94dd29f74695d973cdccc4f959bb21d723b72f3e 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -19,6 +19,7 @@ import warnings
 from paddle import framework
 import paddle
 from paddle.fluid import core
+import paddle.distributed as dist
 from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
 from collections import OrderedDict
 from .log_util import logger
@@ -44,8 +45,9 @@ def _apply_collective_grads(parameters, comm_group):
 
     for coalesced_grad, _, _ in coalesced_grads_and_vars:
         # need to div nranks
-        div_factor = paddle.to_tensor(
-            comm_group.nranks, dtype=coalesced_grad.dtype)
+        nranks = dist.get_world_size(
+        ) if comm_group is None else comm_group.nranks
+        div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
         paddle.fluid.framework._dygraph_tracer().trace_op(
             type="elementwise_div",
             inputs={'X': coalesced_grad,
@@ -115,7 +117,7 @@ def broadcast_dp_parameters(model, hcg):
 
 
 def fused_allreduce_gradients(parameter_list, hcg):
-    data_parallel_group = hcg.get_data_parallel_group()
+    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
     logger.debug("dp start fuse allreduce gradients")
     with framework.no_grad():
         _apply_collective_grads(parameter_list, data_parallel_group)
diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py
index 15f4eece4487c763cfea5c41b88f0ec9f440619f..e4525a8d17992a1e2a1c3250734f9a8181518a49 100644
--- a/python/paddle/fluid/dygraph/parallel.py
+++ b/python/paddle/fluid/dygraph/parallel.py
@@ -426,8 +426,10 @@ class DataParallel(layers.Layer):
         Layer: The data paralleled module.
 
     Examples:
+
         .. code-block:: python
-
+            :name: dp-example
+
             # required: distributed
             import paddle
             import paddle.nn as nn
@@ -471,6 +473,72 @@ class DataParallel(layers.Layer):
             dist.spawn(train, nprocs=2)
             # 2. start by ``paddle.distributed.launch``
             # train()
+
+
+    .. note::
+        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind,
+        it's recommended to skip gradient synchronization among multiple cards by
+        ``no_sync``, and manually implement ``all_reduce`` before model optimization.
+        The following example shows the specific implementation.
+
+    Examples:
+
+        .. code-block:: python
+            :name: dp-pylayer-example
+
+            # required: distributed
+            import numpy
+            import paddle
+            import paddle.distributed as dist
+            from paddle.autograd import PyLayer
+            from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
+
+            class cus_tanh(PyLayer):
+                @staticmethod
+                def forward(ctx, x):
+                    y = paddle.tanh(x)
+                    ctx.save_for_backward(y)
+                    return y
+
+                @staticmethod
+                def backward(ctx, dy):
+                    y, = ctx.saved_tensor()
+                    grad = dy * (1 - paddle.square(y))
+                    return grad
+
+            class SimpleNet(paddle.nn.Layer):
+                def __init__(self):
+                    super(SimpleNet, self).__init__()
+                    self.linear = paddle.nn.Linear(2, 2)
+
+                def forward(self, inputs):
+                    inputs = cus_tanh.apply(inputs)
+                    return self.linear(inputs)
+
+            if __name__ == '__main__':
+                dist.init_parallel_env()
+
+                model = SimpleNet()
+                model = paddle.DataParallel(model)
+                opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+
+                for step in range(10):
+                    x_data = numpy.random.randn(2, 2).astype(numpy.float32)
+                    x = paddle.to_tensor(x_data)
+                    x.stop_gradient = False
+
+                    # step 1 : skip gradient synchronization by 'no_sync'
+                    with model.no_sync():
+                        y_pred = model(x)
+                        loss = y_pred.mean()
+                        loss.backward()
+
+                    # step 2 : fuse + allreduce manually before optimization
+                    fused_allreduce_gradients(list(model.parameters()), None)
+
+                    opt.step()
+                    opt.clear_grad()
+
     """
 
     def __init__(self,
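
The equivalence that makes this manual path safe is straightforward: DataParallel's automatic synchronization averages the per-card gradients, and ``fused_allreduce_gradients`` does the same thing, an allreduce sum followed by division by ``nranks`` (see ``_apply_collective_grads`` above). A toy illustration with made-up numbers, independent of any Paddle API:

    import numpy as np

    # Pretend gradients computed independently on two cards for the same parameter.
    local_grads = [np.array([0.2, -0.4]), np.array([0.6, 0.0])]

    # allreduce(sum) across cards, then divide by the number of ranks: both the
    # automatic and the manual path leave every card with this averaged gradient.
    averaged = sum(local_grads) / len(local_grads)
    print(averaged)  # [ 0.4 -0.2]
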
diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_with_pylayer.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_with_pylayer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f623ba36dcab56deeed32ac3f9e9723c530d16ed
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_with_pylayer.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import paddle
+import numpy as np
+import paddle.distributed as dist
+from paddle.fluid.dygraph.nn import Linear
+from paddle.autograd import PyLayer
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
+
+batch = 5
+in_dim = 20
+out_dim = 10
+
+
+class cus_tanh(PyLayer):
+    @staticmethod
+    def forward(ctx, x):
+        y = paddle.tanh(x)
+        ctx.save_for_backward(y)
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        y, = ctx.saved_tensor()
+        grad = dy * (1 - paddle.square(y))
+        return grad
+
+
+class SimpleNet(paddle.nn.Layer):
+    def __init__(self, train_id, model_id):
+        super(SimpleNet, self).__init__()
+        self.w = self.create_parameter(shape=[in_dim, batch], dtype="float32")
+        self.linear = paddle.nn.Linear(in_dim, out_dim)
+        self.tanh = paddle.tanh
+
+        self.trainer_id = train_id
+        self.model_id = model_id
+
+    def forward(self, inputs):
+        if self.model_id == 0:
+            inputs = cus_tanh.apply(inputs)
+        else:
+            inputs = self.tanh(inputs)
+
+        inputs = paddle.matmul(self.w, inputs)
+        return self.linear(inputs)
+
+
+class TestDistTraning(unittest.TestCase):
+    def test_multiple_gpus(self):
+        self.trainer_id = dist.get_rank()
+        dist.init_parallel_env()
+
+        model_a = SimpleNet(self.trainer_id, 0)
+        model_b = SimpleNet(self.trainer_id, 1)
+
+        state_dict = model_a.state_dict()
+        model_b.set_state_dict(state_dict)
+
+        model_a = paddle.DataParallel(model_a)
+        model_b = paddle.DataParallel(model_b)
+
+        for step in range(10):
+            x_data = np.random.randn(batch, in_dim).astype(np.float32)
+            x = paddle.to_tensor(x_data)
+            x.stop_gradient = False
+
+            with model_a.no_sync():
+                y_pred_a = model_a(x)
+                loss_a = y_pred_a.mean()
+                loss_a.backward()
+            fused_allreduce_gradients(list(model_a.parameters()), None)
+
+            y_pred_b = model_b(x)
+            loss_b = y_pred_b.mean()
+            loss_b.backward()
+
+            self.check_gradient(model_a.parameters())
+            self.check_gradient(model_b.parameters())
+
+            self.check_acc(model_a._layers.w.grad, model_b._layers.w.grad)
+
+            model_a.clear_gradients()
+            model_b.clear_gradients()
+
+    def check_acc(self, grad, acc_grad):
+        grad = grad.numpy() if grad is not None else None
+        acc_grad = acc_grad.numpy() if acc_grad is not None else None
+        return np.testing.assert_allclose(grad, acc_grad, rtol=1e-6)
+
+    def broadcast_param(self, param, root):
+        paddle.distributed.broadcast(param, root)
+        return param
+
+    def check_gradient(self, params):
+        other_param = []
+        for param in params:
+            if param.trainable and (param._grad_ivar() is not None):
+                grad = param._grad_ivar()
+                other_grad = self.broadcast_param(grad.clone(), root=1)
+                if self.trainer_id == 0:
+                    np.testing.assert_allclose(other_grad.numpy(), grad.numpy())
+
+
+if __name__ == '__main__':
+    unittest.main()
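
The new script is driven through the ``TestMultipleGpus`` harness registered in the next hunk; ``run_mnist_2gpu`` starts it on two GPUs. A rough manual equivalent (an assumption for illustration, not the harness's actual code; it presumes two visible GPUs and that ``paddle.distributed.launch`` accepts the ``--gpus`` flag in this Paddle version):

    import subprocess
    import sys

    # Launch the standalone test on GPUs 0 and 1; each process runs
    # TestDistTraning.test_multiple_gpus with its own rank.
    subprocess.check_call([
        sys.executable, "-m", "paddle.distributed.launch", "--gpus", "0,1",
        "parallel_dygraph_dataparallel_with_pylayer.py",
    ])
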
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
index d15e55eb0fa1460b60b1b06582ad049875c7e54e..c97cd56e8a7a401758ceb9e1332db368d61dc1ff 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py
@@ -130,5 +130,10 @@ class TestDataParallelGradientCheck(TestMultipleGpus):
         self.run_mnist_2gpu('parallel_dygraph_gradient_check.py')
 
 
+class TestDataParallelWithPyLayer(TestMultipleGpus):
+    def test_parallel_dygraph_dataparallel_with_pylayer(self):
+        self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py
index f3fc13f3eea1bf5fb3367683d48222d757460e4c..d5eebf01adb7cc8c456cb0115373a1fb75109ce0 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py
@@ -20,7 +20,7 @@ import paddle.fluid as fluid
 from test_parallel_dygraph_dataparallel import TestMultipleGpus
 
 
-class TestModelParallelLayer(TestMultipleGpus):
+class TestDataParallelLayer(TestMultipleGpus):
     def test_parallel_dygraph_dataparallel_no_sync(self):
         self.run_mnist_2gpu('parallel_dygraph_no_sync_gradient_check.py')
 