diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc
index bde2ba3907c8a16d72151103a26828a0a8d75249..0668b08ff7ab3c8ca4f1e989fc7af45a8ec5f63c 100644
--- a/paddle/framework/block_desc.cc
+++ b/paddle/framework/block_desc.cc
@@ -90,7 +90,7 @@ OpDesc *BlockDesc::PrependOp() {
   return ops_.front().get();
 }
 
-void BlockDescBind::RemoveOp(size_t s, size_t e) {
+void BlockDesc::RemoveOp(size_t s, size_t e) {
   if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
     return;
   }
diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc
index e984f4238698c2834efc6b01d75937850cf529c5..517a1946a0c25081b20c320f4104c81503a4249e 100644
--- a/paddle/operators/detail/recv_impl.cc
+++ b/paddle/operators/detail/recv_impl.cc
@@ -58,7 +58,7 @@ Status SendRecvServerImpl::Wait(ServerContext *context,
   return Status::OK;
 }
 
-void SendRecvServerImpl::Start() {
+void SendRecvServerImpl::Reset() {
   std::lock_guard<std::mutex> lock(this->mutex_);
   done_ = false;
 }
diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h
index 82ab3ab689260dc1ce285b5def2e690d4c8b1144..eec9dd38d188247cba4da2a377038a28c847e40e 100644
--- a/paddle/operators/detail/send_recv_impl.h
+++ b/paddle/operators/detail/send_recv_impl.h
@@ -56,7 +56,7 @@ class SendRecvServerImpl final : public SendRecvService::Service {
                      VariableMessage *out_var) override;
   Status Wait(ServerContext *context, const VoidMessage *in_var,
               VoidMessage *out_var) override;
-  void Start();
+  void Reset();
   void Done();
   void SetScope(framework::Scope *scope) { scope_ = scope; };
 
diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index dfb6e7852914d0f8d404fc1909ac89ab5aba107e..efc9fdc46e8b38fe33e46e8a2035fcb861294e83 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -80,7 +80,7 @@ class RecvOp : public framework::OperatorBase {
     auto grad_list = Attr<std::vector<std::string>>("GradList");
     auto trainer_count = Attr<int>("Trainers");
     size_t param_count = param_list.size();
-    rpc_service_->Start();
+    rpc_service_->Reset();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     while (true) {
       // Get from multiple trainers, we don't care about order in which
@@ -93,6 +93,8 @@ class RecvOp : public framework::OperatorBase {
       std::string param_var_name;
       if (it != grad_list.end()) {
         param_var_name = param_list[it - grad_list.begin()];
+      } else {
+        LOG(ERROR) << "grad have no paired param found!";
       }
       VLOG(3) << "recved grad: " << grad_var_name
               << " updating param: " << param_var_name;
@@ -112,7 +114,7 @@ class RecvOp : public framework::OperatorBase {
        // FIXME(typhoonzero): do not copy
        framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
      }
-     rpc_service_->Start();
+     rpc_service_->Reset();
 
      std::string program_str = Attr<std::string>("OptimizeProgram");
      framework::ProgramDesc program_desc;