/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdlib.h>
#include <unistd.h>
#include <memory>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;

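// Link in the op registrations the test programs below depend on.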
USE_NO_KERNEL_OP(lookup_sparse_table_read);
USE_OP(scale);

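// RPC server and request handler shared between each test body and the
// server thread it spawns.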
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;

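// Appends a sub-block holding a single scale op that computes res = 0.5 * x;
// the SENDANDRECV test has the server execute this block.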
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
  auto root_block = program->MutableBlock(0);
  auto* block = program->AppendBlock(*root_block);

  framework::OpDesc* op = block->AppendOp();
  op->SetType("scale");
  op->SetInput("X", {"x"});
  op->SetOutput("Out", {"res"});
  op->SetAttr("scale", 0.5f);

  auto& out = *root_block->Var("res");
  out.SetType(framework::proto::VarType::LOD_TENSOR);
  out.SetShape({1, 10});

  return block;
}

void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
  auto w_var = scope->Var("w");
  w_var->GetMutable<framework::SelectedRows>();

  auto out_var = scope->Var("out");
  out_var->GetMutable<framework::LoDTensor>();

  auto ids_var = scope->Var("ids");
  ids_var->GetMutable<framework::LoDTensor>();

  auto x_var = scope->Var("x");
  x_var->GetMutable<framework::LoDTensor>();

  auto res_var = scope->Var("res");
  res_var->GetMutable<framework::LoDTensor>();
}

void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
  CreateVarsOnScope(scope, place);
  auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
  int64_t* ids_ptr =
      ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
  for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;

  auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
  float* x_ptr =
      x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
  for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
}

void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
  CreateVarsOnScope(scope, place);
  auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
  auto w_value = w->mutable_value();
  w_value->Resize({rows_numel, 10});
  for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);

  auto ptr = w_value->mutable_data<float>(*place);

  for (int64_t i = 0; i < w_value->numel(); ++i) {
    ptr[i] = static_cast<float>(i / 10);
  }
}

void StartServer(const std::string& rpc_name) {
  framework::ProgramDesc program;
  framework::Scope scope;
  platform::CPUPlace place;
  framework::Executor exe(place);
  platform::CPUDeviceContext ctx(place);

  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      prefetch_var_name_to_prepared;

  g_req_handler->SetProgram(&program);
  g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
  g_req_handler->SetDevCtx(&ctx);
  g_req_handler->SetScope(&scope);
  g_req_handler->SetExecutor(&exe);

  g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());

  // Initialize the heart beat monitor (2 workers, watching the "w@grad" var).
  distributed::HeartBeatMonitor::Init(2, true, "w@grad");

  g_req_handler->SetRPCServer(g_rpc_service.get());

  std::thread server_thread(
      std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));

  server_thread.join();
}

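// Like StartServer, but also prepares the appended scale block so the handler
// can execute it for incoming SendAndRecv requests.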
void StartSendAndRecvServer(const std::string& rpc_name) {
  framework::ProgramDesc program;
  framework::Scope scope;
  platform::CPUPlace place;
  framework::Executor exe(place);
  platform::CPUDeviceContext ctx(place);
  auto block = AppendSendAndRecvBlock(&program);
  std::string in_var_name("x");
  std::vector<int> prefetch_block_ids{block->ID()};
  auto prepared = exe.Prepare(program, prefetch_block_ids);
  InitTensorsOnServer(&scope, &place, 10);

  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      grad_to_prepared_ctx;
  grad_to_prepared_ctx[in_var_name] = prepared[0];

  g_req_handler->SetProgram(&program);
  g_req_handler->SetGradToPreparedCtx(&grad_to_prepared_ctx);
  g_req_handler->SetDevCtx(&ctx);
  g_req_handler->SetScope(&scope);
  g_req_handler->SetExecutor(&exe);

  g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
  g_req_handler->SetRPCServer(g_rpc_service.get());

  std::thread server_thread(
      std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));

  server_thread.join();
}

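// One of the two expected clients sends a "complete" notification; the
// server's remaining client count should then be 1.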
TEST(COMPLETE, CPU) {
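  // Make sure the loopback connection is not routed through an HTTP proxy.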
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
  g_req_handler.reset(
      new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
  distributed::RPCClient* client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  PADDLE_ENFORCE_NE(client, nullptr,
                    platform::errors::InvalidArgument(
                        "Client Start Fail, Check Your Code & Env"));
  std::thread server_thread(StartServer, distributed::kRequestSend);
  g_rpc_service->WaitServerReady();
  int port = g_rpc_service->GetSelectedPort();
  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
  client->AsyncSendComplete(ep);
  client->Wait();

  EXPECT_EQ(g_rpc_service->GetClientNum(), 1);

  g_rpc_service->ShutDown();
  server_thread.join();
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
}

// Round trip: send "x", run the scale block on the server, and check that
// every element of the returned "res" equals 0.5.
TEST(SENDANDRECV, CPU) {
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
  g_req_handler.reset(new distributed::RequestSendAndRecvHandler(
      distributed::DistributedMode::kAsync));
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
  distributed::RPCClient* client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  PADDLE_ENFORCE_NE(client, nullptr,
                    platform::errors::InvalidArgument(
                        "Client Start Fail, Check Your Code & Env"));
  std::thread server_thread(StartSendAndRecvServer,
                            distributed::kRequestSendAndRecv);
  g_rpc_service->WaitServerReady();
  int port = g_rpc_service->GetSelectedPort();
  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);

  framework::Scope scope;
  platform::CPUPlace place;
  platform::CPUDeviceContext ctx(place);

  // create var on local scope
  int64_t rows_numel = 10;
  InitTensorsOnClient(&scope, &place, rows_numel);
  std::string in_var_name("x");
  std::string out_var_name("res");

  client->AsyncSendAndRecv(ep, ctx, scope, in_var_name, out_var_name);
  client->Wait();
  auto var = scope.Var(out_var_name);
  auto value = var->GetMutable<framework::LoDTensor>();
  auto ptr = value->mutable_data<float>(place);

  for (int64_t i = 0; i < rows_numel; ++i) {
    EXPECT_EQ(ptr[i], 0.5);
  }
  g_rpc_service->ShutDown();
  server_thread.join();
  LOG(INFO) << "begin reset";
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
}