From 59d8b8cb43381e5c88a8745e9f3fae7f94ceecf1 Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Mon, 25 Oct 2021 11:02:19 +0800
Subject: [PATCH] [HybridParallel]fix bug of check_inf in fleet_base.py (#36651)

* fix bug of check_inf

* fix allreduce
---
 python/paddle/distributed/fleet/base/fleet_base.py  | 8 ++++----
 .../distributed/fleet/utils/hybrid_parallel_util.py | 3 +--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index 571199b99b0..aea7ad07102 100755
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -1586,16 +1586,16 @@ class Fleet(object):
                 _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                                 param_grads_fp32,
                                                 temp_found_inf_fp32)
+            self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
+            is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
 
             # TODO(shenliang03) Since dp allreduce in the optimizer is
             # after the gradscaler, check_finite needs to synchronize global
             # information. In the future, we should use check_group to speed.
             paddle.distributed.all_reduce(
-                paddle.to_tensor(
-                    [self._found_inf], dtype="int32"),
-                op=paddle.distributed.ReduceOp.MAX,
-                group=None)
+                is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
+            self._found_inf = is_found_inf.numpy()[0]
 
             # Only tensor_parallel and pipeline_parallel need to modify scaler
             if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 0f5c24f022e..75aa9766e7b 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -47,6 +47,7 @@ def _apply_collective_grads(parameters, comm_group):
         nranks = paddle.distributed.get_world_size(
         ) if comm_group is None else comm_group.nranks
         div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
+        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
         paddle.fluid.framework._dygraph_tracer().trace_op(
             type="elementwise_div",
             inputs={'X': coalesced_grad,
                     'Y': div_factor},
             outputs={'Out': coalesced_grad},
             attrs={'axis': -1})
 
-        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
-
         _split_tensors(coalesced_grads_and_vars)
-- 
GitLab
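
A minimal sketch of the synchronization pattern the fleet_base.py hunk adopts: each rank folds its local overflow flags into an int32 tensor, all-reduces it with ReduceOp.MAX so that an overflow on any rank is seen by every rank, and then reads the value back into Python (the read-back is the step the original code omitted). This is a standalone illustration, not the patched method itself; the helper name sync_found_inf, the script name, and the launch command are assumptions.

# Hypothetical standalone sketch of the found_inf synchronization pattern;
# launch with, e.g.:
#   python -m paddle.distributed.launch --gpus "0,1" sync_found_inf.py
import paddle
import paddle.distributed as dist


def sync_found_inf(local_overflow_flags):
    # Collapse the per-dtype overflow flags into a single 0/1 value.
    found_inf = 1 if any(local_overflow_flags) else 0
    is_found_inf = paddle.to_tensor([found_inf], dtype="int32")
    # MAX across ranks: if any rank overflowed, every rank ends up with 1.
    # Discarding the all_reduce result instead of reading it back would
    # reproduce the bug this patch fixes.
    dist.all_reduce(is_found_inf, op=dist.ReduceOp.MAX)
    return int(is_found_inf.numpy()[0])


if __name__ == "__main__":
    dist.init_parallel_env()
    # Pretend only rank 0 saw an overflow in its fp16 gradients.
    local_flags = [dist.get_rank() == 0, False]
    print("rank", dist.get_rank(), "found_inf =", sync_found_inf(local_flags))

On the hybrid_parallel_util.py hunk, note that summing gradients across ranks and then dividing by nranks is mathematically equivalent to dividing first and then summing (both are linear, up to floating-point rounding), so reordering the all_reduce ahead of the elementwise_div leaves the averaged gradient unchanged.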