Commit d9320dcd authored by T typhoonzero

complete code

Parent 7237323c
@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
     const std::unordered_set<std::string> &bcast_vars,
     const ProgramDesc &main_program, const std::string &loss_var_name,
     Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
-    bool use_default_grad_scale)
+    bool use_default_grad_scale, size_t num_trainers, size_t trainer_id)
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;
@@ -80,7 +80,13 @@ ParallelExecutor::ParallelExecutor(
   // Bcast Parameters to all GPUs
 #ifdef PADDLE_WITH_CUDA
-  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
+  auto *nccl_id_var = scope->FindVar("NCCLID");
+  ncclUniqueId *nccl_id = nullptr;
+  if (nccl_id_var != nullptr) {
+    nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+  }
+  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
+      member_->places_, nccl_id, num_trainers, trainer_id));
 #endif
   if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
       local_scopes.empty()) {  // Is CUDA
......
@@ -40,7 +40,8 @@ class ParallelExecutor {
                    const ProgramDesc& main_program,
                    const std::string& loss_var_name, Scope* scope,
                    const std::vector<Scope*>& local_scopes,
-                   bool allow_op_delay, bool use_default_grad_scale);
+                   bool allow_op_delay, bool use_default_grad_scale,
+                   size_t num_trainers = 0, size_t trainer_id = 0);

   ~ParallelExecutor();
......
@@ -32,6 +32,7 @@ service SendRecvService {
 enum VarType {
   LOD_TENSOR = 0;
   SELECTED_ROWS = 1;
+  NCCL_ID = 2;
 }

 // NOTICE(gongwb):don't modify this proto if you are not
......
@@ -43,13 +43,16 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   void* buf = buffer.get();
   void* payload = nullptr;
-  size_t payload_size;
+  size_t payload_size = 0;
   ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
   e.WriteString(VarMsg::kVarnameFieldNumber, name);
   if (var->IsType<framework::LoDTensor>()) {
     e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
   } else if (var->IsType<framework::SelectedRows>()) {
     e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
+  } else if (var->IsType<ncclUniqueId>()) {
+    // NOTE: sendrecv only support RAW type for NCCL_ID
+    e.WriteUint64(VarMsg::kTypeFieldNumber, 2);
   }
   if (!out_name.empty()) {
@@ -139,11 +142,27 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
       payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
       e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
     } break;
+    case framework::proto::VarType_Type_RAW: {
+      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
+                                NCCL_UNIQUE_ID_BYTES);
+      ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
+      e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
+    } break;
     default:
       PADDLE_THROW("Serialize does not support type: %s",
                    typeid(var->Type()).name());
       break;
   }
+  if (framework::ToVarType(var->Type()) == framework::proto::VarType_Type_RAW) {
+    // for serialize NCCL_ID
+    ::grpc::Slice slices(e.size());
+    memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
+    ::grpc::ByteBuffer tmp(&slices, 1);
+    msg->Swap(&tmp);
+    return;
+  }
   // steal reference of tensor data
   ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
   int num_slices = 2;       // only SelectedRows have rows buffer
......
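Note on the serialization path above: for NCCL_ID the payload is just the raw bytes of ncclUniqueId::internal, written verbatim by the sender and memcpy'd back into the variable by the receiver (see the VariableResponse::Parse hunk below). A minimal round-trip sketch of that idea, assuming only the NCCL header where NCCL_UNIQUE_ID_BYTES (128) is defined; the helper names are illustrative, not part of this commit:

// Pack an ncclUniqueId into a byte string and restore it, mirroring
// WriteRawBytes(...) above and the memcpy in VariableResponse::Parse.
#include <cstring>
#include <string>
#include <nccl.h>

std::string PackUniqueId(const ncclUniqueId &id) {
  // The id is an opaque fixed-size buffer; copy it as-is, nothing to encode.
  return std::string(id.internal, NCCL_UNIQUE_ID_BYTES);
}

void UnpackUniqueId(const std::string &payload, ncclUniqueId *id) {
  // The receiver restores exactly the same bytes into its own struct.
  std::memcpy(id->internal, payload.data(), payload.size());
}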
@@ -367,9 +367,18 @@ int VariableResponse::Parse(Source* source) {
       }
     case sendrecv::VariableMessage::kSerializedFieldNumber: {
       PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
-                      meta_.type() == sendrecv::LOD_TENSOR) &&
+                      meta_.type() == sendrecv::LOD_TENSOR ||
+                      meta_.type() == sendrecv::NCCL_ID) &&
                          meta_.varname() != "",
                      "meta info should be got first!");
+      if (meta_.type() == sendrecv::NCCL_ID) {
+        auto* var = scope_->FindVar(meta_.varname());
+        if (var != nullptr) {
+          ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
+          memcpy(id->internal, meta_.serialized().c_str(),
+                 meta_.serialized().size());
+        }
+      }
       int length = 0;
       if (wt != WIRETYPE_LENGTH_DELIMITED ||
......
@@ -54,7 +54,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
     auto var = scope->FindVar("NCCLID");
     PADDLE_ENFORCE_NOT_NULL(var);
     auto id = var->GetMutable<ncclUniqueId>();
-    ncclGetUniqueId(id);
+    platform::dynload::ncclGetUniqueId(id);
     std::vector<std::string> endpoint_list =
         Attr<std::vector<std::string>>("endpoint_list");
@@ -120,4 +120,4 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(gen_nccl_id_op, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
+REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
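For context, the operator registered above as gen_nccl_id implements the usual NCCL bootstrap described in its doc string: one trainer generates the unique id with ncclGetUniqueId, ships it to trainers 1~n over the sendrecv service, and every trainer then builds its communicator from the same id. A conceptual sketch of that handshake using the plain NCCL API (not PaddlePaddle code; the helper name is illustrative):

#include <nccl.h>

// On trainer 0 only:
//   ncclUniqueId id;
//   ncclGetUniqueId(&id);                     // fill the 128-byte id
//   ...send the id bytes to trainers 1~n (gRPC in this commit)...
// Every trainer (trainer 0 included) then calls the following with the
// same id, its own distinct rank, and the same total nranks. The caller
// is expected to have selected its GPU with cudaSetDevice beforehand.
ncclComm_t InitComm(ncclUniqueId id, int rank, int nranks) {
  ncclComm_t comm;
  ncclCommInitRank(&comm, nranks, id, rank);
  return comm;
}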
@@ -73,7 +73,9 @@ struct NCCLContextMap {
   std::unordered_map<int, NCCLContext> contexts_;
   std::vector<int> order_;

-  explicit NCCLContextMap(const std::vector<platform::Place> &places) {
+  explicit NCCLContextMap(const std::vector<platform::Place> &places,
+                          ncclUniqueId *nccl_id = nullptr,
+                          size_t node_count = 0, size_t trainer_id = 0) {
     PADDLE_ENFORCE(!places.empty());
     order_.reserve(places.size());
     for (auto &p : places) {
@@ -85,18 +87,36 @@ struct NCCLContextMap {
         order_.size(), contexts_.size(),
         "NCCL Context Map does not support contain two or more same device");
-    if (places.size() > 1) {
-      std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
+    if (places.size() <= 1) {
+      return;
+    }
+    std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
+    // if pass nccl_id here, can assume we are doing multi node training
+    if (nccl_id == nullptr) {
       {
         std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
         PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
             comms.get(), static_cast<int>(order_.size()), order_.data()));
       }
-      int i = 0;
-      for (auto &dev_id : order_) {
-        contexts_.at(dev_id).comm_ = comms[i++];
+    } else {
+      PADDLE_ENFORCE_GT(node_count, 0);
+      PADDLE_ENFORCE_EQ(node_count % places.size(), 0,
+                        "must have same number of GPUs on each node");
+      {
+        std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
+        int nranks = node_count * order_.size();
+        for (auto &gpu_id : order_) {
+          int rank = trainer_id * order_.size() + gpu_id;
+          PADDLE_ENFORCE(cudaSetDevice(gpu_id));
+          PADDLE_ENFORCE(
+              ncclCommInitRank(comms.get() + gpu_id, nranks, *nccl_id, rank));
+        }
       }
     }
+    int i = 0;
+    for (auto &dev_id : order_) {
+      contexts_.at(dev_id).comm_ = comms[i++];
+    }
   }

   NCCLContextMap(const NCCLContextMap &other) = delete;
......
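A quick illustration of the rank layout the ncclCommInitRank loop above produces: with node_count trainers and order_.size() GPUs per node, each trainer owns a contiguous block of ranks (rank = trainer_id * gpus_per_node + gpu_id, out of nranks = node_count * gpus_per_node). A small standalone sketch with hypothetical numbers, not part of the commit:

#include <cstdio>

int main() {
  const int num_trainers = 2;   // hypothetical 2-node job
  const int gpus_per_node = 2;  // corresponds to order_.size()
  const int nranks = num_trainers * gpus_per_node;  // 4 ranks in total
  for (int trainer_id = 0; trainer_id < num_trainers; ++trainer_id) {
    for (int gpu_id = 0; gpu_id < gpus_per_node; ++gpu_id) {
      const int rank = trainer_id * gpus_per_node + gpu_id;
      std::printf("trainer %d, GPU %d -> rank %d of %d\n",
                  trainer_id, gpu_id, rank, nranks);
    }
  }
  return 0;
}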
@@ -502,11 +502,13 @@ All parameter, weight, gradient are variables in Paddle.
                   const std::unordered_set<std::string> &bcast_vars,
                   const ProgramDesc &main_program, const std::string &loss_var_name,
                   Scope *scope, std::vector<Scope *> &local_scopes,
-                  bool allow_op_delay, bool use_default_grad_scale) {
+                  bool allow_op_delay, bool use_default_grad_scale,
+                  size_t num_trainers, size_t trainer_id) {
                new (&self) ParallelExecutor(
                    num_threads, use_event, places, params, bcast_vars,
                    main_program, loss_var_name, scope, local_scopes,
-                   allow_op_delay, use_default_grad_scale);
+                   allow_op_delay, use_default_grad_scale, num_trainers,
+                   trainer_id);
              })
       .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
       // NOTE: even we return a vec<Scope*>* to Python use reference policy.
......