diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index a0e18eb16b60165b1082cba92b7b84da90f3d169..52ef843aa0d751b0b981a3d93e81af6d6f121275 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -14,7 +14,7 @@ import paddle from paddle.fluid import core, unique_name from functools import reduce -from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op +from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY import re @@ -431,15 +431,19 @@ def insert_reduce_ops(block, reduce_vars, shard, op_role=OpRole.Backward, - use_calc_stream=False): + use_calc_stream=False, + rank=None): """ _add_allreduce_ops """ + grad_in_this_device = [] for var in reduce_vars: root_id = get_grad_device(var, shard) assert root_id >= 0, "root id should be a positive int, but now root id is {}".format( root_id) + if rank is not None and rank == root_id: + grad_in_this_device.append(var) block._insert_op_without_sync( insert_idx, type='c_reduce_sum', @@ -451,16 +455,23 @@ def insert_reduce_ops(block, 'use_calc_stream': use_calc_stream, OP_ROLE_KEY: op_role }) - return + + return grad_in_this_device def get_grad_device(grad_name, shard): assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format( grad_name) base_name = None - # mind the traversal order + # NOTE: mind the traversal order possible_suffixes = [ - '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD' + # sharding gm + '.cast_fp16@GRAD@MERGED', + '.cast_fp16@GRAD', + # pipeline + '@GRAD@MERGED@FP16', + '@GRAD@MERGED', + '@GRAD', ] for suffix in possible_suffixes: if suffix in grad_name: @@ -487,6 +498,15 @@ def get_first_check_finite_and_unscale_op_idx(block, raise_error=True): return -1 +def get_first_optimize_op_idx(block): + first_opt_op_idx = None + for index, op in reversed(tuple(enumerate(block.ops))): + if is_backward_op(op) and first_opt_op_idx is None: + first_opt_op_idx = index + 1 + break + return first_opt_op_idx + + def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): """ _add_broadcast_ops @@ -672,23 +692,6 @@ def save_persistables(exe, dirname, main_program, filename=None): return -def get_grad_device(grad_name, shard): - assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format( - grad_name) - base_name = None - # mind the traversal order - possible_suffixes = ['.cast_fp16@GRAD', '@GRAD'] - for suffix in possible_suffixes: - if suffix in grad_name: - base_name = re.sub(suffix, '', grad_name) - break - - assert base_name in shard.global_param2device, "[{}] should be a param variable.".format( - base_name) - - return shard.global_param2device[base_name] - - def append_naive_sync(block, sync_var, ring_id): # NOTE (JZ-LIANG) update this to use barrier sync for more elegent logic # sync within global diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index a5df9486da46563298bd022c6701b05088a608d4..93901b38873b95691696f08cf3ef30e935142c25 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -294,6 +294,8 @@ class ShardingOptimizer(MetaOptimizerBase): if self.pp_degree == 1: return strategy = self.user_defined_strategy + fp16_allreduce = strategy.fp16_allreduce + main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() @@ -317,33 +319,44 @@ class ShardingOptimizer(MetaOptimizerBase): main_block._remove_op(idx) accumulated_grad_names = self._pp_optimizer._accumulate_gradients( - main_block) - # accumulated_grad_names = sorted(accumulated_grad_names) + main_block, fp16_allreduce=fp16_allreduce) + + len_of_ops = len(main_block.ops) + first_optimize_op_index = get_first_optimize_op_idx(main_block) + if self.pp_allreduce_in_optimize: - print("persistable FP32 grad: ") - print(accumulated_grad_names) - first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( - main_block, raise_error=strategy.amp) - insert_reduce_ops( + logger.info("Pipeline Persistable grad is {}".format( + accumulated_grad_names)) + # FIXME(wangxi): accumulated_grad get from pipeline is not + # include sharding's param@BroadCast grad when + # pp_allreduce_in_optimize + accumulated_grad_names = insert_reduce_ops( main_block, first_optimize_op_index, self.sharding_ring_id, accumulated_grad_names, self._shard, core.op_proto_and_checker_maker.OpRole.Optimize, - use_calc_stream=True) + use_calc_stream=True, + rank=self.sharding_rank) + + logger.info("PP-Sharding grad is {}".format(accumulated_grad_names)) + first_optimize_op_index += (len(main_block.ops) - len_of_ops) + len_of_ops = len(main_block.ops) + if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp": - first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( - main_block, raise_error=strategy.amp) - if first_optimize_op_index >= 0: - insert_allreduce_ops( - main_block, - first_optimize_op_index, - self.dp_ring_id, - accumulated_grad_names, - core.op_proto_and_checker_maker.OpRole.Optimize, - use_calc_stream=True, - user_defined_strategy=strategy) + insert_allreduce_ops( + main_block, + first_optimize_op_index, + self.dp_ring_id, + accumulated_grad_names, + core.op_proto_and_checker_maker.OpRole.Optimize, + use_calc_stream=True, + user_defined_strategy=strategy) + first_optimize_op_index += (len(main_block.ops) - len_of_ops) + len_of_ops = len(main_block.ops) + + # FIXME(wangxi): if fp16_allreduce, put cast fp16->fp32 to there? def _adapt_amp_clip_without_sharding(self): if self.sharding_degree > 1: return diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index ab3dbad1ef326dabb3578d79207853cf96028003..7ad94f4be3eb2fce210937e70d01e0615cfd880b 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4528,7 +4528,7 @@ class PipelineOptimizer(object): op._rename_input(old_name, new_name) op._rename_output(old_name, new_name) - def _create_var(self, block, ref_var, name): + def _create_var(self, block, ref_var, name, dtype=None): """ Create a new var for block, which has the same type, shape and dtype as ref_var, then rename it with the @@ -4537,7 +4537,7 @@ class PipelineOptimizer(object): new_var = block.create_var( name=name, shape=ref_var.shape, - dtype=ref_var.dtype, + dtype=ref_var.dtype if dtype is None else dtype, type=ref_var.type, lod_level=ref_var.lod_level, persistable=ref_var.persistable, @@ -5044,7 +5044,10 @@ class PipelineOptimizer(object): new_grad_name = name + "@MERGED" self._rename_arg(op, name, new_grad_name) - def _accumulate_gradients(self, block, pp_allreduce_in_optimize=False): + def _accumulate_gradients(self, + block, + pp_allreduce_in_optimize=False, + fp16_allreduce=False): """ Create a new merged gradient for each parameter and accumulate the corresponding gradient to it. @@ -5052,6 +5055,9 @@ class PipelineOptimizer(object): merged_gradient_names = [] first_opt_op_idx = None + merged_suffix = '@MERGED@FP16' if fp16_allreduce else '@MERGED' + dtype = paddle.float16 if fp16_allreduce else None + for index, op in reversed(tuple(enumerate(list(block.ops)))): # remove the cast op of fp16 grad to fp32 grad if self._is_optimize_op(op) and op.type == 'cast': @@ -5062,12 +5068,10 @@ class PipelineOptimizer(object): block._remove_op(index) continue - if self._is_backward_op(op) and not first_opt_op_idx: + if self._is_backward_op(op) and first_opt_op_idx is None: first_opt_op_idx = index + 1 # no optimize phase if first_opt_op_idx == len(block.ops): return - if block.ops[first_opt_op_idx].type == "c_sync_comm_stream": - first_opt_op_idx += 1 if self._is_backward_op(op) and ( self._op_role_var_key in op.attr_names): @@ -5079,12 +5083,14 @@ class PipelineOptimizer(object): param_name = op_role_var[i] if not block.has_var(param_name): continue if '@BroadCast' in param_name: continue + param_grad_name = param_name + core.grad_var_suffix() - merged_param_grad_name = param_grad_name + '@MERGED' + merged_param_grad_name = param_grad_name + merged_suffix if not block.has_var(merged_param_grad_name): self._create_var(block, block.vars[param_name], - merged_param_grad_name) + merged_param_grad_name, dtype) assert block.has_var(merged_param_grad_name) + param_grad_var = block.var(param_grad_name) merged_param_grad_var = block.var(merged_param_grad_name) merged_param_grad_var.persistable = True @@ -5103,22 +5109,18 @@ class PipelineOptimizer(object): offset += 1 grad_name = op_role_var[i + 1] grad_var = block.vars[grad_name] - if not 'cast_fp16' in grad_name: - block._insert_op( - index=first_opt_op_idx + offset, - type='sum', - inputs={'X': [grad_var, merged_param_grad_var]}, - outputs={'Out': merged_param_grad_var}, - attrs={ - self._op_role_key: self._op_role.Backward, - }) - offset += 1 - merged_gradient_names.append(merged_param_grad_name) - else: - # cast gradient to fp32 to accumulate to merged gradient + + is_fp16_grad = 'cast_fp16' in grad_name + need_cast = (is_fp16_grad is not fp16_allreduce) + + if need_cast: + # if fp16_allreduce: + # cast grad to fp16 to accumulate to merged gradient + # else: + # cast grad to fp32 to accumulate to merged gradient cast_grad_var_name = param_grad_name + '@TMP' - cast_grad_var = self._create_var(block, param_grad_var, - cast_grad_var_name) + cast_grad_var = self._create_var( + block, param_grad_var, cast_grad_var_name, dtype) cast_grad_var.persistable = False block._insert_op( index=first_opt_op_idx + offset, @@ -5131,18 +5133,52 @@ class PipelineOptimizer(object): self._op_role_key: self._op_role.Backward, }) offset += 1 - block._insert_op( - index=first_opt_op_idx + offset, - type='sum', - inputs={ - 'X': [merged_param_grad_var, cast_grad_var] - }, - outputs={'Out': merged_param_grad_var}, - attrs={ - self._op_role_key: self._op_role.Backward, - }) - offset += 1 - merged_gradient_names.append(merged_param_grad_name) + grad_var = cast_grad_var + + block._insert_op( + index=first_opt_op_idx + offset, + type='sum', + inputs={'X': [merged_param_grad_var, grad_var]}, + outputs={'Out': merged_param_grad_var}, + attrs={self._op_role_key: self._op_role.Backward, }) + offset += 1 + merged_gradient_names.append(merged_param_grad_name) + + if not fp16_allreduce: return merged_gradient_names + + first_opt_op_idx = None + for index, op in reversed(tuple(enumerate(list(block.ops)))): + if self._is_backward_op(op) and first_opt_op_idx is None: + first_opt_op_idx = index + 1 + break + assert first_opt_op_idx is not None + + # insert cast op from fp16->fp32 + # FIXME(wangxi): maybe put in sharding is better, for some grad + # is not in sharding device. + for fp16_grad_name in merged_gradient_names: + grad_name = fp16_grad_name.replace('@FP16', '') + param_name = fp16_grad_name.replace('@GRAD@MERGED@FP16', '') + + if not block.has_var(grad_name): + self._create_var(block, block.vars[param_name], grad_name) + assert block.has_var(grad_name) + + fp16_grad_var = block.var(fp16_grad_name) + grad_var = block.var(grad_name) + grad_var.persistable = False + + block._insert_op( + index=first_opt_op_idx, + type='cast', + inputs={'X': fp16_grad_var}, + outputs={'Out': grad_var}, + attrs={ + 'in_dtype': fp16_grad_var.dtype, + 'out_dtype': grad_var.dtype, + self._op_role_key: self._op_role.Optimize, + }) + return merged_gradient_names def _add_sub_blocks(self, main_block, program_list): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index b7cf9dfaec5760e37eadf3d84a439617c5436e8a..d70a58c7d8ab41b0b1907c8543841aaf343911d0 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -552,9 +552,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): 'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', - 'c_sync_comm_stream', 'fill_constant', 'sum', 'fill_constant', + 'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant', 'sum', - 'fill_constant', 'sum', 'momentum', 'momentum', 'momentum', + 'c_sync_comm_stream', 'momentum', 'momentum', 'momentum', 'momentum', 'momentum' ]) @@ -694,6 +694,171 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002']) + def test_hybrid_with_pp_dp_amp_fp16allreduce(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.pp_net(train_prog, startup_prog) + strategy.amp = True + strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], } + strategy.sharding = True + strategy.sharding_configs = { + "sharding_degree": 1, + "mp_degree": 1, + "pp_degree": 2, + "dp_degree": 2, + } + strategy.pipeline = True + strategy.pipeline_configs = { + "schedule_mode": "1F1B", + "micro_batch_size": 2, + "accumulate_steps": 4, + } + strategy.fp16_allreduce = True + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + train_prog = train_prog._pipeline_opt['section_program'] + startup_prog = startup_prog._pipeline_opt['startup_program'] + + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check program + startup_prog_op_types = [op.type for op in startup_prog_ops] + main_prog_op_types = [op.type for op in main_prog_ops] + + # ring: mp, pp_group, pp_pair, pp_pair + self.assertEqual(startup_prog_op_types, [ + 'uniform_random', 'fill_constant', 'uniform_random', + 'fill_constant', 'uniform_random', 'fill_constant', + 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' + ]) + + self.assertEqual(main_prog_op_types, [ + 'recv_v2', 'cast', 'mul', 'cast', 'elementwise_add', 'tanh', 'cast', + 'mul', 'cast', 'elementwise_add', 'tanh', 'cast', 'mul', 'cast', + 'elementwise_add', 'tanh', 'cast', 'mul', 'cast', 'elementwise_add', + 'softmax', 'cross_entropy2', 'mean', 'elementwise_mul', + 'fill_constant', 'scale', 'scale', 'elementwise_mul_grad', + 'mean_grad', 'cross_entropy_grad2', 'softmax_grad', + 'elementwise_add_grad', 'cast', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2', + 'fill_constant', 'cast', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant', + 'sum', 'fill_constant', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum', + 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', + 'c_sync_comm_stream', 'check_finite_and_unscale', 'cast', + 'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum', + 'momentum', 'momentum', 'momentum', 'momentum', 'momentum', + 'momentum', 'momentum' + ]) + + # amp check_finite_and_unscale, allreduce(pp) + self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1) + + # should has ring id for pp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(self.pp_pair_ring_id, created_ring_ids) + self.assertIn(self.dp_ring_id, created_ring_ids) + + # check correctness of pp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_0": + pp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of dp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_3": + dp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + + def test_hybrid_with_sharding_pp_amp_fp16allreduce_in_optimize(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.pp_net(train_prog, startup_prog) + strategy.amp = True + strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], } + strategy.sharding = True + strategy.sharding_configs = { + "segment_broadcast_MB": 0.1, + "sharding_degree": 2, + "mp_degree": 1, + "pp_degree": 2, + "dp_degree": 1, + 'pp_allreduce_in_optimize': True, + } + strategy.pipeline = True + strategy.pipeline_configs = { + "schedule_mode": "1F1B", + "micro_batch_size": 2, + "accumulate_steps": 4, + } + strategy.fp16_allreduce = True + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + train_prog = train_prog._pipeline_opt['section_program'] + startup_prog = startup_prog._pipeline_opt['startup_program'] + + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check program + startup_prog_op_types = [op.type for op in startup_prog_ops] + main_prog_op_types = [op.type for op in main_prog_ops] + + # ring: sharding, pp_group, pp_pair, pp_pair + self.assertEqual(startup_prog_op_types, [ + 'fill_constant', 'uniform_random', 'fill_constant', + 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', + 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', + 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init' + ]) + + # FIXME(wangxi): some bug in sharding+pp with pp_allreduce_in_optimize + # self.assertEqual(main_prog_op_types, []) + + # amp check_finite_and_unscale, allreduce(pp) + self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 2) + + # should has ring id for pp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(self.sharding_ring_id, created_ring_ids) + self.assertIn(self.pp_pair_ring_id, created_ring_ids) + + # check correctness of sharding group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_0": + sharding_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of pp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_1": + pp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002']) + if __name__ == "__main__": unittest.main()