diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 3580e85fc89c19bf57daf53774a88ebe02037c84..5d28c2d5cebd92fd724fd10204f7b347e5c426ca 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -142,32 +142,103 @@ class GradientClipHelper(object):
         return
 
     # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-    def sync_global_norm(self, block, ring_ids):
+    def sync_global_norm(self, block, ring_ids, mp_rank):
         """
         prune gradient_clip related ops for params that not belong to cur shard
         prune: square, reduce_sum, elementwise_mul
         keep: sum, sqrt, elementwise_max, elementwise_div
         """
-        # FIXME(wangxi): mp should prune duplicated param_grads
+        is_clip_grad_by_global_norm = False
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                is_clip_grad_by_global_norm = True
+                break
+        if not is_clip_grad_by_global_norm:
+            # TODO(Yuang Liu): extra handling is needed for clip_grad_norm with mp
+            return
+
+        removed_op_idx = set()
+        removed_tmp_var = set()
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                break
+            for input_name in op.input_arg_names:
+                input_var = block.var(input_name)
+                # NOTE: when mp_degree > 1, some vars are split across the mp ranks.
+                # However, some vars, such as Scale and Bias, are not split.
+                # Those unsplit vars should only be counted once during grad clip
+                # by global norm. Such vars either don't have the is_distributed
+                # attr or have the is_distributed attr set to False.
+                # Therefore, we prune those duplicated vars for grad clip.
+                if mp_rank >= 1 and (not (hasattr(input_var, 'is_distributed')
+                                          and input_var.is_distributed)):
+                    removed_op_idx.add(idx)
+                    for output_name in op.output_arg_names:
+                        removed_tmp_var.add(output_name)
+
         for idx, op in reversed(list(enumerate(block.ops))):
             if not self._is_gradient_clip_op(op):
                 continue
+            if idx in removed_op_idx:
+                block._remove_op(idx, sync=False)
 
-            if op.type == "sum":
-                sum_res = op.desc.output_arg_names()[0]
-                for ring_id in ring_ids:
-                    if ring_id == -1: continue
+        for var_name in removed_tmp_var:
+            block._remove_var(var_name, sync=False)
 
-                    idx = idx + 1
-                    block._insert_op_without_sync(
-                        idx,
-                        type='c_allreduce_sum',
-                        inputs={'X': sum_res},
-                        outputs={'Out': sum_res},
-                        attrs={
-                            'ring_id': ring_id,
-                            'op_namescope': "/gradient_clip_model_parallelism",
-                            'use_calc_stream': True,
-                            OP_ROLE_KEY: OpRole.Optimize,
-                        })
-                return
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                # If mp_rank == 0, no extra handling is needed, just allreduce.
+                # If mp_rank >= 1, some extra handling is needed.
+                sum_rst_var = block.var(op.output_arg_names[0])
+                if mp_rank >= 1:
+                    reserved_vars = []
+                    for input_name in op.input_arg_names:
+                        if input_name not in removed_tmp_var:
+                            reserved_vars.append(input_name)
+
+                    if len(reserved_vars) > 0:
+                        op.desc.set_input("X", reserved_vars)
+                    else:
+                        # If all inputs of the sum op should be removed, remove
+                        # the sum op and set the sum's output value to 0.
+                        namescope = op.attr("op_namescope")
+                        block._remove_op(idx, sync=False)
+                        fill_constant_op = block._insert_op_without_sync(
+                            idx,
+                            type='fill_constant',
+                            inputs={},
+                            outputs={'Out': sum_rst_var},
+                            attrs={
+                                'shape': sum_rst_var.shape,
+                                'dtype': sum_rst_var.dtype,
+                                'value': 0.0,
+                                OP_ROLE_KEY: OpRole.Optimize
+                            })
+                        fill_constant_op._set_attr('op_namescope', namescope)
+                self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
+                break
+
+    @staticmethod
+    def _insert_allreduce(block, ring_ids, idx, var):
+        for ring_id in ring_ids:
+            if ring_id == -1:
+                continue
+
+            idx = idx + 1
+            block._insert_op_without_sync(
+                idx,
+                type='c_allreduce_sum',
+                inputs={'X': var},
+                outputs={'Out': var},
+                attrs={
+                    'ring_id': ring_id,
+                    'op_namescope': "/gradient_clip_model_parallelism",
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Optimize,
+                })
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 1f96ab07d60a8478fb9efe4a368f8f7e6c66d6f3..f14f1e06624028e31f1bd87f973d7859fa27ec83 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase):
         main_block = self._main_program.global_block()
         startup_block = self._startup_program.global_block()
 
-        # FIXME(wangxi): mp should prune duplicated param_grads when calc
         # amp inf_var & clip global_norm_var
 
         rings = [self.mp_ring_id, self.pp_ring_id]
@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         gradientclip_helper = GradientClipHelper(None)
         gradientclip_helper.sync_global_norm(
-            main_block, [self.mp_ring_id, self.pp_ring_id])
+            main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank)
 
     def _insert_loss_grad_scale_op(self):
         main_block = self._main_program.global_block()
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index ed351dcbefdbcc982e54eaac81ce56f5bb888ba9..709b36ed8e32b25ac8913c97f3eb2f51ddae41ee 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object):
                         persistable=source_var.persistable)
                 else:
                     dest_var = block._clone_variable(source_var, False)
-                dest_var.stop_gradient = source_var.stop_gradient
+                self._clone_var_attr(dest_var, source_var)
             # When use with sharding, allreduce_sum and allreduce_max
             # used for global gradient clip and amp will be added by sharding.
             op_idx += 1
@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object):
             persistable=ref_var.persistable,
             is_data=ref_var.is_data,
             need_check_feed=ref_var.desc.need_check_feed())
-        new_var.stop_gradient = ref_var.stop_gradient
+        self._clone_var_attr(new_var, ref_var)
         return new_var
 
+    def _clone_var_attr(self, dest, src):
+        dest.stop_gradient = src.stop_gradient
+        if hasattr(src, 'is_distributed'):
+            dest.is_distributed = src.is_distributed
+
     def _strip_grad_suffix(self, name):
         """
         Strip the grad suffix from the given variable name
@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object):
                 persistable=True,
                 stop_gradient=False)
             real_param = main_block.var(param)
+            if hasattr(real_param, 'is_distributed'):
+                merged_grad_var.is_distributed = real_param.is_distributed
             tmp_size = self._get_var_size(real_grad)
             # two strategies for splitting the grad
             # 1. the current segment's size reach the user defined grad_size_in_MB
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index 1dd368f0848c1dd36e2e0d687ec803b49af2b6e8..6b0a7b79c232cca3e69a037f9ece6ef0f168d083 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init'
         ])
 
+        self.assertEqual(main_prog_op_types, [
+            'partial_recv', 'partial_allgather', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'partial_send', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
+            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
+            'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
+            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
+            'elementwise_div', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
+            'momentum', 'momentum', 'momentum', 'momentum', 'momentum',
+            'momentum', 'momentum'
+        ])
+
         # pp + mp, partial send recv
         self.assertIn('partial_recv', main_prog_op_types)
         self.assertIn('partial_allgather', main_prog_op_types)
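
Note for reviewers: below is a minimal, self-contained sketch (plain Python, no Paddle APIs; the helpers local_sq_norm and allreduce_sum are made up for illustration) of the double-counting problem the new sync_global_norm logic addresses. When mp_degree > 1, grads of vars that are not is_distributed (e.g. Scale, Bias) are replicated on every mp rank, so allreducing each rank's local squared norm over the mp ring would count them mp_degree times; pruning them on ranks with mp_rank >= 1 restores the single-process global norm.

import math

def local_sq_norm(grads):
    # Sum of squared elements of the grads visible on one mp rank.
    return sum(sum(x * x for x in g) for g in grads)

def allreduce_sum(partials):
    # Stand-in for c_allreduce_sum over the mp ring: every rank gets the total.
    return sum(partials)

# Two mp ranks. The weight grad is split across ranks (is_distributed=True);
# the bias grad is replicated on both ranks (is_distributed=False).
w_grad_shard = {0: [1.0, 2.0], 1: [3.0, 4.0]}
bias_grad = [0.5, 0.5]

# Reference: the global norm a single process would compute on the full grads.
ref = math.sqrt(local_sq_norm([w_grad_shard[0] + w_grad_shard[1], bias_grad]))

# Naive hybrid version: every rank counts the replicated bias grad.
naive = math.sqrt(
    allreduce_sum(
        [local_sq_norm([w_grad_shard[r], bias_grad]) for r in (0, 1)]))

# Patched behaviour: ranks with mp_rank >= 1 drop non-distributed grads,
# mirroring the square/reduce_sum ops pruned in sync_global_norm.
pruned = math.sqrt(
    allreduce_sum([
        local_sq_norm([w_grad_shard[0], bias_grad]),  # mp_rank == 0 keeps them
        local_sq_norm([w_grad_shard[1]]),  # mp_rank >= 1 prunes them
    ]))

assert abs(pruned - ref) < 1e-12  # pruning recovers the true global norm
assert naive > ref  # the duplicated bias grad inflates the global norm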