From 686d15c8e02d3e90437050fdde96004d225f7c29 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 26 Nov 2018 14:53:41 +0800 Subject: [PATCH] update grpc_variable_response --- paddle/fluid/operators/distributed/grpc_client.cc | 5 +++-- .../fluid/operators/distributed/grpc_serde_test.cc | 3 ++- .../distributed/grpc_variable_response.cc | 14 ++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 39365dd06..bee6020d5 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -175,6 +175,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, const std::string ep_val = ep; const std::string in_var_name_val = in_var_name; const std::string out_var_name_val = out_var_name; + const std::string table_name_val = table_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); GetProcessor* s = new GetProcessor(ch); @@ -185,12 +186,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, s->Prepare(h, time_out); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - s, method, h, table_name, this] { + s, method, h, table_name_val, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val, - 0, table_name); + 0, table_name_val); VLOG(30) << s->GetVarHandlePtr()->String() << " begin"; diff --git a/paddle/fluid/operators/distributed/grpc_serde_test.cc b/paddle/fluid/operators/distributed/grpc_serde_test.cc index 96ea05e74..1936c2c62 100644 --- a/paddle/fluid/operators/distributed/grpc_serde_test.cc +++ b/paddle/fluid/operators/distributed/grpc_serde_test.cc @@ -130,7 +130,8 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { math::set_constant(ctx, tensor, 31.9); ::grpc::ByteBuffer msg; - operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg); + operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg, + "outvar", 0, "table_name"); EXPECT_GT(msg.Length(), static_cast(0)); // deserialize diff --git a/paddle/fluid/operators/distributed/grpc_variable_response.cc b/paddle/fluid/operators/distributed/grpc_variable_response.cc index d6d219d43..76ad02b03 100644 --- a/paddle/fluid/operators/distributed/grpc_variable_response.cc +++ b/paddle/fluid/operators/distributed/grpc_variable_response.cc @@ -301,6 +301,20 @@ int GRPCVariableResponse::Parse(Source* source) { meta_.set_trainer_id(trainer_id); break; } + case sendrecv::VariableMessage::kTableNameFieldNumber: { + uint32_t length; + if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { + return tag; + } + + std::string temp; + if (!input.ReadString(&temp, length)) { + return tag; + } + + meta_.set_table_name(temp); + break; + } default: { // Unknown tag, return unknown error. return -1; -- GitLab