提交 f5840d89 编写于 作者: T typhoonzero

follow comments

上级 04bde96e
...@@ -80,7 +80,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -80,7 +80,7 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs // Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *nccl_id_var = scope->FindVar("NCCLID"); auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
ncclUniqueId *nccl_id = nullptr; ncclUniqueId *nccl_id = nullptr;
if (nccl_id_var != nullptr) { if (nccl_id_var != nullptr) {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>(); nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
......
...@@ -187,7 +187,7 @@ if(WITH_DISTRIBUTE) ...@@ -187,7 +187,7 @@ if(WITH_DISTRIBUTE)
if(WITH_GPU) if(WITH_GPU)
op_library(gen_nccl_id_op DEPS nccl_common) op_library(gen_nccl_id_op DEPS nccl_common)
else() else()
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif() endif()
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
...@@ -162,8 +162,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -162,8 +162,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<ncclUniqueId>()) { if (var->IsType<ncclUniqueId>()) {
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES); NCCL_UNIQUE_ID_BYTES);
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>(); ncclUniqueId& uid = var->Get<ncclUniqueId>();
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES)); e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
// for serialize NCCL_ID // for serialize NCCL_ID
::grpc::Slice slices(e.size()); ::grpc::Slice slices(e.size());
......
...@@ -52,17 +52,17 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -52,17 +52,17 @@ class GenNCCLIdOp : public framework::OperatorBase {
private: private:
void GenerateAndSend(framework::Scope* scope, void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const { const platform::DeviceContext& dev_ctx) const {
auto var = scope->FindVar("NCCLID"); auto var = scope->FindVar(NCCL_ID_VARNAME);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
auto id = var->GetMutable<ncclUniqueId>(); auto id = var->GetMutable<ncclUniqueId>();
platform::dynload::ncclGetUniqueId(id); PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
std::vector<std::string> endpoint_list = std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list"); Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient client; detail::RPCClient client;
for (auto& ep : endpoint_list) { for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep; VLOG(3) << "sending nccl id to " << ep;
client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID"); client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
} }
client.Wait(); client.Wait();
VLOG(3) << "sending completed..."; VLOG(3) << "sending completed...";
...@@ -71,6 +71,9 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -71,6 +71,9 @@ class GenNCCLIdOp : public framework::OperatorBase {
void GetIdByServer(framework::Scope* scope, void GetIdByServer(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const { const platform::DeviceContext& dev_ctx) const {
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
rpc_service_ = new detail::AsyncGRPCServer(endpoint, true); rpc_service_ = new detail::AsyncGRPCServer(endpoint, true);
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace()); framework::Executor executor(dev_ctx.GetPlace());
......
...@@ -39,7 +39,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service; ...@@ -39,7 +39,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service;
void StartServer() { void StartServer() {
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
scope.Var("NCCLID"); scope.Var(NCCL_ID_VARNAME);
p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace()); auto& dev_ctx = *pool.Get(p::CPUPlace());
...@@ -71,7 +71,7 @@ TEST(SendNcclId, Normal) { ...@@ -71,7 +71,7 @@ TEST(SendNcclId, Normal) {
p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace()); auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var("NCCLID"); auto var = scope.Var(NCCL_ID_VARNAME);
// var->SetType(f::proto::VarType_Type_RAW); // var->SetType(f::proto::VarType_Type_RAW);
auto id = var->GetMutable<ncclUniqueId>(); auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id); p::dynload::ncclGetUniqueId(id);
...@@ -80,7 +80,7 @@ TEST(SendNcclId, Normal) { ...@@ -80,7 +80,7 @@ TEST(SendNcclId, Normal) {
std::string ep = string::Sprintf("127.0.0.1:%d", port); std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client; detail::RPCClient client;
client.AsyncSendVariable(ep, dev_ctx, scope, "NCCLID"); client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait(); client.Wait();
server_thread.join(); server_thread.join();
auto* ptr = rpc_service.release(); auto* ptr = rpc_service.release();
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#define NCCL_ID_VARNAME "NCCLID"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -76,7 +78,7 @@ struct NCCLContextMap { ...@@ -76,7 +78,7 @@ struct NCCLContextMap {
explicit NCCLContextMap(const std::vector<platform::Place> &places, explicit NCCLContextMap(const std::vector<platform::Place> &places,
ncclUniqueId *nccl_id = nullptr, ncclUniqueId *nccl_id = nullptr,
size_t node_count = 0, size_t trainer_id = 0) { size_t num_trainers = 0, size_t trainer_id = 0) {
PADDLE_ENFORCE(!places.empty()); PADDLE_ENFORCE(!places.empty());
order_.reserve(places.size()); order_.reserve(places.size());
for (auto &p : places) { for (auto &p : places) {
...@@ -94,16 +96,14 @@ struct NCCLContextMap { ...@@ -94,16 +96,14 @@ struct NCCLContextMap {
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]); std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
// if pass nccl_id here, can assume we are doing multi node training // if pass nccl_id here, can assume we are doing multi node training
if (nccl_id == nullptr) { if (nccl_id == nullptr) {
{ std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex()); PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( comms.get(), static_cast<int>(order_.size()), order_.data()));
comms.get(), static_cast<int>(order_.size()), order_.data()));
}
} else { } else {
PADDLE_ENFORCE_GT(node_count, 0); PADDLE_ENFORCE_GT(num_trainers, 0);
// TODO(wuyi): need to ensure each node have same number of GPUs // TODO(wuyi): need to ensure each node have same number of GPUs
{ {
int nranks = node_count * order_.size(); int nranks = num_trainers * order_.size();
NCCLGroupGuard gurad; NCCLGroupGuard gurad;
for (auto &gpu_id : order_) { for (auto &gpu_id : order_) {
int rank = trainer_id * order_.size() + gpu_id; int rank = trainer_id * order_.size() + gpu_id;
......
...@@ -31,7 +31,7 @@ class ParallelExecutor(object): ...@@ -31,7 +31,7 @@ class ParallelExecutor(object):
allow_op_delay=False, allow_op_delay=False,
share_vars_from=None, share_vars_from=None,
use_default_grad_scale=True, use_default_grad_scale=True,
num_nodes=0, num_trainers=0,
trainer_id=0): trainer_id=0):
""" """
ParallelExecutor can run program in parallel. ParallelExecutor can run program in parallel.
...@@ -53,10 +53,10 @@ class ParallelExecutor(object): ...@@ -53,10 +53,10 @@ class ParallelExecutor(object):
gradients of each device and scaled gradients would be gradients of each device and scaled gradients would be
aggregated. Otherwise, a customized scale value should be fed aggregated. Otherwise, a customized scale value should be fed
to the network. to the network.
num_nodes(int, default 0): If greater than 0, NCCL will be num_trainers(int, default 0): If greater than 0, NCCL will be
initialized with multpile rank of nodes, each node should have initialized with multpile rank of nodes, each node should have
same number of GPUs. Distributed training will be enabled then. same number of GPUs. Distributed training will be enabled then.
trainer_id(int, default 0): Must use together with num_nodes. trainer_id(int, default 0): Must use together with num_trainers.
trainer_id is the "rank" of current node starts from 0. trainer_id is the "rank" of current node starts from 0.
Returns: Returns:
...@@ -137,7 +137,7 @@ class ParallelExecutor(object): ...@@ -137,7 +137,7 @@ class ParallelExecutor(object):
local_scopes, local_scopes,
allow_op_delay, allow_op_delay,
use_default_grad_scale, use_default_grad_scale,
num_nodes, num_trainers,
trainer_id) trainer_id)
self.scope = scope self.scope = scope
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册