From 467c71696356d6cf793000b4edadc4297ee7278a Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 14 Dec 2020 20:28:13 +0800 Subject: [PATCH] gen nccl id use socket (#29431) --- .../fluid/operators/collective/CMakeLists.txt | 6 +- .../operators/collective/c_gen_nccl_id_op.cc | 81 +--- .../operators/collective/gen_nccl_id_op.cc | 201 ++++++++++ .../collective/gen_nccl_id_op_helper.cc | 351 ++++++++++++++++++ .../collective/gen_nccl_id_op_helper.h | 48 +++ .../operators/distributed_ops/CMakeLists.txt | 1 - .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_dist_base.py | 8 +- .../unittests/test_dist_mnist_hallreduce.py | 1 + .../tests/unittests/test_gen_nccl_id_op.py | 118 ++++++ 10 files changed, 740 insertions(+), 76 deletions(-) create mode 100644 paddle/fluid/operators/collective/gen_nccl_id_op.cc create mode 100644 paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc create mode 100644 paddle/fluid/operators/collective/gen_nccl_id_op_helper.h create mode 100644 python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index 686b3039d4d..395b54c8b6c 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -28,11 +28,13 @@ foreach(src ${OPS}) set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS}) endforeach() -register_operators(EXCLUDES c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) +register_operators(EXCLUDES c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) if(WITH_NCCL) set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper) - op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common) + cc_library(gen_nccl_id_op_helper SRCS gen_nccl_id_op_helper.cc) + op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common gen_nccl_id_op_helper) + op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common gen_nccl_id_op_helper) endif() if(WITH_GLOO) diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index ed478b1f0a0..93a6b50c4db 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -21,14 +21,12 @@ limitations under the License. */ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/var_type_traits.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed/request_handler.h" -#include "paddle/fluid/operators/distributed/request_handler_impl.h" -#include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" + namespace paddle { namespace operators { @@ -42,80 +40,23 @@ class CGenNCCLIdOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - // put nccl id in CPUPlace - auto& dev_ctx = *pool.Get(platform::CPUPlace()); int rank = Attr("rank"); framework::Scope& local_scope = scope.NewScope(); + std::function func = [&](size_t i) -> std::string { + return Output("Out"); + }; + if (rank == 0) { - GenerateAndSend(&local_scope, dev_ctx); + std::vector endpoint_list = + Attr>("other_endpoints"); + SendBroadCastNCCLID(endpoint_list, 1, func, local_scope); } else { - GetIdByServer(&local_scope, dev_ctx); + std::string endpoint = Attr("endpoint"); + RecvBroadCastNCCLID(endpoint, 1, func, local_scope); } scope.DeleteScope(&local_scope); } - - private: - void GenerateAndSend(framework::Scope* scope, - const platform::DeviceContext& dev_ctx) const { - std::string var_name = Output("Out"); - auto var = scope->FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL( - var, platform::errors::InvalidArgument("Output can not be Null")); - auto id = var->GetMutable(); - PADDLE_ENFORCE_EQ(platform::dynload::ncclGetUniqueId(id), 0, - platform::errors::InvalidArgument( - "ncclGetUniqueId failed with id %s", id)); - - std::vector endpoint_list = - Attr>("other_endpoints"); - distributed::RPCClient* client = - distributed::RPCClient::GetInstance(0); - - for (auto& ep : endpoint_list) { - VLOG(3) << "sending nccl id to " << ep; - client->AsyncSendVar(ep, dev_ctx, *scope, var_name); - } - client->Wait(); - for (auto& ep : endpoint_list) { - client->AsyncSendBatchBarrier(ep); - } - client->Wait(); - VLOG(3) << "sending completed..."; - } - - void GetIdByServer(framework::Scope* scope, - const platform::DeviceContext& dev_ctx) const { - std::string endpoint = Attr("endpoint"); - // NOTE: Can not use unique_ptr here because the default - // deleter will call GRPC Server's base class's dtor and - // that will cause a wired crash. - distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync); - std::unique_ptr rpc_service( - new RPCSERVER_T(endpoint, 1)); - - rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); - rpc_h.SetRPCServer(rpc_service.get()); - - framework::ProgramDesc empty_program; - framework::Executor executor(dev_ctx.GetPlace()); - rpc_h.SetScope(scope); - rpc_h.SetDevCtx(&dev_ctx); - rpc_h.SetProgram(&empty_program); - rpc_h.SetExecutor(&executor); - - std::thread server_thread( - std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); - - rpc_service->SetCond(distributed::kRequestSend); - VLOG(3) << "start getting nccl id from trainer 0..."; - rpc_service->WaitBarrier(distributed::kRequestSend); - VLOG(3) << "got nccl id and stop server..."; - rpc_service->ShutDown(); - VLOG(3) << "rpc server stopped"; - server_thread.join(); - } }; class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op.cc b/paddle/fluid/operators/collective/gen_nccl_id_op.cc new file mode 100644 index 00000000000..98b1df9efc9 --- /dev/null +++ b/paddle/fluid/operators/collective/gen_nccl_id_op.cc @@ -0,0 +1,201 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "glog/logging.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/var_type_traits.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" + +#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" + +namespace paddle { +namespace operators { + +class GenNCCLIdOp : public framework::OperatorBase { + public: + GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + std::vector trainers = + Attr>("trainers"); + int trainer_id = Attr("trainer_id"); + std::string endpoint = trainers[trainer_id]; + + PADDLE_ENFORCE_GE(trainer_id, 0, platform::errors::InvalidArgument( + "trainer_id %d is less than 0. Its " + "valid range is [0, trainer_size)")); + PADDLE_ENFORCE_LT( + trainer_id, static_cast(trainers.size()), + platform::errors::OutOfRange("trainer_id %d is out of range. Its valid " + "range is [0, trainer_size)", + trainer_id)); + + int nccl_comm_num = Attr("nccl_comm_num"); + int use_hierarchical_allreduce = Attr("use_hierarchical_allreduce"); + int inter_nranks = Attr("hierarchical_allreduce_inter_nranks"); + int inter_trainer_id = -1; + int exter_trainer_id = -1; + + if (use_hierarchical_allreduce) { + PADDLE_ENFORCE_GT( + trainers.size(), 1, + platform::errors::PreconditionNotMet( + "The number of collective trainers %llu <= 1", trainers.size())); + PADDLE_ENFORCE_GT( + inter_nranks, 1, + platform::errors::PreconditionNotMet( + "inter_nranks %d <= 1 while in hierarchical allreduce mode", + inter_nranks)); + PADDLE_ENFORCE_EQ( + trainers.size() % inter_nranks, 0, + platform::errors::PreconditionNotMet( + "The number of trainers %llu mod inter_nranks %d is not equal 0", + trainers.size(), inter_nranks)); + + inter_trainer_id = trainer_id % inter_nranks; + + if (trainer_id % inter_nranks == 0) { + exter_trainer_id = trainer_id / inter_nranks; + } + } + + std::ostringstream ss; + for (size_t i = 0; i < trainers.size(); i++) { + ss << trainers[i] << ","; + } + + VLOG(1) << "trainer_id:" << trainer_id + << ", use_hierarchical_allreduce:" << use_hierarchical_allreduce + << ", nccl_comm_num:" << nccl_comm_num + << ", inter_nranks:" << inter_nranks + << ", inter_trainer_id:" << inter_trainer_id + << ", exter_trainer_id:" << exter_trainer_id + << ", trainers:" << ss.str(); + + int server_fd = -1; + + /// 1. init flat + std::function func = platform::GetFlatNCCLVarName; + if (trainer_id == 0) { + // server endpoints + std::vector flat_endpoints; + flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1, + trainers.end()); + SendBroadCastNCCLID(flat_endpoints, nccl_comm_num, func, scope); + } else { + server_fd = CreateListenSocket(endpoint); + RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + } + + /// 2. hierarchical inter ncclid + func = platform::GetHierarchicalInterNCCLVarName; + if (inter_trainer_id == 0) { + std::ostringstream ss; + ss << endpoint; + std::vector inter_endpoints; + for (int i = trainer_id + 1; i < trainer_id + inter_nranks && + i < static_cast(trainers.size()); + i++) { + ss << ","; + inter_endpoints.push_back(trainers[i]); + ss << trainers[i]; + } + VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str(); + + SendBroadCastNCCLID(inter_endpoints, nccl_comm_num, func, scope); + } else if (inter_trainer_id > 0) { + VLOG(1) << "Hierarchical inter ring"; + RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + } + + /// 3. hierarchical exter ncclid + func = platform::GetHierarchicalExterNCCLVarName; + if (exter_trainer_id == 0) { + std::ostringstream ss; + std::vector exter_endpoints; + ss << endpoint; + for (size_t i = inter_nranks; i < trainers.size(); i += inter_nranks) { + ss << ","; + exter_endpoints.push_back(trainers[i]); + ss << trainers[i]; + } + VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str(); + + SendBroadCastNCCLID(exter_endpoints, nccl_comm_num, func, scope); + } else if (exter_trainer_id > 0) { + VLOG(1) << "Hierarchical exter ring"; + RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + } + + // close socket server + if (trainer_id != 0) { + CloseSocket(server_fd); + } + } +}; + +class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces."); + AddComment(R"DOC( +GenNCCLId operator + +For trainer 0: generate a new UniqueId and send it to all the other trainers. +For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. +)DOC"); + AddAttr>( + "trainers", + "['trainer0_ip:port', 'trainer1_ip:port', ...] " + "list of all trainer endpoints") + .SetDefault({}); + AddAttr("trainer_id", + "(int) " + "The index of the trainer in distributed training."); + AddAttr("nccl_comm_num", + "(int default 1) " + "The number of nccl communicator num.") + .SetDefault(1); + AddAttr("use_hierarchical_allreduce", + "(bool default false) " + "Wheter to use hierarchical allreduce.") + .SetDefault(false); + AddAttr("hierarchical_allreduce_inter_nranks", + "(int default 1) " + "Wheter to use hierarchical allreduce.") + .SetDefault(-1); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker); diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc new file mode 100644 index 00000000000..f448084019c --- /dev/null +++ b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc @@ -0,0 +1,351 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "glog/logging.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/var_type_traits.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" + +namespace paddle { +namespace operators { + +constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; + +// Check system calls, such as socket, bind. +#define CHECK_SYS_CALL(call, name) \ + do { \ + int retval; \ + CHECK_SYS_CALL_VAL(call, name, retval); \ + } while (false) + +#define CHECK_SYS_CALL_VAL(call, name, retval) \ + do { \ + RETRY_SYS_CALL_VAL(call, name, retval); \ + if (retval == -1) { \ + PADDLE_THROW(platform::errors::Unavailable("Call to %s failed: %s", \ + name, strerror(errno))); \ + } \ + } while (false) + +#define RETRY_SYS_CALL_VAL(call, name, retval) \ + do { \ + retval = (call); \ + if (retval == -1 && \ + (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \ + LOG(WARNING) << "Call " << name << " returned " << strerror(errno) \ + << " retry"; \ + } else { \ + break; \ + } \ + } while (true) + +static int SocketSend(int fd, const char* buffer, int size) { + int offset = 0; + int bytes = 0; + while (offset < size) { + bytes = send(fd, buffer + offset, size - offset, 0); + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + // send failed + return -1; + } else { + bytes = 0; + } + } + offset += bytes; + } + return offset; +} + +static int SocketRecv(int fd, char* buffer, int size) { + int offset = 0; + int bytes = 0; + while (offset < size) { + bytes = recv(fd, buffer + offset, size - offset, 0); + if (bytes == 0) { + // closed by client, maybe probing alive client + return 0; + } + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + return -1; + } else { + bytes = 0; + } + } + offset += bytes; + } + return offset; +} + +static void BindOrConnectFailed(int timeout, int* try_times, int* total_time, + const char* op, const std::string& ep) { + PADDLE_ENFORCE_LT( + *total_time, timeout, + platform::errors::Unavailable("%s addr=%s timeout, failed reason: %s", op, + ep.c_str(), strerror(errno))); + ++(*try_times); + int retry_time = std::min(*try_times * 500, 3000); // max 3 seconds + *total_time += retry_time; + + LOG(WARNING) << op << " addr=" << ep << " failed " << *try_times + << " times with reason: " << strerror(errno) << " retry after " + << retry_time / 1000.0 << " seconds"; + std::this_thread::sleep_for(std::chrono::milliseconds(retry_time)); +} + +int CreateListenSocket(const std::string& ep) { + auto addr = paddle::string::Split(ep, ':'); + PADDLE_ENFORCE_EQ( + addr.size(), 2UL, + platform::errors::InvalidArgument( + "The endpoint should contain host and port, but got %s.", ep)); + std::string host = addr[0]; + int port = std::stoi(addr[1]); + + // creating socket fd + int server_fd = -1; + CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", server_fd); + + // NOTE. Solutions to `Address already in use`. + // 1. Reuse addr&port. Otherwise, once the server closes the socket + // before client, the server will enter TIME-WAIT status. If we bind port + // again, the error `Address already in use` will appear. + // 2. Or we can close the client first to ensure that the server does + // not enter the TIME-WAIT state. But this is obviously not as convenient + // as the reuse method. + int opt = 1; +#if defined(SO_REUSEPORT) + // since Linux kernel 3.9 + CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, + &opt, sizeof(opt)), + "setsockopt"); +#else + CHECK_SYS_CALL( + setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), + "setsockopt"); +#endif + + struct sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + + // TODO(wangxi) Set from env, default 900s=15min + int timeout = 900 * 1000; + int try_times = 0; + int total_time = 0; + while (true) { + int ret_val = -1; + RETRY_SYS_CALL_VAL( + bind(server_fd, (struct sockaddr*)&address, sizeof(address)), "bind", + ret_val); + + if (ret_val == -1) { + BindOrConnectFailed(timeout, &try_times, &total_time, "bind", ep); + continue; + } + break; + } + + CHECK_SYS_CALL(listen(server_fd, 3), "listen"); + LOG(INFO) << "Server listening on: " << ep << " successful."; + return server_fd; +} + +void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); } + +static int SocketAccept(int server_fd, const char* head) { + struct sockaddr_in client_addr; + socklen_t addr_length = sizeof(client_addr); + char buffer[1024] = {0}; + int conn = -1; + + while (true) { + CHECK_SYS_CALL_VAL( + accept(server_fd, reinterpret_cast(&client_addr), + &addr_length), + "accept", conn); + + int ret_val = SocketRecv(conn, buffer, strlen(head)); + if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) { + break; // accept client + } else { + VLOG(3) << "socket read failed with ret_val=" << ret_val; + CloseSocket(conn); + } + } + return conn; +} + +static int ConnectAddr(const std::string& ep, const char* head) { + auto addr = paddle::string::Split(ep, ':'); + PADDLE_ENFORCE_EQ( + addr.size(), 2UL, + platform::errors::InvalidArgument( + "The endpoint should contain host and port, but got %s.", ep)); + std::string host = addr[0]; + int port = std::stoi(addr[1]); + + int sock = -1; + CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock); + + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(port); + + char* ip = NULL; + struct hostent* hp = NULL; + hp = gethostbyname(host.c_str()); + PADDLE_ENFORCE_NOT_NULL(hp, platform::errors::InvalidArgument( + "Fail to get host by name %s.", host)); + + int i = 0; + while (hp->h_addr_list[i] != NULL) { + ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]); + VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip; + break; + } + + PADDLE_ENFORCE_GT(inet_pton(AF_INET, ip, &server_addr.sin_addr), 0, + platform::errors::Unavailable("Open address %s failed: %s", + ep, strerror(errno))); + + // TODO(wangxi) Set from env, default 900s=15min + int timeout = 900 * 1000; + int try_times = 0; + int total_time = 0; + while (true) { + int ret_val = -1; + RETRY_SYS_CALL_VAL( + connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)), + "connect", ret_val); + + if (ret_val == -1) { + BindOrConnectFailed(timeout, &try_times, &total_time, "connect", ep); + continue; + } + + CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send"); + break; + } + return sock; +} + +static void RecvNCCLID(int conn, ncclUniqueId* nccl_id) { + char buffer[1024] = {0}; + static_assert(NCCL_UNIQUE_ID_BYTES <= 1024, + "nccl id bytes must <= buffer size"); + + CHECK_SYS_CALL(SocketRecv(conn, buffer, NCCL_UNIQUE_ID_BYTES), "recv ncc id"); + memcpy(nccl_id, buffer, NCCL_UNIQUE_ID_BYTES); +} + +static void SendNCCLID(int conn, ncclUniqueId* nccl_id) { + char buffer[1024] = {0}; + memcpy(buffer, nccl_id, NCCL_UNIQUE_ID_BYTES); + + CHECK_SYS_CALL(SocketSend(conn, buffer, NCCL_UNIQUE_ID_BYTES), + "send nccl id"); +} + +void SendBroadCastNCCLID(std::vector servers, int nccl_comm_num, + std::function func, + const framework::Scope& scope) { + // connect with server + std::vector connects; + for (auto server : servers) { + VLOG(3) << "connecting endpoint: " << server; + int conn = ConnectAddr(server, COMM_HEAD); + connects.push_back(conn); + } + VLOG(3) << "connecting completed..."; + + for (int i = 0; i < nccl_comm_num; ++i) { + std::string var_name = func(i); + auto var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable with name %s is not found", + var_name.c_str())); + auto nccl_id = var->GetMutable(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(nccl_id)); + + int j = 0; + for (auto conn : connects) { + VLOG(3) << "sending nccl_id_var: " << var_name << " to " << servers[j] + << " nccl_comm_no: " << i; + SendNCCLID(conn, nccl_id); + ++j; + } + VLOG(3) << "sending completed..."; + } + + // close client + for (auto conn : connects) { + CloseSocket(conn); + } +} + +void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, + std::function func, + const framework::Scope& scope) { + int server = CreateListenSocket(endpoint); + RecvBroadCastNCCLID(server, endpoint, nccl_comm_num, func, scope); + CloseSocket(server); +} + +void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, + std::function func, + const framework::Scope& scope) { + int client = SocketAccept(server_fd, COMM_HEAD); + + for (int i = 0; i < nccl_comm_num; ++i) { + std::string var_name = func(i); + auto var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable with name %s is not found", + var_name.c_str())); + auto nccl_id = var->GetMutable(); + + VLOG(3) << "trainer: " << endpoint << " receiving nccl_id_var: " << var_name + << " from trainer 0, nccl_comm_no: " << i; + RecvNCCLID(client, nccl_id); + } + VLOG(3) << "receiving completed..."; + CloseSocket(client); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h new file mode 100644 index 00000000000..38751805191 --- /dev/null +++ b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { + +int CreateListenSocket(const std::string& ep); + +void CloseSocket(int fd); + +void SendBroadCastNCCLID(std::vector servers, int nccl_comm_num, + std::function func, + const framework::Scope& scope); + +// server listen on endpoint, then recv nccl id +void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, + std::function func, + const framework::Scope& scope); + +// recv nccl id from socket +void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, + std::function func, + const framework::Scope& scope); +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index 79f14d75d27..ec48a51baa2 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -32,7 +32,6 @@ register_operators(EXCLUDES gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS}) if(WITH_NCCL) set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} nccl_common) - op_library(gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS} nccl_common) endif() set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 9a17160ee03..b748e220826 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -18,6 +18,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer) list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) list(APPEND DIST_TEST_OPS test_listen_and_serv_op) list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer) +list(APPEND DIST_TEST_OPS test_gen_nccl_id_op) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests. list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 19d9031573d..29ac46e81d8 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -945,7 +945,7 @@ class TestDistBase(unittest.TestCase): tr_cmd += " --use_cuda" env.update({ "FLAGS_selected_gpus": "{}".format(0), - "CUDA_VISIBLE_DEVICES": "{}".format(trainer_id % 2), + "CUDA_VISIBLE_DEVICES": "{}".format(trainer_id), "PADDLE_TRAINERS_NUM": "{}".format(trainer_num), "PADDLE_TRAINER_ID": "{}".format(trainer_id), "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, @@ -960,7 +960,7 @@ class TestDistBase(unittest.TestCase): if self._pipeline_mode: tr_cmd += " --use_pipeline" if self._mp_mode: - env = {"FLAGS_selected_gpus": "{}".format(trainer_id % 2)} + env = {"FLAGS_selected_gpus": "{}".format(trainer_id)} if self._nccl_comm_num > 1: tr_cmd += " --nccl_comm_num {}".format(self._nccl_comm_num) @@ -992,6 +992,7 @@ class TestDistBase(unittest.TestCase): global DIST_UT_PORT if DIST_UT_PORT == 0: + # NOTE(wangxi). hallreduce test must use 4cards after nccl>=2.7 for i in range(0, 4): self._ps_endpoints += "127.0.0.1:%s," % ( self._find_free_port()) @@ -1110,7 +1111,8 @@ class TestDistBase(unittest.TestCase): required_envs["GLOG_vmodule"] = \ "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \ "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \ - "sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,nccl_helper=10,grpc_client=10,grpc_server=10,request_handler_impl=10" + "sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,gen_nccl_id_op_help=10,nccl_helper=10,grpc_client=10," \ + "grpc_server=10,request_handler_impl=10" required_envs["GLOG_logtostderr"] = "1" required_envs.update(need_envs) diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_hallreduce.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_hallreduce.py index 356c5573f95..e1fbbebe171 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist_hallreduce.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_hallreduce.py @@ -29,6 +29,7 @@ class TestDistMnistNCCL2HAllreduce(TestDistBase): self._use_reduce = False self._use_reader_alloc = False self._nccl2_mode = True + # NOTE(wangxi). hallreduce test must use 4cards after nccl>=2.7 self._use_hallreduce = True def test_dist_train(self): diff --git a/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py b/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py new file mode 100644 index 00000000000..bd186e09006 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py @@ -0,0 +1,118 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +from launch_function_helper import wait, _find_free_port +from multiprocessing import Pool, Process + +os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10") + +import paddle +from paddle.fluid import core + +paddle.enable_static() + + +def run_gen_ncc_id(attr): + nccl_comm_num = attr['nccl_comm_num'] + use_hallreduce = attr['use_hierarchical_allreduce'] + + startup_program = paddle.static.default_startup_program() + main_program = paddle.static.default_main_program() + + with paddle.static.program_guard(main_program, startup_program): + nccl_id_var = startup_program.global_block().create_var( + name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW) + + for i in range(1, nccl_comm_num): + startup_program.global_block().create_var( + name="NCCLID_{}".format(i), + persistable=True, + type=core.VarDesc.VarType.RAW) + + if use_hallreduce: + for i in range(0, nccl_comm_num): + startup_program.global_block().create_var( + name="Hierarchical_inter_NCCLID_{}".format(i), + persistable=True, + type=core.VarDesc.VarType.RAW) + startup_program.global_block().create_var( + name="Hierarchical_exter_NCCLID_{}".format(i), + persistable=True, + type=core.VarDesc.VarType.RAW) + + startup_program.global_block().append_op( + type="gen_nccl_id", + inputs={}, + outputs={"NCCLID": nccl_id_var}, + attrs=attr) + + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_program) + + +class TestGenNcclIdOp(unittest.TestCase): + def setUp(self): + try: + self._dist_ut_port_0 = int(os.environ["PADDLE_DIST_UT_PORT"]) + except Exception as e: + self._dist_ut_port_0 = _find_free_port(set()) + + def gen_nccl_id(self, nranks=2): + nccl_comm_num = 1 + if nranks == 2: + use_hallreduce = False + hallreduce_inter_nranks = -1 + elif nranks == 4: + use_hallreduce = True + hallreduce_inter_nranks = 2 + + port = self._dist_ut_port_0 + trainers = [] + for i in range(nranks): + trainers.append('127.0.0.1:{}'.format(port + i)) + + attr = { + "trainers": trainers, + "trainer_id": 0, + "nccl_comm_num": nccl_comm_num, + "use_hierarchical_allreduce": use_hallreduce, + "hierarchical_allreduce_inter_nranks": hallreduce_inter_nranks, + } + + procs = [] + for i in range(nranks): + attr['trainer_id'] = i + p = Process(target=run_gen_ncc_id, args=(attr, )) + p.start() + procs.append(p) + + wait(procs, timeout=120) + + def test_flat(self): + print(">>> test gen flat nccl id") + self.gen_nccl_id(2) + print("<<< end test gen flat nccl id") + + def test_hierarchical(self): + print(">>> test gen hierarchical nccl id") + self.gen_nccl_id(4) + print("<<< end test gen hierarchical nccl id") + + +if __name__ == "__main__": + unittest.main() -- GitLab