diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index c8fd31fcbffe680da36d03276ec0d4c1095030bc..9ffec11354d8ae90e188bc666a1ddc1cd97924f9 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -100,7 +100,19 @@ void NCCLParallelContext::SendNCCLID(const std::string &ep, serv_addr.sin_family = AF_INET; serv_addr.sin_port = htons(port); - if (inet_pton(AF_INET, host.c_str(), &serv_addr.sin_addr) <= 0) { + char *ip = NULL; + struct hostent *hp; + if ((hp = gethostbyname(host.c_str())) == NULL) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Fail to get host by name %s.", host)); + } + int i = 0; + while (hp->h_addr_list[i] != NULL) { + ip = inet_ntoa(*(struct in_addr *)hp->h_addr_list[i]); + VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip; + break; + } + if (inet_pton(AF_INET, ip, &serv_addr.sin_addr) <= 0) { PADDLE_THROW(platform::errors::Unavailable("Open address %s failed.", ep)); } diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h index ac36ed77b482fad29c9003fb190f7323ef0c5a8f..cbd169f8da77edbbb8093e1a2f1cd9a74e6dda96 100644 --- a/paddle/fluid/imperative/nccl_context.h +++ b/paddle/fluid/imperative/nccl_context.h @@ -16,6 +16,7 @@ // network header files #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include +#include #include #include #include diff --git a/paddle/fluid/imperative/tests/nccl_context_test.cc b/paddle/fluid/imperative/tests/nccl_context_test.cc index 93ea988d638e4b67dde4707ca51f8ff1088d9059..e0d6950a97e30c587f0a63140ae588f66e3833aa 100644 --- a/paddle/fluid/imperative/tests/nccl_context_test.cc +++ b/paddle/fluid/imperative/tests/nccl_context_test.cc @@ -20,7 +20,7 @@ namespace imperative = paddle::imperative; namespace platform = paddle::platform; imperative::ParallelStrategy GetStrategy(int local_rank) { - std::vector eps = {"127.0.0.1:9866", "127.0.0.1:9867"}; + std::vector eps = {"127.0.0.1:9866", "localhost:9867"}; imperative::ParallelStrategy strategy; strategy.trainer_endpoints_ = eps; strategy.current_endpoint_ = eps[local_rank];