From b848416166a6a6d0750b2b1ac112cb5e7a0b2cfa Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 21 Dec 2017 20:44:16 +0800 Subject: [PATCH] follow comments --- paddle/framework/block_desc.cc | 2 +- paddle/operators/detail/recv_impl.cc | 2 +- paddle/operators/detail/send_recv_impl.h | 2 +- paddle/operators/recv_op.cc | 6 ++++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index bde2ba3907c..0668b08ff7a 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -90,7 +90,7 @@ OpDesc *BlockDesc::PrependOp() { return ops_.front().get(); } -void BlockDescBind::RemoveOp(size_t s, size_t e) { +void BlockDesc::RemoveOp(size_t s, size_t e) { if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { return; } diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc index e984f423869..517a1946a0c 100644 --- a/paddle/operators/detail/recv_impl.cc +++ b/paddle/operators/detail/recv_impl.cc @@ -58,7 +58,7 @@ Status SendRecvServerImpl::Wait(ServerContext *context, return Status::OK; } -void SendRecvServerImpl::Start() { +void SendRecvServerImpl::Reset() { std::lock_guard lock(this->mutex_); done_ = false; } diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h index 82ab3ab6892..eec9dd38d18 100644 --- a/paddle/operators/detail/send_recv_impl.h +++ b/paddle/operators/detail/send_recv_impl.h @@ -56,7 +56,7 @@ class SendRecvServerImpl final : public SendRecvService::Service { VariableMessage *out_var) override; Status Wait(ServerContext *context, const VoidMessage *in_var, VoidMessage *out_var) override; - void Start(); + void Reset(); void Done(); void SetScope(framework::Scope *scope) { scope_ = scope; }; diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index dfb6e785291..efc9fdc46e8 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ 
-80,7 +80,7 @@ class RecvOp : public framework::OperatorBase { auto grad_list = Attr>("GradList"); auto trainer_count = Attr("Trainers"); size_t param_count = param_list.size(); - rpc_service_->Start(); + rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. while (true) { // Get from multiple trainers, we don't care about order in which @@ -93,6 +93,8 @@ class RecvOp : public framework::OperatorBase { std::string param_var_name; if (it != grad_list.end()) { param_var_name = param_list[it - grad_list.begin()]; + } else { + LOG(ERROR) << "no paired param found for grad: " << grad_var_name; } VLOG(3) << "recved grad: " << grad_var_name << " updating param: " << param_var_name; @@ -112,7 +114,7 @@ class RecvOp : public framework::OperatorBase { // FIXME(typhoonzero): do not copy framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor); } - rpc_service_->Start(); + rpc_service_->Reset(); std::string program_str = Attr("OptimizeProgram"); framework::ProgramDesc program_desc; -- GitLab