diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
index eb5767ec4d343634a6e91916da75071cd155fa00..c110ba805575c6abda11b26dadfd8fc304d9eedb 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -227,12 +227,18 @@ def get_valid_op_role(block, insert_idx):
     return OpRole.Forward or OpRole.Backward
     """
     op_role = block.ops[insert_idx].attr('op_role')
-    if (insert_idx >= len(block.ops)) or (
-            op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
-        return OpRole.Backward
+    # if (insert_idx >= len(block.ops)) or (
+    #         op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
+    #     return OpRole.Backward
+    # if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
+    #     return OpRole.Forward
+
+    # return get_valid_op_role(block, insert_idx + 1)
+    if insert_idx >= len(block.ops): return OpRole.Optimize
+    if op_role == int(OpRole.Backward): return OpRole.Backward
+    if op_role == int(OpRole.Optimize): return OpRole.Optimize
     if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
         return OpRole.Forward
 
-    return get_valid_op_role(block, insert_idx + 1)
 
 
@@ -480,6 +486,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
     This function handles the model saving for sharding training.
     """
 
+    if main_program._pipeline_opt:
+        main_program = main_program._pipeline_opt['section_program']['program']
+
     def is_opt_vars(var):
         # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
         # now only Momentum and adam are compatible with sharding
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 89ee08126621340a3f9348453921053816d02ac7..b17084a979e5deeea8c44dae41379f261a2b562c 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -182,10 +182,17 @@ class ShardingOptimizer(MetaOptimizerBase):
                 if not self._shard.has_param(param_name):
                     main_block._remove_op(idx)
-            param_list = []
-            for param_name, grad_name in params_grads:
-                if self._shard.has_param(param_name):
-                    param_list.append(param_name)
+            for idx, op in reversed(list(enumerate(main_block.ops))):
+                if op.type != 'cast': continue
+                in_name = op.input_arg_names[0]
+                if in_name not in self._params: continue
+                #if self._shard.has_param(param_name): continue
+                if in_name not in main_block.vars:
+                    main_block._remove_op(idx)
+            #param_list = []
+            #for param_name, grad_name in params_grads:
+            #    if self._shard.has_param(param_name):
+            #        param_list.append(param_name)
             #pp_optimizer._clear_gradients(main_block, param_list)
             pp_optimizer._accumulate_gradients(main_block)
             #if not self._shard.has_param(param_name): continue
 
@@ -359,13 +366,19 @@ class ShardingOptimizer(MetaOptimizerBase):
         # config sharding & dp groups
         self._init_comm()
 
-        # inner & outer model parallelism
+        # global
+        print("global_group_endpoints:", self.global_group_endpoints)
+        print("global_rank:", self.global_rank)
+        print("global_ring_id:", self.global_group_id)
         if self._as_outer_parallelism:
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
                 self.global_group_endpoints, self.global_rank,
                 self.global_group_id, True)
 
+        print("mp_group_endpoints:", self.mp_group_endpoints)
+        print("mp_rank:", self.mp_rank)
+        print("mp_ring_id:", self.mp_group_id)
         if self._as_outer_parallelism:
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
@@ -387,6 +400,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
         # pp
         if self.use_pipeline:
+            print("pp_group_endpoints:", self.pp_group_endpoints)
+            print("pp_rank:", self.pp_rank)
+            print("pp_ring_id:", self.pp_ring_id)
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
                 self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
@@ -660,6 +676,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             fill_constant_vars)
 
         # step4: add `cast` ops
+        print("cast_ops:", cast_ops)
         insert_cast_ops(block, segment._end_idx, cast_ops)
 
         # step5: add broadcast ops
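
Review notes follow; the code sketches are illustrative, not part of the diff.

get_valid_op_role: the rewritten branches still read block.ops[insert_idx] one line before the new bounds check, so an out-of-range insert_idx raises IndexError before the OpRole.Optimize return can fire (the same ordering issue existed before the patch), and the docstring still advertises only Forward or Backward even though the function can now return Optimize. A minimal reordered sketch, assuming OpRole as exposed by meta_optimizers.common:

    from paddle.distributed.fleet.meta_optimizers.common import OpRole

    def get_valid_op_role(block, insert_idx):
        # Check the bound first so an out-of-range index returns
        # OpRole.Optimize instead of raising IndexError.
        if insert_idx >= len(block.ops):
            return OpRole.Optimize
        op_role = block.ops[insert_idx].attr('op_role')
        if op_role == int(OpRole.Backward):
            return OpRole.Backward
        if op_role == int(OpRole.Optimize):
            return OpRole.Optimize
        if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
            return OpRole.Forward
        # Fall through to the next op for combined or unknown roles,
        # matching the recursion the patch commented out.
        return get_valid_op_role(block, insert_idx + 1)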
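save_persistables: the new hook assumes main_program always carries a _pipeline_opt attribute shaped like {'section_program': {'program': ...}}. A defensive variant of the same redirect, written as a hypothetical helper under that assumption:

    def unwrap_pipeline_program(main_program):
        # Hypothetical helper: fall back to the original program when
        # pipeline metadata is absent or empty, otherwise return the
        # per-section program that holds this rank's persistables.
        pipeline_opt = getattr(main_program, '_pipeline_opt', None)
        if pipeline_opt:
            return pipeline_opt['section_program']['program']
        return main_program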
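Cast cleanup in ShardingOptimizer: the new loop walks main_block.ops in reverse because _remove_op(idx) renumbers every later op, so a forward scan over pre-computed indices would skip or delete the wrong ops after the first removal; reverse iteration only invalidates indices already visited. A condensed sketch of the same pattern (the helper name is illustrative, not in the patch):

    def remove_dangling_casts(main_block, param_names):
        # Iterate backwards so each _remove_op(idx) only shifts
        # indices we have already passed.
        for idx, op in reversed(list(enumerate(main_block.ops))):
            if op.type != 'cast':
                continue
            in_name = op.input_arg_names[0]
            # Drop casts fed by a parameter this shard no longer
            # holds, i.e. whose variable was pruned from the block.
            if in_name in param_names and in_name not in main_block.vars:
                main_block._remove_op(idx)

Note that the commented line #if self._shard.has_param(param_name): continue inside the patched loop references param_name, which is not defined in that scope; it is dead code either way, but worth removing before merge.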
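The print() calls added around communicator setup (global, mp, pp rings and cast_ops) read like temporary diagnostics. If the ring information should stay observable, routing it through the standard logging module would let users silence it by level. A sketch, with log_ring as a hypothetical helper:

    import logging

    logger = logging.getLogger("paddle.distributed.fleet.sharding")

    def log_ring(tag, endpoints, rank, ring_id):
        # One structured record per communicator instead of three
        # bare print() calls per ring.
        logger.debug("%s endpoints=%s rank=%s ring_id=%s",
                     tag, endpoints, rank, ring_id)

    # e.g. log_ring("pp", self.pp_group_endpoints, self.pp_rank,
    #               self.pp_ring_id) at communicator-init time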