未验证 提交 467c7169 编写于 作者: W WangXi 提交者: GitHub

gen nccl id use socket (#29431)

上级 d72604cd
...@@ -28,11 +28,13 @@ foreach(src ${OPS}) ...@@ -28,11 +28,13 @@ foreach(src ${OPS})
set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS}) set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS})
endforeach() endforeach()
register_operators(EXCLUDES c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) register_operators(EXCLUDES c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
if(WITH_NCCL) if(WITH_NCCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper) set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common) cc_library(gen_nccl_id_op_helper SRCS gen_nccl_id_op_helper.cc)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common gen_nccl_id_op_helper)
op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common gen_nccl_id_op_helper)
endif() endif()
if(WITH_GLOO) if(WITH_GLOO)
......
...@@ -21,14 +21,12 @@ limitations under the License. */ ...@@ -21,14 +21,12 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,80 +40,23 @@ class CGenNCCLIdOp : public framework::OperatorBase { ...@@ -42,80 +40,23 @@ class CGenNCCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace());
int rank = Attr<int>("rank"); int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope(); framework::Scope& local_scope = scope.NewScope();
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
};
if (rank == 0) { if (rank == 0) {
GenerateAndSend(&local_scope, dev_ctx); std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
SendBroadCastNCCLID(endpoint_list, 1, func, local_scope);
} else { } else {
GetIdByServer(&local_scope, dev_ctx); std::string endpoint = Attr<std::string>("endpoint");
RecvBroadCastNCCLID(endpoint, 1, func, local_scope);
} }
scope.DeleteScope(&local_scope); scope.DeleteScope(&local_scope);
} }
private:
void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
std::string var_name = Output("Out");
auto var = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Output can not be Null"));
auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE_EQ(platform::dynload::ncclGetUniqueId(id), 0,
platform::errors::InvalidArgument(
"ncclGetUniqueId failed with id %s", id));
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, var_name);
}
client->Wait();
for (auto& ep : endpoint_list) {
client->AsyncSendBatchBarrier(ep);
}
client->Wait();
VLOG(3) << "sending completed...";
}
void GetIdByServer(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "got nccl id and stop server...";
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
}
}; };
class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h"
namespace paddle {
namespace operators {
class GenNCCLIdOp : public framework::OperatorBase {
public:
GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
std::vector<std::string> trainers =
Attr<std::vector<std::string>>("trainers");
int trainer_id = Attr<int>("trainer_id");
std::string endpoint = trainers[trainer_id];
PADDLE_ENFORCE_GE(trainer_id, 0, platform::errors::InvalidArgument(
"trainer_id %d is less than 0. Its "
"valid range is [0, trainer_size)"));
PADDLE_ENFORCE_LT(
trainer_id, static_cast<int>(trainers.size()),
platform::errors::OutOfRange("trainer_id %d is out of range. Its valid "
"range is [0, trainer_size)",
trainer_id));
int nccl_comm_num = Attr<int>("nccl_comm_num");
int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
int inter_nranks = Attr<int>("hierarchical_allreduce_inter_nranks");
int inter_trainer_id = -1;
int exter_trainer_id = -1;
if (use_hierarchical_allreduce) {
PADDLE_ENFORCE_GT(
trainers.size(), 1,
platform::errors::PreconditionNotMet(
"The number of collective trainers %llu <= 1", trainers.size()));
PADDLE_ENFORCE_GT(
inter_nranks, 1,
platform::errors::PreconditionNotMet(
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
inter_nranks));
PADDLE_ENFORCE_EQ(
trainers.size() % inter_nranks, 0,
platform::errors::PreconditionNotMet(
"The number of trainers %llu mod inter_nranks %d is not equal 0",
trainers.size(), inter_nranks));
inter_trainer_id = trainer_id % inter_nranks;
if (trainer_id % inter_nranks == 0) {
exter_trainer_id = trainer_id / inter_nranks;
}
}
std::ostringstream ss;
for (size_t i = 0; i < trainers.size(); i++) {
ss << trainers[i] << ",";
}
VLOG(1) << "trainer_id:" << trainer_id
<< ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
<< ", nccl_comm_num:" << nccl_comm_num
<< ", inter_nranks:" << inter_nranks
<< ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< ", trainers:" << ss.str();
int server_fd = -1;
/// 1. init flat
std::function<std::string(size_t)> func = platform::GetFlatNCCLVarName;
if (trainer_id == 0) {
// server endpoints
std::vector<std::string> flat_endpoints;
flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1,
trainers.end());
SendBroadCastNCCLID(flat_endpoints, nccl_comm_num, func, scope);
} else {
server_fd = CreateListenSocket(endpoint);
RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
}
/// 2. hierarchical inter ncclid
func = platform::GetHierarchicalInterNCCLVarName;
if (inter_trainer_id == 0) {
std::ostringstream ss;
ss << endpoint;
std::vector<std::string> inter_endpoints;
for (int i = trainer_id + 1; i < trainer_id + inter_nranks &&
i < static_cast<int>(trainers.size());
i++) {
ss << ",";
inter_endpoints.push_back(trainers[i]);
ss << trainers[i];
}
VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();
SendBroadCastNCCLID(inter_endpoints, nccl_comm_num, func, scope);
} else if (inter_trainer_id > 0) {
VLOG(1) << "Hierarchical inter ring";
RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
}
/// 3. hierarchical exter ncclid
func = platform::GetHierarchicalExterNCCLVarName;
if (exter_trainer_id == 0) {
std::ostringstream ss;
std::vector<std::string> exter_endpoints;
ss << endpoint;
for (size_t i = inter_nranks; i < trainers.size(); i += inter_nranks) {
ss << ",";
exter_endpoints.push_back(trainers[i]);
ss << trainers[i];
}
VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str();
SendBroadCastNCCLID(exter_endpoints, nccl_comm_num, func, scope);
} else if (exter_trainer_id > 0) {
VLOG(1) << "Hierarchical exter ring";
RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
}
// close socket server
if (trainer_id != 0) {
CloseSocket(server_fd);
}
}
};
class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
GenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
AddAttr<std::vector<std::string>>(
"trainers",
"['trainer0_ip:port', 'trainer1_ip:port', ...] "
"list of all trainer endpoints")
.SetDefault({});
AddAttr<int>("trainer_id",
"(int) "
"The index of the trainer in distributed training.");
AddAttr<int>("nccl_comm_num",
"(int default 1) "
"The number of nccl communicator num.")
.SetDefault(1);
AddAttr<bool>("use_hierarchical_allreduce",
"(bool default false) "
"Wheter to use hierarchical allreduce.")
.SetDefault(false);
AddAttr<int>("hierarchical_allreduce_inter_nranks",
"(int default 1) "
"Wheter to use hierarchical allreduce.")
.SetDefault(-1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h"
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <algorithm>
#include <ostream>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace operators {
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
// Check system calls, such as socket, bind.
#define CHECK_SYS_CALL(call, name) \
do { \
int retval; \
CHECK_SYS_CALL_VAL(call, name, retval); \
} while (false)
#define CHECK_SYS_CALL_VAL(call, name, retval) \
do { \
RETRY_SYS_CALL_VAL(call, name, retval); \
if (retval == -1) { \
PADDLE_THROW(platform::errors::Unavailable("Call to %s failed: %s", \
name, strerror(errno))); \
} \
} while (false)
#define RETRY_SYS_CALL_VAL(call, name, retval) \
do { \
retval = (call); \
if (retval == -1 && \
(errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
LOG(WARNING) << "Call " << name << " returned " << strerror(errno) \
<< " retry"; \
} else { \
break; \
} \
} while (true)
static int SocketSend(int fd, const char* buffer, int size) {
int offset = 0;
int bytes = 0;
while (offset < size) {
bytes = send(fd, buffer + offset, size - offset, 0);
if (bytes == -1) {
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
// send failed
return -1;
} else {
bytes = 0;
}
}
offset += bytes;
}
return offset;
}
static int SocketRecv(int fd, char* buffer, int size) {
int offset = 0;
int bytes = 0;
while (offset < size) {
bytes = recv(fd, buffer + offset, size - offset, 0);
if (bytes == 0) {
// closed by client, maybe probing alive client
return 0;
}
if (bytes == -1) {
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
return -1;
} else {
bytes = 0;
}
}
offset += bytes;
}
return offset;
}
static void BindOrConnectFailed(int timeout, int* try_times, int* total_time,
const char* op, const std::string& ep) {
PADDLE_ENFORCE_LT(
*total_time, timeout,
platform::errors::Unavailable("%s addr=%s timeout, failed reason: %s", op,
ep.c_str(), strerror(errno)));
++(*try_times);
int retry_time = std::min(*try_times * 500, 3000); // max 3 seconds
*total_time += retry_time;
LOG(WARNING) << op << " addr=" << ep << " failed " << *try_times
<< " times with reason: " << strerror(errno) << " retry after "
<< retry_time / 1000.0 << " seconds";
std::this_thread::sleep_for(std::chrono::milliseconds(retry_time));
}
int CreateListenSocket(const std::string& ep) {
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
platform::errors::InvalidArgument(
"The endpoint should contain host and port, but got %s.", ep));
std::string host = addr[0];
int port = std::stoi(addr[1]);
// creating socket fd
int server_fd = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", server_fd);
// NOTE. Solutions to `Address already in use`.
// 1. Reuse addr&port. Otherwise, once the server closes the socket
// before client, the server will enter TIME-WAIT status. If we bind port
// again, the error `Address already in use` will appear.
// 2. Or we can close the client first to ensure that the server does
// not enter the TIME-WAIT state. But this is obviously not as convenient
// as the reuse method.
int opt = 1;
#if defined(SO_REUSEPORT)
// since Linux kernel 3.9
CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT,
&opt, sizeof(opt)),
"setsockopt");
#else
CHECK_SYS_CALL(
setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)),
"setsockopt");
#endif
struct sockaddr_in address;
address.sin_family = AF_INET;
address.sin_addr.s_addr = INADDR_ANY;
address.sin_port = htons(port);
// TODO(wangxi) Set from env, default 900s=15min
int timeout = 900 * 1000;
int try_times = 0;
int total_time = 0;
while (true) {
int ret_val = -1;
RETRY_SYS_CALL_VAL(
bind(server_fd, (struct sockaddr*)&address, sizeof(address)), "bind",
ret_val);
if (ret_val == -1) {
BindOrConnectFailed(timeout, &try_times, &total_time, "bind", ep);
continue;
}
break;
}
CHECK_SYS_CALL(listen(server_fd, 3), "listen");
LOG(INFO) << "Server listening on: " << ep << " successful.";
return server_fd;
}
void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); }
static int SocketAccept(int server_fd, const char* head) {
struct sockaddr_in client_addr;
socklen_t addr_length = sizeof(client_addr);
char buffer[1024] = {0};
int conn = -1;
while (true) {
CHECK_SYS_CALL_VAL(
accept(server_fd, reinterpret_cast<struct sockaddr*>(&client_addr),
&addr_length),
"accept", conn);
int ret_val = SocketRecv(conn, buffer, strlen(head));
if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) {
break; // accept client
} else {
VLOG(3) << "socket read failed with ret_val=" << ret_val;
CloseSocket(conn);
}
}
return conn;
}
static int ConnectAddr(const std::string& ep, const char* head) {
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
platform::errors::InvalidArgument(
"The endpoint should contain host and port, but got %s.", ep));
std::string host = addr[0];
int port = std::stoi(addr[1]);
int sock = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port);
char* ip = NULL;
struct hostent* hp = NULL;
hp = gethostbyname(host.c_str());
PADDLE_ENFORCE_NOT_NULL(hp, platform::errors::InvalidArgument(
"Fail to get host by name %s.", host));
int i = 0;
while (hp->h_addr_list[i] != NULL) {
ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]);
VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip;
break;
}
PADDLE_ENFORCE_GT(inet_pton(AF_INET, ip, &server_addr.sin_addr), 0,
platform::errors::Unavailable("Open address %s failed: %s",
ep, strerror(errno)));
// TODO(wangxi) Set from env, default 900s=15min
int timeout = 900 * 1000;
int try_times = 0;
int total_time = 0;
while (true) {
int ret_val = -1;
RETRY_SYS_CALL_VAL(
connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)),
"connect", ret_val);
if (ret_val == -1) {
BindOrConnectFailed(timeout, &try_times, &total_time, "connect", ep);
continue;
}
CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send");
break;
}
return sock;
}
static void RecvNCCLID(int conn, ncclUniqueId* nccl_id) {
char buffer[1024] = {0};
static_assert(NCCL_UNIQUE_ID_BYTES <= 1024,
"nccl id bytes must <= buffer size");
CHECK_SYS_CALL(SocketRecv(conn, buffer, NCCL_UNIQUE_ID_BYTES), "recv ncc id");
memcpy(nccl_id, buffer, NCCL_UNIQUE_ID_BYTES);
}
static void SendNCCLID(int conn, ncclUniqueId* nccl_id) {
char buffer[1024] = {0};
memcpy(buffer, nccl_id, NCCL_UNIQUE_ID_BYTES);
CHECK_SYS_CALL(SocketSend(conn, buffer, NCCL_UNIQUE_ID_BYTES),
"send nccl id");
}
void SendBroadCastNCCLID(std::vector<std::string> servers, int nccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
// connect with server
std::vector<int> connects;
for (auto server : servers) {
VLOG(3) << "connecting endpoint: " << server;
int conn = ConnectAddr(server, COMM_HEAD);
connects.push_back(conn);
}
VLOG(3) << "connecting completed...";
for (int i = 0; i < nccl_comm_num; ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto nccl_id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(nccl_id));
int j = 0;
for (auto conn : connects) {
VLOG(3) << "sending nccl_id_var: " << var_name << " to " << servers[j]
<< " nccl_comm_no: " << i;
SendNCCLID(conn, nccl_id);
++j;
}
VLOG(3) << "sending completed...";
}
// close client
for (auto conn : connects) {
CloseSocket(conn);
}
}
void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
int server = CreateListenSocket(endpoint);
RecvBroadCastNCCLID(server, endpoint, nccl_comm_num, func, scope);
CloseSocket(server);
}
void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
int client = SocketAccept(server_fd, COMM_HEAD);
for (int i = 0; i < nccl_comm_num; ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto nccl_id = var->GetMutable<ncclUniqueId>();
VLOG(3) << "trainer: " << endpoint << " receiving nccl_id_var: " << var_name
<< " from trainer 0, nccl_comm_no: " << i;
RecvNCCLID(client, nccl_id);
}
VLOG(3) << "receiving completed...";
CloseSocket(client);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <string>
#include <vector>
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
int CreateListenSocket(const std::string& ep);
void CloseSocket(int fd);
void SendBroadCastNCCLID(std::vector<std::string> servers, int nccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
// server listen on endpoint, then recv nccl id
void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
// recv nccl id from socket
void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
} // namespace operators
} // namespace paddle
...@@ -32,7 +32,6 @@ register_operators(EXCLUDES gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS}) ...@@ -32,7 +32,6 @@ register_operators(EXCLUDES gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS})
if(WITH_NCCL) if(WITH_NCCL)
set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} nccl_common) set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} nccl_common)
op_library(gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS} nccl_common)
endif() endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE) set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE)
......
...@@ -18,6 +18,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer) ...@@ -18,6 +18,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
list(APPEND DIST_TEST_OPS test_listen_and_serv_op) list(APPEND DIST_TEST_OPS test_listen_and_serv_op)
list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer) list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests. #remove distribute unittests.
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
......
...@@ -945,7 +945,7 @@ class TestDistBase(unittest.TestCase): ...@@ -945,7 +945,7 @@ class TestDistBase(unittest.TestCase):
tr_cmd += " --use_cuda" tr_cmd += " --use_cuda"
env.update({ env.update({
"FLAGS_selected_gpus": "{}".format(0), "FLAGS_selected_gpus": "{}".format(0),
"CUDA_VISIBLE_DEVICES": "{}".format(trainer_id % 2), "CUDA_VISIBLE_DEVICES": "{}".format(trainer_id),
"PADDLE_TRAINERS_NUM": "{}".format(trainer_num), "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
"PADDLE_TRAINER_ID": "{}".format(trainer_id), "PADDLE_TRAINER_ID": "{}".format(trainer_id),
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
...@@ -960,7 +960,7 @@ class TestDistBase(unittest.TestCase): ...@@ -960,7 +960,7 @@ class TestDistBase(unittest.TestCase):
if self._pipeline_mode: if self._pipeline_mode:
tr_cmd += " --use_pipeline" tr_cmd += " --use_pipeline"
if self._mp_mode: if self._mp_mode:
env = {"FLAGS_selected_gpus": "{}".format(trainer_id % 2)} env = {"FLAGS_selected_gpus": "{}".format(trainer_id)}
if self._nccl_comm_num > 1: if self._nccl_comm_num > 1:
tr_cmd += " --nccl_comm_num {}".format(self._nccl_comm_num) tr_cmd += " --nccl_comm_num {}".format(self._nccl_comm_num)
...@@ -992,6 +992,7 @@ class TestDistBase(unittest.TestCase): ...@@ -992,6 +992,7 @@ class TestDistBase(unittest.TestCase):
global DIST_UT_PORT global DIST_UT_PORT
if DIST_UT_PORT == 0: if DIST_UT_PORT == 0:
# NOTE(wangxi). hallreduce test must use 4cards after nccl>=2.7
for i in range(0, 4): for i in range(0, 4):
self._ps_endpoints += "127.0.0.1:%s," % ( self._ps_endpoints += "127.0.0.1:%s," % (
self._find_free_port()) self._find_free_port())
...@@ -1110,7 +1111,8 @@ class TestDistBase(unittest.TestCase): ...@@ -1110,7 +1111,8 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_vmodule"] = \ required_envs["GLOG_vmodule"] = \
"fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \ "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
"alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \ "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
"sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,nccl_helper=10,grpc_client=10,grpc_server=10,request_handler_impl=10" "sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,gen_nccl_id_op_help=10,nccl_helper=10,grpc_client=10," \
"grpc_server=10,request_handler_impl=10"
required_envs["GLOG_logtostderr"] = "1" required_envs["GLOG_logtostderr"] = "1"
required_envs.update(need_envs) required_envs.update(need_envs)
......
...@@ -29,6 +29,7 @@ class TestDistMnistNCCL2HAllreduce(TestDistBase): ...@@ -29,6 +29,7 @@ class TestDistMnistNCCL2HAllreduce(TestDistBase):
self._use_reduce = False self._use_reduce = False
self._use_reader_alloc = False self._use_reader_alloc = False
self._nccl2_mode = True self._nccl2_mode = True
# NOTE(wangxi). hallreduce test must use 4cards after nccl>=2.7
self._use_hallreduce = True self._use_hallreduce = True
def test_dist_train(self): def test_dist_train(self):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
from launch_function_helper import wait, _find_free_port
from multiprocessing import Pool, Process
os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10")
import paddle
from paddle.fluid import core
paddle.enable_static()
def run_gen_ncc_id(attr):
nccl_comm_num = attr['nccl_comm_num']
use_hallreduce = attr['use_hierarchical_allreduce']
startup_program = paddle.static.default_startup_program()
main_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
for i in range(1, nccl_comm_num):
startup_program.global_block().create_var(
name="NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
if use_hallreduce:
for i in range(0, nccl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_inter_NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_program.global_block().create_var(
name="Hierarchical_exter_NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_program.global_block().append_op(
type="gen_nccl_id",
inputs={},
outputs={"NCCLID": nccl_id_var},
attrs=attr)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
class TestGenNcclIdOp(unittest.TestCase):
def setUp(self):
try:
self._dist_ut_port_0 = int(os.environ["PADDLE_DIST_UT_PORT"])
except Exception as e:
self._dist_ut_port_0 = _find_free_port(set())
def gen_nccl_id(self, nranks=2):
nccl_comm_num = 1
if nranks == 2:
use_hallreduce = False
hallreduce_inter_nranks = -1
elif nranks == 4:
use_hallreduce = True
hallreduce_inter_nranks = 2
port = self._dist_ut_port_0
trainers = []
for i in range(nranks):
trainers.append('127.0.0.1:{}'.format(port + i))
attr = {
"trainers": trainers,
"trainer_id": 0,
"nccl_comm_num": nccl_comm_num,
"use_hierarchical_allreduce": use_hallreduce,
"hierarchical_allreduce_inter_nranks": hallreduce_inter_nranks,
}
procs = []
for i in range(nranks):
attr['trainer_id'] = i
p = Process(target=run_gen_ncc_id, args=(attr, ))
p.start()
procs.append(p)
wait(procs, timeout=120)
def test_flat(self):
print(">>> test gen flat nccl id")
self.gen_nccl_id(2)
print("<<< end test gen flat nccl id")
def test_hierarchical(self):
print(">>> test gen hierarchical nccl id")
self.gen_nccl_id(4)
print("<<< end test gen hierarchical nccl id")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册