提交 b8484161 编写于 作者: T typhoonzero

follow comments

上级 175a4f52
...@@ -90,7 +90,7 @@ OpDesc *BlockDesc::PrependOp() { ...@@ -90,7 +90,7 @@ OpDesc *BlockDesc::PrependOp() {
return ops_.front().get(); return ops_.front().get();
} }
void BlockDescBind::RemoveOp(size_t s, size_t e) { void BlockDesc::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return; return;
} }
......
...@@ -58,7 +58,7 @@ Status SendRecvServerImpl::Wait(ServerContext *context, ...@@ -58,7 +58,7 @@ Status SendRecvServerImpl::Wait(ServerContext *context,
return Status::OK; return Status::OK;
} }
void SendRecvServerImpl::Start() { void SendRecvServerImpl::Reset() {
std::lock_guard<std::mutex> lock(this->mutex_); std::lock_guard<std::mutex> lock(this->mutex_);
done_ = false; done_ = false;
} }
......
...@@ -56,7 +56,7 @@ class SendRecvServerImpl final : public SendRecvService::Service { ...@@ -56,7 +56,7 @@ class SendRecvServerImpl final : public SendRecvService::Service {
VariableMessage *out_var) override; VariableMessage *out_var) override;
Status Wait(ServerContext *context, const VoidMessage *in_var, Status Wait(ServerContext *context, const VoidMessage *in_var,
VoidMessage *out_var) override; VoidMessage *out_var) override;
void Start(); void Reset();
void Done(); void Done();
void SetScope(framework::Scope *scope) { scope_ = scope; }; void SetScope(framework::Scope *scope) { scope_ = scope; };
......
...@@ -80,7 +80,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -80,7 +80,7 @@ class RecvOp : public framework::OperatorBase {
auto grad_list = Attr<std::vector<std::string>>("GradList"); auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers"); auto trainer_count = Attr<int>("Trainers");
size_t param_count = param_list.size(); size_t param_count = param_list.size();
rpc_service_->Start(); rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
while (true) { while (true) {
// Get from multiple trainers, we don't care about order in which // Get from multiple trainers, we don't care about order in which
...@@ -93,6 +93,8 @@ class RecvOp : public framework::OperatorBase { ...@@ -93,6 +93,8 @@ class RecvOp : public framework::OperatorBase {
std::string param_var_name; std::string param_var_name;
if (it != grad_list.end()) { if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()]; param_var_name = param_list[it - grad_list.begin()];
} else {
LOG(ERROR) << "grad have no paired param found!";
} }
VLOG(3) << "recved grad: " << grad_var_name VLOG(3) << "recved grad: " << grad_var_name
<< " updating param: " << param_var_name; << " updating param: " << param_var_name;
...@@ -112,7 +114,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -112,7 +114,7 @@ class RecvOp : public framework::OperatorBase {
// FIXME(typhoonzero): do not copy // FIXME(typhoonzero): do not copy
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor); framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
} }
rpc_service_->Start(); rpc_service_->Reset();
std::string program_str = Attr<std::string>("OptimizeProgram"); std::string program_str = Attr<std::string>("OptimizeProgram");
framework::ProgramDesc program_desc; framework::ProgramDesc program_desc;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册