diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
old mode 100755
new mode 100644
index 15488f5f654c1e380e4426ee7deb95138c688d7d..f0cbbd9278345fb786bbdcb8e15a9cc8da79d79b
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -13,6 +13,11 @@
 # limitations under the License.
 
+import distutils.util
+import os
+
+import numpy as np
+
 import paddle
 from paddle import framework
 from paddle.autograd import no_grad
@@ -42,9 +47,131 @@ class HybridParallelClipGrad:
         self._clip = clip
         self._hcg = hcg
         self.not_sharding_stage1 = True
+        self._vpp_chunk_num = None
+        self._force_align_vpp_grad_sum_order = distutils.util.strtobool(
+            os.getenv('FLAGS_force_align_vpp_grad_sum_order', '1')
+        )
+
+    def _get_vpp_chunk_num(self, params_grads):
+        chunk_num = -1
+        for p, g in params_grads:
+            if g is None:
+                continue
+            chunk_info = getattr(p, '_chunk_info', {})
+            cur_chunk_num = chunk_info.get('chunk_num', -1)
+            if chunk_num < 0:
+                chunk_num = cur_chunk_num
+            else:
+                assert chunk_num == cur_chunk_num
+        return chunk_num
+
+    @no_grad()
+    def _vpp_dygraph_clip(self, params_grads, chunk_num):
+        pp_group = self._hcg.get_pipe_parallel_group()
+        pp_rank = self._hcg.get_stage_id()
+        pp_size = self._hcg.get_pipe_parallel_world_size()
+
+        if self._vpp_chunk_num is None:
+            all_chunk_nums = []
+            paddle.distributed.all_gather_object(
+                all_chunk_nums, chunk_num, group=pp_group
+            )
+            assert all([chunk_num == n for n in all_chunk_nums])
+            self._vpp_chunk_num = chunk_num
+        else:
+            assert self._vpp_chunk_num == chunk_num
+
+        sum_square_metas = []
+        for p, g in params_grads:
+            if g is None:
+                continue
+            not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
+                hasattr(p, 'is_firstly_shared')
+                and getattr(p, 'is_firstly_shared', True)
+            )
+
+            chunk_id = p._chunk_info['chunk_id']
+            if not_shared_enable:
+                if g.type == core.VarDesc.VarType.SELECTED_ROWS:
+                    merge_grad = clip.merge_selected_rows(g)
+                    g = clip.get_tensor_from_selected_rows(merge_grad)
+                square = paddle.square(g)
+                sum_square = paddle.sum(square)
+                layer_id = chunk_id * pp_size + pp_rank
+                sum_square_metas.append(
+                    [layer_id, p.is_distributed, sum_square.numpy()]
+                )
+
+        all_sum_square_metas = []
+        paddle.distributed.all_gather_object(
+            all_sum_square_metas,
+            sum_square_metas,
+            group=pp_group,
+        )
+
+        # order: FP16, BF16, FP32
+        sum_square_dist = [[], [], []]
+        sum_square_not_dist = [[], [], []]
+
+        pp_stage = self._hcg.get_stage_id()
+        for i, metas in enumerate(all_sum_square_metas):
+            for layer_id, is_distributed, sum_square in metas:
+                rank = layer_id // chunk_num
+                assert rank < pp_size
+                if rank != pp_rank:
+                    continue
+                if sum_square.dtype == np.float32:
+                    idx = 2
+                elif sum_square.dtype == np.float16:
+                    idx = 0
+                else:
+                    assert (
+                        sum_square.dtype == np.uint16
+                    ), "The data type of grad must be FP32, FP16 or BF16, but got {}".format(
+                        sum_square.dtype
+                    )
+                    idx = 1
+
+                if is_distributed:
+                    sum_square_dist[idx].append(sum_square)
+                else:
+                    sum_square_not_dist[idx].append(sum_square)
+
+        global_norm_var_dist = self._add_sum_squares(sum_square_dist)
+        global_norm_var_not_dist = self._add_sum_squares(sum_square_not_dist)
+
+        return self._comm_and_clip(
+            params_grads, global_norm_var_dist, global_norm_var_not_dist
+        )
+
+    def _add_sum_squares(self, sum_squares):
+        norm_sum = None
+        for sq in sum_squares:
+            if len(sq) == 0:
+                continue
+
+            sq = np.concatenate(sq, axis=0).flatten()
+            sq = paddle.to_tensor(sq)
+            sq = paddle.sum(sq)
+            if sq.dtype != paddle.float32:
+                sq = sq.astype(paddle.float32)
+
+            if norm_sum is None:
+                norm_sum = sq
+            else:
+                norm_sum = norm_sum + sq
+
+        if norm_sum is None:
+            norm_sum = paddle.to_tensor([0.0], dtype=paddle.float32)
+
+        return norm_sum
 
     @no_grad()
     def _dygraph_clip(self, params_grads):
+        chunk_num = self._get_vpp_chunk_num(params_grads)
+        if chunk_num > 0 and self._force_align_vpp_grad_sum_order:
+            return self._vpp_dygraph_clip(params_grads, chunk_num)
+
         sum_square_dist_fp16 = []
         sum_square_dist_bf16 = []
         sum_square_dist_fp32 = []
@@ -159,7 +286,13 @@ class HybridParallelClipGrad:
             + global_norm_not_dist_bf16
             + global_norm_not_dist_fp32
         )
+        return self._comm_and_clip(
+            params_grads, global_norm_var_dist, global_norm_var_not_dist
+        )
 
+    def _comm_and_clip(
+        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
+    ):
         # add all reduce to get global norm of distributed params_and_grads
         if self._hcg.get_model_parallel_world_size() > 1:
             sharding_flag = False
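Note on the clipping change above: `_vpp_dygraph_clip` tags each parameter's squared-gradient sum with a `layer_id`, gathers the partial sums across the pipeline group, and accumulates them in a fixed grouping (FP16/BF16/FP32, distributed vs. non-distributed), so the reduction order no longer depends on which virtual-pipeline chunk produced each gradient. The alignment matters because floating-point addition is order-sensitive. A minimal NumPy illustration (not part of the patch):

    import numpy as np

    # Summing the same float32 values in two different orders gives
    # different results: the 1.0 is absorbed when added to 1e8 first.
    x = np.float32([1e8, 1.0, -1e8])
    print(x[0] + x[1] + x[2])  # 0.0
    print(x[0] + x[2] + x[1])  # 1.0

This is the effect that `FLAGS_force_align_vpp_grad_sum_order` (read from the environment above, default '1') guards against when virtual pipeline parallelism would otherwise reorder the accumulation.
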
diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 5f1d32fc6399db597039f1b74a9714992a23c2a8..bb8937d79fb7d8ab035083228ef0767f50b1a351 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -638,6 +638,13 @@ class PipelineLayer(nn.Layer):
             logger.info(f"loss: {self._loss_fn.__class__.__name__}")
 
     def _build_layer_with_interleave(self):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )
+
+        orig_rng_state = paddle.get_rng_state()
+        orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
+
         for i in range(len(self._start_poss)):
             start = self._start_poss[i]
             end = self._end_poss[i]
@@ -648,10 +655,21 @@ class PipelineLayer(nn.Layer):
             self._model_chunks.append(chunk)
             self.add_sublayer(str(start), chunk)
 
+        paddle.set_rng_state(orig_rng_state)
+        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
+
     def _build_layer(self):
         start = self._start_pos
         end = self._end_pos
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+            get_rng_state_tracker,
+        )
+
+        orig_rng_state = paddle.get_rng_state()
+        orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
         self.run_function = self._build_layer_impl(start, end)
+        paddle.set_rng_state(orig_rng_state)
+        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
 
     def _build_layer_impl(self, start, end):
         if self._num_virtual_pipeline_stages > 1:
@@ -661,13 +679,6 @@ class PipelineLayer(nn.Layer):
             # For 1f1b scheduler, just use run_function list
             run_function = self.run_function
 
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
-            get_rng_state_tracker,
-        )
-
-        orig_rng_state = paddle.get_rng_state()
-        orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
-
         for index, layer in enumerate(self._layers_desc[start:end]):
             layer_index = start + index
 
@@ -722,8 +733,6 @@ class PipelineLayer(nn.Layer):
             else:
                 run_function.append(layer)
 
-        paddle.set_rng_state(orig_rng_state)
-        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
         return run_function
 
     def forward_function(self, start, end):
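Note on the pp_layers.py change above: the RNG snapshot/restore that used to live inside `_build_layer_impl` (once per chunk) now wraps `_build_layer` and `_build_layer_with_interleave` as a whole, so successive model chunks draw from one continuous RNG stream instead of each chunk restarting from the same snapshot; the apparent intent is to keep parameter initialization under interleaving consistent with a plain sequential build of the same layers. A minimal sketch of the underlying snapshot/restore primitive, using only the public paddle RNG calls already referenced in the diff (illustrative, not the patch code):

    import paddle

    # Snapshot the global RNG state, draw, restore, draw again:
    # both tensors are identical because the generator was reset.
    state = paddle.get_rng_state()
    w1 = paddle.randn([2, 2])
    paddle.set_rng_state(state)
    w2 = paddle.randn([2, 2])
    assert (w1.numpy() == w2.numpy()).all()
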
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
old mode 100755
new mode 100644
index 7ca44f6e08c4ba3c43f8920d0efda7fef9c7c5ff..4a1b1ad72c05279fcf938b89c4d3c23015a50327
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -709,6 +709,13 @@ class PipelineParallelWithInterleave(PipelineParallel):
         assert len(self.model_chunks) == self.num_model_chunks
         self._virtual_pp_world_size = self.num_model_chunks
         self._virtual_pp_rank = 0
+        self._assign_vpp_info(self.model_chunks)
+
+    def _assign_vpp_info(self, chunks):
+        chunk_num = len(chunks)
+        for i, chunk in enumerate(chunks):
+            for p in chunk.parameters():
+                p._chunk_info = {"chunk_id": i, "chunk_num": chunk_num}
 
     def _get_virtual_pp_rank(self, micro_step, forward):
         virtual_pp_stage = micro_step % (