提交 4e36c0ec 编写于 作者: Q qiaolongfei

update prefetch logic in grpc_server

上级 0d3d4ae7
......@@ -155,16 +155,18 @@ class RequestPrefetch final : public RequestBase {
void Process() override {
// prefetch process...
std::string varname = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << varname;
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
VLOG(3) << "in_var_name: " << in_var_name
<< " RequestPrefetch: " << out_var_name;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(varname);
auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar);
request_handler_->Handle(in_var_name, scope, invar, &outvar);
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
......
......@@ -99,11 +99,17 @@ void StartServer() {
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID());
std::string in_var_name("ids");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
......
......@@ -57,9 +57,12 @@ class RequestHandler {
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
// Used for dist lookup table prefetch
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
prefetch_var_name_to_prepared_ctx_ = g;
}
// Used for async.
......@@ -75,9 +78,6 @@ class RequestHandler {
bool sync_mode() { return sync_mode_; }
framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ExecutorPrepareContext* prefetch_ctx() {
return prefetch_ctx_.get();
}
framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; }
......@@ -106,12 +106,17 @@ class RequestHandler {
framework::Executor* executor_;
framework::Scope* scope_;
framework::ProgramDesc* program_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
// used for distribute lookup table prefetch
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// Used for async.
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
RPCServer* rpc_server_;
};
......
......@@ -111,7 +111,8 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
auto var_desc = program_->Block(0).FindVar(varname);
*outvar = scope->FindVar(varname);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
return true;
}
......
......@@ -89,16 +89,19 @@ void ListenAndServOp::SavePort() const {
rpc_service_->SavePort();
}
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope,
const std::vector<int> &prefetch_block_id_list) const {
// FIXME(qiao) run should not run the block to do prefetch, currently prefetch
// block
// can only be at the last blocks of the program
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
for (size_t blkid = 1; blkid < prefetch_block_id_list[0]; ++blkid) {
block_list.push_back(blkid);
}
auto optimize_prepared = executor->Prepare(*program, block_list);
......@@ -128,16 +131,14 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1);
double ts = detail::GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
for (size_t blkid = 2; blkid < prefetch_block_id_list[0]; ++blkid) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope);
......@@ -203,18 +204,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
} // while(true)
}
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx,
framework::Executor *executor,
framework::ProgramDesc *program,
framework::ExecutorPrepareContext *prefetch_ctx,
detail::RPCServer *rpc_server) {
static void FillRequestCtx(
detail::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx, framework::Executor *executor,
framework::ProgramDesc *program,
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
detail::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
h->SetProgram(program);
h->SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
h->SetPrefetchPreparedCtx(prefetch_ctx);
h->SetRPCServer(rpc_server);
}
......@@ -248,18 +250,41 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto grad_to_block_id_str = Attr<std::vector<std::string>>(kPrefetchBlock);
framework::BlockDesc *prefetch_block = nullptr;
auto *program = optimize_block->Program();
framework::Executor executor(dev_place);
// prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
std::vector<int> prefetch_block_id_list;
std::unordered_map<int32_t, std::string> block_id_to_prefetch_var_name;
auto prefetch_var_name_to_block_id_str =
Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
for (const auto &prefetch_var_name_and_id :
prefetch_var_name_to_block_id_str) {
std::vector<std::string> pieces;
split(prefetch_var_name_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
int block_id = std::stoi(pieces[1]);
prefetch_block_id_list.push_back(block_id);
block_id_to_prefetch_var_name[block_id] = pieces[0];
}
auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared_ctx;
for (int i = 0; i < prefetch_block_id_list.size(); ++i) {
auto block_id = prefetch_block_id_list[i];
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, prefetch_prepared.release(),
rpc_service_.get());
&dev_ctx, &executor, program,
&prefetch_var_name_to_prepared_ctx, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
......@@ -277,7 +302,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use.
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
} else {
RunAsyncLoop(&executor, program);
}
......@@ -303,7 +328,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<std::vector<std::string>>(kPrefetchBlock,
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch block to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <atomic>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -30,7 +31,7 @@ namespace paddle {
namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "prefetch_var_name_to_block_id";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service);
......@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
const std::vector<int>& prefetch_block_id_list) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册