From 861a12ff432abe1725894a90c1e0e54ca9e3ade0 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Mon, 17 Apr 2023 11:08:54 +0800 Subject: [PATCH] [Dygraph] Support delaying div loss by accumulate_steps in PipelineLayer (#52848) (#52972) --- .../framework/distributed_strategy.proto | 6 + .../fleet/base/distributed_strategy.py | 6 + .../hybrid_parallel_optimizer.py | 40 +-- .../fleet/meta_parallel/pipeline_parallel.py | 62 ++++- .../fleet/meta_parallel/pp_utils/utils.py | 113 ++++++++ .../distributed/fleet/utils/__init__.py | 1 + .../fleet/utils/hybrid_parallel_util.py | 12 +- .../fleet/utils/mix_precision_utils.py | 254 ++++++++++++++++++ .../unittests/hybrid_parallel_pp_alexnet.py | 47 ++++ 9 files changed, 519 insertions(+), 22 deletions(-) create mode 100644 python/paddle/distributed/fleet/utils/mix_precision_utils.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index de2e38c2f11..d0e494ad494 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -57,12 +57,18 @@ message MpConfig { optional bool sync_moment= 3 [ default = false ]; } +message PpConfig { + optional bool dp_comm_overlap = 1 [ default = false ]; + optional bool delay_scale_loss = 2 [ default = false ]; +} + message HybridConfig { optional int32 dp_degree = 1 [ default = -1 ]; optional int32 mp_degree = 2 [ default = 1 ]; optional int32 pp_degree = 3 [ default = 1 ]; optional int32 sharding_degree = 4 [ default = 1 ]; optional MpConfig mp_configs = 5; + optional PpConfig pp_configs = 6; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 86292a2d90e..14e2fc09d33 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -1702,6 +1702,12 @@ class DistributedStrategy: self.strategy.hybrid_configs.mp_configs, configs["mp_configs"] ) configs.pop("mp_configs") + if "pp_configs" in configs: + assign_configs_value( + self.strategy.hybrid_configs.pp_configs, configs["pp_configs"] + ) + configs.pop("pp_configs") + assign_configs_value(self.strategy.hybrid_configs, configs) @property diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index acd34f1b1d5..405ef5492af 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -26,6 +26,7 @@ from ...utils.hybrid_parallel_util import ( sharding_reduce_gradients, ) from ...utils.log_util import logger +from ...utils.mix_precision_utils import MixPrecisionOptimizer __all__ = [] @@ -260,38 +261,41 @@ class HybridParallelOptimizer: "or Sharding, the grad clip of original optimizer will be changed." 
) - if self._sharding_enable: - # change sharding inner_optimizer's _grad_clip - self._inner_opt._inner_optimizer._grad_clip = ( - HybridParallelClipGrad(self._inner_opt._grad_clip, hcg) - ) - elif ( - self._inner_opt._parameter_list - and not isinstance(self._inner_opt._parameter_list[0], dict) + inner_opt = ( + self._inner_opt._inner_optimizer + if self._sharding_enable + else self._inner_opt + ) + + if isinstance(inner_opt, MixPrecisionOptimizer): + inner_opt = inner_opt._inner_opt + + if ( + inner_opt._parameter_list + and not isinstance(inner_opt._parameter_list[0], dict) and len( [ p - for p in self._inner_opt._parameter_list + for p in inner_opt._parameter_list if hasattr(p, "main_grad") ] ) > 0 ): - - self._inner_opt._inner_opt._grad_clip = HybridParallelClipGrad( - self._inner_opt._inner_opt._grad_clip, hcg + inner_opt._grad_clip = HybridParallelClipGrad( + inner_opt._grad_clip, hcg ) else: - self._inner_opt._grad_clip = HybridParallelClipGrad( - self._inner_opt._grad_clip, hcg + inner_opt._grad_clip = HybridParallelClipGrad( + inner_opt._grad_clip, hcg ) - if self._inner_opt._parameter_list and isinstance( - self._inner_opt._parameter_list[0], dict + if inner_opt._parameter_list and isinstance( + inner_opt._parameter_list[0], dict ): - for item in self._inner_opt._param_groups: + for item in inner_opt._param_groups: if "grad_clip" in item.keys(): item["grad_clip"] = HybridParallelClipGrad( - self._inner_opt._grad_clip, hcg + inner_opt._grad_clip, hcg ) def _filter_fn(self, param): diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 09ccc198ad6..ab1cf9701dd 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -24,6 +24,7 @@ from ..utils.log_util import logger from .meta_parallel_base import MetaParallelBase from .parallel_layers.pp_layers import PipelineLayer from .pp_utils import p2p_communication as p2p +from .pp_utils.utils import FusedAllReduceBuffer, assign_group_by_size __all__ = [] @@ -59,12 +60,21 @@ class PipelineParallel(MetaParallelBase): self.num_stages = self._hcg.get_pipe_parallel_world_size() self.stage_id = self._hcg.get_stage_id() self.pp_group = self._hcg.get_pipe_parallel_group() + self.dp_group = self._hcg.get_data_parallel_group() self._virtual_pp_world_size = None self._virtual_pp_rank = None self._real_pp_world_size = self.num_stages self._real_pp_rank = self.stage_id + self._delay_scale_loss = self._strategy.hybrid_configs[ + "pp_configs" + ].delay_scale_loss + self._dp_comm_overlap = self._strategy.hybrid_configs[ + "pp_configs" + ].dp_comm_overlap + self._dp_comm_buffers = [] + p2p.initialize_p2p_groups( hcg, self._using_cache, self._enable_partial_send_recv ) @@ -92,6 +102,11 @@ class PipelineParallel(MetaParallelBase): logger.info("start broadcast dp parameters") broadcast_dp_parameters(self._layers, self._hcg) + if self._dp_comm_overlap: + self.register_allreduce_overlap_hook( + self._layers, self.dp_group, self.accumulate_steps + ) + def is_pipeline_first_stage(self, ignore_virtual=False): if not ignore_virtual: if self._virtual_pp_world_size is not None: @@ -114,6 +129,27 @@ class PipelineParallel(MetaParallelBase): def set_virtual_pipeline_rank(self, rank): self._virtual_pp_rank = rank + def bw_hook_func(self, buffer, param): + @paddle.autograd.no_grad() + def fused_allreduce(*_): + buffer.add_grad(param) + + return fused_allreduce + + def 
register_allreduce_overlap_hook(self, model, comm_group, acc_steps): + parameter_list = [p for p in model.parameters() if not p.stop_gradient] + if len(parameter_list) < 1: + return + + var_groups = assign_group_by_size(parameter_list) + for group_idx, parameters in var_groups.items(): + buffer = FusedAllReduceBuffer( + group_idx, parameters, comm_group, acc_steps + ) + self._dp_comm_buffers.append(buffer) + for param in parameters: + param._register_backward_hook(self.bw_hook_func(buffer, param)) + def forward_backward_pipeline(self, data, scaler=None): # use the 1f1b scheduling strategy. # this strategy is inspired by: @@ -192,6 +228,11 @@ class PipelineParallel(MetaParallelBase): ) p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage()) + if self._dp_comm_overlap: + assert len(self._dp_comm_buffers) > 0 + for buffer in self._dp_comm_buffers: + buffer.scale_and_split_grads() + self._layers.allreduce_shared_weight_gradients() with paddle.amp.auto_cast(enable=False): train_loss = self._broadcast_final_loss() @@ -310,7 +351,7 @@ class PipelineParallel(MetaParallelBase): ), "Currently, loss_fn should obtain Paddle.Tensor dtype" with paddle.amp.auto_cast(enable=False): - if self.accumulate_steps > 1: + if self.accumulate_steps > 1 and not self._delay_scale_loss: output_tensor = output_tensor / self.accumulate_steps if self.total_loss is None: @@ -413,7 +454,11 @@ class PipelineParallel(MetaParallelBase): assert ( self.total_loss is not None ), "train_batch() in last stage should obtain vaild loss" - loss = self.total_loss.detach() + loss = ( + self.total_loss.detach() + if not self._delay_scale_loss + else self.total_loss / self.accumulate_steps + ) is_fp32 = ( paddle.full([], 1, 'int64') if loss.dtype == paddle.float32 @@ -447,6 +492,14 @@ class PipelineParallel(MetaParallelBase): return loss def _optimizer_step(self): + if self._delay_scale_loss: + for p in self._layers.parameters(): + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p.grad is None + p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps) + elif p.grad is not None: + p.grad = p.grad.scale(1.0 / self.accumulate_steps) + if self.scaler: self.scaler.step(self.optimizer) self.scaler.update() @@ -746,6 +799,11 @@ class PipelineParallelWithInterleave(PipelineParallel): ) ) + if self._dp_comm_overlap: + assert len(self._dp_comm_buffers) > 0 + for buffer in self._dp_comm_buffers: + buffer.scale_and_split_grads() + self._layers.allreduce_shared_weight_gradients() if compute_loss: diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index f6b447db919..c34ec8f45e1 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -12,8 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections import OrderedDict + +import numpy as np + import paddle from paddle import _legacy_C_ops +from paddle.distributed.parallel import _split_tensors +from paddle.fluid import core __all__ = [] @@ -105,3 +111,110 @@ def _all_gather(tensor, group=None, use_calc_stream=True): 'nranks', nranks, ) + + +class FusedAllReduceBuffer: + def __init__(self, id, params, comm_group, acc_steps=1): + self._id = id + self._params = params + self._acc_steps = acc_steps + self._comm_group = comm_group + + self._tasks = [] + self._grads = [] + self._params_step_dict = {} + self._params_checked_in = 0 + self._coalesced_grads_and_grad_vars = [] + + self._init_step_dict() + + def _init_step_dict(self): + for p in self._params: + self._params_step_dict[p.name] = 0 + + def _reset_params_checked_in(self): + self._tasks.clear() + self._grads.clear() + self._init_step_dict() + self._params_checked_in = 0 + self._coalesced_grads_and_grad_vars.clear() + + @property + def _all_params_checked_in(self): + return ( + len(self._params) == self._params_checked_in + and len(self._params_step_dict) == 0 + ) + + def add_grad(self, param): + assert param.name in self._params_step_dict + + if self._params_step_dict[param.name] == 0: + if getattr(param, "main_grad", None) is not None: + assert param.grad is None + self._grads.append(param.main_grad) + else: + self._grads.append(param.grad) + + self._params_step_dict[param.name] += 1 + + if self._params_step_dict[param.name] == self._acc_steps: + self._params_checked_in += 1 + self._params_step_dict.pop(param.name) + + if self._all_params_checked_in: + self._fused_allreduce_grads() + + def _fused_allreduce_grads(self): + assert self._all_params_checked_in + flattened_vars = [] + g_var_shapes = [] + + for g_var in self._grads: + g_var_shapes.append(g_var.shape) + flattened_vars.append( + paddle.reshape(x=g_var, shape=[np.prod(g_var.shape)]) + ) + + coalesced_grad = paddle.concat(flattened_vars) + self._coalesced_grads_and_grad_vars.append( + [coalesced_grad, self._grads, g_var_shapes] + ) + + for coalesced_grad, _, _ in self._coalesced_grads_and_grad_vars: + self._tasks.append( + paddle.distributed.all_reduce( + coalesced_grad, group=self._comm_group, sync_op=False + ) + ) + + def scale_and_split_grads(self): + for task in self._tasks: + task.wait() + + scale_factor = 1.0 / self._comm_group.nranks + for coalesced_grad, _, _ in self._coalesced_grads_and_grad_vars: + coalesced_grad.scale_(scale_factor) + + _split_tensors(self._coalesced_grads_and_grad_vars) + self._reset_params_checked_in() + + +def assign_group_by_size(parameters, group_size=128 * 1024 * 1024): + + group_idx = 0 + memory_counter = 0 + var_groups = OrderedDict() + dtype = parameters[0].dtype + + for var in parameters: + bytes = np.prod(var.shape) * core.size_of_dtype(var.dtype) + if memory_counter < group_size and dtype == var.dtype: + memory_counter += bytes + else: + memory_counter = bytes + dtype = var.dtype + group_idx += 1 + var_groups.setdefault(group_idx, []).append(var) + + return var_groups diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index c5453fa61e6..25b1c153165 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -22,6 +22,7 @@ import paddle from . import log_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401 from . import tensor_parallel_utils # noqa: F401 +from . 
import mix_precision_utils # noqa: F401 __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"] # noqa diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py index f6a0a27ce35..fc7b463bd81 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py +++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py @@ -232,10 +232,18 @@ def sharding_reduce_gradients(parameter_list, hcg): sharding_nrank = hcg.get_sharding_parallel_group().nranks for param in parameter_list: + g_var = None if param.trainable and (param._grad_ivar() is not None): - param.grad.scale_(1.0 / sharding_nrank) + g_var = param._grad_ivar() + if param.trainable and hasattr(param, "main_grad"): + assert ( + param._grad_ivar() is None + ), "param.grad should be None when using main_grad" + g_var = param.main_grad + if g_var is not None: + g_var.scale_(1.0 / sharding_nrank) paddle.distributed.all_reduce( - param.grad, + g_var, group=hcg.get_sharding_parallel_group(), sync_op=True, ) diff --git a/python/paddle/distributed/fleet/utils/mix_precision_utils.py b/python/paddle/distributed/fleet/utils/mix_precision_utils.py new file mode 100644 index 00000000000..1ee26bce1fb --- /dev/null +++ b/python/paddle/distributed/fleet/utils/mix_precision_utils.py @@ -0,0 +1,254 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import defaultdict +from types import MethodType + +import numpy as np + +import paddle +from paddle import _legacy_C_ops, nn +from paddle.distributed import fleet +from paddle.fluid import framework +from paddle.fluid.dygraph import base as imperative_base +from paddle.fluid.dygraph import to_variable +from paddle.framework import core + + +class MixPrecisionLayer(nn.Layer): + def __init__(self, layers, dtype="float16"): + super().__init__(layers.full_name() + "_mix_precision") + + self._layers = layers + self._dtype = dtype + + assert self._dtype in ["float16", "bfloat16"] + + for param in self._layers.parameters(): + if not hasattr(param, "main_grad"): + param.main_grad = None + param._register_grad_hook(self._update_main_grad_hook(param)) + + def _update_main_grad_hook(self, param): + """Create the update_main_grad hook for backprop.""" + + # Hook used for back-prop and grad-merge. 
+ @paddle.autograd.no_grad() + def param_hook(tmp_grad): + assert ( + param.grad is None + ), "In main_grad node, param.grad should be None, but find param[{}] has grad.".format( + param.name + ) + if param.main_grad is None: + param.main_grad = core.eager.Tensor( + value=tmp_grad.cast(paddle.float32).value(), + place=tmp_grad.place, + name="main_grad@" + param.name, + ) + else: + param.main_grad.add_(tmp_grad.cast(paddle.float32)) + + tmp_grad._clear_data() + return None + + return param_hook + + def forward(self, *inputs, **kwargs): + outputs = self._layers(*inputs, **kwargs) + + return outputs + + def state_dict( + self, + destination=None, + include_sublayers=True, + structured_name_prefix="", + ): + + return self._layers.state_dict( + destination=destination, + include_sublayers=include_sublayers, + structured_name_prefix=structured_name_prefix, + ) + + @framework.deprecate_stat_dict + def set_state_dict(self, state_dict, use_structured_name=True): + + self._layers.set_state_dict( + state_dict, use_structured_name=use_structured_name + ) + + +class MixPrecisionOptimizer: + def __init__(self, optimizer): + self._inner_opt = optimizer + self._parameter_list = self._obtain_optimizer_parameters_list() + + def _obtain_optimizer_parameters_list(self): + if getattr(self._inner_opt, '_param_groups', None) and isinstance( + self._inner_opt._param_groups[0], dict + ): + parameters_list = [] + for group in self._inner_opt._param_groups: + for param in group['params']: + parameters_list.append(param) + else: + parameters_list = list(self._inner_opt._parameter_list) + + return parameters_list + + @imperative_base.no_grad + @framework.dygraph_only + def step(self): + + if not isinstance(self._parameter_list[0], dict): + params_grads = [] + for param in self._parameter_list: + if param.stop_gradient: + continue + grad_var = param.main_grad + if framework.in_dygraph_mode(): + if ( + hasattr(grad_var, "is_selected_rows") + and grad_var.is_selected_rows() + and self._inner_opt.regularization is not None + ): + raise RuntimeError( + "AdamW don't support weight_decay with sparse parameters, please set it to None." + ) + else: + if ( + hasattr(grad_var, "_is_sparse") + and grad_var._is_sparse() + and self._inner_opt.regularization is not None + ): + raise RuntimeError( + "AdamW don't support weight_decay with sparse parameters, please set it to None." + ) + params_grads.append((param, grad_var)) + + optimize_ops = self._inner_opt._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads + ) + else: + # optimize parameters in groups + for param_group in self._inner_opt._param_groups: + params_grads = defaultdict(lambda: []) + for param in param_group['params']: + if param.stop_gradient: + continue + grad_var = param.main_grad + if framework.in_dygraph_mode(): + if ( + hasattr(grad_var, "is_selected_rows") + and grad_var.is_selected_rows() + and self._inner_opt.regularization is not None + ): + raise RuntimeError( + "AdamW don't support weight_decay with sparse parameters, please set it to None." + ) + else: + if ( + hasattr(grad_var, "_is_sparse") + and grad_var._is_sparse() + and self._inner_opt.regularization is not None + ): + raise RuntimeError( + "AdamW don't support weight_decay with sparse parameters, please set it to None." 
+ ) + params_grads['params'].append((param, grad_var)) + params_grads.update( + {k: v for k, v in param_group.items() if k != 'params'} + ) + self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads + ) + + @framework.dygraph_only + def clear_grad(self, set_to_zero=True): + + param_list = [] + if self._parameter_list is None or not isinstance( + self._parameter_list[0], dict + ): + for p in self._parameter_list: + if not p.stop_gradient: + param_list.append(p) + else: + for param_group in self._param_groups: + for p in param_group['params']: + if not p.stop_gradient: + param_list.append(p) + + for p in param_list: + if hasattr(p, "main_grad") and p.main_grad is not None: + if set_to_zero: + p.main_grad.zero_() + else: + p.main_grad._clear() + p.main_grad = None + elif not hasattr(p, "main_grad"): + p.clear_gradient(set_to_zero) + + def __getattr__(self, item): + return getattr(self._inner_opt, item) + + +def unscale_method(self, optimizer): + if not self._enable: + return + param_grads = [] + if getattr(optimizer, '_param_groups', None) and isinstance( + optimizer._param_groups[0], dict + ): + for group in optimizer._param_groups: + for param in group['params']: + if param.main_grad is not None: + assert param.main_grad.dtype == core.VarDesc.VarType.FP32 + param_grads.append(param.main_grad) + else: + for param in optimizer._parameter_list: + if param.main_grad is not None: + assert param.main_grad.dtype == core.VarDesc.VarType.FP32 + param_grads.append(param.main_grad) + + temp_found_inf = to_variable(np.array([0]).astype(np.bool_)) + if len(param_grads): + _legacy_C_ops.check_finite_and_unscale( + param_grads, + self._scale, + param_grads, + temp_found_inf, + ) + + self._found_inf = 1 if temp_found_inf else 0 + + hcg = fleet.get_hybrid_communicate_group() + if hcg is not None and hcg.nranks > hcg.get_data_parallel_world_size(): + is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32") + paddle.distributed.all_reduce( + is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None + ) + self._found_inf = is_found_inf.numpy()[0] + + +class MixPrecisionScaler: + def __init__(self, scaler): + self._inner_scaler = scaler + self._inner_scaler._unscale = MethodType(unscale_method, scaler) + + def __getattr__(self, item): + return getattr(self._inner_scaler, item) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py index 7972ec62fa4..1c3eac9cec4 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py @@ -21,6 +21,10 @@ from hybrid_parallel_pp_layer import AlexNet, AlexNetPipeDesc import paddle import paddle.distributed as dist from paddle.distributed import fleet +from paddle.distributed.fleet.utils.mix_precision_utils import ( + MixPrecisionLayer, + MixPrecisionOptimizer, +) def set_random_seed(seed, dp_id, rank_id): @@ -60,6 +64,9 @@ class TestDistPPTraning(unittest.TestCase): ) return scheduler, optimizer + def wrapper_mix_precision(self, model, optimizer): + return model, optimizer + def test_pp_model(self): hcg = fleet.get_hybrid_communicate_group() word_size = hcg.get_model_parallel_world_size() @@ -81,6 +88,7 @@ class TestDistPPTraning(unittest.TestCase): # construct model b model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size) scheduler_b, optimizer_b = self.build_optimizer(model_b) + model_b, optimizer_b = self.wrapper_mix_precision(model_b, optimizer_b) 
model_b = fleet.distributed_model(model_b) optimizer_b = fleet.distributed_optimizer(optimizer_b) @@ -125,5 +133,44 @@ class TestDistPPTraning(unittest.TestCase): ) +class TestDistPPDelayScaleLoss(TestDistPPTraning): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + "pp_configs": { + "delay_scale_loss": True, + }, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size, + } + fleet.init(is_collective=True, strategy=strategy) + + +class TestDistPPMainGrad(TestDistPPTraning): + def wrapper_mix_precision(self, model, optimizer): + model = MixPrecisionLayer(model, dtype="float16") + optimizer = MixPrecisionOptimizer(optimizer) + return model._layers, optimizer + + def build_optimizer(self, model): + scheduler = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True + ) + optimizer = paddle.optimizer.SGD( + learning_rate=scheduler, + parameters=model.parameters(), + grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), + ) + return scheduler, optimizer + + if __name__ == "__main__": unittest.main() -- GitLab
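
For context, below is a minimal usage sketch of the options and wrappers this patch introduces (delay_scale_loss / dp_comm_overlap under pp_configs, and the MixPrecisionLayer / MixPrecisionOptimizer wrappers), distilled from the TestDistPPDelayScaleLoss and TestDistPPMainGrad cases added above. The PipeModel class, parallel degrees, learning rate, and step counts are placeholders for illustration, not part of the commit.

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.mix_precision_utils import (
    MixPrecisionLayer,
    MixPrecisionOptimizer,
)

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,
    "pp_configs": {
        # divide the accumulated loss by accumulate_steps once, at optimizer step,
        # instead of scaling every micro-batch loss
        "delay_scale_loss": True,
        # set True to fuse and overlap data-parallel gradient allreduce with backward
        "dp_comm_overlap": False,
    },
}
strategy.pipeline_configs = {
    "accumulate_steps": 4,
    "micro_batch_size": 2,
}
fleet.init(is_collective=True, strategy=strategy)

# PipeModel is a placeholder for any PipelineLayer subclass (e.g. AlexNetPipeDesc in the test).
model = PipeModel(num_stages=2)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.001,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
)

# Optional fp16 "main_grad" path: gradients are accumulated into fp32 main_grad
# tensors via the hook registered by MixPrecisionLayer.
model = MixPrecisionLayer(model, dtype="float16")
optimizer = MixPrecisionOptimizer(optimizer)

# As in the unit test, the underlying layers (model._layers) are handed to
# fleet.distributed_model, while the wrapped optimizer goes to fleet.distributed_optimizer.
model = fleet.distributed_model(model._layers)
optimizer = fleet.distributed_optimizer(optimizer)

# Training then proceeds per step with model.train_batch([img, label], optimizer, scheduler),
# exactly as in hybrid_parallel_pp_alexnet.py.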