Unverified · Commit faebadd9 · Authored by: Qiao Longfei · Committed by: GitHub

Merge pull request #10228 from jacquesqiao/use-multi-thread-todo-update

Use multi thread to do update
......@@ -82,7 +82,9 @@ class RequestSend final : public RequestBase {
virtual std::string GetReqName() { return request_->Varname(); }
virtual void Process() {
queue_->Push(std::make_pair(request_->Varname(), request_));
std::string var_name = GetReqName();
VLOG(3) << "RequestSend " << var_name;
queue_->Push(std::make_pair(var_name, request_));
sendrecv::VoidMessage reply;
responder_.Finish(reply, ::grpc::Status::OK, this);
......@@ -106,7 +108,7 @@ class RequestGet final : public RequestBase {
responder_(&ctx_),
scope_(scope),
queue_(queue) {
int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
}
......@@ -118,6 +120,7 @@ class RequestGet final : public RequestBase {
virtual void Process() {
// proc request.
std::string var_name = request_.varname();
VLOG(3) << "RequestGet " << var_name;
auto* var = scope_->FindVar(var_name);
::grpc::ByteBuffer reply;
......@@ -176,7 +179,7 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply;
std::string var_name = request_->OutVarname();
VLOG(3) << "prefetch var " << var_name;
VLOG(3) << "RequestPrefetch " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
......@@ -307,18 +310,20 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
bool ok = false;
while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " while in";
VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
break;
}
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
VLOG(3) << "HandleRequest for " << cq_name << " get Next";
PADDLE_ENFORCE(tag);
if (sync_mode_) {
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
}
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
......@@ -336,9 +341,9 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
switch (base->Status()) {
case PROCESS: {
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
TryToRegisterNewOne();
base->Process();
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
break;
}
case FINISH: {
......
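For context on the grpc_server.cc hunks above: the `queue_` that `RequestSend::Process` pushes into, and the per-gradient queues the new update threads pop from, act as producer/consumer channels whose `Pop` blocks until a message arrives. The sketch below is a minimal stand-in for that contract (a hypothetical `SimpleBlockingQueue`, not the actual `detail::ReceivedQueue` implementation in PaddlePaddle):

```cpp
#include <condition_variable>
#include <deque>
#include <mutex>
#include <utility>

// Hypothetical stand-in for the blocking channel behind detail::ReceivedQueue.
// Only the Push/Pop contract used by RequestSend::Process and the update
// threads is shown; the real type differs.
template <typename T>
class SimpleBlockingQueue {
 public:
  void Push(T item) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      buffer_.push_back(std::move(item));
    }
    cv_.notify_one();  // wake one waiting consumer
  }

  // Blocks until an item is available, then removes and returns it.
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !buffer_.empty(); });
    T item = std::move(buffer_.front());
    buffer_.pop_front();
    return item;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> buffer_;
};
```

With a queue like this, the gRPC handler thread only enqueues the received variable and immediately finishes the RPC, leaving the actual optimization work to the consumer threads.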
......@@ -45,20 +45,6 @@ static void split(const std::string &str, char sep,
}
}
static void AsyncExecuteBlock(framework::Executor *executor,
framework::ExecutorPrepareContext *prepared,
framework::Scope *scope) {
std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
try {
executor->RunPreparedContext(prepared, scope, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
// TODO(qiao) maybe we can remove this
future.wait();
}
static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
......@@ -201,14 +187,40 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
} // while(true)
}
static void AsyncUpdateThread(
const std::string &var_name, const bool &exit_flag,
const std::shared_ptr<detail::ReceivedQueue> &queue,
framework::Executor *executor,
framework::ExecutorPrepareContext *prepared) {
VLOG(3) << "update thread for " << var_name << " started";
while (!exit_flag) {
const detail::ReceivedMessage v = queue->Pop();
auto recv_var_name = v.first;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
auto fs = framework::Async([var_name, &executor, &v, prepared] {
try {
executor->RunPreparedContext(prepared, v.second->GetMutableLocalScope(),
false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
fs.wait();
}
}
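`AsyncUpdateThread` above is, in essence, one consumer loop per gradient over such a queue. Below is a simplified, self-contained sketch of that pattern (hypothetical names, a plain thread-friendly function instead of `framework::AsyncIO`, and a callback in place of `Executor::RunPreparedContext`; it reuses the `SimpleBlockingQueue` sketch from earlier):

```cpp
#include <atomic>
#include <functional>
#include <iostream>
#include <memory>
#include <string>

// Stand-in for detail::ReceivedMessage: variable name plus an opaque payload.
// A null payload is used here as a shutdown sentinel.
struct Message {
  std::string var_name;
  std::shared_ptr<std::string> payload;
};

// One consumer loop per gradient: block on the queue, apply the update, repeat.
// In the real op the update is Executor::RunPreparedContext on the optimize
// block for this gradient; here it is a caller-supplied callback.
void UpdateWorker(const std::string& grad_name,
                  const std::atomic<bool>& exit_flag,
                  SimpleBlockingQueue<Message>& queue,
                  const std::function<void(const Message&)>& apply_update) {
  std::cout << "update thread for " << grad_name << " started\n";
  while (!exit_flag.load()) {
    Message msg = queue.Pop();          // blocks until a gradient arrives
    if (msg.payload == nullptr) break;  // shutdown sentinel
    apply_update(msg);  // serialized per gradient, parallel across gradients
  }
}
```

As written, the flag is only re-checked after `Pop` returns, so shutting the sketch down cleanly also requires pushing something (here a null payload) to unblock a waiting worker.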
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
framework::ProgramDesc *program) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
grad_to_queue;
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
......@@ -220,6 +232,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id;
grad_to_queue[pieces[0]] = std::make_shared<detail::ReceivedQueue>();
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
......@@ -238,8 +251,21 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
}
VLOG(3) << "RunAsyncLoop into while";
bool exit_flag = false;
VLOG(3) << "start async optimize threads";
std::vector<std::future<void>> fs;
for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
std::string grad_name = iter->first;
VLOG(3) << "create async update thread for " << grad_name;
fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
&grad_to_queue, &grad_to_prepared_ctx]() {
AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
executor, grad_to_prepared_ctx[grad_name].get());
}));
}
VLOG(3) << "RunAsyncLoop into while";
while (!exit_flag) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
......@@ -249,13 +275,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
AsyncExecuteBlock(executor, grad_to_prepared_ctx[recv_var_name].get(),
v.second->GetMutableLocalScope());
grad_to_queue[recv_var_name]->Push(v);
}
if (exit_flag) {
......@@ -304,7 +324,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
} else {
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block);
RunAsyncLoop(&executor, program);
}
}
......
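Taken together, `RunAsyncLoop` now works as a dispatcher: it creates one queue and one update thread per gradient up front, then only routes each received message to the matching queue. Below is a minimal sketch of that wiring, reusing the hypothetical `SimpleBlockingQueue`, `Message`, and `UpdateWorker` from the earlier sketches (the real op gets messages from `rpc_service_` and runs prepared executor contexts instead of a callback):

```cpp
#include <atomic>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <vector>

int main() {
  const std::vector<std::string> grad_names = {"w@GRAD", "b@GRAD"};  // hypothetical
  std::atomic<bool> exit_flag{false};
  std::map<std::string, std::shared_ptr<SimpleBlockingQueue<Message>>> grad_to_queue;
  std::vector<std::thread> workers;

  // One queue and one worker per gradient, mirroring grad_to_queue in the op.
  for (const auto& name : grad_names) {
    grad_to_queue[name] = std::make_shared<SimpleBlockingQueue<Message>>();
    workers.emplace_back(UpdateWorker, name, std::cref(exit_flag),
                         std::ref(*grad_to_queue[name]),
                         [](const Message& m) {
                           std::cout << "apply update for " << m.var_name << "\n";
                         });
  }

  // Dispatcher loop: route each received gradient to its own queue, so updates
  // for different gradients run in parallel but stay ordered per gradient.
  for (int step = 0; step < 3; ++step) {
    for (const auto& name : grad_names) {
      grad_to_queue[name]->Push({name, std::make_shared<std::string>("grad bytes")});
    }
  }

  // Shutdown: raise the flag, then unblock every worker with a sentinel.
  exit_flag = true;
  for (const auto& name : grad_names) {
    grad_to_queue[name]->Push({name, nullptr});
  }
  for (auto& t : workers) {
    t.join();
  }
  return 0;
}
```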
......@@ -47,9 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
framework::BlockDesc* prefetch_block) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
framework::ProgramDesc* program) const;
void Stop() override;
......
......@@ -168,7 +168,9 @@ class ListenAndServ(object):
'endpoint': self.endpoint,
'Fanin': self.fan_in,
'OptimizeBlock': current_block,
'PrefetchBlock': empty_block
'PrefetchBlock': empty_block,
'sync_mode': True, # async mode is not supported in layers yet
'grad_to_block_id': [""]
})
......