Commit 3ec289a6 authored by WangXi, committed by gongweibao

fix sync_batch_norm hang in fleet (#21838)

Parent: a0b53376
@@ -337,6 +337,17 @@ class CollectiveOptimizer(DistributedOptimizer):
                 "with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0"
             )
 
+        # NOTE. open sync_batch_norm will hang when use multi num_threads
+        sync_batch_norm = self._strategy.sync_batch_norm
+        if sync_batch_norm is not None and sync_batch_norm is True:
+            self._strategy.nccl_comm_num = 1
+            self._strategy.use_hierarchical_allreduce = False
+            exec_strategy.num_threads = 1
+            logging.warn(
+                "use sync_batch_norm will hang when set num_threads > 1, so "
+                "set num_threads=1, nccl_comm_num=1, use_hierarchical_allreduce=False."
+            )
+
         if self.print_config:
             print("node_num:", node_num, "num_threads:",
                   exec_strategy.num_threads, "use_hierarchical_allreduce:",
...
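The hunk above is the core fix: when sync_batch_norm is enabled in the fleet DistributedStrategy, the collective optimizer falls back to a single executor thread and a single NCCL communicator, because multi-threaded execution hangs with synchronized batch norm. Below is a minimal user-side sketch of the resulting behavior, mirroring the unit test added at the end of this commit (the single-node role maker endpoint and layer sizes are illustrative only):

import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy

# Trivial network; the shapes and sizes are illustrative.
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
loss = fluid.layers.mean(fluid.layers.fc(input=data, size=10))

# Single-process collective role, as in the new unit test below.
fleet.init(role_maker.UserDefinedCollectiveRoleMaker(0, ['127.0.0.1:6170']))

dist_strategy = DistributedStrategy()
dist_strategy.sync_batch_norm = True  # triggers the override added above

dist_optimizer = fleet.distributed_optimizer(
    fluid.optimizer.AdamOptimizer(), strategy=dist_strategy)
dist_optimizer.minimize(loss)

# minimize() should now have forced the safe configuration
# (num_threads=1; nccl_comm_num and hierarchical allreduce are reset as well).
assert dist_strategy.exec_strategy.num_threads == 1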
@@ -128,11 +128,20 @@ function(bash_test_modules TARGET_NAME)
         set(timeout ${bash_test_modules_TIMEOUT})
     endif()
 
-    add_test(NAME ${TARGET_NAME}
-        COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
-        TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
-        bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
-        WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    if(WITH_COVERAGE)
+        add_test(NAME ${TARGET_NAME}
+            COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
+            TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
+            WITH_COVERAGE=ON COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
+            bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
+            WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    else()
+        add_test(NAME ${TARGET_NAME}
+            COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
+            TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
+            bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
+            WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    endif()
 
     if (bash_test_modules_SERIAL)
         set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
...
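A note on the coverage branch above: when WITH_COVERAGE is enabled, the CTest entry also exports WITH_COVERAGE=ON and a COVERAGE_FILE path into the test environment, and the bash runner plus the Python harness in the later hunks read the same variables. The sketch below is not part of the patch; the helper name is hypothetical, but the getenv pattern matches the one used by TestDistBase further down:

import os

def coverage_env(base_env):
    # Hypothetical helper: propagate the coverage settings exported by CMake
    # into the environment of a spawned trainer process.
    env = dict(base_env)
    if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
        # COVERAGE_FILE tells python-coverage where to write its data file.
        env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
    return env

A caller would use something like coverage_env(os.environ) when building the trainer subprocess environment.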
@@ -22,7 +22,15 @@ rm -f ${name}_*.log
 # start the unit test
 run_time=$(( $TEST_TIMEOUT - 10 ))
 echo "run_time: ${run_time}"
-timeout -s SIGKILL ${run_time} python -u ${name}.py > ${name}_run.log 2>&1
+
+if [[ ${WITH_COVERAGE} == "ON" ]]; then
+    PYTHON_EXEC="python -u -m coverage run --branch -p "
+else
+    PYTHON_EXEC="python -u "
+fi
+
+timeout -s SIGKILL ${run_time} ${PYTHON_EXEC} ${name}.py > ${name}_run.log 2>&1
 
 exit_code=$?
 if [[ $exit_code -eq 0 ]]; then
     exit 0
...
@@ -137,6 +137,8 @@ class TestDistRunnerBase(object):
             dist_strategy.use_local_sgd = True
         if args.ut4grad_allreduce:
             dist_strategy._ut4grad_allreduce = True
+        if args.sync_batch_norm:
+            dist_strategy.sync_batch_norm = True
 
         role = role_maker.PaddleCloudRoleMaker(is_collective=True)
         fleet.init(role)
...
@@ -487,6 +489,7 @@ def runtime_main(test_class):
         required=False,
         type=bool,
         default=False)
+    parser.add_argument('--sync_batch_norm', action='store_true')
     args = parser.parse_args()
...
@@ -837,6 +840,8 @@ class TestDistBase(unittest.TestCase):
                 tr_cmd += " --use_local_sgd"
             if self._ut4grad_allreduce:
                 tr_cmd += " --ut4grad_allreduce"
+            if hasattr(self, '_sync_batch_norm') and self._sync_batch_norm:
+                tr_cmd += " --sync_batch_norm"
 
         if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
             env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
...
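With the harness changes above, an individual distributed test opts in by setting self._sync_batch_norm = True: the attribute appends --sync_batch_norm to the trainer command, and the runner then sets dist_strategy.sync_batch_norm = True. The next hunk applies exactly this to TestDistMnistNCCL2FleetApi; the sketch below only condenses the pattern for a hypothetical new test (the _setup_config hook and the import path follow the existing TestDistBase tests and are assumptions here):

from test_dist_base import TestDistBase  # assumed module path of the harness

class TestDistMnistSyncBatchNorm(TestDistBase):
    # Hypothetical test case, not part of this patch.
    def _setup_config(self):
        self._use_reader_alloc = False
        self._nccl2_mode = True
        self._gpu_fleet_api = True
        self._sync_batch_norm = True  # appends --sync_batch_norm to the trainer command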
@@ -24,6 +24,7 @@ class TestDistMnistNCCL2FleetApi(TestDistBase):
         self._use_reader_alloc = False
         self._nccl2_mode = True
         self._gpu_fleet_api = True
+        self._sync_batch_norm = True
 
     def test_dist_train(self):
         import paddle.fluid as fluid
...
@@ -31,5 +32,30 @@ class TestDistMnistNCCL2FleetApi(TestDistBase):
             self.check_with_place("dist_mnist.py", delta=1e-5)
 
 
+class FleetCollectiveTest(unittest.TestCase):
+    def test_open_sync_batch_norm(self):
+        import paddle.fluid as fluid
+        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
+        from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
+
+        data = fluid.layers.data(name='X', shape=[1], dtype='float32')
+        hidden = fluid.layers.fc(input=data, size=10)
+        loss = fluid.layers.mean(hidden)
+
+        optimizer = fluid.optimizer.AdamOptimizer()
+
+        role = role_maker.UserDefinedCollectiveRoleMaker(0, ['127.0.0.1:6170'])
+        fleet.init(role)
+
+        dist_strategy = DistributedStrategy()
+        dist_strategy.sync_batch_norm = True
+
+        dist_optimizer = fleet.distributed_optimizer(
+            optimizer, strategy=dist_strategy)
+        dist_optimizer.minimize(loss)
+
+        self.assertEqual(dist_strategy.exec_strategy.num_threads, 1)
+
+
 if __name__ == "__main__":
     unittest.main()