Commit 7b0c0273 authored by typhoonzero

update by comments

Parent 928418a9
......
@@ -42,7 +42,7 @@ class ParallelExecutor {
       const std::vector<Scope*>& local_scopes,
       bool allow_op_delay, bool use_default_grad_scale,
       bool balance_parameter_opt_between_cards,
-      size_t num_trainers = 0, size_t trainer_id = 0);
+      size_t num_trainers = 1, size_t trainer_id = 0);
   ~ParallelExecutor();
......
......
@@ -75,29 +75,29 @@ class GenNCCLIdOp : public framework::OperatorBase {
     // NOTE: Can not use unique_ptr here because the default
     // deleter will call GRPC Server's base class's dtor and
     // that will cause a weird crash.
-    rpc_service_ = new detail::AsyncGRPCServer(endpoint, true);
+    detail::AsyncGRPCServer rpc_service(endpoint, true);
     framework::ProgramDesc empty_program;
     framework::Executor executor(dev_ctx.GetPlace());
-    rpc_service_->SetScope(scope);
-    rpc_service_->SetDevCtx(&dev_ctx);
-    rpc_service_->SetProgram(&empty_program);
-    rpc_service_->SetExecutor(&executor);
+    rpc_service.SetScope(scope);
+    rpc_service.SetDevCtx(&dev_ctx);
+    rpc_service.SetProgram(&empty_program);
+    rpc_service.SetExecutor(&executor);
     std::thread server_thread(
-        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_));
-    rpc_service_->SetCond(0);
+        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
+    rpc_service.SetCond(0);
     VLOG(3) << "start getting nccl id from trainer 0...";
-    auto recv = rpc_service_->Get();
+    auto recv = rpc_service.Get();
     VLOG(3) << "got nccl id and stop server...";
-    rpc_service_->ShutDown();
+    rpc_service.ShutDown();
     VLOG(3) << "rpc server stopped";
     // TODO(wuyi): reinit nccl communicators
     server_thread.join();
-    delete rpc_service_;
   }
- protected:
-  mutable detail::AsyncGRPCServer* rpc_service_ = nullptr;
+  // protected:
+  //   mutable detail::AsyncGRPCServer* rpc_service_ = nullptr;
 };
 class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
......
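The gist of this hunk is replacing a manually deleted heap pointer with a stack object whose destructor runs at scope exit, after the server thread is joined; that is also why the `rpc_service_` member can be commented out. A minimal sketch of the pattern, with a hypothetical `ToyServer` standing in for `detail::AsyncGRPCServer` (not Paddle code):

#include <iostream>
#include <thread>

// ToyServer is a hypothetical stand-in for detail::AsyncGRPCServer.
class ToyServer {
 public:
  void Run() { std::cout << "serving...\n"; }          // the real Run blocks until ShutDown
  void ShutDown() { std::cout << "shutting down\n"; }  // unblocks the serving loop
};

void FetchOnce() {
  ToyServer server;  // stack-allocated: destroyed automatically at scope exit
  std::thread worker(&ToyServer::Run, &server);  // pass the address, as the commit does
  // ... block here until the payload (the NCCL id) has been received ...
  server.ShutDown();  // ask the serving loop to stop
  worker.join();      // join BEFORE server leaves scope, so Run never touches a dead object
}                     // no manual delete, no dangling rpc_service_ member

int main() { FetchOnce(); }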
......
@@ -78,7 +78,7 @@ struct NCCLContextMap {
   explicit NCCLContextMap(const std::vector<platform::Place> &places,
                           ncclUniqueId *nccl_id = nullptr,
-                          size_t num_trainers = 0, size_t trainer_id = 0) {
+                          size_t num_trainers = 1, size_t trainer_id = 0) {
     PADDLE_ENFORCE(!places.empty());
     order_.reserve(places.size());
     for (auto &p : places) {
......
@@ -100,7 +100,7 @@ struct NCCLContextMap {
       PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
           comms.get(), static_cast<int>(order_.size()), order_.data()));
     } else {
-      PADDLE_ENFORCE_GT(num_trainers, 0);
+      PADDLE_ENFORCE_GT(num_trainers, 1);
       // TODO(wuyi): need to ensure each node has the same number of GPUs
       {
         int nranks = num_trainers * order_.size();
......
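For context, the branch guarded by `PADDLE_ENFORCE_GT(num_trainers, 1)` makes each local GPU one rank of a communicator spanning all trainers, with nranks = num_trainers * ngpus. A sketch of that initialization against the plain NCCL API (error checking elided; `InitTrainerComms` and its parameters are illustrative, not the Paddle code), assuming the `ncclUniqueId` has already been broadcast to every node, which is what GenNCCLIdOp above provides:

#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// One communicator per local GPU; every node contributes `ngpus` ranks
// to a world of num_trainers * ngpus ranks.
std::vector<ncclComm_t> InitTrainerComms(const ncclUniqueId &id, int num_trainers,
                                         int trainer_id, int ngpus) {
  std::vector<ncclComm_t> comms(ngpus);
  int nranks = num_trainers * ngpus;  // assumes the same GPU count on every node
  ncclGroupStart();                   // needed when one process inits several ranks
  for (int i = 0; i < ngpus; ++i) {
    int rank = trainer_id * ngpus + i;  // this device's global rank
    cudaSetDevice(i);
    ncclCommInitRank(&comms[i], nranks, id, rank);
  }
  ncclGroupEnd();
  return comms;
}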
......
@@ -32,7 +32,7 @@ class ParallelExecutor(object):
                  share_vars_from=None,
                  use_default_grad_scale=True,
                  balance_parameter_opt_between_cards=False,
-                 num_trainers=0,
+                 num_trainers=1,
                  trainer_id=0):
         """
         ParallelExecutor can run program in parallel.
......
@@ -57,7 +57,7 @@ class ParallelExecutor(object):
             balance_parameter_opt_between_cards(bool, default False): Whether
                 to update different gradients on different cards. Currently, it
                 is not recommended.
-            num_trainers(int, default 0): If greater than 0, NCCL will be
+            num_trainers(int, default 1): If greater than 1, NCCL will be
                 initialized with multiple ranks across nodes; each node should
                 have the same number of GPUs. Distributed training will then be
                 enabled.
             trainer_id(int, default 0): Must be used together with num_trainers.
......