diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 4796ad2f1f3f1d98766abf2a2587476b48e0dd43..961a789dc081a0c04a32741526a9426a53d72e18 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -51,7 +51,8 @@ class GradientClipHelper(object):
             if deperate_op:
                 deperate_op_idx.add(idx)
                 for output_name in op.desc.output_arg_names():
-                    deperated_vars.add(output_name)
+                    if output_name not in op.desc.input_arg_names():
+                        deperated_vars.add(output_name)
 
         if not deperated_vars:
             # got no gradient_clip op
@@ -111,7 +112,6 @@ class GradientClipHelper(object):
                     to_check_param - should_check_param)
 
         for var_name in deperated_vars:
-            if block.has_var(var_name):
-                block._remove_var(var_name, sync=False)
+            block._remove_var(var_name, sync=False)
         block._sync_with_cpp()
         return