Commit 7b0c0273 authored by typhoonzero

update by comments

Parent 928418a9
@@ -42,7 +42,7 @@ class ParallelExecutor {
                    const std::vector<Scope*>& local_scopes,
                    bool allow_op_delay, bool use_default_grad_scale,
                    bool balance_parameter_opt_between_cards,
-                   size_t num_trainers = 0, size_t trainer_id = 0);
+                   size_t num_trainers = 1, size_t trainer_id = 0);
   ~ParallelExecutor();
......
@@ -75,29 +75,29 @@ class GenNCCLIdOp : public framework::OperatorBase {
     // NOTE: Can not use unique_ptr here because the default
     // deleter will call GRPC Server's base class's dtor and
     // that will cause a wired crash.
-    rpc_service_ = new detail::AsyncGRPCServer(endpoint, true);
+    detail::AsyncGRPCServer rpc_service(endpoint, true);
     framework::ProgramDesc empty_program;
     framework::Executor executor(dev_ctx.GetPlace());
-    rpc_service_->SetScope(scope);
-    rpc_service_->SetDevCtx(&dev_ctx);
-    rpc_service_->SetProgram(&empty_program);
-    rpc_service_->SetExecutor(&executor);
+    rpc_service.SetScope(scope);
+    rpc_service.SetDevCtx(&dev_ctx);
+    rpc_service.SetProgram(&empty_program);
+    rpc_service.SetExecutor(&executor);
     std::thread server_thread(
-        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_));
-    rpc_service_->SetCond(0);
+        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
+    rpc_service.SetCond(0);
     VLOG(3) << "start getting nccl id from trainer 0...";
-    auto recv = rpc_service_->Get();
+    auto recv = rpc_service.Get();
     VLOG(3) << "got nccl id and stop server...";
-    rpc_service_->ShutDown();
+    rpc_service.ShutDown();
     VLOG(3) << "rpc server stopped";
     // TODO(wuyi): reinit nccl communicators
     server_thread.join();
-    delete rpc_service_;
   }
- protected:
-  mutable detail::AsyncGRPCServer* rpc_service_ = nullptr;
+  // protected:
+  // mutable detail::AsyncGRPCServer* rpc_service_ = nullptr;
 };
 class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
......
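As the log messages in this hunk suggest, the op runs on trainers other than trainer 0: each one stands up a temporary gRPC server, blocks until trainer 0 pushes the NCCL unique id, then shuts the server down. Below is a minimal sketch of that push/wait handshake using plain Python sockets; the endpoint, port, environment variable, and helper names are illustrative assumptions, not PaddlePaddle APIs (the real id comes from NCCL's ncclGetUniqueId and is 128 bytes).

```python
import os
import socket

NCCL_UNIQUE_ID_BYTES = 128  # size of ncclUniqueId's opaque buffer in nccl.h


def push_id_to_peers(nccl_id, peer_endpoints):
    """Trainer 0: connect to every other trainer and push the id."""
    for host, port in peer_endpoints:
        with socket.create_connection((host, port)) as conn:
            conn.sendall(nccl_id)


def wait_for_id(listen_port):
    """Trainer k > 0: accept one connection and read the id pushed by trainer 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
        srv.bind(("0.0.0.0", listen_port))
        srv.listen(1)
        conn, _ = srv.accept()
        with conn:
            buf = b""
            while len(buf) < NCCL_UNIQUE_ID_BYTES:
                chunk = conn.recv(NCCL_UNIQUE_ID_BYTES - len(buf))
                if not chunk:
                    break
                buf += chunk
            return buf


if __name__ == "__main__":
    trainer_id = int(os.environ.get("TRAINER_ID", "0"))  # hypothetical env var
    if trainer_id == 0:
        nccl_id = os.urandom(NCCL_UNIQUE_ID_BYTES)  # stand-in for ncclGetUniqueId
        push_id_to_peers(nccl_id, [("192.168.0.2", 6174)])  # hypothetical peer list
    else:
        nccl_id = wait_for_id(6174)  # hypothetical port
    # At this point every trainer holds the same id and can initialize NCCL.
```

The actual op exchanges the id over Paddle's gRPC send/recv machinery rather than raw sockets, which is why the hunk above constructs an AsyncGRPCServer just for this single exchange and tears it down immediately afterwards.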
@@ -78,7 +78,7 @@ struct NCCLContextMap {
   explicit NCCLContextMap(const std::vector<platform::Place> &places,
                           ncclUniqueId *nccl_id = nullptr,
-                          size_t num_trainers = 0, size_t trainer_id = 0) {
+                          size_t num_trainers = 1, size_t trainer_id = 0) {
     PADDLE_ENFORCE(!places.empty());
     order_.reserve(places.size());
     for (auto &p : places) {
@@ -100,7 +100,7 @@ struct NCCLContextMap {
       PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
           comms.get(), static_cast<int>(order_.size()), order_.data()));
     } else {
-      PADDLE_ENFORCE_GT(num_trainers, 0);
+      PADDLE_ENFORCE_GT(num_trainers, 1);
       // TODO(wuyi): need to ensure each node have same number of GPUs
       {
         int nranks = num_trainers * order_.size();
......
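The distributed branch above sizes the communicator as `nranks = num_trainers * order_.size()`, i.e. trainers times local GPUs, which is also why the check is tightened from `> 0` to `> 1`: with the new default of 1, a single-trainer run is no longer accepted as a multi-node configuration. A small sketch of the rank arithmetic this implies, assuming a contiguous layout in which trainer `t` with `g` GPUs owns global ranks `[t*g, (t+1)*g)` (the function and variable names are illustrative, not Paddle's):

```python
def nccl_rank_layout(num_trainers, trainer_id, num_local_gpus):
    """Contiguous layout: trainer t with g GPUs owns global ranks [t*g, (t+1)*g)."""
    assert num_trainers > 1, "single-trainer runs take the local init path"
    nranks = num_trainers * num_local_gpus      # matches num_trainers * order_.size()
    ranks = [trainer_id * num_local_gpus + i    # global rank of local GPU i
             for i in range(num_local_gpus)]
    return nranks, ranks


# Example: 2 trainers with 4 GPUs each -> 8 ranks total; trainer 1 owns ranks 4..7.
print(nccl_rank_layout(num_trainers=2, trainer_id=1, num_local_gpus=4))
```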
@@ -32,7 +32,7 @@ class ParallelExecutor(object):
                  share_vars_from=None,
                  use_default_grad_scale=True,
                  balance_parameter_opt_between_cards=False,
-                 num_trainers=0,
+                 num_trainers=1,
                  trainer_id=0):
         """
         ParallelExecutor can run program in parallel.
@@ -57,7 +57,7 @@ class ParallelExecutor(object):
             balance_parameter_opt_between_cards(bool, default True): Whether
                 updating different gradients on different cards. Currently, it
                 is not recommended.
-            num_trainers(int, default 0): If greater than 0, NCCL will be
+            num_trainers(int, default 1): If greater than 1, NCCL will be
                 initialized with multpile rank of nodes, each node should have
                 same number of GPUs. Distributed training will be enabled then.
             trainer_id(int, default 0): Must use together with num_trainers.
......
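With the default changed to 1, single-node users are unaffected, and multi-node NCCL training is opted into by passing `num_trainers > 1` together with a distinct `trainer_id` on each node. A minimal usage sketch, assuming the constructor also takes `use_cuda` and `loss_name` as in this version of the fluid API, and using hypothetical environment variables set by the launcher:

```python
import os
import paddle.fluid as fluid

# Hypothetical launcher contract: each node exports TRAINERS and TRAINER_ID.
num_trainers = int(os.environ.get("TRAINERS", "1"))
trainer_id = int(os.environ.get("TRAINER_ID", "0"))

# Tiny regression model, just to have a loss to optimize.
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

fluid.Executor(fluid.CUDAPlace(0)).run(fluid.default_startup_program())

pe = fluid.ParallelExecutor(
    use_cuda=True,
    loss_name=loss.name,
    num_trainers=num_trainers,  # 1 = single node; >1 enables multi-node NCCL
    trainer_id=trainer_id)      # this node's index among the trainers
```

Each node is also expected to have obtained the shared NCCL id (via the gen_nccl_id op shown earlier) before the communicators are initialized.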