Unverified · Commit a24d1868 · Author: danleifeng · Committed by: GitHub

fix nccl init failed in parallel dygraph mode (#28497)

Parent 93c39779
@@ -49,16 +49,20 @@ void NCCLParallelContext::RecvNCCLID(const std::string &ep,
   address.sin_port = htons(port);
   int try_times = 0;
+  int retry_time = 0;
   while (true) {
     if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) {
+      retry_time = 3 * (try_times + 1);
       LOG(WARNING) << "Socket bind worker " << ep
-                   << (try_times < 5 ? " failed, try again after 3 seconds."
-                                     : " failed, try again after 3 seconds. "
-                                       "Bind on endpoint %s failed. "
-                                       "Please confirm whether the "
-                                       "communication port or GPU card is "
-                                       "occupied.");
-      std::this_thread::sleep_for(std::chrono::seconds(3));
+                   << (try_times < 9
+                           ? " failed, try again after " +
+                                 std::to_string(retry_time) + " seconds."
+                           : " failed, try again after " +
+                                 std::to_string(retry_time) +
+                                 " seconds. Bind on endpoint " + ep +
+                                 " failed. Please confirm whether the "
+                                 "communication port or GPU card is occupied.");
+      std::this_thread::sleep_for(std::chrono::seconds(retry_time));
       ++try_times;
       continue;
     }
@@ -129,16 +133,20 @@ void NCCLParallelContext::SendNCCLID(const std::string &ep,
   }
   int try_times = 0;
+  int retry_time = 0;
   while (true) {
     if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) {
+      retry_time = 3 * (try_times + 1);
       LOG(WARNING)
           << "Socket connect worker " << ep
-          << (try_times < 5
-                  ? " failed, try again after 3 seconds."
-                  : " failed, try again after 3 seconds. Maybe that "
-                    "some process is occupied the GPUs of this node "
-                    "now, and you should kill those process manually.");
-      std::this_thread::sleep_for(std::chrono::seconds(3));
+          << (try_times < 9
+                  ? " failed, try again after " + std::to_string(retry_time) +
+                        " seconds."
+                  : " failed, try again after " + std::to_string(retry_time) +
+                        " seconds. Maybe that some process is occupied the "
+                        "GPUs of this node now, and you should kill those "
+                        "process manually.");
+      std::this_thread::sleep_for(std::chrono::seconds(retry_time));
       ++try_times;
       continue;
     }
......
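Both hunks above replace the fixed 3-second wait with a linear backoff: retry n sleeps for 3 * (n + 1) seconds, and the longer diagnostic hint is only appended after 9 attempts instead of 5. Below is a minimal sketch of that retry schedule, written in Python with a hypothetical `attempt_bind` callable standing in for the `bind`/`connect` calls in the patched C++ code; it illustrates the schedule, not the actual socket handling.

```python
import time

def retry_with_linear_backoff(attempt_bind, base_seconds=3, verbose_after=9):
    """Keep retrying until attempt_bind() succeeds, using the schedule from the patch.

    attempt_bind is a hypothetical callable returning True on success; it stands
    in for the socket bind()/connect() calls above.
    """
    try_times = 0
    while not attempt_bind():
        retry_time = base_seconds * (try_times + 1)  # 3s, 6s, 9s, ... as in the patch
        hint = ("" if try_times < verbose_after
                else " Please confirm whether the communication port or GPU card is occupied.")
        print("bind failed, try again after %d seconds.%s" % (retry_time, hint))
        time.sleep(retry_time)
        try_times += 1
```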
@@ -125,7 +125,7 @@ def init_parallel_env():
     if ParallelEnv().world_size < 2:
         return

-    # 3: init gloo context
+    # 3: init gloo context (step 1: httpsever start)
     ep_rank_0 = ParallelEnv().trainer_endpoints[0].split(":")
     ep_rank = ParallelEnv().trainer_endpoints[ParallelEnv().rank].split(":")
     manager = Manager()
@@ -138,22 +138,6 @@ def init_parallel_env():
     http_server.daemon = True
     http_server_d["running"] = True
     http_server.start()
-    wait_server_ready([ParallelEnv().trainer_endpoints[0]])
-    gloo_strategy = core.GlooParallelStrategy()
-    gloo_strategy.rank = ParallelEnv().rank
-    gloo_strategy.rank_num = ParallelEnv().world_size
-    gloo_strategy.ip_address = ep_rank_0[0]
-    gloo_strategy.ip_port = int(ep_rank_0[1])
-    default_init_timeout_seconds = 3600
-    default_run_timeout_seconds = 9999999
-    gloo_strategy.init_seconds = default_init_timeout_seconds
-    gloo_strategy.run_seconds = default_run_timeout_seconds
-    gloo = core.GlooParallelContext(gloo_strategy)
-    gloo.init()
-    if ParallelEnv().rank == 0:
-        http_server_d["running"] = False
-        http_server.join()

     # 4. init NCCL ParallelStrategy
     strategy = ParallelStrategy()
@@ -177,6 +161,27 @@ def init_parallel_env():
     parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
     parallel_helper._init_parallel_ctx()
+
+    # 5: init gloo context (step 2: gloo init)
+    # dividing init_gloo into two part beacause nccl and gloo
+    # are separately looking for free ports which sometimes
+    # leads to port-conflict.
+    wait_server_ready([ParallelEnv().trainer_endpoints[0]])
+    gloo_strategy = core.GlooParallelStrategy()
+    gloo_strategy.rank = ParallelEnv().rank
+    gloo_strategy.rank_num = ParallelEnv().world_size
+    gloo_strategy.ip_address = ep_rank_0[0]
+    gloo_strategy.ip_port = int(ep_rank_0[1])
+    default_init_timeout_seconds = 3600
+    default_run_timeout_seconds = 9999999
+    gloo_strategy.init_seconds = default_init_timeout_seconds
+    gloo_strategy.run_seconds = default_run_timeout_seconds
+    gloo = core.GlooParallelContext(gloo_strategy)
+    gloo.init()
+    if ParallelEnv().rank == 0:
+        http_server_d["running"] = False
+        http_server.join()


 def get_rank():
     """
......
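The comment added in the last hunk explains the motivation for the reordering: NCCL and gloo each probe for free ports, and letting them probe at the same time sometimes produces a port conflict, so the gloo context is now created only after the NCCL parallel context has been initialized. For readers unfamiliar with the pattern, a generic free-port probe looks roughly like the sketch below; this is an illustration of the kind of probing involved, not PaddlePaddle's actual code.

```python
import socket

def find_free_port():
    # Bind to port 0 so the kernel picks a currently free port. The port is
    # released when the socket closes, so a port discovered this way can be
    # grabbed by another component before the real bind/listen happens.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

# Running the two probes one after the other (NCCL first, then gloo), as the
# patch effectively enforces, narrows the window in which they can race for
# the same port.
nccl_port = find_free_port()
gloo_port = find_free_port()
print(nccl_port, gloo_port)
```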