diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index c28f86146d3040c6a26cabfb795eff67375d4b76..39365dd0688fcad2d6132a46a716b9f7362af618 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -169,6 +169,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, const framework::Scope& scope, const std::string& in_var_name, const std::string& out_var_name, + const std::string& table_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; @@ -184,11 +185,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, s->Prepare(h, time_out); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - s, method, h, this] { + s, method, h, table_name, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; - SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); + SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val, + 0, table_name); VLOG(30) << s->GetVarHandlePtr()->String() << " begin"; diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index d8e9cee85bd734c2ed4b1cae03ecee04e304b651..a31a465645ee4256a76573576ea7fa5af7a5a101 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -194,6 +194,7 @@ class GRPCClient : public RPCClient { const framework::Scope& scope, const std::string& in_var_name, const std::string& out_var_name, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncSendBatchBarrier( diff --git a/paddle/fluid/operators/distributed/grpc_serde.cc b/paddle/fluid/operators/distributed/grpc_serde.cc index f27b70a5a3dd2927b51a95af7bd1b84a6e232f86..8b3009d39f9be36383d2a73bd8b015d62d04cdd5 100644 --- a/paddle/fluid/operators/distributed/grpc_serde.cc +++ b/paddle/fluid/operators/distributed/grpc_serde.cc @@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) { void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, ::grpc::ByteBuffer* msg, const std::string& out_name, - const int trainer_id) { + const int trainer_id, + const std::string& table_name) { platform::RecordRPCEvent record_event("serial", &ctx); VarMsg request; TensorPayload* payload = nullptr; @@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, if (!out_name.empty()) { request.set_out_varname(out_name); } + if (!table_name.empty()) { + request.set_table_name(table_name); + } if (var->IsType()) { request.set_type(::sendrecv::LOD_TENSOR); payload = new TensorPayload(GetTensorPayload(var, ctx, &request)); diff --git a/paddle/fluid/operators/distributed/grpc_serde.h b/paddle/fluid/operators/distributed/grpc_serde.h index 7ec489e961630747ba00e68ad3603cacbb1aa485..fe566d9b4c67515381d9436c59e9d1c6f28ec6fa 100644 --- a/paddle/fluid/operators/distributed/grpc_serde.h +++ b/paddle/fluid/operators/distributed/grpc_serde.h @@ -39,7 +39,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, ::grpc::ByteBuffer* msg, const std::string& out_varname = std::string(), - const int trainer_id = 0); + const int trainer_id = 0, + const std::string& table_name = std::string()); void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index 327c8cb4dedb61ff50b8239a409c4f25ce699bfc..4d677e30b8e01943abf12c861a8d2a77a69024a5 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -84,7 +84,7 @@ inline void SplitIdsIntoMultipleVarsBySection( const std::vector& height_section, const std::vector>& splited_ids, framework::Scope* scope) { - PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, ""); + PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), ""); auto place = platform::CPUPlace(); @@ -184,7 +184,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i] << " to get " << out_var_names[i] << " back"; rets.push_back(rpc_client->AsyncPrefetchVar( - epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i])); + epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i], + table_name)); } else { VLOG(30) << "don't send no-initialied variable: " << out_var_names[i]; } diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 0f1264ee96eeb2cd8fb23af097e6274acf4eeec8..e041337fd9a9b3a89f23bdb7bec6a3d5ba2a3adf 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -120,12 +120,13 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, const std::string& table_name) { VLOG(40) << "RequestPrefetchHandler " << varname; - auto var_desc = program_->Block(0).FindVar(out_var_name); - InitializeVariable(*outvar, var_desc->GetType()); if (table_name.empty()) { + auto var_desc = program_->Block(0).FindVar(out_var_name); + InitializeVariable(*outvar, var_desc->GetType()); executor_->RunPreparedContext( (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope); } else { + (*outvar)->GetMutable(); auto lookup_table_op = BuildLookupTableOp(table_name, varname, out_var_name); paddle::platform::CPUPlace cpu_place; diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 1983802e49506c79041112ac87d429e4c084ddfd..4cd3abb5a61068bc4f9f5b38cafc2daa8406d448 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -48,7 +48,7 @@ class RPCClient { virtual VarHandlePtr AsyncPrefetchVar( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& in_var_name, - const std::string& out_var_name, + const std::string& out_var_name, const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncSendBatchBarrier(