From b1b7af400f5be0e7bcfde80e04a9ef8da0adc326 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Tue, 19 Dec 2017 14:04:24 +0800
Subject: [PATCH] support multi node

---
 paddle/operators/detail/recv_impl.cc            | 14 +++++++++-----
 paddle/operators/recv_op.cc                     | 17 +++++------------
 paddle/operators/send_op.cc                     |  3 ++-
 python/paddle/v2/fluid/distribute_transpiler.py |  3 +--
 4 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc
index 47decb6d7..e984f4238 100644
--- a/paddle/operators/detail/recv_impl.cc
+++ b/paddle/operators/detail/recv_impl.cc
@@ -51,19 +51,23 @@ Status SendRecvServerImpl::GetVariable(ServerContext *context,
 Status SendRecvServerImpl::Wait(ServerContext *context,
                                 const VoidMessage *in_var,
                                 VoidMessage *out_var) {
-  std::unique_lock<std::mutex> lock(this->mutex_);
-  condition_.wait(lock, [=] { return this->done_ == true; });
+  {
+    std::unique_lock<std::mutex> lock(this->mutex_);
+    condition_.wait(lock, [=] { return this->done_ == true; });
+  }
   return Status::OK;
 }
 
 void SendRecvServerImpl::Start() {
-  std::unique_lock<std::mutex> lock(this->mutex_);
+  std::lock_guard<std::mutex> lock(this->mutex_);
   done_ = false;
 }
 
 void SendRecvServerImpl::Done() {
-  std::unique_lock<std::mutex> lock(this->mutex_);
-  done_ = true;
+  {
+    std::lock_guard<std::mutex> lock(this->mutex_);
+    done_ = true;
+  }
   condition_.notify_all();
 }
 
diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index 6fcb544b5..094084458 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -14,7 +14,6 @@
 
 #include
 #include
-#include
 #include
 #include
 
@@ -81,9 +80,9 @@ class RecvOp : public framework::OperatorBase {
     auto grad_list = Attr<std::vector<std::string>>("GradList");
     auto trainer_count = Attr<int>("Trainers");
     size_t param_count = param_list.size();
+    rpc_service_->Start();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     while (true) {
-      rpc_service_->Start();
      // Get from multiple trainers, we don't care about order in which
      // the gradient arrives, just add suffix 0~n then average the gradient.
      for (size_t i = 0; i < param_count * trainer_count; ++i) {
@@ -95,8 +94,8 @@
         if (it != grad_list.end()) {
           param_var_name = param_list[it - grad_list.begin()];
         }
-        VLOG(10) << "recved grad: " << grad_var_name
-                 << " updating param: " << param_var_name;
+        VLOG(3) << "recved grad: " << grad_var_name
+                << " updating param: " << param_var_name;
         auto *merged_grad = recv_scope.FindVar(grad_var_name);
         if (merged_grad == nullptr) {
           // create output of merged var.
@@ -113,6 +112,7 @@ class RecvOp : public framework::OperatorBase {
         // FIXME(typhoonzero): do not copy
         framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
       }
+      rpc_service_->Start();
 
       std::string program_str = Attr<std::string>("OptimizeProgram");
       framework::ProgramDesc program_desc;
@@ -127,14 +127,7 @@
         LOG(ERROR) << "run sub program error " << e.what();
       }
       rpc_service_->Done();
-
-      // for (size_t i = 0; i < param_count; ++i) {
-      //   auto *out_var = recv_scope.FindVar(param_list[i]);
-      //   detail::TensorWithName out;
-      //   out.first = param_list[i];
-      //   out.second = out_var->Get();
-      //   rpc_service_->Push(out);
-      // }
+      grads_counter_.clear();
     }  // while(true)
   }
 
diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc
index e94209ec4..9eafa1655 100644
--- a/paddle/operators/send_op.cc
+++ b/paddle/operators/send_op.cc
@@ -52,7 +52,8 @@ class SendOp : public framework::OperatorBase {
         LOG(ERROR) << "send variable error: " << ins[i];
       }
     }
-    client_map_[0]->Wait();  // TODO(typhoonzero): support async optimization
+    // TODO(typhoonzero): support async optimization
+    client_map_[epmap[0]]->Wait();
     for (size_t i = 0; i < ins.size(); ++i) {
       bool ret = client_map_[epmap[i]]->GetVariable(scope, ins[i]);
       if (!ret) {
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index e40cdc92b..7dfbab467 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -149,9 +149,8 @@ class DistributeTranspiler:
         epmap = []
         for ep, v in self.param_grad_map.iteritems():
             send_op_ordered_inputs.extend(v["grads"])
-            for i in v:
+            for i in v["grads"]:
                 epmap.append(ep)
-
         send_op = program.global_block().append_op(
             type="send",
             inputs={"X": send_op_ordered_inputs
-- 
GitLab
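
Note on the recv_impl.cc hunk above: Start/Wait/Done now follows the usual condition-variable barrier idiom, where the flag is written under a std::lock_guard, notify_all() runs after the lock scope closes, and waiters block on a predicate under a scoped std::unique_lock. Below is a minimal, self-contained sketch of that idiom, not Paddle's actual SendRecvServerImpl; the Barrier class name and the single-waiter main() are invented for illustration.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Minimal barrier with the same Start/Wait/Done shape as the patched server.
class Barrier {
 public:
  void Start() {  // reset before the next round of updates
    std::lock_guard<std::mutex> lock(mutex_);
    done_ = false;
  }
  void Done() {  // signal that the optimize step finished
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }
    condition_.notify_all();  // notify after releasing the lock
  }
  void Wait() {  // callers block here until Done() is called
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [this] { return done_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  bool done_ = false;
};

int main() {
  Barrier barrier;
  barrier.Start();
  std::thread waiter([&] {
    barrier.Wait();
    std::cout << "resumed after Done()\n";
  });
  barrier.Done();
  waiter.join();
  return 0;
}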
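Similarly, the send_op.cc hunk keys client_map_ by endpoint string rather than by a fixed integer index, using the epmap that distribute_transpiler.py now builds with one endpoint entry per gradient. The sketch below is a hypothetical, simplified model of that endpoint-keyed lookup, not the real detail::RPCClient or SendOp code; RPCClient, the endpoint addresses, and the variable names are made up for illustration.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an RPC client; only models the calls used here.
struct RPCClient {
  explicit RPCClient(const std::string &ep) : endpoint(ep) {}
  void SendVariable(const std::string &var) const {
    std::cout << "send " << var << " -> " << endpoint << "\n";
  }
  void Wait() const { std::cout << "wait on " << endpoint << "\n"; }
  std::string endpoint;
};

int main() {
  // One gradient per input; epmap[i] names the endpoint that receives ins[i],
  // mirroring how the transpiler appends one endpoint per grad.
  std::vector<std::string> ins = {"w1@GRAD", "w2@GRAD"};
  std::vector<std::string> epmap = {"127.0.0.1:6174", "127.0.0.1:6175"};

  // Clients are created and looked up by endpoint string, as in the patched SendOp.
  std::map<std::string, RPCClient> client_map;
  for (const auto &ep : epmap) client_map.emplace(ep, RPCClient(ep));

  for (size_t i = 0; i < ins.size(); ++i)
    client_map.at(epmap[i]).SendVariable(ins[i]);

  // The patch waits on the client for the first endpoint rather than index 0.
  client_map.at(epmap[0]).Wait();
  return 0;
}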