diff --git a/python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py index af84682f45041170cf1d8660112b4581a6d6ac6b..1511769350477094f858688f41ddc31593e4b39a 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py @@ -22,9 +22,10 @@ from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_u class ModelParallelHelper(object): - def __init__(self, role_maker, wait_port=True): + def __init__(self, role_maker, wait_port=True, megatron_dp=False): self.wait_port = wait_port self.role_maker = role_maker + self.megatron_dp = megatron_dp def update_startup_program(self, startup_program=None, @@ -48,24 +49,29 @@ class ModelParallelHelper(object): mp_endpoints, mp_rank, 0, self.wait_port) self._broadcast_params(0, broadcast_distributed_weight=False) - mp_num = len(endpoints) // inner_parallelism - if mp_num == 1: return - # Create rings for gpus as the same model parallel part - eps = [] - dp_rank = rank // inner_parallelism - dp_id = rank % inner_parallelism - #if dp_rank == 1: dp_rank =0 - #if dp_rank == 0: dp_rank =1 - ring_id = 1 - for idx, ep in enumerate(endpoints): - if idx % inner_parallelism == dp_id: - eps.append(ep) - #ep = eps.pop(0) - #eps.insert(1, ep) - print("data parallel eps:{}, rank{}".format(eps, dp_rank)) - self._init_communicator(self.startup_program, current_endpoint, eps, - dp_rank, ring_id, self.wait_port) - self._broadcast_params(ring_id, broadcast_distributed_weight=True) + print("megatron group size: {}".format(inner_parallelism)) + print("megatron rank: {}".format(mp_rank)) + print("megatron endpoints: {}".format(mp_endpoints)) + + if self.megatron_dp: + mp_num = len(endpoints) // inner_parallelism + if mp_num == 1: return + # Create rings for gpus as the same model parallel part + eps = [] + dp_rank = rank // 
inner_parallelism + dp_id = rank % inner_parallelism + #if dp_rank == 1: dp_rank =0 + #if dp_rank == 0: dp_rank =1 + ring_id = 1 + for idx, ep in enumerate(endpoints): + if idx % inner_parallelism == dp_id: + eps.append(ep) + #ep = eps.pop(0) + #eps.insert(1, ep) + print("data parallel eps:{}, rank{}".format(eps, dp_rank)) + self._init_communicator(self.startup_program, current_endpoint, eps, + dp_rank, ring_id, self.wait_port) + self._broadcast_params(ring_id, broadcast_distributed_weight=True) def _init_communicator(self, program, current_endpoint, endpoints, rank, ring_id, wait_port): @@ -129,9 +135,14 @@ class ModelParallelOptimizer(MetaOptimizerBase): def __init__(self, optimizer): super(ModelParallelOptimizer, self).__init__(optimizer) self.inner_opt = optimizer - # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = [ + "RecomputeOptimizer", + "AMPOptimizer", + "LarsOptimizer", + "LambOptimizer", + ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] + self.megatron_dp = False def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): @@ -156,6 +167,10 @@ class ModelParallelOptimizer(MetaOptimizerBase): dist_strategy.model_parallel = True dist_strategy.model_parallel_configs = {"parallelism": 1, } + # the following function will be used by AMP if both Megatron and AMP are turned on together. 
+ def apply_gradients(self, params_grads): + return self.minimize_impl(params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, @@ -167,6 +182,8 @@ class ModelParallelOptimizer(MetaOptimizerBase): if startup_program is None: self.startup_program = fluid.default_startup_program() + # (TODO) check the order of metaoptimizer + # (TODO) check the params_grads optimize_ops, params_grads = self.inner_opt.minimize( loss, self.startup_program, parameter_list, no_grad_set) @@ -179,10 +196,12 @@ class ModelParallelOptimizer(MetaOptimizerBase): self.inner_parallelism) assert self.nranks % self.inner_parallelism == 0 - # data parallelism - dp_parallelism = self.nranks // self.inner_parallelism - self._transpile_main_program(loss, dp_parallelism) + if self.megatron_dp: + # data parallelism + dp_parallelism = self.nranks // self.inner_parallelism + + self._transpile_main_program(loss, dp_parallelism) return optimize_ops, params_grads def _transpile_main_program(self, loss, dp_parallelism): diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py index 03b36262a4fb1e095eb17fa57bf27b5c9f3cf74c..c2177548005e1b2c27a7f6814e3a325cd299a645 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -73,7 +73,7 @@ class FP16Utils(object): @staticmethod def prune_fp16(block, shard, reduced_grads_to_param, ring_id): """ - 1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard + 1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard 2. 
revise amp inifine grad checking for sharding """ # remove cast @@ -103,6 +103,7 @@ class FP16Utils(object): op._rename_input(inf_var_name, inf_var_name + "@sharding") if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: reversed_x = [] + reversed_x_paramname = [] for input_name in op.desc.input('X'): param_name = input_name.strip("@GRAD") if param_name not in shard.global_params: @@ -111,12 +112,26 @@ class FP16Utils(object): "be grads, but {} is not a grad".format(input_name)) if shard.has_param(param_name): reversed_x.append(input_name) + reversed_x_paramname.append(param_name) op.desc.set_input('X', reversed_x) op.desc.set_output('Out', reversed_x) + + # the grad checking should cover all, and only, the params in the current shard + to_check_param = set(reversed_x_paramname) + should_check_param = set(shard.global_params).intersection( + set([ + param + for param, worker_idx in shard.global_param2device. + items() if worker_idx == shard.worker_idx + ])) + assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format( + should_check_param - to_check_param, + to_check_param - should_check_param) + if update_loss_scaling_op_idx == -1: return inf_var = block.var(inf_var_name) - inf_var_fp32 = block.create_var( + inf_var_int32 = block.create_var( name=inf_var_name + "@cast_int32", shape=inf_var.shape, dtype=core.VarDesc.VarType.INT32) @@ -128,32 +143,36 @@ class FP16Utils(object): update_loss_scaling_op_idx, type='cast', inputs={'X': inf_var}, - outputs={'Out': inf_var_fp32}, + outputs={'Out': inf_var_int32}, attrs={ "in_dtype": inf_var.dtype, - "out_dtype": inf_var_fp32.dtype, + "out_dtype": inf_var_int32.dtype, OP_ROLE_KEY: OpRole.Optimize }) - insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, - [inf_var_fp32]) + # this allreduce communication should not overlap with calc + # insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, + # [inf_var_int32]) 
block._insert_op_without_sync( - update_loss_scaling_op_idx + 2, + update_loss_scaling_op_idx + 1, type='c_allreduce_max', - inputs={'X': inf_var_fp32}, - outputs={'Out': inf_var_fp32}, - attrs={'ring_id': ring_id, - OP_ROLE_KEY: OpRole.Optimize}) + inputs={'X': inf_var_int32}, + outputs={'Out': inf_var_int32}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) - comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, - ring_id, [inf_var_fp32]) + # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, + # ring_id, [inf_var_int32]) block._insert_op_without_sync( - update_loss_scaling_op_idx + 3 + comm_op_num, + update_loss_scaling_op_idx + 2, type='cast', - inputs={'X': inf_var_fp32}, + inputs={'X': inf_var_int32}, outputs={'Out': inf_var_sharding}, attrs={ - "in_dtype": inf_var_fp32.dtype, + "in_dtype": inf_var_int32.dtype, "out_dtype": inf_var_sharding.dtype, OP_ROLE_KEY: OpRole.Optimize })