From d827881502592b91a486727f496c40249eee03a4 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 26 Nov 2018 11:34:46 +0800 Subject: [PATCH] fix pserver and prefetch rpc --- paddle/fluid/operators/distributed/grpc_client.cc | 6 ++++-- paddle/fluid/operators/distributed/grpc_client.h | 1 + paddle/fluid/operators/distributed/grpc_serde.cc | 6 +++++- paddle/fluid/operators/distributed/grpc_serde.h | 3 ++- paddle/fluid/operators/distributed/parameter_prefetch.cc | 5 +++-- paddle/fluid/operators/distributed/request_handler_impl.cc | 5 +++-- paddle/fluid/operators/distributed/rpc_client.h | 2 +- 7 files changed, 19 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index c28f86146..39365dd06 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -169,6 +169,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, const framework::Scope& scope, const std::string& in_var_name, const std::string& out_var_name, + const std::string& table_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; @@ -184,11 +185,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, s->Prepare(h, time_out); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - s, method, h, this] { + s, method, h, table_name, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; - SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); + SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val, + 0, table_name); VLOG(30) << s->GetVarHandlePtr()->String() << " begin"; diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index d8e9cee85..a31a46564 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -194,6 +194,7 @@ class GRPCClient : public RPCClient { const framework::Scope& scope, const std::string& in_var_name, const std::string& out_var_name, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncSendBatchBarrier( diff --git a/paddle/fluid/operators/distributed/grpc_serde.cc b/paddle/fluid/operators/distributed/grpc_serde.cc index f27b70a5a..8b3009d39 100644 --- a/paddle/fluid/operators/distributed/grpc_serde.cc +++ b/paddle/fluid/operators/distributed/grpc_serde.cc @@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) { void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, ::grpc::ByteBuffer* msg, const std::string& out_name, - const int trainer_id) { + const int trainer_id, + const std::string& table_name) { platform::RecordRPCEvent record_event("serial", &ctx); VarMsg request; TensorPayload* payload = nullptr; @@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, if (!out_name.empty()) { request.set_out_varname(out_name); } + if (!table_name.empty()) { + request.set_table_name(table_name); + } if (var->IsType()) { request.set_type(::sendrecv::LOD_TENSOR); payload = new TensorPayload(GetTensorPayload(var, ctx, &request)); diff --git a/paddle/fluid/operators/distributed/grpc_serde.h b/paddle/fluid/operators/distributed/grpc_serde.h index 7ec489e96..fe566d9b4 100644 --- a/paddle/fluid/operators/distributed/grpc_serde.h +++ b/paddle/fluid/operators/distributed/grpc_serde.h @@ -39,7 +39,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, ::grpc::ByteBuffer* msg, const std::string& out_varname = std::string(), - const int trainer_id = 0); + const int trainer_id = 0, + const std::string& table_name = std::string()); void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index 327c8cb4d..4d677e30b 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -84,7 +84,7 @@ inline void SplitIdsIntoMultipleVarsBySection( const std::vector& height_section, const std::vector>& splited_ids, framework::Scope* scope) { - PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, ""); + PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), ""); auto place = platform::CPUPlace(); @@ -184,7 +184,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i] << " to get " << out_var_names[i] << " back"; rets.push_back(rpc_client->AsyncPrefetchVar( - epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i])); + epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i], + table_name)); } else { VLOG(30) << "don't send no-initialied variable: " << out_var_names[i]; } diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 0f1264ee9..e041337fd 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -120,12 +120,13 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, const std::string& table_name) { VLOG(40) << "RequestPrefetchHandler " << varname; - auto var_desc = program_->Block(0).FindVar(out_var_name); - InitializeVariable(*outvar, var_desc->GetType()); if (table_name.empty()) { + auto var_desc = program_->Block(0).FindVar(out_var_name); + InitializeVariable(*outvar, var_desc->GetType()); executor_->RunPreparedContext( (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope); } else { + (*outvar)->GetMutable(); auto lookup_table_op = BuildLookupTableOp(table_name, varname, out_var_name); paddle::platform::CPUPlace cpu_place; diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 1983802e4..4cd3abb5a 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -48,7 +48,7 @@ class RPCClient { virtual VarHandlePtr AsyncPrefetchVar( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& in_var_name, - const std::string& out_var_name, + const std::string& out_var_name, const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncSendBatchBarrier( -- GitLab