// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"

#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/request_handler.h"

namespace sendrecv {

25 26 27
namespace distributed = paddle::operators::distributed;

// Maps an RPC name (e.g. distributed::kRequestSend) to the handler that
// serves it. The map stores raw, non-owning handler pointers.
using HandlerMap =
    std::unordered_map<std::string, distributed::RequestHandler*>;

class BRPCServiceImpl : public SendRecvService {
 public:
32 33 34 35 36
  explicit BRPCServiceImpl(const HandlerMap& rpc_call_map,
                           distributed::RPCServer* rpc_server)
      : rpc_server_(rpc_server) {
    VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size();
    auto it = rpc_call_map.find(distributed::kRequestSend);
G
gongweibao 已提交
37 38
    if (it != rpc_call_map.end()) {
      request_send_h_ = it->second;
39 40
      send_threads_.reset(new paddle::framework::ThreadPool(
          rpc_server_->GetThreadNum(distributed::kRequestSend)));
G
gongweibao 已提交
41 42
    }

43
    it = rpc_call_map.find(distributed::kRequestGet);
G
gongweibao 已提交
44 45
    if (it != rpc_call_map.end()) {
      request_get_h_ = it->second;
46 47
      get_threads_.reset(new paddle::framework::ThreadPool(
          rpc_server_->GetThreadNum(distributed::kRequestGet)));
G
gongweibao 已提交
48 49
    }

50 51 52 53 54 55 56
    it = rpc_call_map.find(distributed::kRequestGetNoBarrier);
    if (it != rpc_call_map.end()) {
      request_getnobarrier_h_ = it->second;
      getnobarrier_threads_.reset(new paddle::framework::ThreadPool(
          rpc_server_->GetThreadNum(distributed::kRequestGetNoBarrier)));
    }

57
    it = rpc_call_map.find(distributed::kRequestPrefetch);
G
gongweibao 已提交
58 59
    if (it != rpc_call_map.end()) {
      request_prefetch_h_ = it->second;
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
      prefetch_threads_.reset(new paddle::framework::ThreadPool(
          rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
    }

    it = rpc_call_map.find(distributed::kRequestCheckpoint);
    if (it != rpc_call_map.end()) {
      request_checkpoint_h_ = it->second;
      checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool(
          rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
    }

    it = rpc_call_map.find(distributed::kRequestGetMonomerVariable);
    if (it != rpc_call_map.end()) {
      request_get_monomer_handler_h_ = it->second;
    }

    it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier);
    if (it != rpc_call_map.end()) {
      request_get_monomer_barrier_handler_h_ = it->second;
G
gongweibao 已提交
79 80 81 82 83 84 85
    }
  }

  virtual ~BRPCServiceImpl() {}
  void SendVariable(google::protobuf::RpcController* cntl_butil,
                    const VariableMessage* request, VoidMessage* response,
                    google::protobuf::Closure* done) override {
86 87 88 89 90 91 92
    send_threads_->Run(
        [=] { _SendVariable(cntl_butil, request, response, done); });
  }

  void _SendVariable(google::protobuf::RpcController* cntl_butil,
                     const VariableMessage* request, VoidMessage* response,
                     google::protobuf::Closure* done) {
M
MRXLT 已提交
93 94 95
    PADDLE_ENFORCE_NOT_NULL(
        request_send_h_, platform::errors::PreconditionNotMet(
                             "RequestSend handler should be registed first!"));
G
gongweibao 已提交
96
    brpc::ClosureGuard done_guard(done);
97
    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
G
gongweibao 已提交
98 99

    std::string varname = request->varname();
100 101 102
    VLOG(3) << "RequestSend var_name:" << varname
            << ", trainer_id:" << request->trainer_id()
            << ", from:" << cntl->remote_side();
G
gongweibao 已提交
103

104 105
    distributed::BRPCVariableResponse resp(request_send_h_->scope(),
                                           request_send_h_->dev_ctx(),
1
123malin 已提交
106
                                           request_send_h_->distributed_mode());
M
MRXLT 已提交
107 108 109
    PADDLE_ENFORCE_EQ(
        resp.Parse(cntl->request_attachment(), *request), 0,
        platform::errors::InvalidArgument("parse iobuf to tensor error!"));
G
gongweibao 已提交
110

111 112 113 114
    auto scope = resp.GetMutableLocalScope();
    auto invar = resp.GetVar();
    int trainer_id = request->trainer_id();
    paddle::framework::Variable* outvar = nullptr;
G
gongweibao 已提交
115

116
    request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id);
G
gongweibao 已提交
117 118 119 120 121
  }

  void GetVariable(google::protobuf::RpcController* cntl_butil,
                   const VariableMessage* request, VariableMessage* response,
                   google::protobuf::Closure* done) override {
122 123 124 125
    get_threads_->Run(
        [=] { _GetVariable(cntl_butil, request, response, done); });
  }

126 127 128 129 130 131 132 133
  void GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil,
                            const VariableMessage* request,
                            VariableMessage* response,
                            google::protobuf::Closure* done) override {
    getnobarrier_threads_->Run(
        [=] { _GetVariableNoBarrier(cntl_butil, request, response, done); });
  }

134 135 136
  void _GetVariable(google::protobuf::RpcController* cntl_butil,
                    const VariableMessage* request, VariableMessage* response,
                    google::protobuf::Closure* done) {
M
MRXLT 已提交
137 138 139
    PADDLE_ENFORCE_NOT_NULL(
        request_get_h_, platform::errors::PreconditionNotMet(
                            "RequestGet handler should be registed first!"));
G
gongweibao 已提交
140

141 142 143 144
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);

    std::string varname = request->varname();
145
    std::string out_varname = request->out_varname();
146
    VLOG(3) << "RequestGet varname:" << varname
147
            << ", out_varname:" << out_varname
148 149 150 151
            << ", trainer_id:" << request->trainer_id()
            << ", from:" << cntl->remote_side();

    auto scope = request_get_h_->scope();
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
    paddle::framework::Variable* invar = nullptr;
    int trainer_id = request->trainer_id();
    paddle::framework::Variable* outvar = nullptr;

    request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id,
                           out_varname);

    if (outvar) {
      distributed::SerializeToIOBuf(out_varname, outvar,
                                    *request_get_h_->dev_ctx(), response,
                                    &cntl->response_attachment(), "", false);
    }
  }

  void _GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil,
                             const VariableMessage* request,
                             VariableMessage* response,
                             google::protobuf::Closure* done) {
M
MRXLT 已提交
170 171 172 173
    PADDLE_ENFORCE_NOT_NULL(
        request_getnobarrier_h_,
        platform::errors::PreconditionNotMet(
            "RequestGetNoBarrier handler should be registed first!"));
174 175 176 177 178 179

    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);

    std::string varname = request->varname();
    std::string out_varname = request->out_varname();
180
    int trainer_id = request->trainer_id();
181 182 183 184 185 186 187

    VLOG(3) << "RequestGetNoBarrier varname:" << varname
            << ", out_varname:" << out_varname << ", trainer_id:" << trainer_id
            << ", from:" << cntl->remote_side();

    auto scope = request_getnobarrier_h_->scope();
    paddle::framework::Variable* invar = nullptr;
188 189
    paddle::framework::Variable* outvar = nullptr;

190 191
    request_getnobarrier_h_->Handle(varname, scope, invar, &outvar, trainer_id,
                                    out_varname);
192 193

    if (outvar) {
194 195 196
      distributed::SerializeToIOBuf(
          out_varname, outvar, *request_getnobarrier_h_->dev_ctx(), response,
          &cntl->response_attachment(), "", false);
197 198
    }
  }
199

G
gongweibao 已提交
200 201 202 203
  void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
                        const VariableMessage* request,
                        VariableMessage* response,
                        google::protobuf::Closure* done) override {
204 205 206 207 208 209 210 211
    prefetch_threads_->Run(
        [=] { _PrefetchVariable(cntl_butil, request, response, done); });
  }

  void _PrefetchVariable(google::protobuf::RpcController* cntl_butil,
                         const VariableMessage* request,
                         VariableMessage* response,
                         google::protobuf::Closure* done) {
M
MRXLT 已提交
212 213 214
    PADDLE_ENFORCE_NOT_NULL(request_prefetch_h_,
                   platform::errors::PreconditionNotMet(
                       "kRequestPrefetch handler should be registed first!");
215 216 217 218 219 220 221 222 223 224 225 226 227 228 229

    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);

    // prefetch process...
    std::string in_var_name = request->varname();
    std::string out_var_name = request->out_varname();
    VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
            << ", out_var_name: " << out_var_name
            << ", trainer_id:" << request->trainer_id()
            << ", from:" << cntl->remote_side();

    distributed::BRPCVariableResponse resp(
        request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true);

M
MRXLT 已提交
230 231 232
    PADDLE_ENFORCE_EQ(resp.Parse(cntl->request_attachment(), *request), 0,
                   platform::errors::InvalidArgument(
                       "parse iobuf to tensor error!"));
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257

    auto scope = resp.GetMutableLocalScope();
    auto invar = scope->FindVar(in_var_name);
    std::string table_name = request->table_name();
    int trainer_id = request->trainer_id();
    paddle::framework::Variable* outvar = scope->Var(out_var_name);

    request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
                                out_var_name, table_name);

    distributed::SerializeToIOBuf(out_var_name, outvar,
                                  *request_prefetch_h_->dev_ctx(), response,
                                  &cntl->response_attachment(), "", true);
  }

  void CheckpointNotify(google::protobuf::RpcController* cntl_butil,
                        const VariableMessage* request, VoidMessage* response,
                        google::protobuf::Closure* done) override {
    checkpoint_notify_threads_->Run(
        [=] { _CheckpointNotify(cntl_butil, request, response, done); });
  }

  void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
                         const VariableMessage* request, VoidMessage* response,
                         google::protobuf::Closure* done) {
M
MRXLT 已提交
258 259 260 261
    PADDLE_ENFORCE_NOT_NULL(
        request_checkpoint_h_,
        platform::errors::PreconditionNotMet(
            "kRequestCheckpointNotify handler should be registed first!"));
262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287

    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);

    distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(),
                                           request_checkpoint_h_->dev_ctx());

    auto scope = resp.GetMutableLocalScope();

    std::string checkpoint_notify = request->varname();
    std::string checkpoint_dir = request->out_varname();
    int trainer_id = request->trainer_id();

    VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
            << ", dir: " << checkpoint_dir
            << ", trainer_id:" << request->trainer_id()
            << ", from:" << cntl->remote_side();

    request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr,
                                  trainer_id, checkpoint_dir);
  }

  void GetMonomerVariable(google::protobuf::RpcController* cntl_butil,
                          const VariableMessage* request,
                          VariableMessage* response,
                          google::protobuf::Closure* done) override {
M
MRXLT 已提交
288 289 290 291
    PADDLE_ENFORCE_NOT_NULL(
        request_get_monomer_handler_h_,
        platform::errors::PreconditionNotMet(
            "kRequestGetMonomerVariable handler should be registed first!"));
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320

    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);

    // proc request.
    std::string varname = request->varname();
    VLOG(3) << "GetMonomerVariable " << varname
            << ", trainer_id:" << request->trainer_id()
            << ", from:" << cntl->remote_side();

    rpc_server_->WaitVarCond(varname);
    distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);

    auto scope = h.scope_;
    auto invar = scope->FindVar(varname);
    paddle::framework::Variable* outvar = nullptr;

    request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar,
                                           request->trainer_id());

    if (outvar) {
      distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response,
                                    &cntl->response_attachment(), "", false);
    }
  }

  void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
                         const VariableMessage* request, VoidMessage* response,
                         google::protobuf::Closure* done) override {
M
MRXLT 已提交
321 322 323 324
    PADDLE_ENFORCE_NOT_NULL(
        request_get_monomer_barrier_handler_h_,
        platform::errors::PreconditionNotMet(
            "RequestGetMonomerBarrier handler should be registed first!"));
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342

    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);

    std::string varname = request->varname();
    VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname
            << ", trainer_id:" << request->trainer_id()
            << ", from:" << cntl->remote_side();

    rpc_server_->WaitVarCond(varname);
    distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);

    paddle::framework::Scope* scope = nullptr;
    paddle::framework::Variable* invar = nullptr;
    paddle::framework::Variable* outvar = nullptr;

    request_get_monomer_barrier_handler_h_->Handle(
        varname, scope, invar, &outvar, request->trainer_id());
G
gongweibao 已提交
343 344 345
  }

 private:
346 347
  distributed::RequestHandler* request_send_h_{nullptr};
  distributed::RequestHandler* request_get_h_{nullptr};
348
  distributed::RequestHandler* request_getnobarrier_h_{nullptr};
349 350 351 352 353 354 355
  distributed::RequestHandler* request_prefetch_h_{nullptr};
  distributed::RequestHandler* request_checkpoint_h_{nullptr};
  distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
  distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr};

  distributed::RPCServer* rpc_server_{nullptr};

356
  // FIXME(gongwb): brpc should support process one rpc use one threadpool.
357 358
  std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
  std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
359
  std::unique_ptr<paddle::framework::ThreadPool> getnobarrier_threads_;
360 361
  std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
  std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
G
gongweibao 已提交
362 363 364 365 366
};
}  // namespace sendrecv

namespace paddle {
namespace operators {
367
namespace distributed {
G
gongweibao 已提交
368 369 370

void AsyncBRPCServer::StartServer() {
  // Instance of your service.
371
  sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this);
G
gongweibao 已提交
372 373 374 375 376

  // Add the service into server. Notice the second parameter, because the
  // service is put on stack, we don't want server to delete it, otherwise
  // use brpc::SERVER_OWNS_SERVICE.
  if (server_.AddService(&service_impl, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
377 378
    PADDDLE_THROW(platform::errors::Unavailable(
        "Failed to add service into BRPC server."));
G
gongweibao 已提交
379 380 381 382
    return;
  }

  brpc::ServerOptions options;
383 384 385
#ifdef PADDLE_WITH_BRPC_RDMA
  options.use_rdma = true;
#endif
G
gongweibao 已提交
386 387 388
  options.idle_timeout_sec = idle_timeout_s_;
  options.max_concurrency = max_concurrency_;
  if (server_.Start(bind_address_.c_str(), &options) != 0) {
389 390
    PADDDLE_THROW(platform::errors::Unavailable(
        "Failed to start EchoServer %s.", bind_address_));
G
gongweibao 已提交
391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408
    return;
  }

  butil::EndPoint ep = server_.listen_address();
  selected_port_ = ep.port;

  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();

  server_.Join();
}

// Stops the embedded brpc server. The value 1000 is forwarded unchanged as
// Stop()'s argument (presumably a close-wait in milliseconds; confirm
// against brpc::Server::Stop).
void AsyncBRPCServer::ShutDownImpl() {
  constexpr int kStopArgument = 1000;
  server_.Stop(kStopArgument);
}

void AsyncBRPCServer::WaitServerReady() {
M
minqiyang 已提交
409
  VLOG(3) << "AsyncGRPCServer is wait server ready";
G
gongweibao 已提交
410 411
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
M
minqiyang 已提交
412
  VLOG(3) << "AsyncGRPCServer WaitSeverReady";
G
gongweibao 已提交
413 414
}

415
};  // namespace distributed
G
gongweibao 已提交
416 417
};  // namespace operators
};  // namespace paddle