提交 18461d09 编写于 作者: T typhoonzero

wip

上级 5008020d
...@@ -95,6 +95,13 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -95,6 +95,13 @@ class ListenAndServOp : public framework::OperatorBase {
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
std::vector<framework::ExecutorPrepareContext *> blk_ctx_list;
blk_ctx_list.push_back(nullptr); // block0 is not used.
for (int blkid = 1; blkid < num_blocks; ++blkid) {
auto *exe_ctx = executor.Prepare(*program, blkid);
VLOG(2) << "prepare ctx: " << exe_ctx;
blk_ctx_list.push_back(exe_ctx);
}
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
...@@ -145,12 +152,17 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -145,12 +152,17 @@ class ListenAndServOp : public framework::OperatorBase {
std::vector<std::future<void>> fs; std::vector<std::future<void>> fs;
// block0 contains only listen_and_serv op, start run from block1. // block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) { for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(framework::Async([&executor, &program, &recv_scope, fs.push_back(framework::Async(
blkid]() { [&executor, &program, &recv_scope, &blk_ctx_list, blkid]() {
int run_block = blkid; // thread local int run_block = blkid; // thread local
try { try {
executor.Run(*program, &recv_scope, run_block, VLOG(2) << "run ctx: " << blk_ctx_list[run_block]
false /*create_local_scope*/, false /*create_vars*/); << " block: " << run_block;
executor.RunPreparedContext(blk_ctx_list[run_block],
&recv_scope, false, false);
// executor.Run(*program, &recv_scope, run_block,
// false /*create_local_scope*/,
// false /*create_vars*/);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
...@@ -160,8 +172,10 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -160,8 +172,10 @@ class ListenAndServOp : public framework::OperatorBase {
// Run global block at final step, or block1 if there are only 2 blocks // Run global block at final step, or block1 if there are only 2 blocks
if (num_blocks >= 2) { if (num_blocks >= 2) {
try { try {
executor.Run(*program, &recv_scope, num_blocks - 1, executor.RunPreparedContext(blk_ctx_list[num_blocks - 1], &recv_scope,
false /*create_local_scope*/, false /*create_vars*/); false, false);
// executor.Run(*program, &recv_scope, num_blocks - 1,
// false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
...@@ -180,6 +194,10 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -180,6 +194,10 @@ class ListenAndServOp : public framework::OperatorBase {
rpc_service_->WaitClientGet(fan_in); rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear(); sparse_vars.clear();
} // while(true) } // while(true)
for (int i = 0; i < num_blocks; ++i) {
delete blk_ctx_list[i];
}
} }
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册