提交 3ec289a6 编写于 作者: W WangXi 提交者: gongweibao

fix sync_batch_norm hang in fleet (#21838)

上级 a0b53376
......@@ -337,6 +337,17 @@ class CollectiveOptimizer(DistributedOptimizer):
"with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0"
)
# NOTE. open sync_batch_norm will hang when use multi num_threads
sync_batch_norm = self._strategy.sync_batch_norm
if sync_batch_norm is not None and sync_batch_norm is True:
self._strategy.nccl_comm_num = 1
self._strategy.use_hierarchical_allreduce = False
exec_strategy.num_threads = 1
logging.warn(
"use sync_batch_norm will hang when set num_threads > 1, so "
"set num_threads=1, nccl_comm_num=1, use_hierarchical_allreduce=False."
)
if self.print_config:
print("node_num:", node_num, "num_threads:",
exec_strategy.num_threads, "use_hierarchical_allreduce:",
......
......@@ -128,11 +128,20 @@ function(bash_test_modules TARGET_NAME)
set(timeout ${bash_test_modules_TIMEOUT})
endif()
if(WITH_COVERAGE)
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
WITH_COVERAGE=ON COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else()
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
if (bash_test_modules_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
......
......@@ -22,7 +22,15 @@ rm -f ${name}_*.log
# start the unit test
run_time=$(( $TEST_TIMEOUT - 10 ))
echo "run_time: ${run_time}"
timeout -s SIGKILL ${run_time} python -u ${name}.py > ${name}_run.log 2>&1
if [[ ${WITH_COVERAGE} == "ON" ]]; then
PYTHON_EXEC="python -u -m coverage run --branch -p "
else
PYTHON_EXEC="python -u "
fi
timeout -s SIGKILL ${run_time} ${PYTHON_EXEC} ${name}.py > ${name}_run.log 2>&1
exit_code=$?
if [[ $exit_code -eq 0 ]]; then
exit 0
......
......@@ -137,6 +137,8 @@ class TestDistRunnerBase(object):
dist_strategy.use_local_sgd = True
if args.ut4grad_allreduce:
dist_strategy._ut4grad_allreduce = True
if args.sync_batch_norm:
dist_strategy.sync_batch_norm = True
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
......@@ -487,6 +489,7 @@ def runtime_main(test_class):
required=False,
type=bool,
default=False)
parser.add_argument('--sync_batch_norm', action='store_true')
args = parser.parse_args()
......@@ -837,6 +840,8 @@ class TestDistBase(unittest.TestCase):
tr_cmd += " --use_local_sgd"
if self._ut4grad_allreduce:
tr_cmd += " --ut4grad_allreduce"
if hasattr(self, '_sync_batch_norm') and self._sync_batch_norm:
tr_cmd += " --sync_batch_norm"
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
......
......@@ -24,6 +24,7 @@ class TestDistMnistNCCL2FleetApi(TestDistBase):
self._use_reader_alloc = False
self._nccl2_mode = True
self._gpu_fleet_api = True
self._sync_batch_norm = True
def test_dist_train(self):
import paddle.fluid as fluid
......@@ -31,5 +32,30 @@ class TestDistMnistNCCL2FleetApi(TestDistBase):
self.check_with_place("dist_mnist.py", delta=1e-5)
class FleetCollectiveTest(unittest.TestCase):
def test_open_sync_batch_norm(self):
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
optimizer = fluid.optimizer.AdamOptimizer()
role = role_maker.UserDefinedCollectiveRoleMaker(0, ['127.0.0.1:6170'])
fleet.init(role)
dist_strategy = DistributedStrategy()
dist_strategy.sync_batch_norm = True
dist_optimizer = fleet.distributed_optimizer(
optimizer, strategy=dist_strategy)
dist_optimizer.minimize(loss)
self.assertEqual(dist_strategy.exec_strategy.num_threads, 1)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册