提交 5f4d9130 编写于 作者: T typhoonzero

merge codes

上级 ae19d2ea
...@@ -162,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() { ...@@ -162,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
} }
// This URL explains why shutdown is complicate: // This URL explains why shutdown is complicate:
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
void AsyncGRPCServer::ShutDown() { void AsyncGRPCServer::ShutDown() {
server_->Shutdown(); server_->Shutdown();
ShutdownQueue(); ShutdownQueue();
...@@ -188,6 +187,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -188,6 +187,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "create Requestget status:" << get->Status(); VLOG(4) << "create Requestget status:" << get->Status();
} }
// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name, std::string cq_name,
std::function<void()> TryToRegisterNewOne) { std::function<void()> TryToRegisterNewOne) {
...@@ -202,7 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, ...@@ -202,7 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
} }
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
if (cq_name == "cq_get") WaitCond(2); // FIXME(typhoonzero): de-couple the barriers with recv_op
if (cq_name == "cq_get") WaitCond(1);
if (cq_name == "cq_send") WaitCond(0); if (cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag; RequestBase* base = (RequestBase*)tag;
......
...@@ -42,10 +42,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -42,10 +42,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void RunSyncUpdate(); void RunSyncUpdate();
// functions to sync server barrier status. // functions to sync server barrier status.
void WaitStart(); void WaitCond(int cond);
void WaitDone(); void SetCond(int cond);
void Start();
void Done();
void WaitClientGet(int count); void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; } void SetScope(framework::Scope *scope) { scope_ = scope; }
......
...@@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase { ...@@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase {
framework::ProgramDesc program(program_desc); framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
int64_t barrier_size = param_count * fan_in;
while (!exit_flag) { while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(kCondStart); rpc_service_->SetCond(0);
VLOG(3) << "================ start get from service ==========="; for (size_t i = 0; i < barrier_size; ++i) {
for (size_t i = 0; i < param_count * fan_in; ++i) {
const detail::MessageWithName &v = rpc_service_->Get(); const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first; auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
...@@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase { ...@@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase {
} }
VLOG(3) << "recved grad: " << grad_var_name VLOG(3) << "recved grad: " << grad_var_name
<< " updating param: " << param_var_name; << " updating param: " << param_var_name;
// Assume grad_var_name must appear in global scope.
std::string grad_var_name_trainer;
if (fan_in > 1) { if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
} }
...@@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase { ...@@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase {
if (exit_flag) { if (exit_flag) {
break; break;
} }
// rpc_service_->Reset();
try { try {
executor.Run(program, &recv_scope, 0, /*global_block*/ executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/, false /*create_vars*/); false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
VLOG(3) << "================ run sub program end ==========="; rpc_service_->SetCond(1);
rpc_service_->SetCond(kCondDone); rpc_service_->WaitClientGet(barrier_size);
rpc_service_->WaitClientGet(param_count * fan_in);
grads_counter_.clear(); grads_counter_.clear();
} // while(true) } // while(true)
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册