diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index a594de67e05acd28ffedc5407beecfaea1281444..bd6e25449f05f27cef04cf8f38a1b0b3a55d8da2 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -95,6 +95,13 @@ class ListenAndServOp : public framework::OperatorBase { "server program should have at least 2 blocks"); framework::Executor executor(dev_place); + std::vector blk_ctx_list; + blk_ctx_list.push_back(nullptr); // block0 is not used. + for (int blkid = 1; blkid < num_blocks; ++blkid) { + auto *exe_ctx = executor.Prepare(*program, blkid); + VLOG(2) << "prepare ctx: " << exe_ctx; + blk_ctx_list.push_back(exe_ctx); + } // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; @@ -145,23 +152,30 @@ class ListenAndServOp : public framework::OperatorBase { std::vector> fs; // block0 contains only listen_and_serv op, start run from block1. for (int blkid = 1; blkid < num_blocks - 1; ++blkid) { - fs.push_back(framework::Async([&executor, &program, &recv_scope, - blkid]() { - int run_block = blkid; // thread local - try { - executor.Run(*program, &recv_scope, run_block, - false /*create_local_scope*/, false /*create_vars*/); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - })); + fs.push_back(framework::Async( + [&executor, &program, &recv_scope, &blk_ctx_list, blkid]() { + int run_block = blkid; // thread local + try { + VLOG(2) << "run ctx: " << blk_ctx_list[run_block] + << " block: " << run_block; + executor.RunPreparedContext(blk_ctx_list[run_block], + &recv_scope, false, false); + // executor.Run(*program, &recv_scope, run_block, + // false /*create_local_scope*/, + // false /*create_vars*/); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); } for (int i = 0; i < num_blocks - 2; ++i) fs[i].wait(); // Run global block at final step, or block1 if there are only 2 blocks if (num_blocks >= 2) { try { - executor.Run(*program, &recv_scope, num_blocks - 1, - false /*create_local_scope*/, false /*create_vars*/); + executor.RunPreparedContext(blk_ctx_list[num_blocks - 1], &recv_scope, + false, false); + // executor.Run(*program, &recv_scope, num_blocks - 1, + // false /*create_local_scope*/, false /*create_vars*/); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } @@ -180,6 +194,10 @@ class ListenAndServOp : public framework::OperatorBase { rpc_service_->WaitClientGet(fan_in); sparse_vars.clear(); } // while(true) + + for (int i = 0; i < num_blocks; ++i) { + delete blk_ctx_list[i]; + } } protected: