提交 e9d815e3 编写于 作者: T typhoonzero

prepare and create op before run

上级 18461d09
...@@ -99,7 +99,6 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -99,7 +99,6 @@ class ListenAndServOp : public framework::OperatorBase {
blk_ctx_list.push_back(nullptr); // block0 is not used. blk_ctx_list.push_back(nullptr); // block0 is not used.
for (int blkid = 1; blkid < num_blocks; ++blkid) { for (int blkid = 1; blkid < num_blocks; ++blkid) {
auto *exe_ctx = executor.Prepare(*program, blkid); auto *exe_ctx = executor.Prepare(*program, blkid);
VLOG(2) << "prepare ctx: " << exe_ctx;
blk_ctx_list.push_back(exe_ctx); blk_ctx_list.push_back(exe_ctx);
} }
...@@ -149,6 +148,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -149,6 +148,7 @@ class ListenAndServOp : public framework::OperatorBase {
// should be global ops. // should be global ops.
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work. // and this will still work.
std::vector<std::future<void>> fs; std::vector<std::future<void>> fs;
// block0 contains only listen_and_serv op, start run from block1. // block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) { for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
...@@ -156,13 +156,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -156,13 +156,8 @@ class ListenAndServOp : public framework::OperatorBase {
[&executor, &program, &recv_scope, &blk_ctx_list, blkid]() { [&executor, &program, &recv_scope, &blk_ctx_list, blkid]() {
int run_block = blkid; // thread local int run_block = blkid; // thread local
try { try {
VLOG(2) << "run ctx: " << blk_ctx_list[run_block]
<< " block: " << run_block;
executor.RunPreparedContext(blk_ctx_list[run_block], executor.RunPreparedContext(blk_ctx_list[run_block],
&recv_scope, false, false); &recv_scope, false, false);
// executor.Run(*program, &recv_scope, run_block,
// false /*create_local_scope*/,
// false /*create_vars*/);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
...@@ -174,8 +169,6 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -174,8 +169,6 @@ class ListenAndServOp : public framework::OperatorBase {
try { try {
executor.RunPreparedContext(blk_ctx_list[num_blocks - 1], &recv_scope, executor.RunPreparedContext(blk_ctx_list[num_blocks - 1], &recv_scope,
false, false); false, false);
// executor.Run(*program, &recv_scope, num_blocks - 1,
// false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
......
...@@ -66,6 +66,7 @@ class SendOp : public framework::OperatorBase { ...@@ -66,6 +66,7 @@ class SendOp : public framework::OperatorBase {
auto* client_var = scope.FindVar(client_var_name); auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>(); detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
ctx.Wait(); // wait before sending
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册