From dec53a9c795f77095d36c2de195b2d59e7231281 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 29 Sep 2020 15:47:37 +0800 Subject: [PATCH] Remove DataParallel.scale_loss & apply_collective_grads (#27603) * remove data parallel scale loss & apply collective_grads * move apply in minimize * fix failed unittests --- python/paddle/fluid/dygraph/parallel.py | 339 +++++++----------- .../fluid/dygraph/varbase_patch_methods.py | 12 +- python/paddle/fluid/optimizer.py | 12 +- .../fluid/tests/unittests/test_dist_base.py | 16 - ...test_imperative_parallel_coalesce_split.py | 7 +- python/paddle/optimizer/adam.py | 7 +- python/paddle/optimizer/adamw.py | 11 +- python/paddle/optimizer/optimizer.py | 18 +- 8 files changed, 178 insertions(+), 244 deletions(-) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 30918113be..d810709e67 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -217,6 +217,121 @@ class ParallelEnv(object): Env = ParallelEnv +def _build_default_parallel_strategy(): + strategy = ParallelStrategy() + strategy.nranks = ParallelEnv().nranks + strategy.local_rank = ParallelEnv().local_rank + strategy.trainer_endpoints = ParallelEnv().trainer_endpoints + strategy.current_endpoint = ParallelEnv().current_endpoint + return strategy + + +def _coalesce_tensors(var_groups): + from ..layers import nn + coalesced_grads_and_grad_vars = [] + for group_id, grad_vars in var_groups.items(): + flattened_vars = [] + g_var_shapes = [] + for g_var in grad_vars: + g_var_shapes.append(g_var.shape) + flattened_vars.append( + nn.reshape( + x=g_var, shape=[np.prod(g_var.shape)])) + coalesced_grad = nn.concat(flattened_vars) + coalesced_grads_and_grad_vars.append( + [coalesced_grad, grad_vars, g_var_shapes]) + return coalesced_grads_and_grad_vars + + +@framework.dygraph_only +def _reshape_inplace(x, shape): + x_shape = framework._varbase_creator(dtype=x.dtype) + framework._dygraph_tracer().trace_op( + type="reshape2", + inputs={'X': x}, + outputs={'Out': x, + 'XShape': x_shape}, + attrs={'shape': shape}) + + +@framework.dygraph_only +def _split_tensors(coalesced_grads_and_grad_vars): + for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars: + grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes] + framework._dygraph_tracer().trace_op( + type='split', + inputs={'X': coalesced_grad}, + outputs={'Out': origin_grad_vars}, + attrs={'sections': grad_var_len, + 'axis': 0}) + for g_var, g_shape in zip(origin_grad_vars, grad_shapes): + _reshape_inplace(x=g_var, shape=g_shape) + assert g_var.shape == g_shape + + +def scale_loss(loss): + if not ParallelEnv().world_size > 1: + return loss + + loss_scale = to_variable( + np.array([ParallelEnv().world_size]).astype("float32")) + loss_scale.stop_gradient = True + scaled_loss = loss / loss_scale + return scaled_loss + + +@no_grad +def apply_collective_grads(parameters): + if not ParallelEnv().world_size > 1: + return + + grad_var_set = set() + grad_vars = [] + sparse_grad_vars = [] + strategy = _build_default_parallel_strategy() + for param in parameters: + # NOTE(zcd): The grad_ivar maybe no generated. + if param.trainable and (param._grad_ivar() is not None): + g_var = param._grad_ivar() + if g_var._is_sparse(): + sparse_grad_vars.append(g_var) + continue + grad_vars.append(g_var) + assert g_var not in grad_var_set + grad_var_set.add(g_var) + + if sparse_grad_vars: + sparse_grad_vars.sort(key=lambda x: x.name) + for grad_var in sparse_grad_vars: + grad_var._allreduce(strategy) + + # FIXME(zcd): the type of the var should be LoDTensor, i.e + # the gradients should be dense, otherwise, the following + # logic should be updated. + # 128 MB as a group + mega_bytes = 128 * 1024 * 1024 + group_idx = 0 + memory_counter = 0 + grad_var_groups = OrderedDict() + dtype = grad_vars[0].dtype + for g_var in grad_vars: + # NOTE: the dtype of the same group should be the same. + bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype) + if memory_counter < mega_bytes and dtype == g_var.dtype: + memory_counter += bytes + else: + memory_counter = bytes + group_idx += 1 + grad_var_groups.setdefault(group_idx, []).append(g_var) + + coalesced_grads_and_vars = _coalesce_tensors(grad_var_groups) + + for coalesced_grad, _, _ in coalesced_grads_and_vars: + coalesced_grad._allreduce(strategy) + + _split_tensors(coalesced_grads_and_vars) + + class DataParallel(layers.Layer): """ Run the dygraph module with data parallelism. @@ -309,232 +424,28 @@ class DataParallel(layers.Layer): if strategy is not None: self._strategy = strategy else: - self._strategy = ParallelStrategy() - self._strategy.nranks = ParallelEnv().nranks - self._strategy.local_rank = ParallelEnv().local_rank - self._strategy.trainer_endpoints = ParallelEnv().trainer_endpoints - self._strategy.current_endpoint = ParallelEnv().current_endpoint + self._strategy = _build_default_parallel_strategy() def forward(self, *inputs, **kwargs): return self._layers(*inputs, **kwargs) + @deprecated( + since="2.0.0", reason="This method does not need to be called anymore.") def scale_loss(self, loss): """ - Scale the loss. In data parallel mode, the loss should be scale with - the number of trainers. If not in data parallel mode, return the loss - directly. - - Args: - loss(Variable): The loss of the current Model. - - Returns: - Variable: the scaled loss. - - Examples: - .. code-block:: python - - import paddle - import paddle.nn as nn - import paddle.optimizer as opt - import paddle.distributed as dist - - class LinearNet(nn.Layer): - def __init__(self): - super(LinearNet, self).__init__() - self._linear1 = nn.Linear(10, 10) - self._linear2 = nn.Linear(10, 1) - - def forward(self, x): - return self._linear2(self._linear1(x)) - - def train(): - # 1. enable dynamic mode - paddle.disable_static() - - # 2. initialize parallel environment - dist.init_parallel_env() - - # 3. create data parallel layer & optimizer - layer = LinearNet() - dp_layer = paddle.DataParallel(layer) - - loss_fn = nn.MSELoss() - adam = opt.Adam( - learning_rate=0.001, parameters=dp_layer.parameters()) - - # 4. run layer - inputs = paddle.randn([10, 10], 'float32') - outputs = dp_layer(inputs) - labels = paddle.randn([10, 1], 'float32') - loss = loss_fn(outputs, labels) - - loss = dp_layer.scale_loss(loss) - loss.backward() - dp_layer.apply_collective_grads() - - adam.step() - adam.clear_grad() - - if __name__ == '__main__': - # 1. start by ``paddle.distributed.spawn`` (default) - dist.spawn(train, nprocs=2) - # 2. start by ``paddle.distributed.launch`` - # train() + Deprecated method, now ``scale_loss`` is an empty method, + keep this method just for compatibility. """ - if not self._is_data_parallel_mode(): - return loss - - loss_scale = to_variable( - np.array([self._strategy.nranks]).astype("float32")) - loss_scale.stop_gradient = True - loss = loss / loss_scale return loss - def _coalesce_tensors(self, var_groups): - from ..layers import nn - coalesced_grads_and_grad_vars = [] - for group_id, grad_vars in var_groups.items(): - flattened_vars = [] - g_var_shapes = [] - for g_var in grad_vars: - g_var_shapes.append(g_var.shape) - flattened_vars.append( - nn.reshape( - x=g_var, shape=[np.prod(g_var.shape)])) - coalesced_grad = nn.concat(flattened_vars) - coalesced_grads_and_grad_vars.append( - [coalesced_grad, grad_vars, g_var_shapes]) - return coalesced_grads_and_grad_vars - - def _reshape_inplace(self, x, shape): - x_shape = self._helper.create_variable_for_type_inference(dtype=x.dtype) - self._helper.append_op( - type="reshape2", - inputs={'X': x}, - attrs={'shape': shape}, - outputs={'Out': x, - 'XShape': x_shape}) - - def _split_tensors(self, coalesced_grads_and_grad_vars): - from ..layers import nn - for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars: - grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes] - self._helper.main_program.current_block().append_op( - type='split', - inputs={'X': coalesced_grad}, - outputs={'Out': origin_grad_vars}, - attrs={'sections': grad_var_len, - 'axis': 0}) - for g_var, g_shape in zip(origin_grad_vars, grad_shapes): - self._reshape_inplace(x=g_var, shape=g_shape) - assert g_var.shape == g_shape - - @no_grad + @deprecated( + since="2.0.0", reason="This method does not need to be called anymore.") def apply_collective_grads(self): """ - AllReduce the Parameters' gradient. - - Examples: - .. code-block:: python - - import paddle - import paddle.nn as nn - import paddle.optimizer as opt - import paddle.distributed as dist - - class LinearNet(nn.Layer): - def __init__(self): - super(LinearNet, self).__init__() - self._linear1 = nn.Linear(10, 10) - self._linear2 = nn.Linear(10, 1) - - def forward(self, x): - return self._linear2(self._linear1(x)) - - def train(): - # 1. enable dynamic mode - paddle.disable_static() - - # 2. initialize parallel environment - dist.init_parallel_env() - - # 3. create data parallel layer & optimizer - layer = LinearNet() - dp_layer = paddle.DataParallel(layer) - - loss_fn = nn.MSELoss() - adam = opt.Adam( - learning_rate=0.001, parameters=dp_layer.parameters()) - - # 4. run layer - inputs = paddle.randn([10, 10], 'float32') - outputs = dp_layer(inputs) - labels = paddle.randn([10, 1], 'float32') - loss = loss_fn(outputs, labels) - - loss = dp_layer.scale_loss(loss) - loss.backward() - dp_layer.apply_collective_grads() - - adam.step() - adam.clear_grad() - - if __name__ == '__main__': - # 1. start by ``paddle.distributed.spawn`` (default) - dist.spawn(train, nprocs=2) - # 2. start by ``paddle.distributed.launch`` - # train() + Deprecated method, now ``apply_collective_grads`` is an empty method, + keep this method just for compatibility. """ - if not self._is_data_parallel_mode(): - return - - grad_var_set = set() - grad_vars = [] - sparse_grad_vars = [] - for param in self._layers.parameters(): - # NOTE(zcd): The grad_ivar maybe no generated. - if param.trainable and (param._grad_ivar() is not None): - g_var = param._grad_ivar() - if g_var._is_sparse(): - sparse_grad_vars.append(g_var) - continue - grad_vars.append(g_var) - assert g_var not in grad_var_set - grad_var_set.add(g_var) - - if sparse_grad_vars: - sparse_grad_vars.sort(key=lambda x: x.name) - for grad_var in sparse_grad_vars: - grad_var._allreduce(self._strategy) - - # FIXME(zcd): the type of the var should be LoDTensor, i.e - # the gradients should be dense, otherwise, the following - # logic should be updated. - # 128 MB as a group - mega_bytes = 128 * 1024 * 1024 - group_idx = 0 - memory_counter = 0 - grad_var_groups = OrderedDict() - dtype = grad_vars[0].dtype - for g_var in grad_vars: - # Note: the dtype of the same group should be the same. - bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype) - if memory_counter < mega_bytes and dtype == g_var.dtype: - memory_counter += bytes - else: - memory_counter = bytes - group_idx += 1 - grad_var_groups.setdefault(group_idx, []).append(g_var) - - coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups) - - for coalesced_grad, _, _ in coalesced_grads_and_vars: - coalesced_grad._allreduce(self._strategy) - - self._split_tensors(coalesced_grads_and_vars) - - def _is_data_parallel_mode(self): - return self._strategy.nranks > 1 + return def state_dict(self, destination=None, diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 7cb1784339..6ac13923a2 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -13,12 +13,15 @@ # limitations under the License. import inspect +import numpy as np + +import paddle from .. import framework from .. import core from ..framework import Variable, Parameter, ParamBase from .base import switch_to_static_graph -import numpy as np from .math_op_patch import monkey_patch_math_varbase +from .parallel import scale_loss def monkey_patch_varbase(): @@ -165,7 +168,12 @@ def monkey_patch_varbase(): """ if framework.in_dygraph_mode(): - self._run_backward(framework._dygraph_tracer(), retain_graph) + if paddle.distributed.get_world_size() > 1: + scaled_loss = scale_loss(self) + scaled_loss._run_backward(framework._dygraph_tracer(), + retain_graph) + else: + self._run_backward(framework._dygraph_tracer(), retain_graph) else: raise ValueError( "Variable.backward() is only available in DyGraph mode") diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 0dd1694c86..761f6409fe 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -19,8 +19,10 @@ import six import logging from collections import defaultdict +import paddle from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard +from paddle.fluid.dygraph.parallel import apply_collective_grads from . import framework from . import layers @@ -40,7 +42,6 @@ from paddle.fluid.layers import tensor from functools import reduce from .wrapped_decorator import signature_safe_contextmanager from .. import compat as cpt -import paddle __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'Dpsgd', 'DecayedAdagrad', @@ -771,8 +772,14 @@ class Optimizer(object): self._dtype = loss.dtype if framework.in_dygraph_mode(): + parameter_list = parameter_list if parameter_list \ + else self._parameter_list + + if paddle.distributed.get_world_size() > 1: + apply_collective_grads(parameter_list) + params_grads = [] - for param in self._parameter_list: + for param in parameter_list: if not param.trainable: continue if param._grad_ivar() is not None: @@ -939,6 +946,7 @@ class Optimizer(object): parameter_list = parameter_list if parameter_list \ else self._parameter_list + params_grads = self.backward( loss, startup_program=startup_program, diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 166a44fb2d..10e154044f 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -435,13 +435,7 @@ class TestParallelDyGraphRunnerBase(object): "loss at step %d: %f" % (step_id, loss.numpy())) out_losses.append(loss.numpy()) - # FIXME(Yancey1989): scale the loss inplace - if args.update_method == "nccl2": - loss = model.scale_loss(loss) - loss.backward() - if args.update_method == "nccl2": - model.apply_collective_grads() opt.minimize(loss) model.clear_gradients() @@ -477,12 +471,7 @@ class TestParallelDyGraphRunnerBase(object): loss = self.run_one_loop(model, opt, data) out_losses.append(loss.numpy()) - if args.update_method == "nccl2": - loss = model.scale_loss(loss) - loss.backward() - if args.update_method == "nccl2": - model.apply_collective_grads() opt.minimize(loss) model.clear_gradients() @@ -521,12 +510,7 @@ class TestParallelDyGraphRunnerBase(object): loss = self.run_one_loop(model, opt, data) out_losses.append(loss.numpy()) - if args.update_method == "nccl2": - loss = model.scale_loss(loss) - loss.backward() - if args.update_method == "nccl2": - model.apply_collective_grads() opt.step() opt.clear_grad() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_parallel_coalesce_split.py b/python/paddle/fluid/tests/unittests/test_imperative_parallel_coalesce_split.py index e5c32d0003..480df7482e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_parallel_coalesce_split.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_parallel_coalesce_split.py @@ -22,6 +22,7 @@ import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid.dygraph.parallel import DataParallel from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.dygraph.parallel import _coalesce_tensors, _split_tensors, _reshape_inplace class MyLayer(fluid.Layer): @@ -57,8 +58,8 @@ class TestImperativeParallelCoalesceSplit(unittest.TestCase): orig_var_shapes.append(var.shape) # execute interface - coalesced_vars = test_layer._coalesce_tensors(var_groups) - test_layer._split_tensors(coalesced_vars) + coalesced_vars = _coalesce_tensors(var_groups) + _split_tensors(coalesced_vars) # compare for orig_var_shape, var in zip(orig_var_shapes, vars): @@ -74,7 +75,7 @@ class TestImperativeParallelCoalesceSplit(unittest.TestCase): new_shape = [5, 10] x_data = np.random.random(ori_shape).astype("float32") x = to_variable(x_data) - test_layer._reshape_inplace(x, new_shape) + _reshape_inplace(x, new_shape) self.assertEqual(x.shape, new_shape) diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 24cebf8e6e..9cbb45ce60 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -17,6 +17,9 @@ from ..fluid import core from ..fluid import framework from ..fluid.framework import Variable +import paddle +from paddle.fluid.dygraph.parallel import apply_collective_grads + __all__ = ["Adam"] @@ -276,7 +279,9 @@ class Adam(Optimizer): adam.step() adam.clear_grad() """ - parameter_list = self._parameter_list + if paddle.distributed.get_world_size() > 1: + apply_collective_grads(self._parameter_list) + self._dtype = None params_grads = [] for param in self._parameter_list: diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index edaca7e830..0b04f03eb1 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -16,6 +16,8 @@ from .optimizer import Optimizer from .adam import Adam from ..fluid import framework import paddle +from paddle.fluid.dygraph.parallel import apply_collective_grads + __all__ = ['AdamW'] @@ -184,6 +186,9 @@ class AdamW(Adam): startup_program=None, parameters=None, no_grad_set=None): + parameters = parameters if parameters \ + else self._parameter_list + params_grads = self.backward( loss=loss, startup_program=startup_program, @@ -206,7 +211,9 @@ class AdamW(Adam): @framework.dygraph_only def step(self): - parameter_list = self._parameter_list + if paddle.distributed.get_world_size() > 1: + apply_collective_grads(self._parameter_list) + self._dtype = None params_grads = [] for param in self._parameter_list: @@ -224,7 +231,7 @@ class AdamW(Adam): updated_param = paddle.fluid.layers.elementwise_sub( x=param, y=scaled_param) param.set_value(updated_param.numpy()) - optimize_ops = self._apply_optimize( + self._apply_optimize( loss=None, startup_program=None, params_grads=params_grads) def __str__(self): diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 1bd9a1f144..15519cdd30 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -19,9 +19,10 @@ import six import logging from collections import defaultdict +import paddle from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard -import paddle +from paddle.fluid.dygraph.parallel import apply_collective_grads from ..fluid import framework from ..fluid import layers @@ -675,8 +676,14 @@ class Optimizer(object): self._dtype = loss.dtype if framework.in_dygraph_mode(): + parameter_list = parameters if parameters \ + else self._parameter_list + + if paddle.distributed.get_world_size() > 1: + apply_collective_grads(parameter_list) + params_grads = [] - for param in self._parameter_list: + for param in parameter_list: if not param.trainable: continue if param._grad_ivar() is not None: @@ -871,6 +878,7 @@ class Optimizer(object): parameter_list = parameters if parameters \ else self._parameter_list + params_grads = self.backward( loss, startup_program=startup_program, @@ -907,7 +915,9 @@ class Optimizer(object): adam.step() adam.clear_grad() """ - parameter_list = self._parameter_list + if paddle.distributed.get_world_size() > 1: + apply_collective_grads(self._parameter_list) + self._dtype = None params_grads = [] for param in self._parameter_list: @@ -917,5 +927,5 @@ class Optimizer(object): grad_var = param._grad_ivar() params_grads.append((param, grad_var)) - optimize_ops = self._apply_optimize( + self._apply_optimize( loss=None, startup_program=None, params_grads=params_grads) -- GitLab