未验证 提交 6ab935f8 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #10349 from typhoonzero/gen_nccl_id_op

[Feature] NCCL2 distributed training
...@@ -58,7 +58,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -58,7 +58,8 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay, Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
bool use_default_grad_scale, bool balance_parameter_opt_between_cards) bool use_default_grad_scale, bool balance_parameter_opt_between_cards,
size_t num_trainers, size_t trainer_id)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
...@@ -80,7 +81,13 @@ ParallelExecutor::ParallelExecutor( ...@@ -80,7 +81,13 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs // Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_)); auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
ncclUniqueId *nccl_id = nullptr;
if (nccl_id_var != nullptr) {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
}
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id, num_trainers, trainer_id));
#endif #endif
if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 && if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
local_scopes.empty()) { // Is CUDA local_scopes.empty()) { // Is CUDA
......
...@@ -41,7 +41,8 @@ class ParallelExecutor { ...@@ -41,7 +41,8 @@ class ParallelExecutor {
const std::string& loss_var_name, Scope* scope, const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes, const std::vector<Scope*>& local_scopes,
bool allow_op_delay, bool use_default_grad_scale, bool allow_op_delay, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards); bool balance_parameter_opt_between_cards,
size_t num_trainers = 1, size_t trainer_id = 0);
~ParallelExecutor(); ~ParallelExecutor();
......
...@@ -186,6 +186,11 @@ endif() ...@@ -186,6 +186,11 @@ endif()
add_subdirectory(detail) add_subdirectory(detail)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
if(WITH_GPU)
op_library(gen_nccl_id_op DEPS nccl_common)
else()
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(send_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_op DEPS ${DISTRIBUTE_DEPS})
...@@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE) ...@@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor)
else() else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op) set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
endif() endif()
op_library(cross_entropy_op DEPS cross_entropy) op_library(cross_entropy_op DEPS cross_entropy)
......
...@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
// stub context // stub context
SendProcessor* s = new SendProcessor(ch); SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out); s->Prepare(var_h, time_out);
s->response_call_back_ = NULL; s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
......
...@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); ...@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor { class BaseProcessor {
public: public:
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { context_ = NULL; } explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
context_ = nullptr;
}
virtual ~BaseProcessor() {} virtual ~BaseProcessor() {}
...@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor { ...@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
::grpc::GenericStub stub_g_; ::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_; ::grpc::ByteBuffer reply_;
RequestSendCallBack response_call_back_ = NULL; RequestSendCallBack response_call_back_ = nullptr;
}; };
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
......
...@@ -47,6 +47,7 @@ class AsyncGRPCServer final { ...@@ -47,6 +47,7 @@ class AsyncGRPCServer final {
explicit AsyncGRPCServer(const std::string &address, bool sync_mode) explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {} : address_(address), sync_mode_(sync_mode), ready_(0) {}
~AsyncGRPCServer() {}
void WaitServerReady(); void WaitServerReady();
void RunSyncUpdate(); void RunSyncUpdate();
......
...@@ -32,6 +32,7 @@ service SendRecvService { ...@@ -32,6 +32,7 @@ service SendRecvService {
enum VarType { enum VarType {
LOD_TENSOR = 0; LOD_TENSOR = 0;
SELECTED_ROWS = 1; SELECTED_ROWS = 1;
NCCL_ID = 2;
} }
// NOTICE(gongwb):don't modify this proto if you are not // NOTICE(gongwb):don't modify this proto if you are not
......
...@@ -14,6 +14,9 @@ limitations under the License. */ ...@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h> #include <sys/time.h>
#include <thread> // NOLINT #include <thread> // NOLINT
...@@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
request.set_type(::sendrecv::SELECTED_ROWS); request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size); GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request.set_type(::sendrecv::NCCL_ID);
#endif
} else { } else {
PADDLE_THROW("Serialize does not support type: %s", PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name()); typeid(var->Type()).name());
...@@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void* buf = buffer.get(); void* buf = buffer.get();
ProtoEncodeHelper e(static_cast<char*>(buf), 1024); ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size())); e.WriteRawBytes(std::string(header.data(), header.size()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#ifdef PADDLE_WITH_CUDA
if (var->IsType<ncclUniqueId>()) {
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
::grpc::ByteBuffer tmp(&slices, 1);
msg->Swap(&tmp);
return;
}
#endif
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data // steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
...@@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) { ...@@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) {
} }
case sendrecv::VariableMessage::kSerializedFieldNumber: { case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) && meta_.type() == sendrecv::LOD_TENSOR ||
meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "", meta_.varname() != "",
"meta info should be got first!"); "meta info should be got first!");
...@@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) { ...@@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) {
return tag; return tag;
} }
if (meta_.type() == sendrecv::NCCL_ID) {
#ifdef PADDLE_WITH_CUDA
auto* var = scope_->FindVar(meta_.varname());
if (var != nullptr) {
ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
num_bytes)) {
return tag;
}
}
break;
#else
PADDLE_THROW("Not compiled with CUDA!");
#endif
}
framework::DDim dims = GetDims(meta_.dims()); framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) { if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0, PADDLE_ENFORCE(meta_.lod_size() >= 0,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <nccl.h>
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
namespace operators {
class GenNCCLIdOp : public framework::OperatorBase {
public:
GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace());
int trainer_id = Attr<int>("trainer_id");
framework::Scope& local_scope = scope.NewScope();
if (trainer_id == 0) {
GenerateAndSend(&local_scope, dev_ctx);
} else {
GetIdByServer(&local_scope, dev_ctx);
}
}
private:
void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
auto var = scope->FindVar(NCCL_ID_VARNAME);
PADDLE_ENFORCE_NOT_NULL(var);
auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient client;
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
}
client.Wait();
VLOG(3) << "sending completed...";
}
void GetIdByServer(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
detail::AsyncGRPCServer rpc_service(endpoint, true);
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_service.SetScope(scope);
rpc_service.SetDevCtx(&dev_ctx);
rpc_service.SetProgram(&empty_program);
rpc_service.SetExecutor(&executor);
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
rpc_service.SetCond(0);
VLOG(3) << "start getting nccl id from trainer 0...";
auto recv = rpc_service.Get();
VLOG(3) << "got nccl id and stop server...";
rpc_service.ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
}
};
class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
GenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
AddAttr<std::string>("endpoint",
"(string), e.g. 127.0.0.1:6175 "
"current listen endpoint");
AddAttr<std::vector<std::string>>(
"endpoint_list",
"['trainer1_ip:port', 'trainer2_ip:port', ...] "
"list of trainer endpoints start from trainer 1")
.SetDefault({});
AddAttr<int>("trainer_id",
"(int default 0) "
"The index of the trainer in distributed training.")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/printf.h"
USE_NO_KERNEL_OP(listen_and_serv);
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail;
namespace string = paddle::string;
std::unique_ptr<detail::AsyncGRPCServer> rpc_service;
void StartServer(std::atomic<bool>* initialized) {
f::Scope scope;
p::CPUPlace place;
scope.Var(NCCL_ID_VARNAME);
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace());
rpc_service->SetScope(&scope);
rpc_service->SetDevCtx(&dev_ctx);
rpc_service->SetProgram(&empty_program);
rpc_service->SetExecutor(&executor);
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get()));
*initialized = true;
rpc_service->SetCond(0);
auto recv = rpc_service->Get();
LOG(INFO) << "got nccl id and stop server...";
rpc_service->ShutDown();
server_thread.join();
}
TEST(SendNcclId, Normal) {
std::atomic<bool> initialized{false};
std::thread server_thread(StartServer, &initialized);
while (!initialized) {
}
// wait server to start
// sleep(2);
rpc_service->WaitServerReady();
f::Scope scope;
p::CPUPlace place;
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var(NCCL_ID_VARNAME);
// var->SetType(f::proto::VarType_Type_RAW);
auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id);
int port = rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client;
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait();
server_thread.join();
auto* ptr = rpc_service.release();
delete ptr;
}
...@@ -14,12 +14,15 @@ ...@@ -14,12 +14,15 @@
#pragma once #pragma once
#include <stdio.h>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#define NCCL_ID_VARNAME "NCCLID"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -73,7 +76,9 @@ struct NCCLContextMap { ...@@ -73,7 +76,9 @@ struct NCCLContextMap {
std::unordered_map<int, NCCLContext> contexts_; std::unordered_map<int, NCCLContext> contexts_;
std::vector<int> order_; std::vector<int> order_;
explicit NCCLContextMap(const std::vector<platform::Place> &places) { explicit NCCLContextMap(const std::vector<platform::Place> &places,
ncclUniqueId *nccl_id = nullptr,
size_t num_trainers = 1, size_t trainer_id = 0) {
PADDLE_ENFORCE(!places.empty()); PADDLE_ENFORCE(!places.empty());
order_.reserve(places.size()); order_.reserve(places.size());
for (auto &p : places) { for (auto &p : places) {
...@@ -85,19 +90,35 @@ struct NCCLContextMap { ...@@ -85,19 +90,35 @@ struct NCCLContextMap {
order_.size(), contexts_.size(), order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device"); "NCCL Context Map does not support contain two or more same device");
if (places.size() > 1) { if (places.size() <= 1) {
return;
}
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]); std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
{ // if pass nccl_id here, can assume we are doing multi node training
if (nccl_id == nullptr) {
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex()); std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
comms.get(), static_cast<int>(order_.size()), order_.data())); comms.get(), static_cast<int>(order_.size()), order_.data()));
} else {
PADDLE_ENFORCE_GT(num_trainers, 1);
// TODO(wuyi): need to ensure each node have same number of GPUs
{
int nranks = num_trainers * order_.size();
NCCLGroupGuard gurad;
for (auto &gpu_id : order_) {
int rank = trainer_id * order_.size() + gpu_id;
VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks;
PADDLE_ENFORCE(cudaSetDevice(gpu_id));
PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
comms.get() + gpu_id, nranks, *nccl_id, rank));
}
}
} }
int i = 0; int i = 0;
for (auto &dev_id : order_) { for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++]; contexts_.at(dev_id).comm_ = comms[i++];
} }
} }
}
NCCLContextMap(const NCCLContextMap &other) = delete; NCCLContextMap(const NCCLContextMap &other) = delete;
NCCLContextMap &operator=(const NCCLContextMap &other) = delete; NCCLContextMap &operator=(const NCCLContextMap &other) = delete;
......
...@@ -503,12 +503,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -503,12 +503,13 @@ All parameter, weight, gradient are variables in Paddle.
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, std::vector<Scope *> &local_scopes, Scope *scope, std::vector<Scope *> &local_scopes,
bool allow_op_delay, bool use_default_grad_scale, bool allow_op_delay, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards) { bool balance_parameter_opt_between_cards, size_t num_trainers,
size_t trainer_id) {
new (&self) ParallelExecutor( new (&self) ParallelExecutor(
num_threads, use_event, places, params, bcast_vars, num_threads, use_event, places, params, bcast_vars,
main_program, loss_var_name, scope, local_scopes, main_program, loss_var_name, scope, local_scopes,
allow_op_delay, use_default_grad_scale, allow_op_delay, use_default_grad_scale,
balance_parameter_opt_between_cards); balance_parameter_opt_between_cards, num_trainers, trainer_id);
}) })
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
// NOTE: even we return a vec<Scope*>* to Python use reference policy. // NOTE: even we return a vec<Scope*>* to Python use reference policy.
......
...@@ -489,7 +489,7 @@ class Operator(object): ...@@ -489,7 +489,7 @@ class Operator(object):
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'listen_and_serv', 'parallel_do', 'save_combine', 'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
'load_combine', 'ncclInit', 'channel_create', 'channel_close', 'load_combine', 'ncclInit', 'channel_create', 'channel_close',
'channel_send', 'channel_recv', 'select' 'channel_send', 'channel_recv', 'select', 'gen_nccl_id'
} }
if type not in no_kernel_op_set: if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
......
...@@ -31,7 +31,9 @@ class ParallelExecutor(object): ...@@ -31,7 +31,9 @@ class ParallelExecutor(object):
allow_op_delay=False, allow_op_delay=False,
share_vars_from=None, share_vars_from=None,
use_default_grad_scale=True, use_default_grad_scale=True,
balance_parameter_opt_between_cards=False): balance_parameter_opt_between_cards=False,
num_trainers=1,
trainer_id=0):
""" """
ParallelExecutor can run program in parallel. ParallelExecutor can run program in parallel.
...@@ -55,6 +57,11 @@ class ParallelExecutor(object): ...@@ -55,6 +57,11 @@ class ParallelExecutor(object):
balance_parameter_opt_between_cards(bool, default True): Whether balance_parameter_opt_between_cards(bool, default True): Whether
updating different gradients on different cards. Currently, it updating different gradients on different cards. Currently, it
is not recommended. is not recommended.
num_trainers(int, default 1): If greater than 1, NCCL will be
initialized with multpile rank of nodes, each node should have
same number of GPUs. Distributed training will be enabled then.
trainer_id(int, default 0): Must use together with num_trainers.
trainer_id is the "rank" of current node starts from 0.
Returns: Returns:
A ParallelExecutor object. A ParallelExecutor object.
...@@ -134,8 +141,9 @@ class ParallelExecutor(object): ...@@ -134,8 +141,9 @@ class ParallelExecutor(object):
local_scopes, local_scopes,
allow_op_delay, allow_op_delay,
use_default_grad_scale, use_default_grad_scale,
balance_parameter_opt_between_cards) balance_parameter_opt_between_cards,
num_trainers,
trainer_id)
self.scope = scope self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None): def run(self, fetch_list, feed=None, feed_dict=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册