test_send_nccl_id.cc 3.4 KB
Newer Older
T
typhoonzero 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <unistd.h>
#include <string>
#include <thread>  // NOLINT

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
G
gongweibao 已提交
23
#include "paddle/fluid/operators/detail/macros.h"
24
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
T
typhoonzero 已提交
25 26 27
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
Y
yi.wu 已提交
28
#include "paddle/fluid/platform/nccl_helper.h"
T
typhoonzero 已提交
29 30
#include "paddle/fluid/string/printf.h"

G
gongweibao 已提交
31 32 33 34
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/send_recv_util.h"
#endif

T
typhoonzero 已提交
35 36 37 38 39
USE_NO_KERNEL_OP(listen_and_serv);

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
40
namespace distributed = paddle::operators::distributed;
T
typhoonzero 已提交
41 42
namespace string = paddle::string;

43 44
// Server-side objects shared by the TEST body and StartServer(): the test
// creates both before starting the server thread and resets both after it
// has been joined; StartServer() wires the handler into the RPC server.
std::unique_ptr<distributed::RPCServer> g_rpc_service{};
std::unique_ptr<distributed::RequestHandler> g_req_handler{};
T
typhoonzero 已提交
45

46
void StartServer() {
T
typhoonzero 已提交
47 48
  f::Scope scope;
  p::CPUPlace place;
T
typhoonzero 已提交
49
  scope.Var(NCCL_ID_VARNAME);
T
typhoonzero 已提交
50 51 52 53 54
  p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
  auto& dev_ctx = *pool.Get(p::CPUPlace());

  f::ProgramDesc empty_program;
  f::Executor executor(dev_ctx.GetPlace());
55 56 57 58 59
  g_req_handler->SetScope(&scope);
  g_req_handler->SetDevCtx(&dev_ctx);
  g_req_handler->SetProgram(&empty_program);
  g_req_handler->SetExecutor(&executor);

60
  g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get());
61
  g_req_handler->SetRPCServer(g_rpc_service.get());
T
typhoonzero 已提交
62 63

  std::thread server_thread(
64
      std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
65

66 67
  g_rpc_service->SetCond(distributed::kRequestSend);
  g_rpc_service->WaitBarrier(distributed::kRequestSend);
68

T
typhoonzero 已提交
69
  LOG(INFO) << "got nccl id and stop server...";
70
  g_rpc_service->ShutDown();
T
typhoonzero 已提交
71 72 73
  server_thread.join();
}

G
gongweibao 已提交
74
// End-to-end check that an NCCL unique id held in a local scope can be sent
// to an RPC server on 127.0.0.1 with RPCClient::AsyncSendVar.
//
// The server runs in StartServer() on a separate thread. It binds to port 0,
// and the actual port is read back with GetSelectedPort() to build the
// endpoint. The client sends the id and then a batch barrier, after which
// StartServer() shuts the server down and the test joins and cleans up.
TEST(SendNcclId, RPCServer) {
  // NOTE(review): `true` presumably selects sync mode and `1` the expected
  // client count — confirm against RequestSendHandler / RPCSERVER_T.
  g_req_handler.reset(new distributed::RequestSendHandler(true));
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));

  std::thread server_thread(StartServer);
  g_rpc_service->WaitServerReady();

  f::Scope scope;
  p::CPUPlace place;
  p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
  auto& dev_ctx = *pool.Get(place);

  // Generate a fresh NCCL id directly into the variable that will be sent.
  auto* var = scope.Var(NCCL_ID_VARNAME);
  auto* id = var->GetMutable<ncclUniqueId>();
  // Non-fatal check on purpose: a fatal ASSERT_* would return before
  // server_thread is joined, and destroying a joinable std::thread calls
  // std::terminate.
  EXPECT_EQ(ncclSuccess, p::dynload::ncclGetUniqueId(id));

  int port = g_rpc_service->GetSelectedPort();
  std::string ep = string::Sprintf("127.0.0.1:%d", port);

  distributed::RPCClient* client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>();

  LOG(INFO) << "connect to server " << ep;
  client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
  client->Wait();
  // The batch barrier is what StartServer() waits for before shutting down.
  client->AsyncSendBatchBarrier(ep);
  client->Wait();

  server_thread.join();
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
}