/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdlib.h>
#include <unistd.h>
#include <memory>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
32

33 34
namespace framework = paddle::framework;
namespace platform = paddle::platform;
35
namespace distributed = paddle::operators::distributed;
36

37
USE_NO_KERNEL_OP(lookup_sparse_table);
Y
Yancey1989 已提交
38

39 40
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
41

Y
Yancey1989 已提交
42 43 44
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
  auto root_block = program->MutableBlock(0);
  auto* block = program->AppendBlock(*root_block);
Y
Yancey1989 已提交
45 46 47 48

  framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
  framework::VariableNameMap output({{"Output", {"out"}}});
  auto op = block->AppendOp();
49
  op->SetType("lookup_sparse_table");
Y
Yancey1989 已提交
50 51 52
  op->SetInput("W", {"w"});
  op->SetInput("Ids", {"ids"});
  op->SetOutput("Out", {"out"});
Y
Yancey1989 已提交
53 54

  auto& out = *root_block->Var("out");
55
  out.SetType(framework::proto::VarType::LOD_TENSOR);
Y
Yancey1989 已提交
56 57
  out.SetShape({10, 10});

Y
Yancey1989 已提交
58 59 60
  return block;
}

Y
Yancey1989 已提交
61 62
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
  auto w_var = scope->Var("w");
Y
Yancey1989 已提交
63
  w_var->GetMutable<framework::SelectedRows>();
Y
Yancey1989 已提交
64

Y
Yancey1989 已提交
65
  auto out_var = scope->Var("out");
66
  out_var->GetMutable<framework::LoDTensor>();
Y
Yancey1989 已提交
67

Y
Yancey1989 已提交
68
  auto ids_var = scope->Var("ids");
69
  ids_var->GetMutable<framework::LoDTensor>();
Y
Yancey1989 已提交
70 71
}

Y
Yancey1989 已提交
72 73
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
Y
Yancey1989 已提交
74
  CreateVarsOnScope(scope, place);
75 76 77 78
  auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
  int64_t* ids_ptr =
      ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
  for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
Y
Yancey1989 已提交
79 80
}

Y
Yancey1989 已提交
81 82
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
Y
Yancey1989 已提交
83
  CreateVarsOnScope(scope, place);
Y
Yancey1989 已提交
84 85 86
  auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
  auto w_value = w->mutable_value();
  w_value->Resize({rows_numel, 10});
87
  for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
Y
Yancey1989 已提交
88 89 90 91

  auto ptr = w_value->mutable_data<float>(*place);

  for (int64_t i = 0; i < w_value->numel(); ++i) {
Y
Yancey1989 已提交
92 93 94
    ptr[i] = static_cast<float>(i / 10);
  }
}
Y
Yancey1989 已提交
95

Y
Yancey1989 已提交
96
void StartServer(const std::string& rpc_name) {
Y
Yancey1989 已提交
97 98 99 100 101
  framework::ProgramDesc program;
  framework::Scope scope;
  platform::CPUPlace place;
  framework::Executor exe(place);
  platform::CPUDeviceContext ctx(place);
Y
Yancey1989 已提交
102
  auto* block = AppendPrefetchBlcok(&program);
103 104 105
  std::string in_var_name("ids");
  std::vector<int> prefetch_block_ids{block->ID()};
  auto prepared = exe.Prepare(program, prefetch_block_ids);
Y
Yancey1989 已提交
106
  InitTensorsOnServer(&scope, &place, 10);
Y
Yancey1989 已提交
107

108 109 110 111
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      prefetch_var_name_to_prepared;
  prefetch_var_name_to_prepared[in_var_name] = prepared[0];
Y
Yancey1989 已提交
112

113
  g_req_handler->SetProgram(&program);
114
  g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
115 116 117 118
  g_req_handler->SetDevCtx(&ctx);
  g_req_handler->SetScope(&scope);
  g_req_handler->SetExecutor(&exe);

Y
Yancey1989 已提交
119
  g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
120 121 122

  distributed::HeartBeatMonitor::Init(2, true, "w@grad");

123 124 125
  g_req_handler->SetRPCServer(g_rpc_service.get());

  std::thread server_thread(
126
      std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
Y
Yancey1989 已提交
127

128
  server_thread.join();
129 130
}

131
TEST(PREFETCH, CPU) {
132 133
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
1
123malin 已提交
134 135
  g_req_handler.reset(new distributed::RequestPrefetchHandler(
      distributed::DistributedMode::kSync));
G
gongweibao 已提交
136
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
137
  distributed::RPCClient* client =
W
Wu Yi 已提交
138
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
139

Y
Yancey1989 已提交
140
  std::thread server_thread(StartServer, distributed::kRequestPrefetch);
141 142 143 144 145
  g_rpc_service->WaitServerReady();

  int port = g_rpc_service->GetSelectedPort();
  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);

146 147 148
  framework::Scope scope;
  platform::CPUPlace place;
  platform::CPUDeviceContext ctx(place);
149 150 151 152 153 154 155
  {
    // create var on local scope
    int64_t rows_numel = 5;
    InitTensorsOnClient(&scope, &place, rows_numel);
    std::string in_var_name("ids");
    std::string out_var_name("out");

G
gongweibao 已提交
156
    client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
W
Wu Yi 已提交
157
    client->Wait();
158
    auto var = scope.Var(out_var_name);
159 160
    auto value = var->GetMutable<framework::LoDTensor>();
    auto ptr = value->mutable_data<float>(place);
161 162

    for (int64_t i = 0; i < rows_numel; ++i) {
163
      EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
164
    }
Y
Yancey1989 已提交
165
  }
166

W
Wu Yi 已提交
167
  g_rpc_service->ShutDown();
168 169 170 171
  server_thread.join();
  LOG(INFO) << "begin reset";
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
172
}
Y
Yancey1989 已提交
173 174

TEST(COMPLETE, CPU) {
175 176
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
T
tangwei12 已提交
177 178
  g_req_handler.reset(
      new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
Y
Yancey1989 已提交
179 180
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
  distributed::RPCClient* client =
W
Wu Yi 已提交
181
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
Y
Yancey1989 已提交
182
  PADDLE_ENFORCE(client != nullptr);
T
tangwei12 已提交
183
  std::thread server_thread(StartServer, distributed::kRequestSend);
Y
Yancey1989 已提交
184 185 186 187 188 189
  g_rpc_service->WaitServerReady();
  int port = g_rpc_service->GetSelectedPort();
  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
  client->AsyncSendComplete(ep);
  client->Wait();

T
tangwei12 已提交
190
  EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
Y
Yancey1989 已提交
191 192 193 194 195 196

  g_rpc_service->ShutDown();
  server_thread.join();
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
}