/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/core/flags.h"

namespace google {
namespace protobuf {
class Closure;
class RpcController;
}  // namespace protobuf
}  // namespace google
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
}  // namespace framework
}  // namespace paddle

PHI_DECLARE_double(eager_delete_tensor_gb);

namespace paddle {
namespace distributed {

DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size);
DECLARE_int32(switch_send_recv_timeout_s);

using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;

using serviceHandler =
    std::function<int32_t(const PsRequestMessage& request,
                          PsResponseMessage& response,  // NOLINT
                          brpc::Controller* cntl)>;
using HeterServiceHandler =
    std::function<int32_t(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>;

using HeterRpcCallbackFunc = std::function<void(void*)>;

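// Base class for message handlers on the heter server: each handler is bound
// to a scope and a device context and implements Handle() for one message
// type.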
class ServiceHandlerBase {
 public:
  ServiceHandlerBase() : dev_ctx_(nullptr), scope_(nullptr) {}

  virtual ~ServiceHandlerBase() {}

  void SetScope(const framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }

  virtual int Handle(const MultiVarMsg* request,
                     MultiVarMsg* response,
                     brpc::Controller* cntl) = 0;

 protected:
  const platform::DeviceContext* dev_ctx_;
  const framework::Scope* scope_;
};

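// State shared with HeterPipelineTrainer: minibatch id -> minibatch scope,
// minibatch id -> its micro-batch scopes, and minibatch id -> the blocking
// queue of (message_name, microbatch id) tasks consumed by the worker.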
using SharedMiniScope =
    std::shared_ptr<std::unordered_map<int, ::paddle::framework::Scope*>>;

using SharedMicroScope = std::shared_ptr<std::unordered_map<
    int,
    std::shared_ptr<std::vector<::paddle::framework::Scope*>>>>;

using SharedTaskQueue = std::shared_ptr<
    std::unordered_map<int,
                       std::shared_ptr<::paddle::framework::BlockingQueue<
                           std::pair<std::string, int>>>>>;

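// A resizable byte buffer used as the value type of the switch's local
// key/value shards (see SendAndRecvVariableHandler::_local_shards).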
class ValueInSwitch {
 public:
  ValueInSwitch() {}
  ~ValueInSwitch() {}
  char* data() { return _data.data(); }
  size_t size() { return _data.size(); }
  void resize(size_t size) { _data.resize(size); }
  void shrink_to_fit() { _data.shrink_to_fit(); }

 private:
  std::vector<char> _data;
};

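// Handles SendAndRecv requests on a heter worker: deserializes request
// variables into the matching micro-batch scope and enqueues a task for the
// trainer. Also implements the save/query paths used when running as a
// switch.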
class SendAndRecvVariableHandler final : public ServiceHandlerBase {
 public:
  SendAndRecvVariableHandler() {
    this->num_microbatch_ = 0;
    this->num_minibatch_ = 0;
    _local_shards.reset(new shard_type[FLAGS_heter_world_size]);
  }

  virtual ~SendAndRecvVariableHandler() {}

  void SetMiniScopes(SharedMiniScope mini_scopes) {
    mini_scopes_ = mini_scopes;
    num_minibatch_ = mini_scopes_->size();
  }

  void SetMicroScopes(SharedMicroScope micro_scopes) {
    micro_scopes_ = micro_scopes;
    // every minibatch owns the same number of micro scopes, so reading the
    // first entry is enough
    for (auto& scope_pair : (*micro_scopes_)) {
      auto& micro_scopes = scope_pair.second;
      num_microbatch_ = micro_scopes->size();
      break;
    }
  }

  int GetThreadNum() {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    return (*task_queue_).size();
  }

  int SaveInSwitchWithScope(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);

  // Spin until the consumer has drained the vars of this group (flag back
  // to 0).
  void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 0) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not consumed exceed 10 minutes";
        break;
      }
      */
    }
    return;
  }

  // Spin until the producer has published the vars of this group (flag set
  // to 1).
  void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 1) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not produced exceed 10 minutes";
        break;
      }
      */
    }
    return;
  }

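  // Shard-based switch storage: SaveInSwitchWithShard writes the request
  // payload into _local_shards and QueryInSwitchWithShard reads it back;
  // the *WithScope variants go through local_scope_ptr instead.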
  int SaveInSwitchWithShard(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);

  int QueryInSwitchWithShard(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);

  int QueryInSwitchWithScope(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);

  void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }

  int Handle(const MultiVarMsg* request,
             MultiVarMsg* response,
             brpc::Controller* cntl) override {
    LOG(INFO) << "entered Handle";
    platform::RecordEvent record_event("SendAndRecvVariableHandler->Handle",
                                       platform::TracerEventType::Communication,
                                       1);
    FLAGS_eager_delete_tensor_gb = -1;

    // get microID from request
    // deserialize variable to micro scope
    // Push to heter worker's task_queue
    std::unique_ptr<paddle::framework::Scope> local_scope_ptr(
        new paddle::framework::Scope());
    auto& local_scope = *(local_scope_ptr.get());
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::CPUPlace cpu_place;
    auto& cpu_dev_ctx = *pool.Get(cpu_place);

    auto message_name = request->message_name();
    auto& request_io_buffer = cntl->request_attachment();

    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, cpu_dev_ctx, &local_scope);

    auto* var = local_scope.FindVar("microbatch_id");
    PADDLE_ENFORCE_NE(var,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "Cannot find variable microbatch_id in scope."));
    auto* tensor = var->GetMutable<phi::DenseTensor>();
    auto data = reinterpret_cast<const float*>(tensor->data());
    auto micro_id = static_cast<int>(data[0]);
    VLOG(4) << "micro_id in heter server: " << micro_id;
    // micro_id encodes minibatch_index * 10 + microbatch_index
    int minibatch_index = micro_id / 10;
    int microbatch_index = micro_id % 10;

    // check minibatch_index is in mini_scopes_
    std::unique_lock<std::mutex> lk(scope_mutex_);
    if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
      lk.unlock();

      PADDLE_ENFORCE_EQ(
          (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(),
          true,
          platform::errors::InvalidArgument(
              "minibatch index should be in current trainer"));

    } else {
      // create mini scope & micro scopes
      auto* minibatch_scope = &(scope_->NewScope());
      (*mini_scopes_)[minibatch_index] = minibatch_scope;
      (*micro_scopes_)[minibatch_index].reset(
          new std::vector<paddle::framework::Scope*>{});
      for (int i = 0; i < num_microbatch_; i++) {
        auto* micro_scope = &(minibatch_scope->NewScope());
        (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
      }
      (*task_queue_)[minibatch_index].reset(
          new ::paddle::framework::BlockingQueue<
              std::pair<std::string, int>>());
      lk.unlock();
    }

    auto* micro_scope =
        (*((*micro_scopes_)[minibatch_index]))[microbatch_index];

    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, *dev_ctx_, micro_scope);
    // blocking queue handles multi thread
    VLOG(4) << "Handle in HeterServer: " << message_name << ", "
            << microbatch_index;
    VLOG(4) << "task_queue_ size: " << task_queue_->size();
    (*task_queue_)[minibatch_index]->Push(
        std::make_pair(message_name, microbatch_index));

    auto response_var_nums = request->recv_var_names_size();
    std::vector<std::string> response_var_names(response_var_nums),
        empty_var_names{};
    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
      response_var_names[var_idx] = request->recv_var_names(var_idx);
    }
    auto& response_io_buffer = cntl->response_attachment();
    distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
                                                response_var_names,
                                                empty_var_names,
                                                *dev_ctx_,
                                                &local_scope,
                                                response,
                                                &response_io_buffer);
    VLOG(4) << "Handle over";
    return 0;
  }

 public:
  using shard_type = SparseTableShard<std::string, ValueInSwitch>;
  std::shared_ptr<paddle::framework::Scope> local_scope_ptr;  // for switch
  // group_id -> var_name -> ready flag (1: produced, 0: consumed)
  std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
      vars_ready_flag;
  std::unique_ptr<shard_type[]> _local_shards;
  platform::Timer timeline_;

 private:
  // share with HeterPipelineTrainer
  SharedMiniScope mini_scopes_{nullptr};
  SharedMicroScope micro_scopes_{nullptr};

  int num_microbatch_;
  int num_minibatch_;
  std::mutex scope_mutex_;

  bool is_first_stage_ = false;
  bool is_last_stage_ = false;

  SharedTaskQueue task_queue_;
};

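// The brpc service exported by a heter worker / switch node. PS commands are
// dispatched through _service_handler_map and variable messages through
// handler_map_; SendToSwitch / SendToWorker / SendS2S / RecvFromSwitch
// implement the switch-side forwarding.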
class HeterService : public PsService {
 public:
  HeterService() {
    _service_handler_map[PS_STOP_SERVER] =
        std::bind(&HeterService::stop_heter_worker,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_START_PROFILER] =
        std::bind(&HeterService::start_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_STOP_PROFILER] =
        std::bind(&HeterService::stop_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);

    service_handler_.local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  virtual ~HeterService() {}

  virtual void service(::google::protobuf::RpcController* controller,
                       const PsRequestMessage* request,
                       PsResponseMessage* response,
                       ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);

    response->set_err_code(0);
    response->set_err_msg("");
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      response->set_err_code(-1);
      response->set_err_msg(err_msg);
      return;
    }
    serviceHandler handler = itr->second;
    int service_ret = handler(*request, *response, cntl);
    VLOG(4) << "handler in service ret: " << service_ret;
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  }

  virtual void SendAndRecvVariable(
      ::google::protobuf::RpcController* controller,
      const MultiVarMsg* request,
      MultiVarMsg* response,
      ::google::protobuf::Closure* done) {
    // done_guard calls done->Run() in RAII style when it leaves scope; to
    // process the request asynchronously, pass done_guard.release() instead.
    brpc::ClosureGuard done_guard(done);
    std::string message_name = request->message_name();
    VLOG(0) << "SendAndRecvVariable message_name: " << message_name;
    auto itr = handler_map_.find(message_name);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    LOG(INFO) << "SendAndRecvVariable(client addr) = " << cntl->remote_side();
    PADDLE_ENFORCE_NE(
        itr,
        handler_map_.end(),
        platform::errors::InvalidArgument(
            "HeterService::SendAndRecvVariable received illegal message_name: "
            "%s, which is not in HeterService::handler_map_",
            message_name));
    itr->second(request, response, cntl);
  }

  virtual void RecvFromSwitch(::google::protobuf::RpcController* controller,
                              const MultiVarMsg* request,
                              MultiVarMsg* response,
                              ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // scope-based alternative:
    // service_handler_.QueryInSwitchWithScope(request, response, cntl);
    int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "QueryInSwitchWithShard failed!";
    }
  }

  virtual void SendToSwitch(::google::protobuf::RpcController* controller,
                            const MultiVarMsg* request,
                            PsResponseMessage* response,
                            ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendToSwitch";
    brpc::ClosureGuard done_guard(done);
    std::shared_ptr<HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
    if (switch_client_ptr_->peer_switch_channels_.empty()) {
      LOG(ERROR) << "switch_client_ptr_->peer_switch_channels_ null";
    }
    brpc::Channel* channel = switch_client_ptr_->peer_switch_channels_[0].get();
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // proxy: create a new OnHeterRpcDone object (or reset one inside
    // OnHeterRpcDone)
    OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
      auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
      int ret = closure->CheckResponse();
      closure->set_promise_value(ret);
      if (closure->cntl.Failed()) {
        PADDLE_ENFORCE_NE(
            closure->cntl.Failed(),
            true,
            platform::errors::Unimplemented(
                "HeterClient::SendS2S meets brpc error, error message is %s",
                closure->cntl.ErrorText()));
      }
    });
    auto& std_cntl = closure2->cntl;
    std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
    std_cntl.request_attachment().append(cntl->request_attachment().movable());

    auto promise = std::make_shared<std::promise<int32_t>>();
    closure2->add_promise(promise);
    std::future<int> fut = promise->get_future();
    PsService_Stub stub(channel);
    stub.SendS2S(&std_cntl, request, response, closure2);
    // wait for the async RPC to finish before reading its response attachment
    fut.wait();
    cntl->response_attachment().append(
        std_cntl.response_attachment().movable());
    VLOG(4) << "SendToSwitch done";
    delete closure2;
  }

  void SendS2S(::google::protobuf::RpcController* controller,
               const MultiVarMsg* request,
               PsResponseMessage* response,
               ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendS2S";
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // scope-based alternative:
    // service_handler_.SaveInSwitchWithScope(request, response, cntl);
    int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "SaveInSwitchWithShard failed";
    }
    std::string err_msg = "ok";
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(ret);
    VLOG(4) << "heter server SendS2S done";
  }

  void SendToWorker(::google::protobuf::RpcController* controller,
                    const MultiVarMsg* request,
                    PsResponseMessage* response,
                    ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    VLOG(4) << "SendToWorker(client addr) = " << cntl->remote_side();
    std::shared_ptr<distributed::HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
    VLOG(4) << "in switch client, peer worker 0: "
            << switch_client_ptr_->peer_worker_list_[0];
    brpc::Channel* channel = switch_client_ptr_->peer_worker_channels_[0].get();

    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PsService_Stub stub(channel);
    stub.SendAndRecvVariable(controller, request, &closure->response, done);
    // fill response content
    std::string err_msg("pass to worker");
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(0);
  }

  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func) {
    handler_map_[message_name] = func;
  }

  void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }

  void SetInterEndpoint(const std::string& end_point) {
    endpoint_inter_ = end_point;
  }

  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    peer_endpoints_ = peer_endpoints;
  }

  void SetFanin(const int& fan_in) { fan_in_ = fan_in; }

  void ForceExit() {
    VLOG(3) << "heter service force exit";
    is_exit_ = true;
    return;
  }

  bool IsExit() { return is_exit_; }

 private:
  int32_t stop_profiler(const PsRequestMessage& request UNUSED,
                        PsResponseMessage& response UNUSED,  // NOLINT
                        brpc::Controller* cntl UNUSED) {
    platform::DisableProfiler(
        platform::EventSortingKey::kDefault,
        string::Sprintf("heter_worker_%s_profile", endpoint_));
    return 0;
  }

  int32_t start_profiler(const PsRequestMessage& request UNUSED,
                         PsResponseMessage& response UNUSED,  // NOLINT
                         brpc::Controller* cntl UNUSED) {
    platform::EnableProfiler(platform::ProfilerState::kAll);
    return 0;
  }

  int32_t stop_heter_worker(const PsRequestMessage& request,
                            PsResponseMessage& response UNUSED,  // NOLINT
                            brpc::Controller* cntl UNUSED) {
    auto client_id = request.client_id();
    stop_cpu_worker_set_.insert(client_id);
    if (stop_cpu_worker_set_.size() == fan_in_) {
      is_exit_ = true;
    }
    return 0;
  }

 private:
  SendAndRecvVariableHandler service_handler_;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;

  std::unordered_map<int32_t, serviceHandler> _service_handler_map;
  std::unordered_map<std::string, HeterServiceHandler> handler_map_;
  std::unordered_set<int> stop_cpu_worker_set_;
  uint32_t fan_in_;
  bool is_exit_ = false;
};

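// Process-wide server wrapper: owns the brpc::Server instances and the
// HeterService, exposed as a lazily created singleton. A minimal usage
// sketch (hypothetical endpoint and handler; the real launch flow lives in
// the heter trainers):
//
//   auto server = HeterServer::GetInstance();
//   server->SetEndPoint("127.0.0.1:8500");
//   server->SetServiceHandler(
//       std::make_shared<SendAndRecvVariableHandler>());
//   server->RegisterServiceHandler("send_and_recv", my_handler);
//   server->StartHeterService();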
class HeterServer {
 public:
  HeterServer() : ready_(0) {}
  virtual ~HeterServer() {}
  void Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_) return;
    if (!IsExit()) {
      service_.ForceExit();
    }
    stoped_ = true;
    cv_.notify_all();
    server_.Stop(1000);
    server_.Join();
  }

  bool IsStop() {
    std::unique_lock<std::mutex> lock(mutex_);
    return stoped_;
  }

  bool IsExit() { return service_.IsExit(); }

  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func);

  void StartHeterService(bool need_encrypt = false);

  void StartHeterInterService(bool need_encrypt = false);

  void SetEndPoint(const std::string& endpoint) {
    this->endpoint_ = endpoint;
    service_.SetEndpoint(endpoint);
  }

  void SetLocalScope() {
    request_handler_->local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  void SetInterEndpoint(const std::string& endpoint) {
    this->endpoint_inter_ = endpoint;
    service_.SetInterEndpoint(endpoint);
  }

  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    this->peer_endpoints_ = peer_endpoints;
    service_.SetPeerEndPoints(peer_endpoints);
  }

  void SetFanin(const int& fan_in);

  void SetServiceHandler(
      std::shared_ptr<SendAndRecvVariableHandler> request_handler) {
    request_handler_ = request_handler;
  }

  void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
    request_handler_->SetMiniScopes(mini_scopes);
  }

  void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
    request_handler_->SetMicroScopes(micro_scopes);
  }

  int GetThreadNum() { return request_handler_->GetThreadNum(); }

  void SetTaskQueue(SharedTaskQueue task_queue) {
    request_handler_->SetTaskQueue(task_queue);
  }

  // HeterWrapper singleton
  static std::shared_ptr<HeterServer> GetInstance() {
    std::unique_lock<std::mutex> lock(mtx_);
    if (s_instance_ == nullptr) {
      s_instance_.reset(new HeterServer());
    }
    return s_instance_;
  }

  void WaitServerReady();

 private:
  static std::shared_ptr<HeterServer> s_instance_;
  mutable std::mutex mutex_;
  static std::mutex mtx_;
  std::condition_variable cv_;
  std::condition_variable condition_ready_;
  bool stoped_ = true;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;

 protected:
  brpc::Server server_;
  brpc::Server server_inter_;
  HeterService service_;
  std::shared_ptr<SendAndRecvVariableHandler> request_handler_;

  DISABLE_COPY_AND_ASSIGN(HeterServer);
  std::mutex mutex_ready_;

  int ready_;
};

}  // end namespace distributed
}  // end namespace paddle