diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 784004269d797bc847f7e1b2f4d3f0787f897174..6a96a102e1433ddc6eb9e74128edf824a75cf36c 100644 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -74,3 +74,4 @@ state_dict = fleet.state_dict set_state_dict = fleet.set_state_dict shrink = fleet.shrink get_hybrid_communicate_group = fleet.get_hybrid_communicate_group +distributed_scaler = fleet.distributed_scaler diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 5e17794dfeac1255c05eefc280bdd07943690418..2a4977847b1821417ab6ed8cecfa5de5611b7470 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -30,6 +30,7 @@ from . import topology as tp from .topology import ParallelMode from ..meta_parallel import ModelParallel from ..meta_optimizers import HybridParallelOptimizer +from ..meta_optimizers import HybridParallelGradScaler def _inited_runtime_handler_(func): @@ -1333,3 +1334,7 @@ class Fleet(object): fleet.util._set_strategy(context["valid_strategy"]) return optimize_ops, params_grads + + @dygraph_only + def distributed_scaler(self, scaler): + return HybridParallelGradScaler(scaler, self._hcg) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index d26dee331ccf5de2e015a1ab1eb6a9f5803f97dc..da0fe8aee60e8b3fa82b8165ed0a0f2234e9c3a9 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -19,6 +19,8 @@ import collections import numpy as np from itertools import product from functools import reduce +from ..utils.log_util import logger + __all__ = ['CommunicateTopology', 'HybridCommunicateGroup'] _HYBRID_PARALLEL_GROUP = None @@ -129,12 +131,17 @@ class HybridCommunicateGroup(object): # create comm group for model parallel self._mp_group, self._mp_comm_group = self._set_comm_group("model") + + # create global group for check inf_nan / clip global norm + self._check_group, self._check_comm_group = self._set_check_group( + "data") + debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \ "mp_degree: %d, pp_degree: %d\n" % (self.global_rank, self._dp_degree, self._mp_degree,self._pp_degree) - debug_str += "dp_group: %s, mp_group: %s" % (self._dp_group, - self._mp_group) - print(debug_str, file=sys.stderr) + debug_str += "dp_group: %s, mp_group: %s, check/clip group: %s" % ( + self._dp_group, self._mp_group, self._check_group) + logger.info(debug_str) global _HYBRID_PARALLEL_GROUP _HYBRID_PARALLEL_GROUP = self @@ -168,6 +175,22 @@ class HybridCommunicateGroup(object): return parallel_group, parallel_comm_group + def _set_check_group(self, parallel_method="data"): + parallel_group = [] + parallel_comm_group = None + parallel_size = self._topo.get_dim(parallel_method) + for idx in range(parallel_size): + parallel_groups = self._topo.get_axis_list(parallel_method, idx) + comm_group = paddle.distributed.new_group(ranks=parallel_groups) + if self.global_rank in parallel_groups: + parallel_group = parallel_groups + parallel_comm_group = comm_group + + assert len(parallel_group) > 0 + assert parallel_comm_group is not None + + return parallel_group, parallel_comm_group + def topology(self): return self._topo @@ -205,3 +228,7 @@ class HybridCommunicateGroup(object): def get_model_parallel_group_src_rank(self): return 
self._mp_comm_group.ranks[0] + + # check parallel group + def get_check_parallel_group(self): + return self._check_comm_group diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 8dd57c87ef896caaae494d28ec593cdc9f60a51b..3be8a479491dc04042e2d333163ff38a0ac67f3b 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -26,3 +26,4 @@ from .lamb_optimizer import LambOptimizer from .fp16_allreduce_optimizer import FP16AllReduceOptimizer from .sharding_optimizer import ShardingOptimizer from .dygraph_optimizer import HybridParallelOptimizer +from .dygraph_optimizer import HybridParallelGradScaler diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py index a2a3bb8d17201c62305938aef8b23e39500ba21a..4e41723cb622dce26f511fe5dc051a59b5f3eb7a 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py @@ -11,3 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and from .hybrid_parallel_optimizer import HybridParallelOptimizer +from .hybrid_parallel_gradscaler import HybridParallelGradScaler diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..97b0e7306c13d5b28a11f32440362615a921d394 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py @@ -0,0 +1,77 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import sys +from paddle.optimizer import Optimizer +from ...base.topology import ParallelMode +from paddle.fluid.dygraph import base as imperative_base +from paddle.fluid import framework +from paddle.fluid.framework import Variable +import types +from paddle.fluid import core +import paddle + + +class HybridParallelGradScaler: + def __init__(self, scaler, hcg): + self._scaler = scaler + self._hcg = hcg + self._is_mp = ( + self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL) + + def scale(self, var): + return self._scaler.scale(var) + + def minimize(self, optimizer, *args, **kwargs): + if not self._enable: + return optimizer.minimize(*args, **kwargs) + + # unscale the grad + self._unscale(optimizer) + + optimize_ops, params_grads = (None, None) + + if self._found_inf: + self._cache_founf_inf = True + else: + optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) + self._cache_founf_inf = False + + if self._use_dynamic_loss_scaling: + self._update() + + return optimize_ops, params_grads + + @imperative_base.no_grad + def _unscale(self, optimizer): + if not self._enable: + return + param_grads = [ + param._grad_ivar() for param in optimizer._parameter_list + if param._grad_ivar() is not None + ] + core.ops.check_finite_and_unscale(param_grads, self._scale, param_grads, + self._found_inf) + # allreduce_max found_inf in check_group + if self._is_mp: + self._found_inf = paddle.cast(self._found_inf, dtype="int64") + paddle.distributed.all_reduce( + self._found_inf, + op=paddle.distributed.ReduceOp.MAX, + group=self._hcg.get_check_parallel_group()) + self._found_inf = paddle.cast(self._found_inf, dtype="bool") + + def __getattr__(self, item): + return getattr(self._scaler, item) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index b1cf98b4b1d2fde34aff21753444ab89bd5ed225..52e87173684a342bb9d7c15e22012e307d474a98 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -12,15 +12,77 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import print_function
+import sys
 from paddle.optimizer import Optimizer
+from paddle.fluid.clip import ClipGradByGlobalNorm
 from ...utils.hybrid_parallel_util import fused_allreduce_gradients
 from ...base.topology import ParallelMode
 from paddle.fluid.dygraph import base as imperative_base
 from paddle.fluid import framework
 from paddle.fluid.framework import Variable
+from ...utils.log_util import logger
+from paddle.fluid import core
+from paddle.fluid import layers
+import paddle
+
+
+class HybridParallelClipGrad:
+    def __init__(self, clip, hcg):
+        self._clip = clip
+        self._hcg = hcg
+
+    @imperative_base.no_grad
+    def _dygraph_clip(self, params_grads):
+        params_and_grads = []
+        sum_square_list = []
+        for p, g in params_grads:
+            if g is None:
+                continue
+            if getattr(p, 'need_clip', True) is False:
+                continue
+            merge_grad = g
+            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
+                merge_grad = layers.merge_selected_rows(g)
+                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
+            square = layers.square(merge_grad)
+            sum_square = layers.reduce_sum(square)
+            sum_square_list.append(sum_square)
+
+        # all parameters have been filtered out
+        if len(sum_square_list) == 0:
+            return params_grads
+
+        global_norm_var = layers.concat(sum_square_list)
+        global_norm_var = layers.reduce_sum(global_norm_var)
+        # allreduce the local norms over the check/clip group to get the global norm
+        paddle.distributed.all_reduce(
+            global_norm_var, group=self._hcg.get_check_parallel_group())
+        global_norm_var = layers.sqrt(global_norm_var)
+
+        max_global_norm = layers.fill_constant(
+            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
+        clip_var = layers.elementwise_div(
+            x=max_global_norm,
+            y=layers.elementwise_max(
+                x=global_norm_var, y=max_global_norm))
+        for p, g in params_grads:
+            if g is None:
+                continue
+            if getattr(p, 'need_clip', True) is False:
+                params_and_grads.append((p, g))
+                continue
+            new_grad = layers.elementwise_mul(x=g, y=clip_var)
+            params_and_grads.append((p, new_grad))
+
+        return params_and_grads
+
+    def __getattr__(self, item):
+        return getattr(self._clip, item)
+
+    def __call__(self, params_grads):
+        return self._dygraph_clip(params_grads)
 
 
 class HybridParallelOptimizer:
+    # adapter wrapper for optimizer
     def __init__(self, optimizer, hcg, strategy):
         self._inner_opt = optimizer
         self._strategy = strategy
@@ -29,6 +91,13 @@ class HybridParallelOptimizer:
             self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
 
+        if isinstance(self._inner_opt._grad_clip,
+                      ClipGradByGlobalNorm) and self._is_mp:
+            logger.warning("using ClipGradByGlobalNorm in ModelParallel, the original " \
+                "optimizer's grad clip will be changed.")
+            self._inner_opt._grad_clip = HybridParallelClipGrad(
+                self._inner_opt._grad_clip, hcg)
+
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
diff --git a/python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py b/python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py
index 5cf1242a37ad52582a89d9f860ca521c2bbff61a..6c8bf68fd1fb631a8b8b713f755635eb6dd07309 100644
--- a/python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py
+++ b/python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 from paddle.fluid.dygraph.layers import Layer
-import logging
 
 
 class MetaParallelBase(Layer):
diff --git a/python/paddle/distributed/fleet/meta_parallel/model_parallel.py b/python/paddle/distributed/fleet/meta_parallel/model_parallel.py
index 62f5266250f60dcddc0198988ff008a3fe74c06b..ebf26498d93243168b874323a38171b7db3df3be 100644
--- a/python/paddle/distributed/fleet/meta_parallel/model_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/model_parallel.py
@@ -15,6 +15,7 @@
 from paddle.fluid.dygraph.layers import Layer
 from .meta_parallel_base import MetaParallelBase
 from ..utils.hybrid_parallel_util import *
+from ..utils.log_util import logger
 
 
 class ModelParallel(MetaParallelBase):
@@ -22,8 +23,14 @@ class ModelParallel(MetaParallelBase):
         super(ModelParallel, self).__init__(layers, hcg, **kwargs)
 
     def _prepare_for_model(self):
+        logger.info("start broadcast mp parameters")
         broadcast_mp_parameters(self._layers, self._hcg)
+
+        logger.info("start broadcast dp parameters")
         broadcast_dp_parameters(self._layers, self._hcg)
+        logger.info("mp's parameters are ready")
+
 
     def _pre_forward(self, *inputs, **kwargs):
+        logger.debug("mp start broadcast input data")
         return broadcast_input_data(self._hcg, *inputs, **kwargs)
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index a866d5be64891780e062d4dd1f5034972b8f24c0..1f4222d478cd90ffeddd2de8f84a4ea21993c3b6 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -19,8 +19,9 @@ import warnings
 from paddle import framework
 import paddle
 from paddle.fluid import core
-from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, construct_groups
+from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
 from collections import OrderedDict
+from .log_util import logger
 
 
 def _apply_collective_grads(parameters, comm_group):
@@ -37,7 +38,7 @@ def _apply_collective_grads(parameters, comm_group):
         assert g_var not in grad_var_set
         grad_var_set.add(g_var)
 
-    coalesced_grads_and_vars = construct_groups(grad_vars, 128 * 1024 * 1024)
+    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)
 
     for coalesced_grad, _, _ in coalesced_grads_and_vars:
         # need to div nranks
@@ -60,7 +61,7 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
                 group=model_parallel_group,
                 use_calc_stream=True)
         else:
-            print("it doesn't support data type {}".format(type(input_)))
+            logger.error("unsupported data type {} for broadcast".format(type(input_)))
 
     for k, v in kwargs.items():
         if isinstance(v, core.VarBase):
@@ -72,7 +73,7 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
                 use_calc_stream=True)
             kwargs[k] = v
         else:
-            print("it doesn't support data type {}".format(type(v)))
+            logger.error("unsupported data type {} for broadcast".format(type(v)))
 
     return inputs, kwargs
@@ -92,5 +93,6 @@ def broadcast_dp_parameters(model, hcg):
 
 def fused_allreduce_gradients(parameter_list, hcg):
     data_parallel_group = hcg.get_data_parallel_group()
+    logger.debug("dp start fused allreduce of gradients")
     with framework.no_grad():
         _apply_collective_grads(parameter_list, data_parallel_group)
diff --git a/python/paddle/distributed/fleet/utils/log_util.py b/python/paddle/distributed/fleet/utils/log_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..906891961dd455f7dae41ad03432a2bc385da240
--- /dev/null
+++ b/python/paddle/distributed/fleet/utils/log_util.py
@@
-0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys + + +class LoggerFactory: + @staticmethod + def build_logger(name=None, level=logging.INFO): + assert name is not None, "name for logger should not be None" + + formatter = logging.Formatter( + "%(asctime)s-%(levelname)s: " + "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s") + + _logger = logging.getLogger(name) + _logger.setLevel(level) + _logger.propagate = False + handler = logging.StreamHandler(stream=sys.stderr) + handler.setFormatter(formatter) + handler.setLevel(level) + _logger.addHandler(handler) + return _logger + + +logger = LoggerFactory.build_logger(name="HybridParallel", level=logging.INFO) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index d06a5c890feeb824db9003c5804c9931252448c6..ca5e5606e432b00f4a04206f5db605b25cded0c0 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -323,7 +323,7 @@ def scale_loss(loss): @imperative_base.no_grad @framework.dygraph_only -def construct_groups(vars, group_size): +def build_groups(vars, group_size): group_idx = 0 memory_counter = 0 var_groups = OrderedDict() @@ -334,7 +334,7 @@ def construct_groups(vars, group_size): if memory_counter < group_size and dtype == var.dtype: memory_counter += bytes else: - memory_counter = 0 + memory_counter = bytes dtype = var.dtype group_idx += 1 var_groups.setdefault(group_idx, []).append(var) @@ -361,7 +361,7 @@ def sync_params_buffers(model, return # group size is 128M - coalesced_vars = construct_groups(model_vars, 128 * 1024 * 1024) + coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024) for coalesced_var, _, _ in coalesced_vars: paddle.distributed.broadcast( diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c3ec312331b8f4c1bb5955b24848832ad1491e65..843eaedd69084fd24c7979e5bbdd4a9d421ae1ae 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -852,7 +852,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120) - set_tests_properties(test_parallel_dygraph_hybrid_parallel PROPERTIES TIMEOUT 120 LABELS "RUN_TYPE=DIST") + set_tests_properties(test_parallel_dygraph_hybrid_parallel PROPERTIES TIMEOUT 200 LABELS "RUN_TYPE=DIST") if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py new 
file mode 100644 index 0000000000000000000000000000000000000000..248c271eec6a1bd139e0727b7a824aaa2f4269bf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import paddle +import numpy as np +from hybrid_parallel_mp_model import TestDistMPTraning +import paddle.distributed.fleet as fleet +import unittest + + +class TestMPClipGrad(TestDistMPTraning): + def build_optimizer(self, model): + grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0) + scheduler = paddle.optimizer.lr.ExponentialDecay( + learning_rate=0.001, gamma=0.999, verbose=True) + optimizer = paddle.optimizer.SGD(scheduler, + grad_clip=grad_clip, + parameters=model.parameters()) + return optimizer + + def train_batch(self, batch, model, optimizer, is_mp): + scaler = paddle.amp.GradScaler(init_loss_scaling=5160) + if is_mp: + scaler = fleet.distributed_scaler(scaler) + with paddle.amp.auto_cast(): + output = model(batch) + loss = output.mean() + + scaled = scaler.scale(loss) # scale the loss + scaled.backward() # do backward + + scaler.minimize(optimizer, scaled) # update parameters + optimizer.clear_grad() + return scaled + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_clip_grad.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_clip_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..ad95aceaa2cf9ce47d42d8bda664837f5859e681 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_clip_grad.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import paddle +import numpy as np +from hybrid_parallel_mp_model import TestDistMPTraning +import unittest +import logging + +#log = logging.getLogger("HybridParallel") +#log.setLevel(logging.WARNING) + + +class TestMPClipGrad(TestDistMPTraning): + def build_optimizer(self, model): + grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0) + scheduler = paddle.optimizer.lr.ExponentialDecay( + learning_rate=0.001, gamma=0.999, verbose=True) + optimizer = paddle.optimizer.SGD(scheduler, + grad_clip=grad_clip, + parameters=model.parameters()) + return optimizer + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py index ed5b9060e5eba9fc5b555b689746014968bca552..dfbef998a2f07ab697d27b19197b3cb65cb41205 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py @@ -173,9 +173,9 @@ class TestDistTraning(unittest.TestCase): self.word_size = self.hcg.get_model_parallel_world_size() self.rank_id = self.hcg.get_model_parallel_rank() - input_size_per_card = 17 + input_size_per_card = 11 input_size = input_size_per_card * self.model_parallel_size - output_size_per_card = 13 + output_size_per_card = 10 output_size = output_size_per_card * self.model_parallel_size batch_size = 4 diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_model.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_model.py index 0336f9220ab8c1e0f3814755ace52dcf0c707456..767bf5d57e74aff64d13170267785c6a8ed4347b 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_model.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_model.py @@ -21,7 +21,6 @@ import random import paddle.distributed as dist import paddle.fluid as fluid import paddle.distributed.fleet as fleet -import paddle.fluid.generator as generator from paddle.io import DataLoader, Dataset import unittest @@ -143,7 +142,7 @@ class TrainDataset(Dataset): return np_input_data -class TestDistTraning(unittest.TestCase): +class TestDistMPTraning(unittest.TestCase): def setUp(self): strategy = fleet.DistributedStrategy() self.model_parallel_size = 2 @@ -155,7 +154,20 @@ class TestDistTraning(unittest.TestCase): } fleet.init(is_collective=True, strategy=strategy) - def test_mp_model(self): + def train_batch(self, batch, model, optimizer, is_mp): + output = model(batch) + loss = output.mean() + loss.backward() # do backward + optimizer.step() # update parameters + optimizer.clear_grad() + return loss + + def build_optimizer(self, model): + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) + return optimizer + + def build_model_optimizer(self): hcg = fleet.get_hybrid_communicate_group() word_size = hcg.get_model_parallel_world_size() mp_id = hcg.get_model_parallel_rank() @@ -182,31 +194,29 @@ class TestDistTraning(unittest.TestCase): model_a = SimpleMPNet(vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2, mp_id) - optimizer_a = paddle.optimizer.SGD(learning_rate=0.001, - parameters=model_a.parameters()) + optimizer_a = self.build_optimizer(model_a) model_a = fleet.distributed_model(model_a) optimizer_a = fleet.distributed_optimizer(optimizer_a) model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2) - optimizer_b = paddle.optimizer.SGD(learning_rate=0.001, - 
parameters=model_b.parameters()) + optimizer_b = self.build_optimizer(model_b) + + return model_a, optimizer_a, model_b, optimizer_b, train_data_loader + + def test_mp_model(self): + model_a, optimizer_a, model_b, optimizer_b, train_data_loader = self.build_model_optimizer( + ) for step, batch in enumerate(train_data_loader): if step > 5: return - output_a = model_a(batch) - loss_a = output_a.mean() - loss_a.backward() - optimizer_a.step() - optimizer_a.clear_grad() - - output_b = model_b(batch) - loss_b = output_b.mean() - loss_b.backward() - optimizer_b.step() - optimizer_b.clear_grad() - np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy()) + + loss_a = self.train_batch(batch, model_a, optimizer_a, True) + loss_b = self.train_batch(batch, model_b, optimizer_b, False) + + np.testing.assert_allclose( + loss_a.numpy(), loss_b.numpy(), rtol=1e-5) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py index c3cb26c078e2dd1739a7e7a3b5abc1d38a6cef98..ac37edc266f2ca22f2af7ddb8ddac1f1e8787494 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py @@ -30,6 +30,12 @@ class TestHybridParallel(TestMultipleGpus): def test_hybrid_parallel_mp_model(self): self.run_mnist_2gpu('hybrid_parallel_mp_model.py') + def test_hybrid_parallel_mp_amp(self): + self.run_mnist_2gpu('hybrid_parallel_mp_amp.py') + + def test_hybrid_parallel_mp_clip_grad(self): + self.run_mnist_2gpu('hybrid_parallel_mp_clip_grad.py') + if __name__ == "__main__": unittest.main()
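
Usage sketch (reviewer note, not part of the patch): how the new fleet.distributed_scaler API is expected to be combined with AMP and ClipGradByGlobalNorm under hybrid parallelism, condensed from hybrid_parallel_mp_amp.py above. The Linear layer, the random batches, and all hyper-parameters are placeholders, and the script is assumed to be launched on two ranks (e.g. via paddle.distributed.launch).

import paddle
import paddle.distributed.fleet as fleet

# hybrid_configs mirror the unittests above; dp_degree * mp_degree * pp_degree must match the rank count
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(16, 16)  # placeholder; a real MP model is built from mp-aware layers
grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                 grad_clip=grad_clip,
                                 parameters=model.parameters())

# distributed_optimizer swaps ClipGradByGlobalNorm for HybridParallelClipGrad under MP,
# so the global grad norm is allreduced over the new check/clip group
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)
# new API from this patch: allreduce-max of found_inf over the check group,
# so all ranks skip or take the optimizer step together
scaler = fleet.distributed_scaler(scaler)

for _ in range(5):
    batch = paddle.rand([4, 16])  # placeholder data
    with paddle.amp.auto_cast():
        loss = model(batch).mean()
    scaled = scaler.scale(loss)  # scale the loss
    scaled.backward()            # backward on the scaled loss
    scaler.minimize(optimizer, scaled)  # unscale, check inf/nan, then update
    optimizer.clear_grad()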
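
A second reviewer note, also not part of the patch: the new fleet/utils/log_util.py registers a dedicated "HybridParallel" logger for the broadcast/allreduce tracing added in this change, so its verbosity can be adjusted from user code without touching Paddle (this is what the commented-out lines in hybrid_parallel_mp_clip_grad.py hint at).

import logging

# silence the INFO/DEBUG tracing emitted by the "HybridParallel" logger,
# or set logging.DEBUG to see the broadcast/allreduce messages
logging.getLogger("HybridParallel").setLevel(logging.WARNING)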