提交 9851a534 编写于 作者: Q Qiao Longfei

add prefetch part in pserver

上级 1f87f263
......@@ -181,6 +181,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process...
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(40) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
......
......@@ -191,7 +191,8 @@ class RequestHandler {
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") = 0;
const std::string& out_var_name = "",
const std::string& table_name = "") = 0;
protected:
const bool sync_mode_;
......
......@@ -37,7 +37,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(40) << "RequestSendHandler:" << varname;
// Sync
......@@ -77,7 +78,8 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(40) << "RequestGetHandler:" << varname;
if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) {
......@@ -114,14 +116,21 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(40) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
if (table_name.empty()) {
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
} else {
auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place;
lookup_table_op->Run(*scope, cpu_place);
}
return true;
}
......@@ -130,7 +139,8 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
......@@ -43,8 +44,8 @@ class RequestSendHandler final : public RequestHandler {
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
bool enable_dc_asgd_;
......@@ -59,21 +60,44 @@ class RequestGetHandler final : public RequestHandler {
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
bool enable_dc_asgd_;
};
static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
class RequestPrefetchHandler final : public RequestHandler {
public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("lookup_table");
BuildVar("W", {table_name.data()}, op_desc.add_inputs());
BuildVar("Ids", {id_name.data()}, op_desc.add_inputs());
BuildVar("Out", {out_name.data()}, op_desc.add_outputs());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
};
class RequestCheckpointHandler final : public RequestHandler {
......@@ -85,8 +109,8 @@ class RequestCheckpointHandler final : public RequestHandler {
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
int checkpoint_notify_id;
......
......@@ -80,6 +80,7 @@ message VariableMessage {
// when profile switches from 1 to 2.
int64 profile = 11;
int64 trainer_id = 12;
string table_name = 13;
}
message VoidMessage {}
......@@ -85,6 +85,7 @@ class VariableResponse {
inline framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); }
inline std::string TableName() const { return meta_.table_name(); }
// should call parse first.
framework::Variable* GetVar() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册