Commit a0ced3df authored by Q qiaolongfei

async update can run

Parent 34f28185
@@ -315,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
     VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
     PADDLE_ENFORCE(tag);
-    // FIXME(typhoonzero): de-couple the barriers with recv_op
-    if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
-    if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
+    if (sync_mode_) {
+      // FIXME(typhoonzero): de-couple the barriers with recv_op
+      if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
+      if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
+    }
     RequestBase* base = reinterpret_cast<RequestBase*>(tag);
     // reference:
@@ -334,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
     switch (base->Status()) {
       case PROCESS: {
-        VLOG(4) << cq_name << " status:" << base->Status();
+        VLOG(4) << cq_name << " PROCESS status:" << base->Status();
         TryToRegisterNewOne();
         base->Process();
         break;
       }
       case FINISH: {
-        VLOG(4) << cq_name << " status:" << base->Status();
+        VLOG(4) << cq_name << " FINISH status:" << base->Status();
        delete base;
         break;
       }
......
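Note on the server-side change above: in sync mode the completion-queue loop still serializes the send and get phases through the `WaitCond` barriers, while in async mode the whole block is skipped and requests are handled as they arrive. Below is a minimal sketch of that gating pattern; `CondBarrier` is a hypothetical stand-in for illustration, not the Paddle implementation behind `WaitCond`.

```cpp
#include <condition_variable>
#include <mutex>

// Toy condition barrier that becomes a no-op when sync mode is off.
class CondBarrier {
 public:
  explicit CondBarrier(bool sync_mode) : sync_mode_(sync_mode) {}

  // Block until SetCond(cond) is called; returns immediately in async mode.
  void WaitCond(int cond) {
    if (!sync_mode_) return;  // async: never hold requests back
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this, cond] { return cond_ == cond; });
  }

  // Publish the new phase and wake all waiters.
  void SetCond(int cond) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      cond_ = cond;
    }
    cv_.notify_all();
  }

 private:
  bool sync_mode_;
  int cond_ = -1;
  std::mutex mu_;
  std::condition_variable cv_;
};

int main() {
  CondBarrier async_barrier(/*sync_mode=*/false);
  async_barrier.WaitCond(1);  // no lockstep in async mode: returns at once
  return 0;
}
```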
@@ -61,7 +61,7 @@ class VariableResponse {
   // other: number of error field.
   int Parse(const ::grpc::ByteBuffer& byte_buffer);
-  const framework::Scope& GetLocalScope() const { return *local_scope_; }
+  framework::Scope& GetLocalScope() const { return *local_scope_; }
   inline std::string Varname() { return meta_.varname(); }
   inline std::string OutVarname() { return meta_.out_varname(); }
......
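The `const` qualifier is dropped from the returned `framework::Scope&` so the async path can execute the optimize block inside the per-request scope that holds the just-received variable. A simplified illustration of the resulting ownership pattern, with a toy `Scope` standing in for `framework::Scope`:

```cpp
#include <memory>
#include <string>
#include <unordered_map>

struct Scope {  // toy stand-in for framework::Scope
  std::unordered_map<std::string, int> vars;  // toy variable storage
};

class VariableResponse {
 public:
  VariableResponse() : local_scope_(new Scope()) {}

  // Const method, mutable result: the response logically exposes its
  // per-request scope so an executor can create and update variables in it.
  Scope& GetLocalScope() const { return *local_scope_; }

 private:
  std::unique_ptr<Scope> local_scope_;
};

int main() {
  const VariableResponse resp;
  resp.GetLocalScope().vars["x@GRAD"] = 1;  // mutate through a const object
  return 0;
}
```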
@@ -48,13 +48,15 @@ static void split(const std::string &str, char sep,
 static void AsyncExecuteBlock(framework::Executor *executor,
                               framework::ExecutorPrepareContext *prepared,
                               framework::Scope *scope) {
-  framework::Async([&executor, &prepared, &scope]() {
+  std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
     try {
       executor->RunPreparedContext(prepared, scope, false, false);
     } catch (std::exception &e) {
       LOG(ERROR) << "run sub program error " << e.what();
     }
   });
+  // TODO(qiao) maybe we can remove this
+  future.wait();
 }

 static void ParallelExecuteBlocks(
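Keeping the `std::future` returned by `framework::Async` and waiting on it makes `AsyncExecuteBlock` synchronous per request, which is what the `TODO` suggests might later be relaxed. The same pattern, shown with standard-library `std::async` as a stand-in for `framework::Async`:

```cpp
#include <future>
#include <iostream>

void RunBlock() { std::cout << "optimize block ran\n"; }

int main() {
  // Launch the work on another thread and keep the future ...
  std::future<void> fut = std::async(std::launch::async, RunBlock);
  // ... then wait, which makes this call effectively synchronous.
  // Dropping the wait() would let the caller continue immediately,
  // presumably why the TODO says it may be removable.
  fut.wait();
  return 0;
}
```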
@@ -203,6 +205,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                    framework::ProgramDesc *program,
                                    framework::Scope *recv_scope,
                                    framework::BlockDesc *prefetch_block) const {
+  VLOG(3) << "RunAsyncLoop in";
   // grad name to block id
   std::unordered_map<std::string, int32_t> grad_to_id;
   std::unordered_map<int32_t, std::string> id_to_grad;
@@ -210,7 +213,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
   auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id");
   for (auto &grad_and_id : grad_to_id_str) {
     std::vector<std::string> pieces;
-    split(grad_and_id, ' ', &pieces);
+    split(grad_and_id, ':', &pieces);
+    VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
     PADDLE_ENFORCE_EQ(pieces.size(), 2);
     PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
     int block_id = std::stoi(pieces[1]);
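The separator switches from `' '` to `':'`, matching the `name:block_id` format of the `grad_to_id` attribute entries that the transpiler emits. A self-contained sketch of the parsing step; the gradient names and block ids in the example are hypothetical, and the local `split` mirrors the helper in this file:

```cpp
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// Split str on sep, appending the pieces (same contract as the op's helper).
static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  std::stringstream ss(str);
  std::string piece;
  while (std::getline(ss, piece, sep)) pieces->push_back(piece);
}

int main() {
  std::unordered_map<std::string, int> grad_to_id;
  // Example entries are made up; the point is that a ':' separator cannot
  // collide with gradient variable names such as "w@GRAD".
  for (const std::string &entry : {"w@GRAD:1", "b@GRAD:2"}) {
    std::vector<std::string> pieces;
    split(entry, ':', &pieces);
    assert(pieces.size() == 2);          // exactly name and block id
    grad_to_id[pieces[0]] = std::stoi(pieces[1]);
  }
  std::cout << "w@GRAD -> block " << grad_to_id["w@GRAD"] << "\n";
  return 0;
}
```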
@@ -223,14 +227,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
   std::vector<int> block_list;
   for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
-    if (blkid != static_cast<size_t>(prefetch_block->ID())) {
-      block_list.push_back(blkid);
-    }
+    block_list.push_back(blkid);
   }
-  PADDLE_ENFORCE_EQ(grad_to_id_str.size(), block_list.size(),
-                    "grad num should be equal to optimize block num");
   auto optimize_prepared = executor->Prepare(*program, block_list);
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>
       grad_to_prepared;
@@ -238,6 +237,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
     grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
   }

+  VLOG(3) << "RunAsyncLoop into while";
   bool exit_flag = false;
   while (!exit_flag) {
     const detail::ReceivedMessage v = rpc_service_->Get();
@@ -254,7 +254,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
         PADDLE_THROW("Can not find server side var");
       }
       AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
-                        recv_scope);
+                        &(v.second->GetLocalScope()));
       // TODO(qiao): explain why
       if (var->IsType<framework::SelectedRows>()) {
         var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
......
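The `TODO(qiao): explain why` is left unresolved in the commit. One plausible reading, offered here as an assumption rather than the author's rationale, is that clearing the `SelectedRows` row list after each update prevents stale sparse row indices from being applied again when the next gradient arrives. A toy illustration with a simplified stand-in type:

```cpp
#include <iostream>
#include <vector>

struct SelectedRows {         // toy stand-in: rows index into a dense table
  std::vector<long> rows;     // which rows this sparse update touches
  std::vector<float> values;  // flattened row data
};

int main() {
  SelectedRows grad;
  grad.rows = {3, 7};
  grad.values = {0.1f, 0.2f};
  // After one async update has been applied, reset the received row index
  // list so the same indices are not treated as a fresh update next time.
  grad.rows.clear();
  std::cout << "rows left: " << grad.rows.size() << "\n";
  return 0;
}
```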
@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase {
     std::vector<std::string> endpoints =
         Attr<std::vector<std::string>>("endpoints");
+    bool sync_mode = Attr<bool>("sync_mode");
+
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase {
     }
     PADDLE_ENFORCE(rpc_client->Wait());
-    for (auto& ep : endpoints) {
-      VLOG(3) << "batch barrier, ep: " << ep;
-      rpc_client->AsyncSendBatchBarrier(ep);
-    }
-    PADDLE_ENFORCE(rpc_client->Wait());
+    if (sync_mode) {
+      for (auto& ep : endpoints) {
+        VLOG(3) << "batch barrier, ep: " << ep;
+        rpc_client->AsyncSendBatchBarrier(ep);
+      }
+      PADDLE_ENFORCE(rpc_client->Wait());
+    }
     if (outs.size() > 0) {
       for (size_t i = 0; i < outs.size(); i++) {
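With `sync_mode` off, the send op now skips both the batch barriers and the extra wait, so trainers push gradients without synchronizing with one another. A sketch of the resulting client-side flow; `RPCClient` here is an illustrative stand-in for `detail::RPCClient`, and the endpoint is made up:

```cpp
#include <iostream>
#include <string>
#include <vector>

struct RPCClient {  // illustration only, not the Paddle client
  void AsyncSendBatchBarrier(const std::string& ep) {
    std::cout << "batch barrier -> " << ep << "\n";
  }
  bool Wait() { return true; }  // block until in-flight RPCs finish
};

void SendBarriers(RPCClient* client, const std::vector<std::string>& eps,
                  bool sync_mode) {
  // Sync mode: every pserver endpoint gets a batch barrier and the client
  // blocks. Async mode: both steps are skipped, so trainers never wait
  // for each other between minibatches.
  if (sync_mode) {
    for (const auto& ep : eps) client->AsyncSendBatchBarrier(ep);
    if (!client->Wait()) std::cerr << "wait failed\n";
  }
}

int main() {
  RPCClient client;
  SendBarriers(&client, {"127.0.0.1:6174"}, /*sync_mode=*/true);
  return 0;
}
```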
@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server.
                                       "Server endpoints in the order of input "
                                       "variables for mapping")
         .SetDefault({});
+    AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
   }
 };
......
@@ -297,8 +297,11 @@ class DistributeTranspiler:
             inputs={"X": send_inputs},
             outputs={"Out": send_outputs,
                      "RPCClient": rpc_client_var},
-            attrs={"endpoints": pserver_endpoints,
-                   "epmap": eplist})
+            attrs={
+                "endpoints": pserver_endpoints,
+                "epmap": eplist,
+                "sync_mode": self.sync_mode
+            })
         # step4: Concat the parameters splits together after recv.
         for varname, splited_var in param_var_mapping.iteritems():
             if len(splited_var) <= 1:
@@ -404,8 +407,8 @@ class DistributeTranspiler:
         for op in self.optimize_ops:
             if op.type == "scale":
                 for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or\
-                            in_name.startswith("beta2_pow_acc"):
+                    if in_name.startswith("beta1_pow_acc") or \
+                       in_name.startswith("beta2_pow_acc"):
                         global_ops.append(op)

         def __append_optimize_op__(op, block, grad_to_block_id):
@@ -434,7 +437,6 @@ class DistributeTranspiler:
                 __append_optimize_op__(op, per_opt_block, grad_to_block_id)

         # append global ops
-        opt_state_block = None
         if global_ops:
             opt_state_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)