提交 f5840d89 编写于 作者: T typhoonzero

follow comments

上级 04bde96e
......@@ -80,7 +80,7 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA
auto *nccl_id_var = scope->FindVar("NCCLID");
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
ncclUniqueId *nccl_id = nullptr;
if (nccl_id_var != nullptr) {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
......
......@@ -162,8 +162,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<ncclUniqueId>()) {
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
ncclUniqueId& uid = var->Get<ncclUniqueId>();
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
......
......@@ -52,17 +52,17 @@ class GenNCCLIdOp : public framework::OperatorBase {
private:
void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
auto var = scope->FindVar("NCCLID");
auto var = scope->FindVar(NCCL_ID_VARNAME);
PADDLE_ENFORCE_NOT_NULL(var);
auto id = var->GetMutable<ncclUniqueId>();
platform::dynload::ncclGetUniqueId(id);
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient client;
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID");
client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
}
client.Wait();
VLOG(3) << "sending completed...";
......@@ -71,6 +71,9 @@ class GenNCCLIdOp : public framework::OperatorBase {
void GetIdByServer(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
rpc_service_ = new detail::AsyncGRPCServer(endpoint, true);
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
......
......@@ -39,7 +39,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service;
void StartServer() {
f::Scope scope;
p::CPUPlace place;
scope.Var("NCCLID");
scope.Var(NCCL_ID_VARNAME);
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
......@@ -71,7 +71,7 @@ TEST(SendNcclId, Normal) {
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var("NCCLID");
auto var = scope.Var(NCCL_ID_VARNAME);
// var->SetType(f::proto::VarType_Type_RAW);
auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id);
......@@ -80,7 +80,7 @@ TEST(SendNcclId, Normal) {
std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client;
client.AsyncSendVariable(ep, dev_ctx, scope, "NCCLID");
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait();
server_thread.join();
auto* ptr = rpc_service.release();
......
......@@ -21,6 +21,8 @@
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
#define NCCL_ID_VARNAME "NCCLID"
namespace paddle {
namespace platform {
......@@ -76,7 +78,7 @@ struct NCCLContextMap {
explicit NCCLContextMap(const std::vector<platform::Place> &places,
ncclUniqueId *nccl_id = nullptr,
size_t node_count = 0, size_t trainer_id = 0) {
size_t num_trainers = 0, size_t trainer_id = 0) {
PADDLE_ENFORCE(!places.empty());
order_.reserve(places.size());
for (auto &p : places) {
......@@ -94,16 +96,14 @@ struct NCCLContextMap {
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
// if pass nccl_id here, can assume we are doing multi node training
if (nccl_id == nullptr) {
{
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
comms.get(), static_cast<int>(order_.size()), order_.data()));
}
} else {
PADDLE_ENFORCE_GT(node_count, 0);
PADDLE_ENFORCE_GT(num_trainers, 0);
// TODO(wuyi): need to ensure each node have same number of GPUs
{
int nranks = node_count * order_.size();
int nranks = num_trainers * order_.size();
NCCLGroupGuard gurad;
for (auto &gpu_id : order_) {
int rank = trainer_id * order_.size() + gpu_id;
......
......@@ -31,7 +31,7 @@ class ParallelExecutor(object):
allow_op_delay=False,
share_vars_from=None,
use_default_grad_scale=True,
num_nodes=0,
num_trainers=0,
trainer_id=0):
"""
ParallelExecutor can run program in parallel.
......@@ -53,10 +53,10 @@ class ParallelExecutor(object):
gradients of each device and scaled gradients would be
aggregated. Otherwise, a customized scale value should be fed
to the network.
num_nodes(int, default 0): If greater than 0, NCCL will be
num_trainers(int, default 0): If greater than 0, NCCL will be
initialized with multpile rank of nodes, each node should have
same number of GPUs. Distributed training will be enabled then.
trainer_id(int, default 0): Must use together with num_nodes.
trainer_id(int, default 0): Must use together with num_trainers.
trainer_id is the "rank" of current node starts from 0.
Returns:
......@@ -137,7 +137,7 @@ class ParallelExecutor(object):
local_scopes,
allow_op_delay,
use_default_grad_scale,
num_nodes,
num_trainers,
trainer_id)
self.scope = scope
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册