未验证 提交 a24d1868 编写于 作者: D danleifeng 提交者: GitHub

fix nccl init failed in parallel dygraph mode (#28497)

上级 93c39779
......@@ -49,16 +49,20 @@ void NCCLParallelContext::RecvNCCLID(const std::string &ep,
address.sin_port = htons(port);
int try_times = 0;
int retry_time = 0;
while (true) {
if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) {
retry_time = 3 * (try_times + 1);
LOG(WARNING) << "Socket bind worker " << ep
<< (try_times < 5 ? " failed, try again after 3 seconds."
: " failed, try again after 3 seconds. "
"Bind on endpoint %s failed. "
"Please confirm whether the "
"communication port or GPU card is "
"occupied.");
std::this_thread::sleep_for(std::chrono::seconds(3));
<< (try_times < 9
? " failed, try again after " +
std::to_string(retry_time) + " seconds."
: " failed, try again after " +
std::to_string(retry_time) +
" seconds. Bind on endpoint " + ep +
" failed. Please confirm whether the "
"communication port or GPU card is occupied.");
std::this_thread::sleep_for(std::chrono::seconds(retry_time));
++try_times;
continue;
}
......@@ -129,16 +133,20 @@ void NCCLParallelContext::SendNCCLID(const std::string &ep,
}
int try_times = 0;
int retry_time = 0;
while (true) {
if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) {
retry_time = 3 * (try_times + 1);
LOG(WARNING)
<< "Socket connect worker " << ep
<< (try_times < 5
? " failed, try again after 3 seconds."
: " failed, try again after 3 seconds. Maybe that "
"some process is occupied the GPUs of this node "
"now, and you should kill those process manually.");
std::this_thread::sleep_for(std::chrono::seconds(3));
<< (try_times < 9
? " failed, try again after " + std::to_string(retry_time) +
" seconds."
: " failed, try again after " + std::to_string(retry_time) +
" seconds. Maybe that some process is occupied the "
"GPUs of this node now, and you should kill those "
"process manually.");
std::this_thread::sleep_for(std::chrono::seconds(retry_time));
++try_times;
continue;
}
......
......@@ -125,7 +125,7 @@ def init_parallel_env():
if ParallelEnv().world_size < 2:
return
# 3: init gloo context
# 3: init gloo context (step 1: httpsever start)
ep_rank_0 = ParallelEnv().trainer_endpoints[0].split(":")
ep_rank = ParallelEnv().trainer_endpoints[ParallelEnv().rank].split(":")
manager = Manager()
......@@ -138,22 +138,6 @@ def init_parallel_env():
http_server.daemon = True
http_server_d["running"] = True
http_server.start()
wait_server_ready([ParallelEnv().trainer_endpoints[0]])
gloo_strategy = core.GlooParallelStrategy()
gloo_strategy.rank = ParallelEnv().rank
gloo_strategy.rank_num = ParallelEnv().world_size
gloo_strategy.ip_address = ep_rank_0[0]
gloo_strategy.ip_port = int(ep_rank_0[1])
default_init_timeout_seconds = 3600
default_run_timeout_seconds = 9999999
gloo_strategy.init_seconds = default_init_timeout_seconds
gloo_strategy.run_seconds = default_run_timeout_seconds
gloo = core.GlooParallelContext(gloo_strategy)
gloo.init()
if ParallelEnv().rank == 0:
http_server_d["running"] = False
http_server.join()
# 4. init NCCL ParallelStrategy
strategy = ParallelStrategy()
......@@ -165,7 +149,7 @@ def init_parallel_env():
strategy.current_endpoint = ParallelEnv().current_endpoint
# NOTE(chenweihang): [ why config global place here? ]
# the dygraph mode will be set to default mode,
# the dygraph mode will be set to default mode,
# users will not call `dygraph.guard` or `enable_dygraph`
# directly, if they want to switch default place,
# they need to call a function to change default place,
......@@ -177,6 +161,27 @@ def init_parallel_env():
parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()
# 5: init gloo context (step 2: gloo init)
# dividing init_gloo into two part beacause nccl and gloo
# are separately looking for free ports which sometimes
# leads to port-conflict.
wait_server_ready([ParallelEnv().trainer_endpoints[0]])
gloo_strategy = core.GlooParallelStrategy()
gloo_strategy.rank = ParallelEnv().rank
gloo_strategy.rank_num = ParallelEnv().world_size
gloo_strategy.ip_address = ep_rank_0[0]
gloo_strategy.ip_port = int(ep_rank_0[1])
default_init_timeout_seconds = 3600
default_run_timeout_seconds = 9999999
gloo_strategy.init_seconds = default_init_timeout_seconds
gloo_strategy.run_seconds = default_run_timeout_seconds
gloo = core.GlooParallelContext(gloo_strategy)
gloo.init()
if ParallelEnv().rank == 0:
http_server_d["running"] = False
http_server.join()
def get_rank():
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册