提交 9851a534 编写于 作者: Q Qiao Longfei

add prefetch part in pserver

上级 1f87f263
...@@ -181,6 +181,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -181,6 +181,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process... // prefetch process...
std::string in_var_name = request_->Varname(); std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname(); std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId(); int trainer_id = request_->GetTrainerId();
VLOG(40) << "RequestPrefetch, in_var_name: " << in_var_name VLOG(40) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name; << " out_var_name: " << out_var_name;
......
...@@ -191,7 +191,8 @@ class RequestHandler { ...@@ -191,7 +191,8 @@ class RequestHandler {
virtual bool Handle(const std::string& varname, framework::Scope* scope, virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name = "") = 0; const std::string& out_var_name = "",
const std::string& table_name = "") = 0;
protected: protected:
const bool sync_mode_; const bool sync_mode_;
......
...@@ -37,7 +37,8 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -37,7 +37,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar, framework::Variable** outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name) { const std::string& out_var_name,
const std::string& table_name) {
VLOG(40) << "RequestSendHandler:" << varname; VLOG(40) << "RequestSendHandler:" << varname;
// Sync // Sync
...@@ -77,7 +78,8 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -77,7 +78,8 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar, framework::Variable** outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name) { const std::string& out_var_name,
const std::string& table_name) {
VLOG(40) << "RequestGetHandler:" << varname; VLOG(40) << "RequestGetHandler:" << varname;
if (sync_mode_) { if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) { if (varname == FETCH_BARRIER_MESSAGE) {
...@@ -114,14 +116,21 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, ...@@ -114,14 +116,21 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar, framework::Variable** outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name) { const std::string& out_var_name,
const std::string& table_name) {
VLOG(40) << "RequestPrefetchHandler " << varname; VLOG(40) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(out_var_name); auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType()); InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext( if (table_name.empty()) {
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope); executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
} else {
auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place;
lookup_table_op->Run(*scope, cpu_place);
}
return true; return true;
} }
...@@ -130,7 +139,8 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -130,7 +139,8 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar, framework::Variable** outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name) { const std::string& out_var_name,
const std::string& table_name) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
checkpoint_notify_id != -1, checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke."); "when checkpoint_notify_id = -1, there should be no RPC invoke.");
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
...@@ -43,8 +44,8 @@ class RequestSendHandler final : public RequestHandler { ...@@ -43,8 +44,8 @@ class RequestSendHandler final : public RequestHandler {
virtual ~RequestSendHandler() {} virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const int trainer_id, const std::string& out_var_name = "",
const std::string& out_var_name = "") override; const std::string& table_name = "") override;
private: private:
bool enable_dc_asgd_; bool enable_dc_asgd_;
...@@ -59,21 +60,44 @@ class RequestGetHandler final : public RequestHandler { ...@@ -59,21 +60,44 @@ class RequestGetHandler final : public RequestHandler {
virtual ~RequestGetHandler() {} virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const int trainer_id, const std::string& out_var_name = "",
const std::string& out_var_name = "") override; const std::string& table_name = "") override;
private: private:
bool enable_dc_asgd_; bool enable_dc_asgd_;
}; };
static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
class RequestPrefetchHandler final : public RequestHandler { class RequestPrefetchHandler final : public RequestHandler {
public: public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {} explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {} virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const int trainer_id, const std::string& out_var_name = "",
const std::string& out_var_name = "") override; const std::string& table_name = "") override;
private:
std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("lookup_table");
BuildVar("W", {table_name.data()}, op_desc.add_inputs());
BuildVar("Ids", {id_name.data()}, op_desc.add_inputs());
BuildVar("Out", {out_name.data()}, op_desc.add_outputs());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
}; };
class RequestCheckpointHandler final : public RequestHandler { class RequestCheckpointHandler final : public RequestHandler {
...@@ -85,8 +109,8 @@ class RequestCheckpointHandler final : public RequestHandler { ...@@ -85,8 +109,8 @@ class RequestCheckpointHandler final : public RequestHandler {
virtual ~RequestCheckpointHandler() {} virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const int trainer_id, const std::string& out_var_name = "",
const std::string& out_var_name = "") override; const std::string& table_name = "") override;
private: private:
int checkpoint_notify_id; int checkpoint_notify_id;
......
...@@ -80,6 +80,7 @@ message VariableMessage { ...@@ -80,6 +80,7 @@ message VariableMessage {
// when profile switches from 1 to 2. // when profile switches from 1 to 2.
int64 profile = 11; int64 profile = 11;
int64 trainer_id = 12; int64 trainer_id = 12;
string table_name = 13;
} }
message VoidMessage {} message VoidMessage {}
...@@ -85,6 +85,7 @@ class VariableResponse { ...@@ -85,6 +85,7 @@ class VariableResponse {
inline framework::Scope* GetMutableLocalScope() const { return local_scope_; } inline framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() const { return meta_.varname(); } inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); } inline std::string OutVarname() const { return meta_.out_varname(); }
inline std::string TableName() const { return meta_.table_name(); }
// should call parse first. // should call parse first.
framework::Variable* GetVar() { framework::Variable* GetVar() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册