From 9851a534780471b5eefed15fed8846e25a319149 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 23 Nov 2018 15:18:24 +0800 Subject: [PATCH] add prefetch part in pserver --- .../operators/distributed/grpc_server.cc | 1 + .../operators/distributed/request_handler.h | 3 +- .../distributed/request_handler_impl.cc | 24 +++++++---- .../distributed/request_handler_impl.h | 40 +++++++++++++++---- .../operators/distributed/send_recv.proto.in | 1 + .../operators/distributed/variable_response.h | 1 + 6 files changed, 54 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc index ffd2b1707..d5295dc63 100644 --- a/paddle/fluid/operators/distributed/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -181,6 +181,7 @@ class RequestPrefetch final : public RequestBase { // prefetch process... std::string in_var_name = request_->Varname(); std::string out_var_name = request_->OutVarname(); + std::string table_name = request_->TableName(); int trainer_id = request_->GetTrainerId(); VLOG(40) << "RequestPrefetch, in_var_name: " << in_var_name << " out_var_name: " << out_var_name; diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 3bcc59a47..f29b2bf7d 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -191,7 +191,8 @@ class RequestHandler { virtual bool Handle(const std::string& varname, framework::Scope* scope, framework::Variable* var, framework::Variable** outvar, const int trainer_id, - const std::string& out_var_name = "") = 0; + const std::string& out_var_name = "", + const std::string& table_name = "") = 0; protected: const bool sync_mode_; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index dae56cc84..0f1264ee9 100644 --- 
a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -37,7 +37,8 @@ bool RequestSendHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const int trainer_id, - const std::string& out_var_name) { + const std::string& out_var_name, + const std::string& table_name) { VLOG(40) << "RequestSendHandler:" << varname; // Sync @@ -77,7 +78,8 @@ bool RequestGetHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const int trainer_id, - const std::string& out_var_name) { + const std::string& out_var_name, + const std::string& table_name) { VLOG(40) << "RequestGetHandler:" << varname; if (sync_mode_) { if (varname == FETCH_BARRIER_MESSAGE) { @@ -114,14 +116,21 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const int trainer_id, - const std::string& out_var_name) { + const std::string& out_var_name, + const std::string& table_name) { VLOG(40) << "RequestPrefetchHandler " << varname; auto var_desc = program_->Block(0).FindVar(out_var_name); InitializeVariable(*outvar, var_desc->GetType()); - executor_->RunPreparedContext( - (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope); - + if (table_name.empty()) { + executor_->RunPreparedContext( + (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope); + } else { + auto lookup_table_op = + BuildLookupTableOp(table_name, varname, out_var_name); + paddle::platform::CPUPlace cpu_place; + lookup_table_op->Run(*scope, cpu_place); + } return true; } @@ -130,7 +139,8 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Variable* invar, framework::Variable** outvar, const int trainer_id, - const std::string& out_var_name) { + const std::string& out_var_name, + const std::string& table_name) { PADDLE_ENFORCE( checkpoint_notify_id != -1, "when 
checkpoint_notify_id = -1, there should be no RPC invoke."); diff --git a/paddle/fluid/operators/distributed/request_handler_impl.h b/paddle/fluid/operators/distributed/request_handler_impl.h index c1afda9dd..5e0b25c5c 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.h +++ b/paddle/fluid/operators/distributed/request_handler_impl.h @@ -24,6 +24,7 @@ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" @@ -43,8 +44,8 @@ class RequestSendHandler final : public RequestHandler { virtual ~RequestSendHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, framework::Variable* var, framework::Variable** outvar, - const int trainer_id, - const std::string& out_var_name = "") override; + const int trainer_id, const std::string& out_var_name = "", + const std::string& table_name = "") override; private: bool enable_dc_asgd_; @@ -59,21 +60,44 @@ class RequestGetHandler final : public RequestHandler { virtual ~RequestGetHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, framework::Variable* var, framework::Variable** outvar, - const int trainer_id, - const std::string& out_var_name = "") override; + const int trainer_id, const std::string& out_var_name = "", + const std::string& table_name = "") override; private: bool enable_dc_asgd_; }; +static inline void BuildVar(const std::string& param_name, + std::initializer_list<const char*> arguments, + paddle::framework::proto::OpDesc::Var* var) { + var->set_parameter(param_name); + for (auto& arg_name : arguments) { + *var->mutable_arguments()->Add() = arg_name; + } +} + class RequestPrefetchHandler final : public RequestHandler { public: explicit RequestPrefetchHandler(bool sync_mode) : 
RequestHandler(sync_mode) {} virtual ~RequestPrefetchHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, framework::Variable* var, framework::Variable** outvar, - const int trainer_id, - const std::string& out_var_name = "") override; + const int trainer_id, const std::string& out_var_name = "", + const std::string& table_name = "") override; + + private: + std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp( + const std::string& table_name, const std::string& id_name, + const std::string& out_name) { + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("lookup_table"); + BuildVar("W", {table_name.data()}, op_desc.add_inputs()); + BuildVar("Ids", {id_name.data()}, op_desc.add_inputs()); + BuildVar("Out", {out_name.data()}, op_desc.add_outputs()); + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + return op; + } }; class RequestCheckpointHandler final : public RequestHandler { @@ -85,8 +109,8 @@ class RequestCheckpointHandler final : public RequestHandler { virtual ~RequestCheckpointHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, framework::Variable* var, framework::Variable** outvar, - const int trainer_id, - const std::string& out_var_name = "") override; + const int trainer_id, const std::string& out_var_name = "", + const std::string& table_name = "") override; private: int checkpoint_notify_id; diff --git a/paddle/fluid/operators/distributed/send_recv.proto.in b/paddle/fluid/operators/distributed/send_recv.proto.in index 55820c980..7b7d069f1 100644 --- a/paddle/fluid/operators/distributed/send_recv.proto.in +++ b/paddle/fluid/operators/distributed/send_recv.proto.in @@ -80,6 +80,7 @@ message VariableMessage { // when profile switches from 1 to 2. 
int64 profile = 11; int64 trainer_id = 12; + string table_name = 13; } message VoidMessage {} diff --git a/paddle/fluid/operators/distributed/variable_response.h b/paddle/fluid/operators/distributed/variable_response.h index 4c7fcbbdf..a4324f67b 100644 --- a/paddle/fluid/operators/distributed/variable_response.h +++ b/paddle/fluid/operators/distributed/variable_response.h @@ -85,6 +85,7 @@ class VariableResponse { inline framework::Scope* GetMutableLocalScope() const { return local_scope_; } inline std::string Varname() const { return meta_.varname(); } inline std::string OutVarname() const { return meta_.out_varname(); } + inline std::string TableName() const { return meta_.table_name(); } // should call parse first. framework::Variable* GetVar() { -- GitLab