提交 5f4d9130 编写于 作者: T typhoonzero

merge codes

上级 ae19d2ea
......@@ -162,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
}
// This URL explains why shutdown is complicated:
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
void AsyncGRPCServer::ShutDown() {
server_->Shutdown();
ShutdownQueue();
......@@ -188,6 +187,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "create Requestget status:" << get->Status();
}
// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
......@@ -202,7 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
PADDLE_ENFORCE(tag);
if (cq_name == "cq_get") WaitCond(2);
// FIXME(typhoonzero): decouple the barriers from recv_op
if (cq_name == "cq_get") WaitCond(1);
if (cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag;
......
......@@ -42,10 +42,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void RunSyncUpdate();
// functions to sync server barrier status.
void WaitStart();
void WaitDone();
void Start();
void Done();
void WaitCond(int cond);
void SetCond(int cond);
void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; }
......
......@@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase {
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place);
// rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
int64_t barrier_size = param_count * fan_in;
while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrive, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(kCondStart);
VLOG(3) << "================ start get from service ===========";
for (size_t i = 0; i < param_count * fan_in; ++i) {
rpc_service_->SetCond(0);
for (size_t i = 0; i < barrier_size; ++i) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
......@@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase {
}
VLOG(3) << "recved grad: " << grad_var_name
<< " updating param: " << param_var_name;
// Assume grad_var_name must appear in global scope.
std::string grad_var_name_trainer;
if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
......@@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase {
if (exit_flag) {
break;
}
// rpc_service_->Reset();
try {
executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
VLOG(3) << "================ run sub program end ===========";
rpc_service_->SetCond(kCondDone);
rpc_service_->WaitClientGet(param_count * fan_in);
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(barrier_size);
grads_counter_.clear();
} // while(true)
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册