From 3ec289a6a33c5392b914cc256736dcb00b2cecce Mon Sep 17 00:00:00 2001
From: WangXi
Date: Tue, 31 Dec 2019 11:10:51 +0800
Subject: [PATCH] fix sync_batch_norm hang in fleet (#21838)

---
 .../incubate/fleet/collective/__init__.py     | 11 ++++++++
 .../fluid/tests/unittests/CMakeLists.txt      | 19 ++++++++++----
 .../paddle/fluid/tests/unittests/dist_test.sh | 10 ++++++-
 .../fluid/tests/unittests/test_dist_base.py   |  5 ++++
 .../unittests/test_dist_mnist_fleetapi.py     | 26 +++++++++++++++++++
 5 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py
index 481747e6039..e33662cf082 100644
--- a/python/paddle/fluid/incubate/fleet/collective/__init__.py
+++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py
@@ -337,6 +337,17 @@ class CollectiveOptimizer(DistributedOptimizer):
                 "with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0"
             )
 
+        # NOTE: enabling sync_batch_norm hangs when num_threads > 1
+        sync_batch_norm = self._strategy.sync_batch_norm
+        if sync_batch_norm is not None and sync_batch_norm is True:
+            self._strategy.nccl_comm_num = 1
+            self._strategy.use_hierarchical_allreduce = False
+            exec_strategy.num_threads = 1
+            logging.warn(
+                "sync_batch_norm will hang when num_threads > 1 is set, so "
+                "set num_threads=1, nccl_comm_num=1, use_hierarchical_allreduce=False."
+            )
+
         if self.print_config:
             print("node_num:", node_num, "num_threads:",
                   exec_strategy.num_threads, "use_hierarchical_allreduce:",
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 25d73df503f..9f26e695506 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -128,11 +128,20 @@ function(bash_test_modules TARGET_NAME)
         set(timeout ${bash_test_modules_TIMEOUT})
     endif()
 
-    add_test(NAME ${TARGET_NAME}
-             COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
-             TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
-             bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
-             WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    if(WITH_COVERAGE)
+        add_test(NAME ${TARGET_NAME}
+                 COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
+                 TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
+                 WITH_COVERAGE=ON COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
+                 bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
+                 WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    else()
+        add_test(NAME ${TARGET_NAME}
+                 COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
+                 TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
+                 bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
+                 WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    endif()
 
     if (bash_test_modules_SERIAL)
         set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
diff --git a/python/paddle/fluid/tests/unittests/dist_test.sh b/python/paddle/fluid/tests/unittests/dist_test.sh
index b185ab54a95..42566f63b68 100644
--- a/python/paddle/fluid/tests/unittests/dist_test.sh
+++ b/python/paddle/fluid/tests/unittests/dist_test.sh
@@ -22,7 +22,15 @@ rm -f ${name}_*.log
 # start the unit test
 run_time=$(( $TEST_TIMEOUT - 10 ))
 echo "run_time: ${run_time}"
-timeout -s SIGKILL ${run_time} python -u ${name}.py > ${name}_run.log 2>&1
+
+if [[ ${WITH_COVERAGE} == "ON" ]]; then
PYTHON_EXEC="python -u -m coverage run --branch -p " +else + PYTHON_EXEC="python -u " +fi + +timeout -s SIGKILL ${run_time} ${PYTHON_EXEC} ${name}.py > ${name}_run.log 2>&1 + exit_code=$? if [[ $exit_code -eq 0 ]]; then exit 0 diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index f94c6c1184d..4288a6c52af 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -137,6 +137,8 @@ class TestDistRunnerBase(object): dist_strategy.use_local_sgd = True if args.ut4grad_allreduce: dist_strategy._ut4grad_allreduce = True + if args.sync_batch_norm: + dist_strategy.sync_batch_norm = True role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) @@ -487,6 +489,7 @@ def runtime_main(test_class): required=False, type=bool, default=False) + parser.add_argument('--sync_batch_norm', action='store_true') args = parser.parse_args() @@ -837,6 +840,8 @@ class TestDistBase(unittest.TestCase): tr_cmd += " --use_local_sgd" if self._ut4grad_allreduce: tr_cmd += " --ut4grad_allreduce" + if hasattr(self, '_sync_batch_norm') and self._sync_batch_norm: + tr_cmd += " --sync_batch_norm" if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py index 30f8592e1da..bc86cba80b8 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py @@ -24,6 +24,7 @@ class TestDistMnistNCCL2FleetApi(TestDistBase): self._use_reader_alloc = False self._nccl2_mode = True self._gpu_fleet_api = True + self._sync_batch_norm = True def test_dist_train(self): import paddle.fluid as fluid @@ -31,5 +32,30 @@ class TestDistMnistNCCL2FleetApi(TestDistBase): self.check_with_place("dist_mnist.py", delta=1e-5) +class FleetCollectiveTest(unittest.TestCase): + def test_open_sync_batch_norm(self): + import paddle.fluid as fluid + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy + + data = fluid.layers.data(name='X', shape=[1], dtype='float32') + hidden = fluid.layers.fc(input=data, size=10) + loss = fluid.layers.mean(hidden) + + optimizer = fluid.optimizer.AdamOptimizer() + + role = role_maker.UserDefinedCollectiveRoleMaker(0, ['127.0.0.1:6170']) + fleet.init(role) + + dist_strategy = DistributedStrategy() + dist_strategy.sync_batch_norm = True + + dist_optimizer = fleet.distributed_optimizer( + optimizer, strategy=dist_strategy) + dist_optimizer.minimize(loss) + + self.assertEqual(dist_strategy.exec_strategy.num_threads, 1) + + if __name__ == "__main__": unittest.main() -- GitLab