未验证 提交 7a993ee4 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #10080 from jacquesqiao/refine-listen-and-serve-op

Refine listen and serve op
...@@ -59,15 +59,13 @@ class AsyncGRPCServer final { ...@@ -59,15 +59,13 @@ class AsyncGRPCServer final {
void SetProgram(framework::ProgramDesc *program) { program_ = program; } void SetProgram(framework::ProgramDesc *program) { program_ = program; }
void SetPrefetchBlkdId(int blkid) { prefetch_blk_id_ = blkid; }
void SetExecutor(framework::Executor *executor) { executor_ = executor; } void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) { void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
prefetch_ctx_ = prepared; prefetch_ctx_ = prepared;
} }
int GetSelectedPort() { return selected_port_; } int GetSelectedPort() const { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
...@@ -114,7 +112,6 @@ class AsyncGRPCServer final { ...@@ -114,7 +112,6 @@ class AsyncGRPCServer final {
std::unique_ptr<std::thread> t_get_; std::unique_ptr<std::thread> t_get_;
std::unique_ptr<std::thread> t_prefetch_; std::unique_ptr<std::thread> t_prefetch_;
int prefetch_blk_id_;
framework::ExecutorPrepareContext *prefetch_ctx_; framework::ExecutorPrepareContext *prefetch_ctx_;
framework::ProgramDesc *program_; framework::ProgramDesc *program_;
framework::Executor *executor_; framework::Executor *executor_;
......
...@@ -27,20 +27,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) { ...@@ -27,20 +27,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
static void CreateTensorFromMessageType(framework::Variable *var,
sendrecv::VarType var_type) {
if (var_type == sendrecv::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
var->GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW(
"VariableMessage type %d is not in "
"[LoDTensor, SelectedRows]",
var_type);
}
}
static void ParallelExecuteBlocks( static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor, const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>> const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
...@@ -62,6 +48,13 @@ static void ParallelExecuteBlocks( ...@@ -62,6 +48,13 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} }
static void SavePort(std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
std::ofstream port_file;
port_file.open("/tmp/paddle.selected_port");
port_file << rpc_service->GetSelectedPort();
port_file.close();
}
ListenAndServOp::ListenAndServOp(const std::string &type, ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
...@@ -77,59 +70,26 @@ void ListenAndServOp::Stop() { ...@@ -77,59 +70,26 @@ void ListenAndServOp::Stop() {
server_thread_->join(); server_thread_->join();
} }
void ListenAndServOp::RunImpl(const framework::Scope &scope, void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
const platform::Place &dev_place) const { framework::ProgramDesc *program,
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); framework::Scope *recv_scope,
auto &dev_ctx = *pool.Get(dev_place); framework::BlockDesc *prefetch_block) const {
framework::Scope &recv_scope = scope.NewScope();
if (!rpc_service_) {
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
}
auto ins = Inputs("X");
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
framework::Executor executor(dev_place);
std::vector<int> block_list; std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) { for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) { block_list.push_back(blkid);
block_list.push_back(blkid);
}
} }
auto optimize_prepared = executor.Prepare(*program, block_list); auto optimize_prepared = executor->Prepare(*program, block_list);
// Insert placeholder for block0 which holds current op itself. // Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert( optimize_prepared.insert(
optimize_prepared.begin(), optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr)); std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
rpc_service_->SetProgram(program);
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
sleep(5);
// Write to a file of server selected port for python use.
std::ofstream port_file;
port_file.open("/tmp/paddle.selected_port");
port_file << rpc_service_->GetSelectedPort();
port_file.close();
bool exit_flag = false; bool exit_flag = false;
// Record received sparse variables, so that // Record received sparse variables, so that
// we could reset those after execute optimize program // we could reset those after execute optimize program
...@@ -170,7 +130,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -170,7 +130,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
break; break;
} }
// NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work. // and this will still work.
// The optimize blocks which have the same parent ID would run parallel // The optimize blocks which have the same parent ID would run parallel
...@@ -182,16 +142,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -182,16 +142,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
for (size_t blkid = 2; blkid < num_blocks; ++blkid) { for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) { if (blkid != static_cast<size_t>(prefetch_block->ID())) {
if (program->Block(blkid).Parent() != last_parent_blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, &recv_scope); program, recv_scope);
parallel_blkids.clear(); parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent(); last_parent_blkid = program->Block(blkid).Parent();
} }
parallel_blkids.push_back(blkid); parallel_blkids.push_back(blkid);
} }
} }
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
program, &recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not // Reset the received sparse variables, the sum operator would not
...@@ -209,6 +169,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -209,6 +169,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
} // while(true) } // while(true)
} }
void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
framework::Executor executor(dev_place);
// prepare rpc_service
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
rpc_service_->SetProgram(program);
rpc_service_->SetExecutor(&executor);
// prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
sleep(5);
// Write to a file of server selected port for python use.
SavePort(rpc_service_);
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
}
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker) ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
......
...@@ -34,17 +34,22 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service); ...@@ -34,17 +34,22 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
class ListenAndServOp : public framework::OperatorBase { class ListenAndServOp : public framework::OperatorBase {
public: public:
ListenAndServOp(const std::string &type, ListenAndServOp(const std::string& type,
const framework::VariableNameMap &inputs, const framework::VariableNameMap& inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap &attrs); const framework::AttributeMap& attrs);
int GetSelectedPort() const; int GetSelectedPort() const;
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
void Stop() override; void Stop() override;
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope& scope,
const platform::Place &dev_place) const override; const platform::Place& dev_place) const override;
protected: protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_; mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
......
...@@ -127,7 +127,7 @@ void StartServerNet(bool is_sparse) { ...@@ -127,7 +127,7 @@ void StartServerNet(bool is_sparse) {
const auto &root_block = program.Block(0); const auto &root_block = program.Block(0);
auto *optimize_block = program.AppendBlock(root_block); auto *optimize_block = program.AppendBlock(root_block);
auto *prefetch_block = program.AppendBlock(root_block); auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensers, must be of same shape. // X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
f::AttributeMap attrs; f::AttributeMap attrs;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册