diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index 9c04b95a732e8ca6e0574c4b8a44f95070e83830..f3ac17cc46cd21e47c2779b6d143fdb93cd76699 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -775,6 +775,13 @@ void EagerReducer::ProcessUnusedDenseVars() { continue; } + // NOTE(haohongxiang): Calling SetFakeEmpty here is to make sure that + // gradient accumulation can continue normally after clear_gradients() + // especially in cases including complex control flow. + std::static_pointer_cast<egr::GradNodeAccumulation>( + GetGradNodeFromTensor(&tensors_[var_index])) + ->SetFakeEmpty(false); + Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor)); auto dest_var_base = tensors_[var_index]; diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py index e2f7af769d39e94ebf34b3ad5dab2fd1fc950136..161f4d3262ab173dfa3380b2b4195a3f27579960 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py +++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py @@ -43,12 +43,11 @@ def _apply_collective_grads(parameters, comm_group): coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024) + nranks = paddle.distributed.get_world_size( + ) if comm_group is None else comm_group.nranks for coalesced_grad, _, _ in coalesced_grads_and_vars: # need to div nranks - nranks = paddle.distributed.get_world_size( - ) if comm_group is None else comm_group.nranks div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype) - paddle.distributed.all_reduce(coalesced_grad, group=comm_group) paddle.fluid.framework._dygraph_tracer().trace_op( type="elementwise_div", inputs={ @@ -57,6 +56,7 @@ def _apply_collective_grads(parameters, comm_group): }, outputs={'Out': coalesced_grad}, attrs={'axis': -1}) + paddle.distributed.all_reduce(coalesced_grad, group=comm_group) 
_split_tensors(coalesced_grads_and_vars) @@ -76,10 +76,11 @@ def _apply_collective_grads_eager(parameters, comm_group): coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024) - div_factor = 1.0 / comm_group.nranks + nranks = paddle.distributed.get_world_size( + ) if comm_group is None else comm_group.nranks for coalesced_grad, _, _ in coalesced_grads_and_vars: # need to div nranks - coalesced_grad.scale_(div_factor) + coalesced_grad.scale_(1.0 / nranks) paddle.distributed.all_reduce(coalesced_grad, group=comm_group) _split_tensors(coalesced_grads_and_vars) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 214c68c250ea98a672b8b4b27e15ecd8ed4f49c6..6710ddb97dc24f7246117f15af2e2518c49e5d2e 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1507,7 +1507,7 @@ if(WITH_DISTRIBUTE 350) set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 300) set_tests_properties(test_parallel_dygraph_no_sync_gradient_check - PROPERTIES TIMEOUT 30) + PROPERTIES TIMEOUT 60) set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 500) set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py index 930bf5345fcae3466c5a2efdc4bc953d94cec97b..1e8aae7226a7e864c1e27834cc2d56276a696b8d 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py @@ -200,7 +200,8 @@ class TestMultipleWithGloo(unittest.TestCase): class TestDataParallelGradientCheck(TestMultipleGpus): def test_multiple_gpus_dynamic(self): - self.run_mnist_2gpu('parallel_dygraph_gradient_check.py') + 
self.run_mnist_2gpu('parallel_dygraph_gradient_check.py', + eager_mode=False) class TestDataParallelWithPyLayer(TestMultipleGpus): @@ -218,4 +219,5 @@ class TestGradientCheckInEagerMode(TestMultipleGpus): if __name__ == "__main__": + os.environ["FLAGS_enable_eager_mode"] = "1" unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py index fad9e902cc91eaf0dbb7c2baaf6b3a41edb87a04..d6a48b504a2dc60e4ee154e3d59d571e167d4eed 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_gradient_check.py @@ -14,6 +14,7 @@ from __future__ import print_function +import os import unittest import paddle.fluid as fluid @@ -24,7 +25,10 @@ class TestDataParallelLayer(TestMultipleGpus): def test_parallel_dygraph_dataparallel_no_sync(self): self.run_mnist_2gpu('parallel_dygraph_no_sync_gradient_check.py') + self.run_mnist_2gpu('parallel_dygraph_no_sync_gradient_check.py', + eager_mode=False) if __name__ == "__main__": + os.environ["FLAGS_enable_eager_mode"] = "1" unittest.main()