From d3105dbf505cd7aee39ee0488bf7c9d8c2cd048c Mon Sep 17 00:00:00 2001
From: sandyhouse
Date: Sun, 7 Feb 2021 18:11:45 +0800
Subject: [PATCH] update

---
 .../sharding/gradient_clip_helper.py          | 53 +++++++++++++------
 .../fleet/meta_optimizers/sharding/utils.py   | 14 ++---
 2 files changed, 39 insertions(+), 28 deletions(-)
 mode change 100755 => 100644 python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index c6aee792fcf..3b0cfe21a79 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -16,8 +16,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 
 
 class GradientClipHelper(object):
-    def __init__(self, sharding_ring_id):
-        self.sharding_ring_id = sharding_ring_id
+    def __init__(self, mp_ring_id):
+        self.mp_ring_id = mp_ring_id
 
     def _is_gradient_clip_op(self, op):
         return op.desc.has_attr("op_namescope") \
@@ -31,6 +31,7 @@ class GradientClipHelper(object):
         """
         deperated_vars = set()
         deperate_op_idx = set()
+        reversed_x_paramname = []
         for idx, op in enumerate(block.ops):
             if not self._is_gradient_clip_op(op):
                 continue
@@ -44,6 +45,8 @@
                     if shard.is_param(param_name) and \
                             not shard.has_param(param_name):
                         deperate_op = True
+                    elif shard.is_param(param_name):
+                        reversed_x_paramname.append(param_name)
 
             if deperate_op:
                 deperate_op_idx.add(idx)
@@ -65,31 +68,47 @@ class GradientClipHelper(object):
                 for input_name in op.desc.input_arg_names():
                     if input_name not in deperated_vars:
                         reversed_inputs.append(input_name)
+
                 op.desc.set_input("X", reversed_inputs)
                 assert (len(op.desc.output_arg_names()) == 1)
                 sum_res = op.desc.output_arg_names()[0]
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_sync_comm_stream',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={'ring_id': 0,
-                           OP_ROLE_KEY: OpRole.Optimize})
+
+                # this allreduce should not overlap with calc and should be scheduled in calc stream
+                # block._insert_op_without_sync(
+                #     idx + 1,
+                #     type='c_sync_comm_stream',
+                #     inputs={'X': sum_res},
+                #     outputs={'Out': sum_res},
+                #     attrs={'ring_id': 0,
+                #            OP_ROLE_KEY: OpRole.Optimize})
                 block._insert_op_without_sync(
                     idx + 1,
                     type='c_allreduce_sum',
                     inputs={'X': sum_res},
                     outputs={'Out': sum_res},
                     attrs={
-                        'ring_id': self.sharding_ring_id,
-                        OP_ROLE_KEY: OpRole.Optimize
+                        'ring_id': self.mp_ring_id,
+                        'op_namescope': "/gradient_clip_model_parallelism",
+                        'use_calc_stream': True,
+                        OP_ROLE_KEY: OpRole.Optimize,
                     })
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_sync_calc_stream',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={OP_ROLE_KEY: OpRole.Optimize})
+                # block._insert_op_without_sync(
+                #     idx + 1,
+                #     type='c_sync_calc_stream',
+                #     inputs={'X': sum_res},
+                #     outputs={'Out': sum_res},
+                #     attrs={OP_ROLE_KEY: OpRole.Optimize})
+
+        # the grad sum here should take the all and only param in the current shard
+        to_check_param = set(reversed_x_paramname)
+        should_check_param = set(shard.global_params).intersection(
+            set([
+                param for param, worker_idx in shard.global_param2device.items()
+                if worker_idx == shard.worker_idx
+            ]))
+        assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
+            should_check_param - to_check_param,
+            to_check_param - should_check_param)
 
         for var_name in deperated_vars:
             block._remove_var(var_name, sync=False)
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
old mode 100755
new mode 100644
index 1691bf7387a..eb5767ec4d3
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -227,14 +227,9 @@ def get_valid_op_role(block, insert_idx):
     return OpRole.Forward or OpRole.Backward
     """
     op_role = block.ops[insert_idx].attr('op_role')
-    #if (insert_idx >= len(block.ops)) or (
-    #        op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
-    #    return OpRole.Backward
-    #if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
-    #    return OpRole.Forward
-    if insert_idx >= len(block.ops): return OpRole.Optimize
-    if op_role == int(OpRole.Backward): return OpRole.Backward
-    if op_role == int(OpRole.Optimize): return OpRole.Optimize
+    if (insert_idx >= len(block.ops)) or (
+            op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
+        return OpRole.Backward
     if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
         return OpRole.Forward
 
@@ -485,9 +480,6 @@ def save_persistables(exe, dirname, main_program, filename=None):
     This function handles the model saving for sharding training.
     """
-    if main_program._pipeline_opt:
-        main_program = main_program._pipeline_opt['section_program']['program']
-
    def is_opt_vars(var):
        # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
        # now only Momentum and adam are compatible with sharding
--
GitLab
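
A minimal standalone sketch of the shard/parameter consistency check that the patch adds after pruning the gradient-clip ops: the parameters whose gradients still feed the local grad-norm sum must be exactly the parameters the sharding plan assigns to the current worker. This is an illustration only, not code from the patch; _ToyShard is a hypothetical stand-in for the real Shard metadata object, which exposes global_params, global_param2device, and worker_idx.

# Standalone sketch (not from the patch): the consistency check performed after
# pruning gradient-clip ops, using a hypothetical _ToyShard stand-in for the
# real Shard metadata object.


class _ToyShard(object):
    """Hypothetical stand-in exposing only the fields the check relies on."""

    def __init__(self, global_param2device, worker_idx):
        # param name -> index of the worker that owns the param
        self.global_param2device = global_param2device
        self.global_params = set(global_param2device.keys())
        self.worker_idx = worker_idx


def check_clip_params_match_shard(clipped_param_names, shard):
    # Params whose gradients still feed the local grad-norm sum after pruning.
    to_check_param = set(clipped_param_names)
    # Params that the sharding plan assigns to this worker.
    should_check_param = set(shard.global_params).intersection(
        set(param for param, worker_idx in shard.global_param2device.items()
            if worker_idx == shard.worker_idx))
    assert to_check_param == should_check_param, \
        "gradient clip checking miss [{}] and got unexpected [{}]".format(
            should_check_param - to_check_param,
            to_check_param - should_check_param)


if __name__ == "__main__":
    shard = _ToyShard({"fc_0.w_0": 0, "fc_1.w_0": 1}, worker_idx=0)
    check_clip_params_match_shard(["fc_0.w_0"], shard)  # passes: exactly this worker's params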