diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 39365dd0688fcad2d6132a46a716b9f7362af618..bee6020d5dd4d0105965df4fc50c81239b225a7f 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -175,6 +175,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, const std::string ep_val = ep; const std::string in_var_name_val = in_var_name; const std::string out_var_name_val = out_var_name; + const std::string table_name_val = table_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); GetProcessor* s = new GetProcessor(ch); @@ -185,12 +186,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, s->Prepare(h, time_out); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - s, method, h, table_name, this] { + s, method, h, table_name_val, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val, - 0, table_name); + 0, table_name_val); VLOG(30) << s->GetVarHandlePtr()->String() << " begin"; diff --git a/paddle/fluid/operators/distributed/grpc_serde_test.cc b/paddle/fluid/operators/distributed/grpc_serde_test.cc index 96ea05e74ed76768248a27ab435dc801b7d1b995..1936c2c623a779c2599aa560247fa5e24f28cd62 100644 --- a/paddle/fluid/operators/distributed/grpc_serde_test.cc +++ b/paddle/fluid/operators/distributed/grpc_serde_test.cc @@ -130,7 +130,8 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { math::set_constant(ctx, tensor, 31.9); ::grpc::ByteBuffer msg; - operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg); + operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg, + "outvar", 0, "table_name"); EXPECT_GT(msg.Length(), static_cast(0)); // deserialize diff --git a/paddle/fluid/operators/distributed/grpc_variable_response.cc b/paddle/fluid/operators/distributed/grpc_variable_response.cc index d6d219d4369ba785e5c369538d4a18dc682952c1..76ad02b0300a58cd19ff2541ad53d067197f4177 100644 --- a/paddle/fluid/operators/distributed/grpc_variable_response.cc +++ b/paddle/fluid/operators/distributed/grpc_variable_response.cc @@ -301,6 +301,20 @@ int GRPCVariableResponse::Parse(Source* source) { meta_.set_trainer_id(trainer_id); break; } + case sendrecv::VariableMessage::kTableNameFieldNumber: { + uint32_t length; + if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { + return tag; + } + + std::string temp; + if (!input.ReadString(&temp, length)) { + return tag; + } + + meta_.set_table_name(temp); + break; + } default: { // Unknown tag, return unknown error. return -1;