Commit 0a881a1e authored by qiaolongfei

init RunAsyncUpdate

Parent 36083018
...@@ -27,6 +27,36 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
  VLOG(4) << "RunServer thread end";
}
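// Split "str" on "sep" into "pieces"; the trailing segment is kept only if
// it is non-empty.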
static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
  if (str.empty()) {
    return;
  }
  size_t pos = 0;
  size_t next = str.find(sep, pos);
  while (next != std::string::npos) {
    pieces->push_back(str.substr(pos, next - pos));
    pos = next + 1;
    next = str.find(sep, pos);
  }
  if (!str.substr(pos).empty()) {
    pieces->push_back(str.substr(pos));
  }
}
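// Run one prepared block on the given scope in a background thread via
// framework::Async; exceptions from the executor are caught and logged.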
static void AsyncExecuteBlock(framework::Executor *executor,
                              framework::ExecutorPrepareContext *prepared,
                              framework::Scope *scope) {
  // Capture the pointers by value: the task may run after this function has
  // returned, so referencing the local parameters would dangle.
  framework::Async([executor, prepared, scope]() {
    try {
      executor->RunPreparedContext(prepared, scope, false, false);
    } catch (std::exception &e) {
      LOG(ERROR) << "run sub program error " << e.what();
    }
  });
}
static void ParallelExecuteBlocks(
    const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
    const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
...@@ -168,12 +198,82 @@ void ListenAndServOp::RunSyncUpdate(
  }  // while(true)
}
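// Async update loop: each gradient variable received over RPC is immediately
// dispatched to the optimize block mapped to it by the "grad_map" attribute.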
void ListenAndServOp::RunAsyncUpdate(
    framework::Executor *executor, framework::ProgramDesc *program,
    framework::Scope *recv_scope, framework::BlockDesc *prefetch_block) const {
  // grad name to block id
  std::unordered_map<std::string, int32_t> grad_to_id;
  std::unordered_map<int32_t, std::string> id_to_grad;

  auto grad_map_str = Attr<std::vector<std::string>>("grad_map");
  for (auto &grad_and_id : grad_map_str) {
    std::vector<std::string> pieces;
    split(grad_and_id, ' ', &pieces);
    PADDLE_ENFORCE_EQ(pieces.size(), 2);
    PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
    int block_id = std::stoi(pieces[1]);
    grad_to_id[pieces[0]] = block_id;
    id_to_grad[block_id] = pieces[0];
  }

  size_t num_blocks = program->Size();
  PADDLE_ENFORCE_GE(num_blocks, 2,
                    "server program should have at least 2 blocks");

  std::vector<int> block_list;
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
    if (blkid != static_cast<size_t>(prefetch_block->ID())) {
      block_list.push_back(blkid);
    }
  }
  PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(),
                    "grad num should be equal to optimize block num");
  auto optimize_prepared = executor->Prepare(*program, block_list);

  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      grad_to_prepared;
  for (size_t i = 0; i < block_list.size(); ++i) {
    grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
  }

  bool exit_flag = false;
  while (!exit_flag) {
    const detail::ReceivedMessage v = rpc_service_->Get();
    auto recv_var_name = v.first;
    if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
      LOG(INFO) << "received terminate message and exit";
      exit_flag = true;
      break;
    } else {
      VLOG(3) << "received grad: " << recv_var_name;
      auto var = v.second->GetVar();
      if (var == nullptr) {
        LOG(ERROR) << "Can not find server side var: " << recv_var_name;
        PADDLE_THROW("Can not find server side var");
      }
      AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
                        recv_scope);
      // TODO(qiao): explain why
      if (var->IsType<framework::SelectedRows>()) {
        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
      }
    }

    if (exit_flag) {
      rpc_service_->ShutDown();
      break;
    }
  }  // while(true)
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
                              const platform::Place &dev_place) const {
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(dev_place);
  framework::Scope &recv_scope = scope.NewScope();

  bool sync_mode = Attr<bool>("sync_mode");

  PADDLE_ENFORCE(!rpc_service_);
  std::string endpoint = Attr<std::string>("endpoint");
  rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
...@@ -201,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
  sleep(5);
  // Write to a file of server selected port for python use.
  SavePort(rpc_service_);
  if (sync_mode) {
    RunSyncUpdate(&executor, program, &recv_scope, prefetch_block);
  } else {
    RunAsyncUpdate(&executor, program, &recv_scope, prefetch_block);
  }
}

class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -220,6 +324,12 @@ from send_op and send back variables to recv_op.
                 "IP address to listen on.")
        .SetDefault("127.0.0.1:6164")
        .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
    AddAttr<std::vector<std::string>>(
        "grad_map(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])",
        "a map from grad name to its optimize block id")
        .SetDefault({});
    AddAttr<bool>("sync_mode", "whether to run in sync mode or not")
        .SetDefault(false);
    AddAttr<framework::BlockDesc *>(kOptimizeBlock,
                                    "BlockID to run on server side.");
    AddAttr<framework::BlockDesc *>(kPrefetchBlock,
......
...@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
                     framework::Scope* recv_scope,
                     framework::BlockDesc* prefetch_block) const;
  void RunAsyncUpdate(framework::Executor* executor,
                      framework::ProgramDesc* program,
                      framework::Scope* recv_scope,
                      framework::BlockDesc* prefetch_block) const;
  void Stop() override;

  void RunImpl(const framework::Scope& scope,
......
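The async path in this commit hinges on two things: the "grad_map" attribute, whose entries are split on a space into a gradient name and the id of the optimize block that updates it, and AsyncExecuteBlock, which runs that block on a background thread as soon as the gradient arrives. The standalone sketch below (not part of the commit) illustrates both mechanisms under those assumptions: std::async stands in for Paddle's framework::Async, and RunOptimizeBlock is a hypothetical placeholder for executor->RunPreparedContext(...).

#include <future>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Same splitting rule as the helper in the commit: separate on `sep`, keep
// the trailing piece only if it is non-empty.
static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
  size_t pos = 0;
  size_t next = str.find(sep, pos);
  while (next != std::string::npos) {
    pieces->push_back(str.substr(pos, next - pos));
    pos = next + 1;
    next = str.find(sep, pos);
  }
  if (!str.substr(pos).empty()) pieces->push_back(str.substr(pos));
}

// Hypothetical stand-in for running a prepared optimize block.
static void RunOptimizeBlock(int block_id, const std::string &grad_name) {
  std::cout << "optimize block " << block_id << " consumed " << grad_name
            << std::endl;
}

int main() {
  // Entries in the "grad_name block_id" form that RunAsyncUpdate parses.
  const std::vector<std::string> grad_map = {"param1@GRAD 1", "param2@GRAD 2"};
  std::unordered_map<std::string, int> grad_to_id;
  for (const auto &entry : grad_map) {
    std::vector<std::string> pieces;
    split(entry, ' ', &pieces);
    grad_to_id[pieces[0]] = std::stoi(pieces[1]);
  }

  // Dispatch each "received" gradient on its own background task, the way
  // AsyncExecuteBlock hands work to framework::Async. Arguments are copied
  // into the task so nothing dangles after this scope ends.
  std::vector<std::future<void>> tasks;
  for (const std::string recv : {"param2@GRAD", "param1@GRAD"}) {
    tasks.push_back(std::async(std::launch::async, RunOptimizeBlock,
                               grad_to_id[recv], recv));
  }
  for (auto &t : tasks) t.wait();
  return 0;
}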