/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdlib.h>
#include <unistd.h>

#include <memory>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
#include <vector>

#include "gtest/gtest.h"

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"
32 33
namespace framework = paddle::framework;
namespace platform = paddle::platform;
34
namespace distributed = paddle::operators::distributed;
35

36
USE_NO_KERNEL_OP(lookup_sparse_table);
Y
Yancey1989 已提交
37

38 39
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
40

Y
Yancey1989 已提交
41 42 43
// Appends a sub-block to `program` containing a single
// "lookup_sparse_table" op reading vars "w" (table) and "ids" (keys)
// and writing "out", and declares "out" as a 10x10 LoD tensor on the
// root block.  Returns the appended block (its ID is what the prefetch
// handler executes).
// NOTE(review): "Blcok" is a typo for "Block"; kept to avoid touching
// callers in one edit.
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
  auto root_block = program->MutableBlock(0);
  auto* block = program->AppendBlock(*root_block);

  auto op = block->AppendOp();
  op->SetType("lookup_sparse_table");
  op->SetInput("W", {"w"});
  op->SetInput("Ids", {"ids"});
  op->SetOutput("Out", {"out"});

  auto& out = *root_block->Var("out");
  out.SetType(framework::proto::VarType::LOD_TENSOR);
  out.SetShape({10, 10});

  return block;
}

Y
Yancey1989 已提交
60 61
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
  auto w_var = scope->Var("w");
Y
Yancey1989 已提交
62
  w_var->GetMutable<framework::SelectedRows>();
Y
Yancey1989 已提交
63

Y
Yancey1989 已提交
64
  auto out_var = scope->Var("out");
65
  out_var->GetMutable<framework::LoDTensor>();
Y
Yancey1989 已提交
66

Y
Yancey1989 已提交
67
  auto ids_var = scope->Var("ids");
68
  ids_var->GetMutable<framework::LoDTensor>();
Y
Yancey1989 已提交
69 70
}

Y
Yancey1989 已提交
71 72
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
Y
Yancey1989 已提交
73
  CreateVarsOnScope(scope, place);
74 75 76 77
  auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
  int64_t* ids_ptr =
      ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
  for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
Y
Yancey1989 已提交
78 79
}

Y
Yancey1989 已提交
80 81
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
Y
Yancey1989 已提交
82
  CreateVarsOnScope(scope, place);
Y
Yancey1989 已提交
83 84 85
  auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
  auto w_value = w->mutable_value();
  w_value->Resize({rows_numel, 10});
86
  for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
Y
Yancey1989 已提交
87 88 89 90

  auto ptr = w_value->mutable_data<float>(*place);

  for (int64_t i = 0; i < w_value->numel(); ++i) {
Y
Yancey1989 已提交
91 92 93
    ptr[i] = static_cast<float>(i / 10);
  }
}
Y
Yancey1989 已提交
94

Y
Yancey1989 已提交
95
void StartServer(const std::string& rpc_name) {
Y
Yancey1989 已提交
96 97 98 99 100
  framework::ProgramDesc program;
  framework::Scope scope;
  platform::CPUPlace place;
  framework::Executor exe(place);
  platform::CPUDeviceContext ctx(place);
Y
Yancey1989 已提交
101
  auto* block = AppendPrefetchBlcok(&program);
102 103 104
  std::string in_var_name("ids");
  std::vector<int> prefetch_block_ids{block->ID()};
  auto prepared = exe.Prepare(program, prefetch_block_ids);
Y
Yancey1989 已提交
105
  InitTensorsOnServer(&scope, &place, 10);
Y
Yancey1989 已提交
106

107 108 109 110
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      prefetch_var_name_to_prepared;
  prefetch_var_name_to_prepared[in_var_name] = prepared[0];
Y
Yancey1989 已提交
111

112
  g_req_handler->SetProgram(&program);
113
  g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
114 115 116 117
  g_req_handler->SetDevCtx(&ctx);
  g_req_handler->SetScope(&scope);
  g_req_handler->SetExecutor(&exe);

Y
Yancey1989 已提交
118
  g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
119 120 121
  g_req_handler->SetRPCServer(g_rpc_service.get());

  std::thread server_thread(
122
      std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
Y
Yancey1989 已提交
123

124
  server_thread.join();
125 126
}

127
TEST(PREFETCH, CPU) {
128 129
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
130
  g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
G
gongweibao 已提交
131
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
132
  distributed::RPCClient* client =
W
Wu Yi 已提交
133
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
134

Y
Yancey1989 已提交
135
  std::thread server_thread(StartServer, distributed::kRequestPrefetch);
136 137 138 139 140
  g_rpc_service->WaitServerReady();

  int port = g_rpc_service->GetSelectedPort();
  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);

141 142 143
  framework::Scope scope;
  platform::CPUPlace place;
  platform::CPUDeviceContext ctx(place);
144 145 146 147 148 149 150
  {
    // create var on local scope
    int64_t rows_numel = 5;
    InitTensorsOnClient(&scope, &place, rows_numel);
    std::string in_var_name("ids");
    std::string out_var_name("out");

G
gongweibao 已提交
151
    client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
W
Wu Yi 已提交
152
    client->Wait();
153
    auto var = scope.Var(out_var_name);
154 155
    auto value = var->GetMutable<framework::LoDTensor>();
    auto ptr = value->mutable_data<float>(place);
156 157

    for (int64_t i = 0; i < rows_numel; ++i) {
158
      EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
159
    }
Y
Yancey1989 已提交
160
  }
161

W
Wu Yi 已提交
162
  g_rpc_service->ShutDown();
163 164 165 166
  server_thread.join();
  LOG(INFO) << "begin reset";
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
167
}
Y
Yancey1989 已提交
168 169

TEST(COMPLETE, CPU) {
170 171
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
Y
Yancey1989 已提交
172 173 174
  g_req_handler.reset(new distributed::RequestSendHandler(true));
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
  distributed::RPCClient* client =
W
Wu Yi 已提交
175
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
Y
Yancey1989 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
  PADDLE_ENFORCE(client != nullptr);
  std::thread server_thread(StartServer, distributed::kRequestSend);
  g_rpc_service->WaitServerReady();
  int port = g_rpc_service->GetSelectedPort();
  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
  client->AsyncSendComplete(ep);
  client->Wait();

  EXPECT_EQ(g_rpc_service->GetClientNum(), 1);

  g_rpc_service->ShutDown();
  server_thread.join();
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
}