brpc_client.cc 16.2 KB
Newer Older
G
gongweibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/distributed/brpc/brpc_client.h"

#include <string>
#include <vector>

#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
G
gongweibao 已提交
19 20 21

namespace paddle {
namespace operators {
22
namespace distributed {
G
gongweibao 已提交
23 24 25 26 27 28

// Channel-level RPC timeout (milliseconds) applied to every brpc channel that
// BRPCClient::GetChannel creates (options.timeout_ms). Individual calls in
// this file also set a per-call timeout via Controller::set_timeout_ms.
DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
// Maximum retries configured on each channel (options.max_retry); the first
// attempt is not counted.
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");

// Destruction blocks until every request issued through this client has
// completed, so no completion callback can observe a dead client.
BRPCClient::~BRPCClient() {
  this->Wait();
}

29 30 31
// Completion callback (brpc `done` closure) for send-style RPCs whose response
// is a VoidMessage; used by BRPCClient::AsyncSendVar and
// BRPCClient::AsyncSendVarMessage.
//
// Takes ownership of `cntl` and `response` (deleted on return), pushes
// `ch_ctx` back onto `ch_ptr` so the channel can be reused, finishes `var_h`
// and decrements the client's outstanding-request counter.
//
// On RPC failure the handle is finished with `false` and the counter is
// decremented BEFORE the error is raised. If the throw came first, that
// bookkeeping would be skipped and BRPCClient::Wait() could block forever.
// NOTE(review): this runs as a brpc completion closure -- confirm that an
// exception escaping from here is handled rather than aborting the process.
void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response,
                        VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
                        ChannelContextPtr ch_ctx, BRPCClient* cls) {
  // std::unique_ptr makes sure cntl/response will be deleted before returning.
  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
  std::unique_ptr<sendrecv::VoidMessage> response_guard(response);

  // This channel can be used by other requests now.
  ch_ptr->Push(ch_ctx);

  if (cntl->Failed()) {
    var_h->Finish(false);
    cls->DecreaseReqCount();
    PADDLE_THROW(platform::errors::Unavailable(
        "Failed to send variable %s, error text is %s.", var_h->name(),
        cntl->ErrorText()));
  }

  var_h->Finish(true);
  cls->DecreaseReqCount();

  VLOG(4) << "HandleSendResponse from: " << cntl->remote_side()
          << ", varname: " << var_h->name()
          << ", latency: " << cntl->latency_us() << "us";
  VLOG(4) << "Finish HandleSendResponse";
}

56 57 58 59 60
// Asynchronously serializes variable `var_name` from `scope` and sends it to
// endpoint `ep` with the SendVariable RPC.
//
// Serialization and the RPC are issued from a framework::AsyncIO task that
// captures only pointers to `ctx` and `scope`, so both must outlive the
// request. The returned handle is finished by HandleSendResponse.
VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch_ptr = GetChannel(ep_val);
  const std::string method = kSendRPC;
  VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));

  // Count the request before dispatching it. The async task and the RPC's
  // completion callback (which calls DecreaseReqCount()) may finish before
  // this thread returns from AsyncIO(). Incrementing afterwards would let
  // Wait() observe a wrong count.
  req_count_++;

  framework::AsyncIO([=] {
    // Borrow a channel; HandleSendResponse pushes it back.
    auto ch_ctx = ch_ptr->Pop();
    brpc::Controller* cntl = new brpc::Controller();
    sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
    cntl->set_timeout_ms(time_out);

    auto* var = p_scope->FindVar(var_name_val);
    sendrecv::VariableMessage request;
    distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request,
                                  &cntl->request_attachment(), "", false,
                                  trainer_id_);

    // Ownership of cntl/response passes to HandleSendResponse.
    google::protobuf::Closure* done = brpc::NewCallback(
        &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);

    platform::RecordRPCEvent record_event(method);

    ch_ctx->stub->SendVariable(cntl, &request, response, done);

    if (UNLIKELY(platform::IsProfileEnabled())) {
      var_h->Wait();
    }
  });

  return var_h;
}
96 97 98 99 100 101 102 103 104 105
// Completion callback (brpc `done` closure) for the fetch-barrier RPC issued
// by BRPCClient::AsyncSendFetchBarrier. The response payload is not inspected.
//
// Takes ownership of `cntl` and `response` (deleted on return), pushes
// `ch_ctx` back onto `ch_ptr` so the channel can be reused, finishes `var_h`
// and decrements the client's outstanding-request counter.
//
// On RPC failure the handle is finished with `false` and the counter is
// decremented BEFORE the error is raised. If the throw came first, that
// bookkeeping would be skipped and BRPCClient::Wait() could block forever.
// NOTE(review): this runs as a brpc completion closure -- confirm that an
// exception escaping from here is handled rather than aborting the process.
void HandleFetchBarrierResponse(brpc::Controller* cntl,
                                sendrecv::VariableMessage* response,
                                VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
                                ChannelContextPtr ch_ctx, BRPCClient* cls) {
  // std::unique_ptr makes sure cntl/response will be deleted before returning.
  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
  std::unique_ptr<sendrecv::VariableMessage> response_guard(response);

  // This channel can be used by other requests now.
  ch_ptr->Push(ch_ctx);

  if (cntl->Failed()) {
    var_h->Finish(false);
    cls->DecreaseReqCount();
    PADDLE_THROW(platform::errors::Unavailable(
        "Failed to get HandleFetchBarrierResponse %s, error text is %s.",
        var_h->name(), cntl->ErrorText()));
  }

  var_h->Finish(true);
  cls->DecreaseReqCount();

  VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side()
          << ", varname: " << var_h->name()
          << ", latency: " << cntl->latency_us() << "us";
  VLOG(4) << "Finish HandleFetchBarrierResponse";
}
G
gongweibao 已提交
124
void HandleGetResponse(brpc::Controller* cntl,
125 126 127
                       sendrecv::VariableMessage* response, VarHandlePtr var_h,
                       ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx,
                       BRPCClient* cls) {
G
gongweibao 已提交
128 129 130 131
  // std::unique_ptr makes sure cntl/response will be deleted before returning.
  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
  std::unique_ptr<sendrecv::VariableMessage> response_guard(response);

132 133 134
  // this channel can be used other now.
  ch_ptr->Push(ch_ctx);

G
gongweibao 已提交
135
  if (cntl->Failed()) {
136 137 138
    PADDLE_THROW(platform::errors::Unavailable(
        "Failed to get variable %s, error text is %s.", var_h->name(),
        cntl->ErrorText()));
139 140
    cls->DecreaseReqCount();
    var_h->Finish(false);
G
gongweibao 已提交
141 142 143
    return;
  }

144 145 146 147 148 149 150 151 152 153 154 155
  VLOG(4) << "HandleGetResponse from: " << cntl->remote_side()
          << ", varname: " << var_h->name()
          << ", latency: " << cntl->latency_us() << "us";

  framework::Variable* outvar = nullptr;
  int trainer_id;
  distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(),
                                    *var_h->ctx(), var_h->scope(), &outvar,
                                    &trainer_id);
  VLOG(4) << "Finish HandleGetResponse";
  cls->DecreaseReqCount();
  var_h->Finish(true);
G
gongweibao 已提交
156 157
}

158 159 160 161
// Shared implementation of the asynchronous "get variable" family.
//
// Sends a request naming `var_name` (with out_varname = `out_var_name`) to
// endpoint `ep`. The stub method is chosen by `method_name`:
// kGetMonomerRPC -> GetMonomerVariable, kGetNoBarrierRPC ->
// GetVariableNoBarrier, anything else -> GetVariable. The response is handled
// by HandleGetResponse, which deserializes it into `scope`. The work runs in
// a framework::AsyncIO task that captures only pointers to `ctx` and
// `scope`, so both must outlive the request.
VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      const std::string& out_var_name,
                                      const std::string& method_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const std::string out_varname_val = out_var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch_ptr = GetChannel(ep_val);
  const std::string method = kGetRPC;
  VarHandlePtr var_h(
      new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));

  // Count the request before dispatching it. The async task and the RPC's
  // completion callback (which calls DecreaseReqCount()) may finish before
  // this thread returns from AsyncIO().
  req_count_++;

  framework::AsyncIO([=] {
    // Borrow a channel; HandleGetResponse pushes it back.
    auto ch_ctx = ch_ptr->Pop();

    brpc::Controller* cntl = new brpc::Controller();
    sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
    cntl->set_timeout_ms(time_out);

    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_out_varname(out_varname_val);
    req.set_trainer_id(trainer_id_);

    // Ownership of cntl/response passes to HandleGetResponse.
    google::protobuf::Closure* done = brpc::NewCallback(
        &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);

    platform::RecordRPCEvent record_event(method);

    if (method_name == kGetMonomerRPC) {
      ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
    } else if (method_name == kGetNoBarrierRPC) {
      ch_ctx->stub->GetVariableNoBarrier(cntl, &req, response, done);
    } else {
      ch_ctx->stub->GetVariable(cntl, &req, response, done);
    }

    if (UNLIKELY(platform::IsProfileEnabled())) {
      var_h->Wait();
    }
  });

  return var_h;
}

210 211 212 213 214 215 216 217 218 219 220
// Fetches `var_name` from `ep` through the GetVariableNoBarrier RPC.
// The name sent on the wire is `var_name` followed by WITHOUT_BARRIER_MESSAGE
// (presumably how the server tells this request apart -- confirm against the
// server-side handler).
VarHandlePtr BRPCClient::AsyncGetVarNoBarrier(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    const std::string& out_var_name, int64_t time_out) {
  return _AsyncGetVar(
      ep, ctx, scope,
      /*var_name=*/string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE),
      /*out_var_name=*/out_var_name,
      /*method_name=*/kGetNoBarrierRPC,
      /*time_out=*/time_out);
}

221 222 223 224
// Fetches `var_name` from `ep` through the GetMonomerVariable RPC; the
// requested name and the output name are the same.
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  const std::string& same_name = var_name;
  return _AsyncGetVar(ep, ctx, scope, same_name, same_name, kGetMonomerRPC,
                      time_out);
}

// Sends a message carrying `var_name` via the kSendMonomerFetchBarrierRPC
// method (routed to the GetMonomerBarrier stub by AsyncSendVarMessage).
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const std::string rpc_method = kSendMonomerFetchBarrierRPC;
  return AsyncSendMessage(ep, rpc_method, var_name, time_out);
}

235 236 237 238
// Fetches `var_name` from `ep` (out_varname = `out_var_name`) through the
// plain GetVariable RPC; the response is deserialized into `scope`.
// `table_name` is accepted for interface compatibility but is not used by
// this brpc implementation.
VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     const std::string& out_var_name,
                                     const std::string& table_name,
                                     int64_t time_out) {
  return _AsyncGetVar(ep, ctx, scope,
                      /*var_name=*/var_name,
                      /*out_var_name=*/out_var_name,
                      /*method_name=*/kGetRPC,
                      /*time_out=*/time_out);
}

// Asynchronously serializes `in_var_name` from `scope` and sends it to `ep`
// with the PrefetchVariable RPC. The request also names `out_var_name` and
// `table_name`. The response is handled by HandleGetResponse, which
// deserializes it into `scope`. The work runs in a framework::AsyncIO task
// that captures only pointers to `ctx` and `scope`, so both must outlive the
// request.
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const std::string table_name_val = table_name;
  const framework::Scope* p_scope = &scope;
  const auto ch_ptr = GetChannel(ep_val);

  const std::string method = kPrefetchRPC;

  VarHandlePtr var_h(
      new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));

  // Count the request before dispatching it. The async task and the RPC's
  // completion callback (which calls DecreaseReqCount()) may finish before
  // this thread returns from AsyncIO().
  req_count_++;

  framework::AsyncIO([=] {
    // Borrow a channel; HandleGetResponse pushes it back.
    auto ch_ctx = ch_ptr->Pop();

    brpc::Controller* cntl = new brpc::Controller();
    sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
    cntl->set_timeout_ms(time_out);

    auto* var = p_scope->FindVar(in_var_name_val);
    sendrecv::VariableMessage req;
    distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req,
                                  &cntl->request_attachment(), out_var_name_val,
                                  false, 0, table_name_val);

    platform::RecordRPCEvent record_event(method);

    // Ownership of cntl/response passes to HandleGetResponse.
    google::protobuf::Closure* done = brpc::NewCallback(
        &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);

    ch_ctx->stub->PrefetchVariable(cntl, &req, response, done);

    if (UNLIKELY(platform::IsProfileEnabled())) {
      var_h->Wait();
    }
  });

  return var_h;
}

295 296
// Sends BATCH_BARRIER_MESSAGE to `ep` using the kBatchBarrierRPC method.
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const std::string rpc_method = kBatchBarrierRPC;
  return AsyncSendMessage(ep, rpc_method, BATCH_BARRIER_MESSAGE, time_out);
}

301 302 303 304 305 306 307 308 309 310 311 312
// Issues a GetVariable RPC for FETCH_BARRIER_MESSAGE to `ep`. Completion is
// handled by HandleFetchBarrierResponse, which finishes the returned handle.
// Unlike the send/get variants, the channel is popped on the calling thread
// (not inside an AsyncIO task).
VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  auto ch_ptr = GetChannel(ep);
  auto ch_ctx = ch_ptr->Pop();

  brpc::Controller* cntl = new brpc::Controller();
  sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
  cntl->set_timeout_ms(time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  const std::string method = kFetchBarrierRPC;
  // var handle
  VarHandlePtr var_h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));

  platform::RecordRPCEvent record_event(method);

  // Ownership of cntl/response passes to HandleFetchBarrierResponse.
  google::protobuf::Closure* done = brpc::NewCallback(
      &HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);

  // Count the request before issuing it: `done` may run (and call
  // DecreaseReqCount()) before the asynchronous stub call returns.
  req_count_++;

  ch_ctx->stub->GetVariable(cntl, &req, response, done);

  if (UNLIKELY(platform::IsProfileEnabled())) {
    var_h->Wait();
  }

  return var_h;
}

334 335 336 337 338 339 340 341
// Blocks the caller until req_count_ drops to zero, i.e. until no request
// issued through this client is outstanding. Always returns true.
bool BRPCClient::Wait() {
  VLOG(9) << "begin to brpcclient wait";
  std::unique_lock<std::mutex> lock(sync_mutex_);
  sync_cond_.wait(lock, [this]() { return req_count_ == 0; });
  lock.unlock();
  VLOG(9) << "end to brpcclient wait";
  return true;
}

// Returns the pool of channel contexts for endpoint `ep`, creating it on first
// use. Callers Pop() a context to issue a request, and the completion
// callbacks Push() it back.
//
// A new pool holds `brpc_channel_num_per_server_` channels configured with
// the "baidu_std" protocol, a "single" connection type, the FLAGS_timeout_ms /
// FLAGS_max_retry settings, and RDMA when PADDLE_WITH_BRPC_RDMA is defined.
//
// Throws platform::errors::Unavailable if a channel fails to initialize.
ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
  VLOG(4) << "begin to GetChannel:" << ep;
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    auto it = channels_.find(ep);
    if (it != channels_.end()) {
      VLOG(4) << "end to GetChannel:" << ep;
      return it->second;
    }
  }

  // The pool is built without holding chan_mutex_, so two first-time callers
  // for the same `ep` can both get here. The insertion below resolves that.
  ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());

  brpc::ChannelOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
  options.use_rdma = true;
#endif
  options.protocol = "baidu_std";
  // don't use pooled type. the server can't afford that.
  options.connection_type = "single";
  options.connect_timeout_ms = 1000;
  options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
  options.max_retry = FLAGS_max_retry;

  VLOG(1) << "create " << brpc_channel_num_per_server_
          << " brpc channels to pserver:" << ep;

  for (int i = 0; i < brpc_channel_num_per_server_; ++i) {
    std::shared_ptr<ChannelContext> c(new ChannelContext());
    if (c->channel.Init(ep.c_str(), &options) != 0) {
      PADDLE_THROW(
          platform::errors::Unavailable("Failed to initialize channel."));
    }

    c->stub.reset(new sendrecv::SendRecvService_Stub(
        static_cast<google::protobuf::RpcChannel*>(&c->channel)));
    q->Push(c);
  }

  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    // emplace() does not overwrite. If another thread registered a pool for
    // `ep` first, every caller shares that one, and the pool built here is
    // released when `q` is reassigned. Previously a later writer replaced the
    // map entry, leaving the earlier caller on an orphaned pool.
    q = channels_.emplace(ep, q).first->second;
  }

  VLOG(4) << "end to GetChannel:" << ep;
  return q;
}

393 394
// Sends COMPLETE_MESSAGE to `ep` using the kSendCompleteRPC method.
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const std::string rpc_method = kSendCompleteRPC;
  return AsyncSendMessage(ep, rpc_method, COMPLETE_MESSAGE, time_out);
}

void BRPCClient::SendComplete() {
  for (auto& kv : channels_) {
    AsyncSendComplete(kv.first);
  }
}

// Sends the prepared message `req` to `ep`. The stub method is chosen by
// `method_name`: kCheckPointNotifyRPC -> CheckpointNotify,
// kSendMonomerFetchBarrierRPC -> GetMonomerBarrier, anything else ->
// SendVariable. Completion is handled by HandleSendResponse, which finishes
// the returned handle.
VarHandlePtr BRPCClient::AsyncSendVarMessage(
    const std::string& ep, const std::string& method_name,
    const sendrecv::VariableMessage& req, int64_t time_out) {
  auto ch_ptr = GetChannel(ep);
  auto ch_ctx = ch_ptr->Pop();

  brpc::Controller* cntl = new brpc::Controller();
  sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
  cntl->set_timeout_ms(time_out);

  platform::RecordRPCEvent record_event(method_name);

  VarHandlePtr var_h(
      new VarHandle(ep, method_name, req.varname(), nullptr, nullptr));

  // Ownership of cntl/response passes to HandleSendResponse.
  google::protobuf::Closure* done = brpc::NewCallback(
      &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);

  // Count the request before issuing it: `done` may run (and call
  // DecreaseReqCount()) before the asynchronous stub call returns.
  req_count_++;

  if (method_name == kCheckPointNotifyRPC) {
    ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
  } else if (method_name == kSendMonomerFetchBarrierRPC) {
    ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
  } else {
    ch_ctx->stub->SendVariable(cntl, &req, response, done);
  }

  if (UNLIKELY(platform::IsProfileEnabled())) {
    var_h->Wait();
  }

  return var_h;
}

// Wraps `message` as the varname of a VariableMessage and forwards it to
// AsyncSendVarMessage with the given `method_name`.
VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
                                          const std::string& method_name,
                                          const std::string& message,
                                          int64_t time_out) {
  sendrecv::VariableMessage request;
  request.set_varname(message);
  return AsyncSendVarMessage(ep, method_name, request, time_out);
}

// Sends a checkpoint notification to `ep`: varname = `varname`,
// out_varname = `dirname`, routed to the CheckpointNotify stub.
//
// Uses the kCheckPointNotifyRPC constant rather than a hand-typed literal.
// AsyncSendVarMessage selects CheckpointNotify by comparing against that
// constant, so any spelling drift would silently route the request to
// SendVariable instead.
// NOTE(review): `mode` is currently not forwarded in the request -- confirm
// whether the server expects it.
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dirname,
                                               const std::string& varname,
                                               const int mode,
                                               int64_t time_out) {
  sendrecv::VariableMessage req;
  req.set_varname(varname);
  req.set_out_varname(dirname);

  return AsyncSendVarMessage(ep, kCheckPointNotifyRPC, req, time_out);
}

460
}  // namespace distributed
G
gongweibao 已提交
461 462
}  // namespace operators
}  // namespace paddle