提交 d8278815 编写于 作者: Q Qiao Longfei

fix pserver and prefetch rpc

上级 5856c2f3
......@@ -169,6 +169,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
......@@ -184,11 +185,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
s, method, h, this] {
s, method, h, table_name, this] {
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
0, table_name);
VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
......
......@@ -194,6 +194,7 @@ class GRPCClient : public RPCClient {
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendBatchBarrier(
......
......@@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name,
const int trainer_id) {
const int trainer_id,
const std::string& table_name) {
platform::RecordRPCEvent record_event("serial", &ctx);
VarMsg request;
TensorPayload* payload = nullptr;
......@@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (!table_name.empty()) {
request.set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
payload = new TensorPayload(GetTensorPayload(var, ctx, &request));
......
......@@ -39,7 +39,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string(),
const int trainer_id = 0);
const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
......
......@@ -84,7 +84,7 @@ inline void SplitIdsIntoMultipleVarsBySection(
const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, "");
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), "");
auto place = platform::CPUPlace();
......@@ -184,7 +184,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i]));
epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i],
table_name));
} else {
VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
}
......
......@@ -120,12 +120,13 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
const std::string& table_name) {
VLOG(40) << "RequestPrefetchHandler " << varname;
if (table_name.empty()) {
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
if (table_name.empty()) {
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
} else {
(*outvar)->GetMutable<framework::LoDTensor>();
auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place;
......
......@@ -48,7 +48,7 @@ class RPCClient {
virtual VarHandlePtr AsyncPrefetchVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& in_var_name,
const std::string& out_var_name,
const std::string& out_var_name, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendBatchBarrier(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册