提交 a0ced3df 编写于 作者: Q qiaolongfei

async update can run

上级 34f28185
...@@ -315,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -315,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
VLOG(3) << "HandleRequest for " << cq_name << " while after Next"; VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
// FIXME(typhoonzero): de-couple the barriers with recv_op if (sync_mode_) {
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); // FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
}
RequestBase* base = reinterpret_cast<RequestBase*>(tag); RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference: // reference:
...@@ -334,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -334,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
switch (base->Status()) { switch (base->Status()) {
case PROCESS: { case PROCESS: {
VLOG(4) << cq_name << " status:" << base->Status(); VLOG(4) << cq_name << " PROCESS status:" << base->Status();
TryToRegisterNewOne(); TryToRegisterNewOne();
base->Process(); base->Process();
break; break;
} }
case FINISH: { case FINISH: {
VLOG(4) << cq_name << " status:" << base->Status(); VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base; delete base;
break; break;
} }
......
...@@ -61,7 +61,7 @@ class VariableResponse { ...@@ -61,7 +61,7 @@ class VariableResponse {
// other: number of error field. // other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer); int Parse(const ::grpc::ByteBuffer& byte_buffer);
const framework::Scope& GetLocalScope() const { return *local_scope_; } framework::Scope& GetLocalScope() const { return *local_scope_; }
inline std::string Varname() { return meta_.varname(); } inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); } inline std::string OutVarname() { return meta_.out_varname(); }
......
...@@ -48,13 +48,15 @@ static void split(const std::string &str, char sep, ...@@ -48,13 +48,15 @@ static void split(const std::string &str, char sep,
static void AsyncExecuteBlock(framework::Executor *executor, static void AsyncExecuteBlock(framework::Executor *executor,
framework::ExecutorPrepareContext *prepared, framework::ExecutorPrepareContext *prepared,
framework::Scope *scope) { framework::Scope *scope) {
framework::Async([&executor, &prepared, &scope]() { std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
try { try {
executor->RunPreparedContext(prepared, scope, false, false); executor->RunPreparedContext(prepared, scope, false, false);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
}); });
// TODO(qiao) maybe we can remove this
future.wait();
} }
static void ParallelExecuteBlocks( static void ParallelExecuteBlocks(
...@@ -203,6 +205,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -203,6 +205,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const { framework::BlockDesc *prefetch_block) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id // grad name to block id
std::unordered_map<std::string, int32_t> grad_to_id; std::unordered_map<std::string, int32_t> grad_to_id;
std::unordered_map<int32_t, std::string> id_to_grad; std::unordered_map<int32_t, std::string> id_to_grad;
...@@ -210,7 +213,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -210,7 +213,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id"); auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id");
for (auto &grad_and_id : grad_to_id_str) { for (auto &grad_and_id : grad_to_id_str) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(grad_and_id, ' ', &pieces); split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0); PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]); int block_id = std::stoi(pieces[1]);
...@@ -223,14 +227,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -223,14 +227,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
std::vector<int> block_list; std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) { for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) { block_list.push_back(blkid);
block_list.push_back(blkid);
}
} }
PADDLE_ENFORCE_EQ(grad_to_id_str.size(), block_list.size(),
"grad num should be equal to optimize block num");
auto optimize_prepared = executor->Prepare(*program, block_list); auto optimize_prepared = executor->Prepare(*program, block_list);
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>> std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared; grad_to_prepared;
...@@ -238,6 +237,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -238,6 +237,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i]; grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
} }
VLOG(3) << "RunAsyncLoop into while";
bool exit_flag = false; bool exit_flag = false;
while (!exit_flag) { while (!exit_flag) {
const detail::ReceivedMessage v = rpc_service_->Get(); const detail::ReceivedMessage v = rpc_service_->Get();
...@@ -254,7 +254,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -254,7 +254,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
PADDLE_THROW("Can not find server side var"); PADDLE_THROW("Can not find server side var");
} }
AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(), AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
recv_scope); &(v.second->GetLocalScope()));
// TODO(qiao): explain why // TODO(qiao): explain why
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
......
...@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase { ...@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase {
std::vector<std::string> endpoints = std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints"); Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
...@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase { ...@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase {
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) { if (sync_mode) {
VLOG(3) << "batch barrier, ep: " << ep; for (auto& ep : endpoints) {
rpc_client->AsyncSendBatchBarrier(ep); VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
} }
PADDLE_ENFORCE(rpc_client->Wait());
if (outs.size() > 0) { if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
...@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server. ...@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server.
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
} }
}; };
......
...@@ -297,8 +297,11 @@ class DistributeTranspiler: ...@@ -297,8 +297,11 @@ class DistributeTranspiler:
inputs={"X": send_inputs}, inputs={"X": send_inputs},
outputs={"Out": send_outputs, outputs={"Out": send_outputs,
"RPCClient": rpc_client_var}, "RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints, attrs={
"epmap": eplist}) "endpoints": pserver_endpoints,
"epmap": eplist,
"sync_mode": self.sync_mode
})
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1: if len(splited_var) <= 1:
...@@ -404,8 +407,8 @@ class DistributeTranspiler: ...@@ -404,8 +407,8 @@ class DistributeTranspiler:
for op in self.optimize_ops: for op in self.optimize_ops:
if op.type == "scale": if op.type == "scale":
for in_name in op.input_arg_names: for in_name in op.input_arg_names:
if in_name.startswith("beta1_pow_acc") or\ if in_name.startswith("beta1_pow_acc") or \
in_name.startswith("beta2_pow_acc"): in_name.startswith("beta2_pow_acc"):
global_ops.append(op) global_ops.append(op)
def __append_optimize_op__(op, block, grad_to_block_id): def __append_optimize_op__(op, block, grad_to_block_id):
...@@ -434,7 +437,6 @@ class DistributeTranspiler: ...@@ -434,7 +437,6 @@ class DistributeTranspiler:
__append_optimize_op__(op, per_opt_block, grad_to_block_id) __append_optimize_op__(op, per_opt_block, grad_to_block_id)
# append global ops # append global ops
opt_state_block = None
if global_ops: if global_ops:
opt_state_block = pserver_program.create_block( opt_state_block = pserver_program.create_block(
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册