From f29fb396df2f354cc677e2483a98f76cd2c6f4be Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Wed, 21 Oct 2020 15:16:11 +0800 Subject: [PATCH] dygraph nccl init support host domain name (#28107) * nccl init support hostname and ip; test=develop --- paddle/fluid/imperative/nccl_context.cc | 14 +++++++++++++- paddle/fluid/imperative/nccl_context.h | 1 + paddle/fluid/imperative/tests/nccl_context_test.cc | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index c8fd31fcbff..9ffec11354d 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -100,7 +100,19 @@ void NCCLParallelContext::SendNCCLID(const std::string &ep, serv_addr.sin_family = AF_INET; serv_addr.sin_port = htons(port); - if (inet_pton(AF_INET, host.c_str(), &serv_addr.sin_addr) <= 0) { + char *ip = NULL; + struct hostent *hp; + if ((hp = gethostbyname(host.c_str())) == NULL) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Fail to get host by name %s.", host)); + } + int i = 0; + while (hp->h_addr_list[i] != NULL) { + ip = inet_ntoa(*(struct in_addr *)hp->h_addr_list[i]); + VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip; + break; + } + if (inet_pton(AF_INET, ip, &serv_addr.sin_addr) <= 0) { PADDLE_THROW(platform::errors::Unavailable("Open address %s failed.", ep)); } diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h index ac36ed77b48..cbd169f8da7 100644 --- a/paddle/fluid/imperative/nccl_context.h +++ b/paddle/fluid/imperative/nccl_context.h @@ -16,6 +16,7 @@ // network header files #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include +#include #include #include #include diff --git a/paddle/fluid/imperative/tests/nccl_context_test.cc b/paddle/fluid/imperative/tests/nccl_context_test.cc index 93ea988d638..e0d6950a97e 100644 --- a/paddle/fluid/imperative/tests/nccl_context_test.cc +++ b/paddle/fluid/imperative/tests/nccl_context_test.cc @@ -20,7 +20,7 @@ namespace imperative = paddle::imperative; namespace platform = paddle::platform; imperative::ParallelStrategy GetStrategy(int local_rank) { - std::vector eps = {"127.0.0.1:9866", "127.0.0.1:9867"}; + std::vector eps = {"127.0.0.1:9866", "localhost:9867"}; imperative::ParallelStrategy strategy; strategy.trainer_endpoints_ = eps; strategy.current_endpoint_ = eps[local_rank]; -- GitLab