From 5b357e021d33e02e16d5eb96d2f6d644c9c78277 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Tue, 26 Oct 2021 17:31:25 +0800 Subject: [PATCH] [cherry-pick]Support FP16 in HybridParallel and Fix bugs in HybridOptimizer (#36707) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer (#36237) * fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer * update * update * fix bugs in mp_layers、pp_layers and HybridParallelClipGrad (#36144) * fix calling bug of HybridParallelClipGrad * fix bugs of HybridParallelClipGrad * add unittest of pp with HybridParallelClipGrad * fix bugs in mp_layers.py * update * fix bugs in pp_layers.py * update * [HybridParallel]Rebuild code for pipeline (#36396) * add no_sync for parameters sync * add pipeline for moe * [HybridParallel]Support fp16 in dygraph hybrid parallel (#36420) * [HybridParallel]Support fp16 in dygraph hybrid parallel * update * update * update for recompute * add unittest of pp+fp16 * add unittest of recompute+fp16 * update * modify ut * modify ut of cond (#36475) * fix bugs of ClipGradByGlobalNorm in HybridParallel (#36555) * fix bugs of ClipGradByGlobalNorm * add unittests * add unittests * [HybridParallel]fix bug of check_inf in fleet_base.py (#36651) * fix bug of check_inf * fix allreduce * support ClipGradByGlobalNorm in sharding (#36012) * support ClipGradByGlobalNorm in sharding * support ClipGradByGlobalNorm in sharding * test=allcase * Update test_linalg_cond.py * Update hybrid_parallel_util.py * Update hybrid_parallel_util.py Co-authored-by: ShenLiang <1422485404@qq.com> Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> --- .../distributed/fleet/base/fleet_base.py | 40 ++++- .../dygraph_optimizer/__init__.py | 1 + .../hybrid_parallel_optimizer.py | 120 ++++++++++++--- .../parallel_layers/mp_layers.py | 8 +- .../parallel_layers/pp_layers.py | 7 + .../fleet/meta_parallel/pipeline_parallel.py | 86 +++++++---- .../fleet/meta_parallel/pp_utils/utils.py | 13 +- .../distributed/fleet/utils/recompute.py | 15 +- python/paddle/fluid/dygraph/parallel.py | 10 +- python/paddle/fluid/framework.py | 2 +- .../unittests/hybrid_parallel_mp_fp16.py | 59 ++++++++ .../unittests/hybrid_parallel_pp_alexnet.py | 17 ++- .../tests/unittests/hybrid_parallel_pp_amp.py | 4 + .../unittests/hybrid_parallel_pp_clip_grad.py | 35 +++++ .../unittests/hybrid_parallel_pp_fp16.py | 142 ++++++++++++++++++ .../hybrid_parallel_sharding_model.py | 19 ++- .../tests/unittests/test_dygraph_recompute.py | 38 ++++- ...test_parallel_dygraph_pipeline_parallel.py | 8 +- .../test_parallel_dygraph_tensor_parallel.py | 3 + 19 files changed, 533 insertions(+), 94 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_mp_fp16.py create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_pp_clip_grad.py create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 687295b1f2c..c930e1c06ae 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -35,6 +35,8 @@ from ..meta_parallel import TensorParallel, model_parallel_random_seed from ..meta_parallel import PipelineParallel, ShardingParallel from ..meta_optimizers import HybridParallelOptimizer from paddle import 
_C_ops +from paddle.fluid import core +from paddle.fluid.dygraph import to_variable __all__ = [] @@ -1547,26 +1549,52 @@ class Fleet(object): if getattr(optimizer, '_param_groups', None) and isinstance( optimizer._param_groups[0], dict): param_grads = [] + param_grads_fp16 = [] + param_grads_fp32 = [] for group in optimizer._param_groups: for param in group['params']: if param._grad_ivar() is not None: param_grads.append(param._grad_ivar()) + if param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP16: + param_grads_fp16.append(param._grad_ivar()) + else: + param_grads_fp32.append(param._grad_ivar()) else: param_grads = [ param._grad_ivar() for param in optimizer._parameter_list if param._grad_ivar() is not None ] - _C_ops.check_finite_and_unscale(param_grads, self._scale, - param_grads, self._found_inf) - - self._found_inf = paddle.cast(self._found_inf, dtype="int32") + param_grads_fp16 = [ + param._grad_ivar() for param in optimizer._parameter_list + if (param._grad_ivar() is not None) and (param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP16) + ] + param_grads_fp32 = [ + param._grad_ivar() for param in optimizer._parameter_list + if (param._grad_ivar() is not None) and (param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP32) + ] + temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) + temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) + if len(param_grads_fp16): + _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale, + param_grads_fp16, + temp_found_inf_fp16) + if len(param_grads_fp32): + _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale, + param_grads_fp32, + temp_found_inf_fp32) + + self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0 + is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32") # TODO(shenliang03) Since dp allreduce in the optimizer is # after the gradscaler, check_finite needs to synchronize global # information. In the future, we should use check_group to speed. 
paddle.distributed.all_reduce( - self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None) - self._found_inf = paddle.cast(self._found_inf, dtype="bool") + is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None) + self._found_inf = is_found_inf.numpy()[0] # Only tensor_parallel and pipeline_parallel need to modify scaler if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL, diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py index f0f26bd2e0d..28260d7aa18 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and from .hybrid_parallel_optimizer import HybridParallelOptimizer from .hybrid_parallel_gradscaler import HybridParallelGradScaler +from .dygraph_sharding_optimizer import DygraphShardingOptimizer __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index 581fbc5153a..e7108b3f4f3 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -50,7 +50,12 @@ class HybridParallelClipGrad: @imperative_base.no_grad def _dygraph_clip(self, params_grads): params_and_grads = [] - sum_square_list = [] + + sum_square_dist_fp16 = [] + sum_square_dist_fp32 = [] + sum_square_not_dist_fp16 = [] + sum_square_not_dist_fp32 = [] + for p, g in params_grads: if g is None: continue @@ -62,32 +67,98 @@ class HybridParallelClipGrad: merge_grad = layers.get_tensor_from_selected_rows(merge_grad) square = layers.square(merge_grad) sum_square = layers.reduce_sum(square) - sum_square_list.append(sum_square) - # all parameters have been filterd out - if len(sum_square_list) == 0: - return params_grads - - global_norm_var = layers.concat(sum_square_list) - global_norm_var = layers.reduce_sum(global_norm_var) - # add all reduce to get global norm in world size - paddle.distributed.all_reduce(global_norm_var, - self._hcg.get_check_parallel_group()) - global_norm_var = layers.sqrt(global_norm_var) + not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or ( + hasattr(p, 'is_firstly_shared') and + getattr(p, 'is_firstly_shared', True)) + + if not_shared_enable: + if p.is_distributed: + if p.dtype == paddle.float16: + sum_square_dist_fp16.append(sum_square) + elif p.dtype == paddle.float32: + sum_square_dist_fp32.append(sum_square) + else: + if p.dtype == paddle.float16: + sum_square_not_dist_fp16.append(sum_square) + elif p.dtype == paddle.float32: + sum_square_not_dist_fp32.append(sum_square) + + # global norm of distributed FP16 params_and_grads + if len(sum_square_dist_fp16) == 0: + global_norm_dist_fp16 = paddle.to_tensor([0.], dtype=paddle.float32) + else: + global_norm_dist_fp16 = layers.concat(sum_square_dist_fp16) + global_norm_dist_fp16 = layers.reduce_sum(global_norm_dist_fp16) + global_norm_dist_fp16 = paddle.cast( + global_norm_dist_fp16, dtype=paddle.float32) + + # global norm of non-distributed FP16 params_and_grads + if len(sum_square_not_dist_fp16) == 0: + global_norm_not_dist_fp16 = paddle.to_tensor( + [0.], dtype=paddle.float32) + else: + 
global_norm_not_dist_fp16 = layers.concat(sum_square_not_dist_fp16) + global_norm_not_dist_fp16 = layers.reduce_sum( + global_norm_not_dist_fp16) + global_norm_not_dist_fp16 = paddle.cast( + global_norm_not_dist_fp16, dtype=paddle.float32) + + # global norm of distributed FP32 params_and_grads + global_norm_dist_fp32 = layers.concat(sum_square_dist_fp32) if len( + sum_square_dist_fp32) != 0 else paddle.to_tensor( + [0.], dtype=paddle.float32) + global_norm_dist_fp32 = layers.reduce_sum(global_norm_dist_fp32) + + # global norm of non-distributed FP32 params_and_grads + global_norm_not_dist_fp32 = layers.concat( + sum_square_not_dist_fp32) if len( + sum_square_not_dist_fp32) != 0 else paddle.to_tensor( + [0.], dtype=paddle.float32) + global_norm_not_dist_fp32 = layers.reduce_sum(global_norm_not_dist_fp32) + + global_norm_var_dist = global_norm_dist_fp16 + global_norm_dist_fp32 + global_norm_var_not_dist = global_norm_not_dist_fp16 + global_norm_not_dist_fp32 + + # add all reduce to get global norm of distributed params_and_grads + if self._hcg.get_model_parallel_world_size() > 1: + paddle.distributed.all_reduce( + global_norm_var_dist, + group=self._hcg.get_check_parallel_group()) + + # add all reduce to get global norm of non-distributed params_and_grads in groups of pp + if self._hcg.get_pipe_parallel_world_size() > 1: + paddle.distributed.all_reduce( + global_norm_var_not_dist, + group=self._hcg.get_pipe_parallel_group()) + + # In Sharding mode, param and grad is mapping different rank in optimizer. + # ClipGradByGlobalNorm need allreduce to get globol norm + if self._hcg.get_sharding_parallel_world_size() > 1: + paddle.distributed.all_reduce( + global_norm_var_not_dist, + group=self._hcg.get_sharding_parallel_group()) + + global_norm_var_fp32 = layers.sqrt(global_norm_var_dist + + global_norm_var_not_dist) max_global_norm = layers.fill_constant( - shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm) + shape=[1], dtype=global_norm_var_fp32.dtype, value=self.clip_norm) clip_var = layers.elementwise_div( x=max_global_norm, y=layers.elementwise_max( - x=global_norm_var, y=max_global_norm)) + x=global_norm_var_fp32, y=max_global_norm)) + clip_var_fp16 = paddle.cast(clip_var, paddle.float16) for p, g in params_grads: if g is None: continue if getattr(p, 'need_clip', True) is False: params_and_grads.append((p, g)) continue - new_grad = layers.elementwise_mul(x=g, y=clip_var) + if p.dtype == paddle.float16: + new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16) + else: + new_grad = layers.elementwise_mul(x=g, y=clip_var) params_and_grads.append((p, new_grad)) return params_and_grads @@ -96,7 +167,7 @@ class HybridParallelClipGrad: return getattr(self._clip, item) def __call__(self, params_grads): - return self._clip(params_grads) + return self._dygraph_clip(params_grads) class HybridParallelOptimizer: @@ -112,7 +183,7 @@ class HybridParallelOptimizer: self._need_dp = (self._hcg.get_data_parallel_world_size() > 1) # NOTE(shenliang03): Because of the pure DataParallel mode, the gradient synchronization - # is achieved through reducer, so there is no need to call fuse_allreduce in oprimizer. + # is achieved through reducer, so there is no need to call fuse_allreduce in optimizer. 
self._dp_enable = not self._use_dp_mode and self._need_dp self._sharding_enable = ( @@ -120,11 +191,16 @@ class HybridParallelOptimizer: if isinstance(self._inner_opt._grad_clip, ClipGradByGlobalNorm) and not self._use_dp_mode: - logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \ - "optmizer'grad clip will be changed.") - - self._inner_opt._grad_clip = HybridParallelClipGrad( - self._inner_opt._grad_clip, hcg) + logger.warning("While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel " \ + "or Sharding, the grad clip of original optimizer will be changed.") + + if self._sharding_enable: + # change sharding inner_optimizer's _grad_clip + self._inner_opt._inner_optimizer._grad_clip = HybridParallelClipGrad( + self._inner_opt._grad_clip, hcg) + else: + self._inner_opt._grad_clip = HybridParallelClipGrad( + self._inner_opt._grad_clip, hcg) @imperative_base.no_grad @framework.dygraph_only diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py index 2555d73462b..2ce8cf7bdeb 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py @@ -70,7 +70,7 @@ class VocabParallelEmbedding(Layer): dtype=self._dtype, is_bias=False) - self.weight.is_distributed = True + self.weight.is_distributed = True if self.is_mp else False def forward(self, x): if self.is_mp: @@ -135,7 +135,7 @@ class ColumnParallelLinear(Layer): dtype=self._dtype, is_bias=False) - self.weight.is_distributed = True + self.weight.is_distributed = True if self.is_mp else False if has_bias: # initialize bias to zero like Megatron @@ -144,7 +144,7 @@ class ColumnParallelLinear(Layer): attr=paddle.nn.initializer.Constant(value=0.0), dtype=self._dtype, is_bias=True) - self.bias.is_distributed = True + self.bias.is_distributed = True if self.is_mp else False else: self.bias = None @@ -212,7 +212,7 @@ class RowParallelLinear(Layer): dtype=self._dtype, is_bias=False) - self.weight.is_distributed = True + self.weight.is_distributed = True if self.is_mp else False if has_bias: self.bias = self.create_parameter( diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index db6fc964895..9920bbd400c 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -261,6 +261,10 @@ class PipelineLayer(Layer): src=min(comm['ranks']), group=comm['group']) + for param in comm['layer'].parameters(): + if self.global_rank != min(comm['ranks']): + setattr(param, 'is_firstly_shared', False) + def allreduce_shared_weight_gradients(self): for key, comm in self.shared_comm.items(): param = getattr(self.shared_layers[key], comm['weight_attr']) @@ -316,6 +320,9 @@ class PipelineLayer(Layer): self.shared_layers[layer.layer_name] = layer.build_layer() self.shared_weight_attrs[ layer.layer_name] = layer.shared_weight_attr + for param in self.shared_layers[ + layer.layer_name].parameters(): + setattr(param, "is_firstly_shared", True) if layer.forward_func is None: self.run_function.append(self.shared_layers[ diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 431bc6d7bc3..7c7637a90fe 100755 --- 
a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -77,26 +77,15 @@ class PipelineParallel(MetaParallelBase): logger.info("start broadcast dp parameters") broadcast_dp_parameters(self._layers, self._hcg) - def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): - assert isinstance(optimizer, HybridParallelOptimizer), ( - 'optimizer should be HybridParallelOptimizer subclass.') - - assert fluid.framework._dygraph_tracer()._has_grad, ( - 'Please enable the generation of gradients.') - - if self.is_first_stage or self.is_last_stage: - assert data is not None, ( - "For the first and the last stage, the data must be set.") - else: - data = None + def forward_backward_pipeline(self, data, scaler=None): + # use the 1f1b scheduling strategy. + # this strategy is inspired by: + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler self.scaler = scaler - self.data = data - self._compute_loss = True - self._layers.train() + # store data for train + self.data = data # store total loss of entire batch self.total_loss = None @@ -104,10 +93,6 @@ class PipelineParallel(MetaParallelBase): # store data id for micro_batch self.micro_batch_id = 0 - # Next, use the 1f1b scheduling strategy. - # this strategy is inspired by: - # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py - startup_steps = (self.num_stages - self.stage_id - 1) startup_steps = min(startup_steps, self.accumulate_steps) steady_steps = self.accumulate_steps - startup_steps @@ -160,12 +145,36 @@ class PipelineParallel(MetaParallelBase): p2p.send_backward(input_tensor_grad) self._layers.allreduce_shared_weight_gradients() + with paddle.amp.auto_cast(enable=False): + train_loss = self._broadcast_final_loss() + return train_loss - self.train_loss = self._broadcast_final_loss() + def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): + assert isinstance(optimizer, HybridParallelOptimizer), ( + 'optimizer should be HybridParallelOptimizer subclass.') + + assert fluid.framework._dygraph_tracer()._has_grad, ( + 'Please enable the generation of gradients.') + + if self.is_first_stage or self.is_last_stage: + assert data is not None, ( + "For the first and the last stage, the data must be set.") + else: + data = None + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + self._layers.train() + + # 1f1b for pipeline + train_loss = self.forward_backward_pipeline(data, scaler) # optimizer - self._optimizer_step() - return self.train_loss + with paddle.amp.auto_cast(enable=False): + self._optimizer_step() + + return train_loss def eval_batch(self, data, compute_loss=False): self._layers.eval() @@ -233,12 +242,13 @@ class PipelineParallel(MetaParallelBase): output_tensor, paddle.Tensor ), "Currently, loss_fn should obtain Paddle.Tensor dtype" - if self.accumulate_steps > 1: - output_tensor = output_tensor / self.accumulate_steps + with paddle.amp.auto_cast(enable=False): + if self.accumulate_steps > 1: + output_tensor = output_tensor / self.accumulate_steps - if self.total_loss is None: - self.total_loss = paddle.zeros_like(output_tensor) - self.total_loss += output_tensor.detach() + if self.total_loss is None: + self.total_loss = paddle.zeros_like(output_tensor) + self.total_loss += output_tensor.detach() self.micro_batch_id += 1 return output_tensor @@ -312,13 +322,29 @@ class PipelineParallel(MetaParallelBase): if 
self.is_last_stage: assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" loss = self.total_loss.detach() + is_fp32 = paddle.to_tensor( + 1) if loss.dtype == paddle.float32 else paddle.to_tensor(0) + paddle.distributed.broadcast( + is_fp32, + src=self.global_rank, + use_calc_stream=True, + group=self.pp_group) paddle.distributed.broadcast( loss, src=self.global_rank, use_calc_stream=True, group=self.pp_group) else: - loss = paddle.zeros(shape=[1], dtype="float32") + is_fp32 = paddle.to_tensor(1) + paddle.distributed.broadcast( + is_fp32, + src=self._hcg.get_rank_from_stage(self.num_stages - 1), + use_calc_stream=True, + group=self.pp_group) + loss = paddle.zeros( + shape=[1], + dtype="float32") if is_fp32.numpy()[0] else paddle.zeros( + shape=[1], dtype="float16") paddle.distributed.broadcast( loss, src=self._hcg.get_rank_from_stage(self.num_stages - 1), diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 08266096548..7224ba6dedd 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -198,11 +198,14 @@ class _HPRecomputeFunction(PyLayer): # TODO support AMP tracer = framework._dygraph_tracer() - if tracer._amp_level == core.AmpLevel.O0: - ctx.is_fw_autocast = False + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' else: - ctx.is_fw_autocast = True - ctx.amp_mode = 'O1' + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): @@ -263,7 +266,7 @@ class _HPRecomputeFunction(PyLayer): enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, custom_black_list=ctx.amp_black_list, - level=ctx.amp_mode): + level=ctx.amp_level): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index 56a64049b16..2d1db5db945 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -98,11 +98,14 @@ class RecomputeFunction(PyLayer): # TODO support AMP tracer = framework._dygraph_tracer() - if tracer._amp_level == core.AmpLevel.O0: - ctx.is_fw_autocast = False + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' else: - ctx.is_fw_autocast = True - ctx.amp_mode = 'O1' + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): @@ -133,7 +136,7 @@ class RecomputeFunction(PyLayer): enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, custom_black_list=ctx.amp_black_list, - level=ctx.amp_mode): + level=ctx.amp_level): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) else: @@ -141,7 +144,7 @@ class RecomputeFunction(PyLayer): enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, custom_black_list=ctx.amp_black_list, - 
level=ctx.amp_mode): + level=ctx.amp_level): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index e4525a8d179..7dd8d38aa70 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -354,9 +354,15 @@ def sync_params_buffers(model, if not isinstance(param, core.VarBase): raise TypeError("The data type of '%s' must be Varbase" % param.name) + # is_distributed param not need to sync when in mp mode - if is_model_parallel and isinstance(param, ParamBase): - if param.is_distributed: + if isinstance(param, ParamBase): + if is_model_parallel and param.is_distributed: + continue + + # NOTE(shenliang03): Support situations that do not require synchronization parameters, + # such as moe's expert parameters + if getattr(param, "no_sync", False): continue model_vars.append(param.detach()) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b6241f6e529..d93b407c1f3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -5995,7 +5995,7 @@ class ParamBase(core.VarBase): self.need_clip = kwargs.get('need_clip', True) - self.is_distributed = False + self.is_distributed = kwargs.get('is_distributed', False) # self.block = default_main_program().global_block() @property diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_fp16.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_fp16.py new file mode 100644 index 00000000000..3e5eedbec9a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_fp16.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import paddle +import numpy as np +from hybrid_parallel_mp_model import TestDistMPTraning +import paddle.distributed.fleet as fleet +import unittest + + +class TestMPFP16(TestDistMPTraning): + def build_optimizer(self, model): + grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0) + scheduler = paddle.optimizer.lr.ExponentialDecay( + learning_rate=0.001, gamma=0.999, verbose=True) + optimizer = paddle.optimizer.SGD(scheduler, + grad_clip=grad_clip, + parameters=model.parameters()) + + model, optimizer = paddle.amp.decorate( + models=model, + optimizers=optimizer, + level='O2', + save_dtype='float32') + + return optimizer + + def train_batch(self, batch, model, optimizer, is_mp): + scaler = paddle.amp.GradScaler(init_loss_scaling=5160) + if is_mp: + scaler = fleet.distributed_scaler(scaler) + with paddle.amp.auto_cast(enable=True, level="O2"): + output = model(batch) + loss = output.mean() + + scaled = scaler.scale(loss) + scaled.backward() + scaler.step(optimizer) + scaler.update() + optimizer.clear_grad() + return scaled + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py index 912849ffbeb..71e873b0e2f 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py @@ -53,6 +53,13 @@ class TestDistPPTraning(unittest.TestCase): } fleet.init(is_collective=True, strategy=strategy) + def build_optimizer(self, model): + scheduler = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer = paddle.optimizer.SGD(learning_rate=scheduler, + parameters=model.parameters()) + return scheduler, optimizer + def test_pp_model(self): hcg = fleet.get_hybrid_communicate_group() word_size = hcg.get_model_parallel_world_size() @@ -63,10 +70,7 @@ class TestDistPPTraning(unittest.TestCase): #construct model a model_a = AlexNet(10) - scheduler_a = paddle.optimizer.lr.PiecewiseDecay( - boundaries=[2], values=[0.001, 0.002], verbose=True) - optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a, - parameters=model_a.parameters()) + scheduler_a, optimizer_a = self.build_optimizer(model_a) param_len = len(model_a.parameters()) @@ -76,10 +80,7 @@ class TestDistPPTraning(unittest.TestCase): # construct model b model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size) - scheduler_b = paddle.optimizer.lr.PiecewiseDecay( - boundaries=[2], values=[0.001, 0.002], verbose=True) - optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b, - parameters=model_b.parameters()) + scheduler_b, optimizer_b = self.build_optimizer(model_b) model_b = fleet.distributed_model(model_b) optimizer_b = fleet.distributed_optimizer(optimizer_b) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py index 33a04a5e7e1..84d11670027 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py @@ -61,11 +61,14 @@ class TestDistPPTraning(unittest.TestCase): rank_id = dist.get_rank() set_random_seed(1024, dp_id, rank_id) + grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0) + #construct model a model_a = AlexNet(10) scheduler_a = paddle.optimizer.lr.PiecewiseDecay( boundaries=[2], values=[0.001, 0.002], verbose=True) optimizer_a 
= paddle.optimizer.SGD(learning_rate=scheduler_a, + grad_clip=grad_clip, parameters=model_a.parameters()) scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5) @@ -80,6 +83,7 @@ class TestDistPPTraning(unittest.TestCase): scheduler_b = paddle.optimizer.lr.PiecewiseDecay( boundaries=[2], values=[0.001, 0.002], verbose=True) optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b, + grad_clip=grad_clip, parameters=model_b.parameters()) model_b = fleet.distributed_model(model_b) optimizer_b = fleet.distributed_optimizer(optimizer_b) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_clip_grad.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_clip_grad.py new file mode 100644 index 00000000000..de980f3c3f7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_clip_grad.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import paddle +import unittest +from hybrid_parallel_pp_alexnet import TestDistPPTraning + + +class TestPPClipGrad(TestDistPPTraning): + def build_optimizer(self, model): + grad_clip = paddle.nn.ClipGradByGlobalNorm(0.5) + scheduler = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer = paddle.optimizer.SGD(learning_rate=scheduler, + grad_clip=grad_clip, + parameters=model.parameters()) + return scheduler, optimizer + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py new file mode 100644 index 00000000000..9042cdba976 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 4 +micro_batch_size = 2 + + +class TestDistPPTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + set_random_seed(1024, dp_id, rank_id) + + grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0) + + #construct model a + model_a = AlexNet(10) + scheduler_a = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a, + grad_clip=grad_clip, + parameters=model_a.parameters()) + + scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5) + + # construct model b + model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size) + scheduler_b = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b, + grad_clip=grad_clip, + parameters=model_b.parameters()) + + param_len = len(model_a.parameters()) + parameters = [] + for param in model_a.parameters(): + parameters.append(param.numpy()) + + for idx, param in enumerate(model_b.parameters()): + param.set_value(parameters[idx + pp_id * (param_len // 2)]) + + model_a, optimizer_a = paddle.amp.decorate( + models=model_a, + optimizers=optimizer_a, + level='O2', + save_dtype='float32') + model_b, optimizer_b = paddle.amp.decorate( + models=model_b, + optimizers=optimizer_b, + level='O2', + save_dtype='float32') + + model_b = fleet.distributed_model(model_b) + optimizer_b = fleet.distributed_optimizer(optimizer_b) + scaler_b = paddle.amp.GradScaler(init_loss_scaling=2**5) + scaler_b = fleet.distributed_scaler(scaler_b) + + # construct reader + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True) + + for step_id, data in enumerate(train_reader()): + x_data = np.array([x[0] for x in data]).astype('float32').reshape( + batch_size, 1, 28, 28) + y_data = np.array([x[1] for x in data]).astype('int64').reshape( + batch_size, 1) + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + img.stop_gradient = True + label.stop_gradient = True + + if step_id >= 5: + return True + + with paddle.amp.auto_cast(enable=True, level='O2'): + loss_a = model_a(img, label) + scaler_a.scale(loss_a).backward() + with paddle.amp.auto_cast(enable=False): + scaler_a.minimize(optimizer_a, loss_a) + optimizer_a.clear_grad() + scheduler_a.step() + + loss_b = model_b.train_batch( + [img, 
label], optimizer_b, scheduler_b, scaler=scaler_b) + + print("loss: ", loss_a.numpy(), loss_b.numpy()) + np.testing.assert_allclose( + loss_a.numpy(), loss_b.numpy(), rtol=5e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py index 2995e4dbf84..8cb1166cd0d 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py @@ -183,21 +183,23 @@ class TestDistMPTraning(unittest.TestCase): strategy=None, is_sharding=True, Optimizer="adam"): - + clip = paddle.nn.ClipGradByGlobalNorm(0.5) if Optimizer == "adam": if is_sharding: optimizer = DygraphShardingOptimizer( hcg=fleet.get_hybrid_communicate_group(), user_defined_strategy=strategy, params=model.parameters(), - inner_optimizer_class=paddle.optimizer.Adam, + inner_optimizer_class=paddle.optimizer.AdamW, learning_rate=0.001, - weight_decay=0.00001, ) + weight_decay=0.00001, + grad_clip=clip) else: - optimizer = paddle.optimizer.Adam( + optimizer = paddle.optimizer.AdamW( parameters=model.parameters(), learning_rate=0.001, - weight_decay=0.00001, ) + weight_decay=0.00001, + grad_clip=clip) else: if is_sharding: optimizer = DygraphShardingOptimizer( @@ -205,10 +207,13 @@ class TestDistMPTraning(unittest.TestCase): user_defined_strategy=strategy, params=model.parameters(), inner_optimizer_class=paddle.optimizer.Momentum, - learning_rate=0.001, ) + learning_rate=0.001, + grad_clip=clip) else: optimizer = paddle.optimizer.Momentum( - learning_rate=0.001, parameters=model.parameters()) + learning_rate=0.001, + parameters=model.parameters(), + grad_clip=clip) return optimizer def build_model_optimizer(self, Optimizer="adam"): diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py index 332603b8129..4a4bcd2b816 100755 --- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py @@ -92,7 +92,10 @@ class Naive_fc_net(paddle.nn.Layer): return inputs -def run_model(recompute_block=[], recompute_kwargs={}, enable_autocast=False): +def run_model(recompute_block=[], + recompute_kwargs={}, + enable_autocast=False, + pure_fp16=False): gen = paddle.seed(10) gen.manual_seed(10) np.random.seed(10) @@ -118,7 +121,8 @@ def run_model(recompute_block=[], recompute_kwargs={}, enable_autocast=False): x_data = np.random.randn(batch_size, input_size).astype(np.float32) x = paddle.to_tensor(x_data) # x.stop_gradient = False - with paddle.amp.auto_cast(True): + level = 'O2' if pure_fp16 else 'O1' + with paddle.amp.auto_cast(True, level=level): y_pred = model(x) loss = y_pred.mean() if enable_autocast: @@ -196,6 +200,36 @@ class TestPyLayer(unittest.TestCase): recompute_block=[1, 3], enable_autocast=True) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_fp16(self): + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + + # without recompute + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], enable_autocast=True, pure_fp16=True) + + # recompute second block + loss, param, grad = run_model( + recompute_block=[1], enable_autocast=True, pure_fp16=True) + check_identical(loss_ref, param_ref, 
grad_ref, loss, param, grad) + + # recompute fourth block + loss, param, grad = run_model( + recompute_block=[3], enable_autocast=True, pure_fp16=True) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second to fourth block + loss, param, grad = run_model( + recompute_block=[1, 2, 3], enable_autocast=True, pure_fp16=True) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second & fourth block + loss, param, grad = run_model( + recompute_block=[1, 3], enable_autocast=True, pure_fp16=True) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py index 7a4f7f9fbd6..71c254dabb9 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py @@ -30,9 +30,12 @@ class TestHybridPipeParallel(TestMultipleGpus): def test_hybrid_parallel_shared_weight(self): self.run_mnist_2gpu('hybrid_parallel_shared_weight.py') - def test_pipeline_parallel(self): + def test_pipeline_parallel_amp(self): self.run_mnist_2gpu('hybrid_parallel_pp_amp.py') + def test_pipeline_parallel_fp16(self): + self.run_mnist_2gpu('hybrid_parallel_pp_fp16.py') + def test_hybrid_parallel_transformer(self): self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py') @@ -42,6 +45,9 @@ class TestHybridPipeParallel(TestMultipleGpus): def test_hybrid_parallel_recompute(self): self.run_mnist_2gpu('hybrid_parallel_pp_recompute.py') + def test_hybrid_parallel_pp_clip_grad(self): + self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py') + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py index 4b9d6764bbb..3705deb5ad8 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py @@ -30,6 +30,9 @@ class TestHybridParallel(TestMultipleGpus): def test_hybrid_parallel_mp_amp(self): self.run_mnist_2gpu('hybrid_parallel_mp_amp.py') + def test_hybrid_parallel_mp_fp16(self): + self.run_mnist_2gpu('hybrid_parallel_mp_fp16.py') + def test_hybrid_parallel_mp_clip_grad(self): self.run_mnist_2gpu('hybrid_parallel_mp_clip_grad.py') -- GitLab