rpc_server_test.cc 6.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <stdlib.h>
16
#include <unistd.h>
17
#include <memory>
18
#include <string>
Y
Yancey1989 已提交
19
#include <thread>  // NOLINT
20
#include <unordered_map>
21 22

#include "gtest/gtest.h"
Y
Yancey1989 已提交
23
#include "paddle/fluid/framework/block_desc.h"
Y
Yancey1989 已提交
24 25 26
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

W
Wu Yi 已提交
27
#include "paddle/fluid/operators/distributed/distributed.h"
28
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
29 30 31
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
32

33 34
// Short aliases for the Paddle namespaces used throughout this test.
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;

// Pull in the operator this test's prefetch block executes; it is registered
// without a kernel (CPU-side logic only).
USE_NO_KERNEL_OP(lookup_sparse_table);

// Shared between the test body (client side) and the server thread started
// by StartServer(); reset at the end of every TEST.
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
41

Y
Yancey1989 已提交
42 43 44
// Appends a sub-block to |program| containing a single lookup_sparse_table
// op that reads the rows indexed by "ids" from the table "w" and writes them
// to "out".  Also declares "out" in the root block as a 10x10 LoD tensor so
// the executor has a typed destination.  Returns the new block; it is owned
// by |program|, not by the caller.
// NOTE(review): the name keeps the historical "Blcok" typo because the call
// site in StartServer() uses it; renaming is a coupled two-site change.
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
  auto root_block = program->MutableBlock(0);
  auto* block = program->AppendBlock(*root_block);

  // Configure the op directly via Set{Type,Input,Output}.  (Two unused
  // VariableNameMap locals that previously shadowed this setup were removed.)
  auto op = block->AppendOp();
  op->SetType("lookup_sparse_table");
  op->SetInput("W", {"w"});
  op->SetInput("Ids", {"ids"});
  op->SetOutput("Out", {"out"});

  // Declare the output variable in the root block so it is visible when the
  // prefetch block runs.
  auto& out = *root_block->Var("out");
  out.SetType(framework::proto::VarType::LOD_TENSOR);
  out.SetShape({10, 10});

  return block;
}

Y
Yancey1989 已提交
61 62
// Declares the three variables this test exchanges over RPC:
//   "w"   - the server-side parameter table (SelectedRows),
//   "out" - the rows fetched back from the server (LoDTensor),
//   "ids" - the row indices to look up (LoDTensor).
// Values are filled in later by the InitTensorsOn{Client,Server} helpers;
// |place| is accepted for signature symmetry but not used here.
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
  scope->Var("w")->GetMutable<framework::SelectedRows>();
  scope->Var("out")->GetMutable<framework::LoDTensor>();
  scope->Var("ids")->GetMutable<framework::LoDTensor>();
}

Y
Yancey1989 已提交
72 73
// Prepares the client-side scope: declares all test variables, then fills
// the "ids" tensor (shape [rows_numel x 1]) with the even ids 0, 2, 4, ...
// which the PREFETCH test later asserts map to rows of matching value.
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
  CreateVarsOnScope(scope, place);
  auto* ids_tensor = scope->Var("ids")->GetMutable<framework::LoDTensor>();
  auto* ids_data = ids_tensor->mutable_data<int64_t>(
      framework::DDim({rows_numel, 1}), *place);
  for (int64_t row = 0; row < rows_numel; ++row) {
    ids_data[row] = row * 2;
  }
}

Y
Yancey1989 已提交
81 82
// Prepares the server-side scope: declares all test variables, then builds
// the parameter table "w" as a [rows_numel x 10] SelectedRows where element
// i holds i/10 (integer division) -- i.e. every element of row r equals r.
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
  CreateVarsOnScope(scope, place);
  auto* table = scope->Var("w")->GetMutable<framework::SelectedRows>();
  auto* table_value = table->mutable_value();
  table_value->Resize({rows_numel, 10});

  // Register every row index with the table before writing the data.
  for (int64_t row = 0; row < rows_numel; ++row) {
    table->AutoGrownIndex(row, true);
  }

  auto* table_data = table_value->mutable_data<float>(*place);
  for (int64_t i = 0; i < table_value->numel(); ++i) {
    table_data[i] = static_cast<float>(i / 10);
  }
}
Y
Yancey1989 已提交
95

Y
Yancey1989 已提交
96
// Runs one RPC server until it is shut down from outside (the test body
// calls g_rpc_service->ShutDown()).  |rpc_name| is the request type the
// shared g_req_handler is registered for (kRequestPrefetch or kRequestSend,
// depending on the test).  All server-side state below lives on this
// function's stack, so the raw pointers handed to g_req_handler are valid
// only while this function is still running -- callers must join the thread
// running StartServer before destroying the handler.
void StartServer(const std::string& rpc_name) {
  framework::ProgramDesc program;
  framework::Scope scope;
  platform::CPUPlace place;
  framework::Executor exe(place);
  platform::CPUDeviceContext ctx(place);
  // Build the lookup_sparse_table sub-block and pre-compile it.
  auto* block = AppendPrefetchBlcok(&program);
  std::string in_var_name("ids");
  std::vector<int> prefetch_block_ids{block->ID()};
  auto prepared = exe.Prepare(program, prefetch_block_ids);
  // Fill the parameter table "w" with 10 rows on the server scope.
  InitTensorsOnServer(&scope, &place, 10);

  // Map the prefetch input variable name to its prepared execution context,
  // so the handler can run the right sub-block per incoming variable.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      prefetch_var_name_to_prepared;
  prefetch_var_name_to_prepared[in_var_name] = prepared[0];

  // Wire the request handler to the server-side state created above.
  g_req_handler->SetProgram(&program);
  g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
  g_req_handler->SetDevCtx(&ctx);
  g_req_handler->SetScope(&scope);
  g_req_handler->SetExecutor(&exe);

  // Registration must happen before the server starts serving.
  g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());

  // NOTE(review): monitors 2 workers on variable "w@grad"; presumably needed
  // by the send-handler path -- confirm against HeartBeatMonitor's contract.
  distributed::HeartBeatMonitor::Init(2, true, "w@grad");

  g_req_handler->SetRPCServer(g_rpc_service.get());

  // Serve on a helper thread; block here until ShutDown() stops it so the
  // stack-allocated state above outlives all request handling.
  std::thread server_thread(
      std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));

  server_thread.join();
}

131
// End-to-end prefetch test: a client asks the server for the table rows
// indexed by "ids" (0, 2, 4, ...) and checks each returned row holds the
// row's own index value (rows were filled so row r contains r everywhere).
TEST(PREFETCH, CPU) {
  // Clear proxy variables so the loopback RPC connection is direct.
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
  g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
  // Port 0 lets the OS pick a free port; queried below via GetSelectedPort.
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
  distributed::RPCClient* client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);

  std::thread server_thread(StartServer, distributed::kRequestPrefetch);
  // Block until the server is accepting requests before contacting it.
  g_rpc_service->WaitServerReady();

  int port = g_rpc_service->GetSelectedPort();
  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);

  framework::Scope scope;
  platform::CPUPlace place;
  platform::CPUDeviceContext ctx(place);
  {
    // create var on local scope
    int64_t rows_numel = 5;
    InitTensorsOnClient(&scope, &place, rows_numel);
    std::string in_var_name("ids");
    std::string out_var_name("out");

    client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
    client->Wait();
    auto var = scope.Var(out_var_name);
    auto value = var->GetMutable<framework::LoDTensor>();
    auto ptr = value->mutable_data<float>(place);

    // Row i was fetched with id 2*i, and row r of the table is filled with
    // the value r, so the first element of fetched row i must equal 2*i.
    for (int64_t i = 0; i < rows_numel; ++i) {
      EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
    }
  }

  // Stop the server, join its thread, then release the globals so the next
  // TEST can re-create them.
  g_rpc_service->ShutDown();
  server_thread.join();
  LOG(INFO) << "begin reset";
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
}
Y
Yancey1989 已提交
172 173

// Tests the "send complete" path: a client announces completion to a server
// expecting 2 clients, and the server's client count reflects one arrival.
TEST(COMPLETE, CPU) {
  // Clear proxy variables so the loopback RPC connection is direct.
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
  g_req_handler.reset(new distributed::RequestSendHandler(true));
  // Port 0 = OS-assigned port; 2 = number of clients the server waits for.
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
  distributed::RPCClient* client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  PADDLE_ENFORCE(client != nullptr);
  std::thread server_thread(StartServer, distributed::kRequestSend);
  // Block until the server is accepting requests before contacting it.
  g_rpc_service->WaitServerReady();
  int port = g_rpc_service->GetSelectedPort();
  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
  client->AsyncSendComplete(ep);
  client->Wait();

  // Exactly one of the two expected clients has reported completion.
  EXPECT_EQ(g_rpc_service->GetClientNum(), 1);

  // Stop the server, join its thread, then release the shared globals.
  g_rpc_service->ShutDown();
  server_thread.join();
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
}