提交 d8278815 编写于 作者: Q Qiao Longfei

fix pserver and prefetch rpc

上级 5856c2f3
...@@ -169,6 +169,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -169,6 +169,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) { int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
...@@ -184,11 +185,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -184,11 +185,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
s->Prepare(h, time_out); s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
s, method, h, this] { s, method, h, table_name, this] {
auto* var = p_scope->FindVar(in_var_name_val); auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
0, table_name);
VLOG(30) << s->GetVarHandlePtr()->String() << " begin"; VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
......
...@@ -194,6 +194,7 @@ class GRPCClient : public RPCClient { ...@@ -194,6 +194,7 @@ class GRPCClient : public RPCClient {
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendBatchBarrier( VarHandlePtr AsyncSendBatchBarrier(
......
...@@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) { ...@@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name, ::grpc::ByteBuffer* msg, const std::string& out_name,
const int trainer_id) { const int trainer_id,
const std::string& table_name) {
platform::RecordRPCEvent record_event("serial", &ctx); platform::RecordRPCEvent record_event("serial", &ctx);
VarMsg request; VarMsg request;
TensorPayload* payload = nullptr; TensorPayload* payload = nullptr;
...@@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (!out_name.empty()) { if (!out_name.empty()) {
request.set_out_varname(out_name); request.set_out_varname(out_name);
} }
if (!table_name.empty()) {
request.set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR); request.set_type(::sendrecv::LOD_TENSOR);
payload = new TensorPayload(GetTensorPayload(var, ctx, &request)); payload = new TensorPayload(GetTensorPayload(var, ctx, &request));
......
...@@ -39,7 +39,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -39,7 +39,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, ::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string(), const std::string& out_varname = std::string(),
const int trainer_id = 0); const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
......
...@@ -84,7 +84,7 @@ inline void SplitIdsIntoMultipleVarsBySection( ...@@ -84,7 +84,7 @@ inline void SplitIdsIntoMultipleVarsBySection(
const std::vector<int64_t>& height_section, const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids, const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) { framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, ""); PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), "");
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
...@@ -184,7 +184,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -184,7 +184,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i] VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back"; << " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar( rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i])); epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i],
table_name));
} else { } else {
VLOG(30) << "don't send no-initialied variable: " << out_var_names[i]; VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
} }
......
...@@ -120,12 +120,13 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, ...@@ -120,12 +120,13 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
const std::string& table_name) { const std::string& table_name) {
VLOG(40) << "RequestPrefetchHandler " << varname; VLOG(40) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
if (table_name.empty()) { if (table_name.empty()) {
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext( executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope); (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
} else { } else {
(*outvar)->GetMutable<framework::LoDTensor>();
auto lookup_table_op = auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name); BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place; paddle::platform::CPUPlace cpu_place;
......
...@@ -48,7 +48,7 @@ class RPCClient { ...@@ -48,7 +48,7 @@ class RPCClient {
virtual VarHandlePtr AsyncPrefetchVar( virtual VarHandlePtr AsyncPrefetchVar(
const std::string& ep, const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& in_var_name, const framework::Scope& scope, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendBatchBarrier( virtual VarHandlePtr AsyncSendBatchBarrier(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册