未验证 提交 8088395a 编写于 作者: G gongweibao 提交者: GitHub

Set unique port to every distribute test to avoid potential port conflicts (#20759)

上级 0687bcd6
...@@ -239,6 +239,10 @@ if(WITH_DISTRIBUTE) ...@@ -239,6 +239,10 @@ if(WITH_DISTRIBUTE)
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transpiler") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transpiler")
#not need
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_base")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_base")
py_test_modules(test_lookup_remote_table_op MODULES test_lookup_remote_table_op ENVS ${dist_ENVS}) py_test_modules(test_lookup_remote_table_op MODULES test_lookup_remote_table_op ENVS ${dist_ENVS})
py_test_modules(test_hsigmoid_remote_table_op MODULES test_hsigmoid_remote_table_op ENVS ${dist_ENVS}) py_test_modules(test_hsigmoid_remote_table_op MODULES test_hsigmoid_remote_table_op ENVS ${dist_ENVS})
py_test_modules(test_nce_remote_table_op MODULES test_nce_remote_table_op ENVS ${dist_ENVS}) py_test_modules(test_nce_remote_table_op MODULES test_nce_remote_table_op ENVS ${dist_ENVS})
...@@ -249,8 +253,11 @@ if(WITH_DISTRIBUTE) ...@@ -249,8 +253,11 @@ if(WITH_DISTRIBUTE)
bash_test_modules(test_listen_and_serv_op MODULES test_listen_and_serv.sh) bash_test_modules(test_listen_and_serv_op MODULES test_listen_and_serv.sh)
bash_test_modules(test_launch MODULES test_launch.sh) bash_test_modules(test_launch MODULES test_launch.sh)
set(dist_ut_port 1000)
foreach(TEST_OP ${DIST_TEST_OPS}) foreach(TEST_OP ${DIST_TEST_OPS})
bash_test_modules(${TEST_OP} MODULES dist_test.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE") message(STATUS "set dist_ut_port=${dist_ut_port} on ${TEST_OP}")
bash_test_modules(${TEST_OP} MODULES dist_test.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}")
MATH(EXPR dist_ut_port "${dist_ut_port}+50")
endforeach(TEST_OP) endforeach(TEST_OP)
endif(NOT APPLE) endif(NOT APPLE)
endif() endif()
......
...@@ -44,8 +44,9 @@ done ...@@ -44,8 +44,9 @@ done
#display system context #display system context
for i in {1..2}; do for i in {1..2}; do
sleep 2 sleep 3
ps -ef | grep -E "(test_|_test)" ps -aux
netstat -anlp
if hash "nvidia-smi" > /dev/null; then if hash "nvidia-smi" > /dev/null; then
nvidia-smi nvidia-smi
......
...@@ -36,6 +36,7 @@ import paddle.fluid.incubate.fleet.base.role_maker as role_maker ...@@ -36,6 +36,7 @@ import paddle.fluid.incubate.fleet.base.role_maker as role_maker
RUN_STEP = 5 RUN_STEP = 5
DEFAULT_BATCH_SIZE = 2 DEFAULT_BATCH_SIZE = 2
DIST_UT_PORT = 0
def print_to_out(out_losses): def print_to_out(out_losses):
...@@ -486,8 +487,6 @@ class TestDistBase(unittest.TestCase): ...@@ -486,8 +487,6 @@ class TestDistBase(unittest.TestCase):
self._trainers = 2 self._trainers = 2
self._pservers = 2 self._pservers = 2
self._port_set = set() self._port_set = set()
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable self._python_interp = sys.executable
self._sync_mode = True self._sync_mode = True
self._hogwild_mode = False self._hogwild_mode = False
...@@ -512,6 +511,20 @@ class TestDistBase(unittest.TestCase): ...@@ -512,6 +511,20 @@ class TestDistBase(unittest.TestCase):
self._ut4grad_allreduce = False self._ut4grad_allreduce = False
self._use_hallreduce = False self._use_hallreduce = False
self._setup_config() self._setup_config()
global DIST_UT_PORT
if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))
if DIST_UT_PORT == 0:
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
else:
print("set begin_port:", DIST_UT_PORT)
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
DIST_UT_PORT, DIST_UT_PORT + 1)
DIST_UT_PORT += 2
self._after_setup_config() self._after_setup_config()
def _find_free_port(self): def _find_free_port(self):
...@@ -790,8 +803,16 @@ class TestDistBase(unittest.TestCase): ...@@ -790,8 +803,16 @@ class TestDistBase(unittest.TestCase):
check_error_log, log_name): check_error_log, log_name):
if self._use_hallreduce: if self._use_hallreduce:
self._ps_endpoints = "" self._ps_endpoints = ""
global DIST_UT_PORT
if DIST_UT_PORT == 0:
for i in range(0, 4):
self._ps_endpoints += "127.0.0.1:%s," % (
self._find_free_port())
else:
for i in range(0, 4): for i in range(0, 4):
self._ps_endpoints += "127.0.0.1:%s," % (self._find_free_port()) self._ps_endpoints += "127.0.0.1:%s," % (DIST_UT_PORT + i)
DIST_UT_PORT += 4
self._ps_endpoints = self._ps_endpoints[:-1] self._ps_endpoints = self._ps_endpoints[:-1]
# NOTE: we reuse ps_endpoints as nccl2 worker endpoints # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
...@@ -858,7 +879,7 @@ class TestDistBase(unittest.TestCase): ...@@ -858,7 +879,7 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_vmodule"] = \ required_envs["GLOG_vmodule"] = \
"fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \ "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
"alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \ "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
"sparse_all_reduce_op_handle=10,gen_nccl_id_op=10" "sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,nccl_helper=10,grpc_client=10,grpc_server=10"
required_envs["GLOG_logtostderr"] = "1" required_envs["GLOG_logtostderr"] = "1"
required_envs.update(need_envs) required_envs.update(need_envs)
......
...@@ -40,6 +40,7 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo ...@@ -40,6 +40,7 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
RUN_STEP = 5 RUN_STEP = 5
LEARNING_RATE = 0.01 LEARNING_RATE = 0.01
DIST_UT_PORT = 0
class FleetDistRunnerBase(object): class FleetDistRunnerBase(object):
...@@ -123,8 +124,20 @@ class TestFleetBase(unittest.TestCase): ...@@ -123,8 +124,20 @@ class TestFleetBase(unittest.TestCase):
self._trainers = 2 self._trainers = 2
self._pservers = 2 self._pservers = 2
self._port_set = set() self._port_set = set()
global DIST_UT_PORT
if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))
if DIST_UT_PORT:
print("set begin_port:", DIST_UT_PORT)
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
DIST_UT_PORT, DIST_UT_PORT + 1)
DIST_UT_PORT += 2
else:
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port()) self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable self._python_interp = sys.executable
self._geo_sgd = False self._geo_sgd = False
self._geo_sgd_need_push_nums = 5 self._geo_sgd_need_push_nums = 5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册