communicator.h 22.2 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <ThreadPool.h>
#include <stdint.h>
19

T
tangwei12 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "gflags/gflags.h"
33
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
34
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
35
#include "paddle/fluid/distributed/ps/service/ps_client.h"
Z
zhaocaibei123 已提交
36
#include "paddle/fluid/framework/channel.h"
T
tangwei12 已提交
37 38 39 40 41 42 43 44
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
45 46
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
T
tangwei12 已提交
47

48 49 50 51 52 53 54
namespace paddle {
namespace distributed {
// Forward declarations of types referenced later in this header.
struct CommContext;
class PSClient;
}  // namespace distributed
}  // namespace paddle

// Declares the gflags flag `communicator_is_sgd_optimizer`; its definition
// (DEFINE_bool) lives in another translation unit. The flag is not read in
// this header -- presumably consumed by the communicator implementation
// (TODO confirm).
DECLARE_bool(communicator_is_sgd_optimizer);

namespace paddle {
namespace distributed {

// Short aliases for the framework types used throughout this header.
using Scope = framework::Scope;
using Variable = framework::Variable;

template <typename T>
class BlockingQueue {
 public:
S
seemingwang 已提交
66
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
67 68
    PADDLE_ENFORCE_GT(capacity_,
                      0,
S
seemingwang 已提交
69 70 71
                      platform::errors::InvalidArgument(
                          "The capacity must be greater than 0."));
  }
T
tangwei12 已提交
72 73

  bool Push(const T &elem) {
S
seemingwang 已提交
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);

    queue_.push_back(elem);

    Notify();
    return true;
  }
  bool WaitForWrite(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (FullUnlocked()) {
      if (empty_waiters_ != 0) {
        empty_cond_.notify_one();
      }
      full_waiters_++;
      full_cond_.wait(lock);
      full_waiters_--;
T
tangwei12 已提交
90 91 92
    }
    return true;
  }
S
seemingwang 已提交
93 94 95 96 97 98 99 100
  bool WaitForRead(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (EmptyUnlocked()) {
      if (full_waiters_ != 0) {
        full_cond_.notify_one();
      }
      empty_waiters_++;
      empty_cond_.wait(lock);
      empty_waiters_--;
T
tangwei12 已提交
101 102 103
    }
    return true;
  }
S
seemingwang 已提交
104
  bool EmptyUnlocked() { return queue_.empty(); }
T
tangwei12 已提交
105

S
seemingwang 已提交
106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
  bool FullUnlocked() { return queue_.size() >= capacity_; }
  void Notify() {
    if (empty_waiters_ != 0 && (!EmptyUnlocked())) {
      empty_cond_.notify_one();
    }
    if (full_waiters_ != 0 && (!FullUnlocked())) {
      full_cond_.notify_one();
    }
  }

  bool Push(T &&elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);
    queue_.emplace_back(std::move(elem));

    Notify();
    return true;
  }
T
tangwei12 已提交
124 125
  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
S
seemingwang 已提交
126
    WaitForRead(lock);
T
tangwei12 已提交
127 128
    T rc(std::move(queue_.front()));
    queue_.pop_front();
S
seemingwang 已提交
129
    Notify();
T
tangwei12 已提交
130 131 132 133 134 135 136 137 138 139 140 141 142 143
    return rc;
  }

  size_t Cap() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return capacity_;
  }

  size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
S
seemingwang 已提交
144 145 146 147
  int empty_waiters_ = 0;
  int full_waiters_ = 0;
  std::condition_variable empty_cond_;
  std::condition_variable full_cond_;
T
tangwei12 已提交
148 149 150 151 152 153
  const size_t capacity_;
  std::deque<T> queue_;

  mutable std::mutex mutex_;
};

154 155
template <typename T,
          int MajorType = Eigen::RowMajor,
T
tangwei12 已提交
156 157 158 159 160 161
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename T>
inline void MergeVars(const std::string &var_name,
                      const std::vector<std::shared_ptr<Variable>> &vars,
162 163
                      Scope *scope,
                      bool merge_add = true) {
164
  PADDLE_ENFORCE_NE(
165 166
      vars.empty(),
      true,
167
      platform::errors::InvalidArgument("vector vars are empty."));
T
tangwei12 已提交
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
  auto cpu_place = platform::CPUPlace();
  auto &var0 = vars[0];
  auto *out_var = scope->Var(var_name);

  if (var0->IsType<framework::LoDTensor>()) {
    auto dims = var0->Get<framework::LoDTensor>().dims();
    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
            << "; merge add: " << merge_add;
    // init output tensor
    auto *out_t = out_var->GetMutable<framework::LoDTensor>();
    out_t->mutable_data<T>(dims, cpu_place);
    // check the input dims
    for (auto &var : vars) {
      auto &var_t = var->Get<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(
183 184
          var_t.dims(),
          dims,
T
tangwei12 已提交
185 186 187 188
          platform::errors::InvalidArgument("vars should have the same dims."));
    }

    // set output tensor to 0.
L
Leo Chen 已提交
189 190
    phi::CPUContext cpu_ctx;
    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
T
tangwei12 已提交
191 192 193 194 195 196 197 198 199 200 201 202
    constant_functor(cpu_ctx, out_t, static_cast<T>(0));
    // sum all vars to out
    auto result = EigenVector<T>::Flatten(*out_t);
    for (auto &var : vars) {
      auto &in_t = var->Get<framework::LoDTensor>();
      auto in = EigenVector<T>::Flatten(in_t);
      result.device(*cpu_ctx.eigen_device()) = result + in;
    }
    if (!merge_add) {
      result.device(*cpu_ctx.eigen_device()) =
          result / static_cast<T>(vars.size());
    }
203 204 205
  } else if (var0->IsType<phi::SelectedRows>()) {
    auto &slr0 = var0->Get<phi::SelectedRows>();
    auto *out_slr = out_var->GetMutable<phi::SelectedRows>();
T
tangwei12 已提交
206 207
    out_slr->mutable_rows()->clear();
    out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
208
    std::vector<const phi::SelectedRows *> inputs;
T
tangwei12 已提交
209 210
    inputs.reserve(vars.size());
    for (auto &var : vars) {
211
      inputs.push_back(&var->Get<phi::SelectedRows>());
T
tangwei12 已提交
212
    }
L
Leo Chen 已提交
213
    phi::CPUContext dev_ctx;
T
tangwei12 已提交
214
    if (merge_add) {
L
Leo Chen 已提交
215
      paddle::operators::math::scatter::MergeAdd<phi::CPUContext, T> merge_add;
T
tangwei12 已提交
216 217
      merge_add(dev_ctx, inputs, out_slr);
    } else {
L
Leo Chen 已提交
218 219
      paddle::operators::math::scatter::MergeAverage<phi::CPUContext, T>
          merge_average;
T
tangwei12 已提交
220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
      merge_average(dev_ctx, inputs, out_slr);
    }

    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
            << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
                                                   var0->Type()));
  }
}

// Variable name -> CommContext (the type of send_varname_to_ctx_).
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
// uint64_t key -> list of variable names (the type of recv_varname_to_ctx_);
// the key is presumably a table id -- TODO confirm against callers.
using RecvCtxMap = std::unordered_map<uint64_t, std::vector<std::string>>;
// int64_t id -> vector of floats; presumably sparse feature id -> row values
// (TODO confirm).
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;

class Communicator {
 public:
  Communicator();

  explicit Communicator(const std::map<std::string, std::string> &envs_) {
240
    VLOG(3) << "Communicator Init Envs";
T
tangwei12 已提交
241 242
    for (auto &iter : envs_) {
      envs[iter.first] = iter.second;
243
      VLOG(3) << iter.first << ": " << iter.second;
T
tangwei12 已提交
244
    }
245 246 247 248 249
    if (!envs.empty()) {
      barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
      trainer_id_ = std::stoi(envs.at("trainer_id"));
      trainers_ = std::stoi(envs.at("trainers"));
    }
T
tangwei12 已提交
250 251 252 253
  }

  virtual void InitBrpcClient(const std::string &dist_desc,
                              const std::vector<std::string> &host_sign_list);
Z
zhaocaibei123 已提交
254 255 256 257 258

  virtual std::vector<uint64_t> GetClientInfo();

  virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT

T
tangwei12 已提交
259 260
  // 1. recv dense param
  virtual void RpcRecvDense(const std::vector<std::string> &varnames,
261 262
                            int table_id,
                            Scope *scope);
T
tangwei12 已提交
263 264
  // 2. send dense param
  virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
265 266
                                 int table_id,
                                 const Scope &scope);
T
tangwei12 已提交
267 268 269
  // 3. send dense grad
  virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
  // 4. send sparse grad
270 271
  virtual void RpcSendSparse(const std::string &var_name,
                             int table_id,
T
tangwei12 已提交
272 273
                             const Scope &scope);
  // 5. send sparse param
274 275
  virtual void RpcSendSparseParam(const std::string &varname,
                                  int table_id,
T
tangwei12 已提交
276 277
                                  const Scope &scope);
  // 6. recv sparse param
278 279
  virtual void RpcRecvSparse(const std::string &varname,
                             int table_id,
T
tangwei12 已提交
280
                             Scope *scope);
281
  // 7. send gloabl step
282 283
  virtual void SendGlobalStep(const CommContext &ctx,
                              int batches,
284
                              Scope *send_scope);
T
tangwei12 已提交
285

286 287 288 289 290 291 292 293 294
  virtual std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
    return {};
  }
  virtual void SaveFLStrategy(
      const std::unordered_map<uint32_t, std::string> &fl_strategy) {}
  virtual void StartCoordinator(
      const std::string &self_endpoint,
      const std::vector<std::string> &trainer_endpoints) {}

T
tangwei12 已提交
295 296 297 298 299
  virtual ~Communicator() {}
  virtual void RpcProfilerControl();

  virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

Z
zhaocaibei123 已提交
300
  // note: only for pull dense param first before training
301 302
  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

T
tangwei12 已提交
303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321
  virtual void Start() = 0;

  virtual void Stop() = 0;

  virtual bool IsRunning() { return running_; }

  virtual void Clean() {}

  virtual bool Check(const int table_id) = 0;
  virtual bool Check(const std::vector<std::string> &var_tables) = 0;

  virtual void Send(const std::vector<std::string> &var_names,
                    const framework::Scope &scope) = 0;

  virtual void RecvNoBarrier() {}

  virtual void Barrier() {}

  virtual void BarrierWithTable(uint32_t barrier_type) {
Z
zhaocaibei123 已提交
322
    auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
T
tangwei12 已提交
323
    rets.wait();
324
    int status = rets.get();
325 326
    PADDLE_ENFORCE_EQ(status,
                      0,
327 328
                      platform::errors::InvalidArgument(
                          "The ret status must be 0 when barrier with table"));
T
tangwei12 已提交
329 330
  }

Z
zhaocaibei123 已提交
331 332 333
  virtual void CreateC2CConnection(int pserver_timeout_ms,
                                   int pserver_connect_timeout_ms,
                                   int max_retry) {
Z
zhaocaibei123 已提交
334
    _worker_ptr->CreateClient2ClientConnection(
Z
zhaocaibei123 已提交
335 336 337
        pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
  }

T
tangwei12 已提交
338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355
  virtual void BarrierTriggerDecrement() {}

  virtual void BarrierTriggerReset(int init_counter) {}

  virtual void InitEnvs() = 0;

  virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                        const RecvCtxMap &recv_varname_to_ctx,
                        Scope *recv_scope) {}

  static Communicator *GetInstance() { return communicator_.get(); }

  static std::shared_ptr<Communicator> GetInstantcePtr() {
    return communicator_;
  }

  template <typename T>
  static Communicator *InitInstance(
356 357
      const RpcCtxMap &send_ctx,
      const RecvCtxMap &recv_ctx,
T
tangwei12 已提交
358
      const std::string &dist_desc,
359 360
      const std::vector<std::string> &host_sign_list,
      Scope *recv_scope,
T
tangwei12 已提交
361
      const std::map<std::string, std::string> &envs) {
362 363 364 365 366 367 368
    std::call_once(init_flag_,
                   &Communicator::InitWithRpcCtx<T>,
                   send_ctx,
                   recv_ctx,
                   dist_desc,
                   host_sign_list,
                   recv_scope,
T
tangwei12 已提交
369 370 371 372
                   std::ref(envs));
    return communicator_.get();
  }

373
  // called by InitInstance.
T
tangwei12 已提交
374 375 376 377 378 379 380
  template <typename T>
  static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
                             const RecvCtxMap &recv_ctx,
                             const std::string &dist_desc,
                             const std::vector<std::string> &host_sign_list,
                             Scope *recv_scope,
                             const std::map<std::string, std::string> &envs) {
381
    VLOG(0) << "Communicator type is: " << typeid(T).name();
T
tangwei12 已提交
382 383 384 385 386 387 388 389 390 391
    if (communicator_.get() == nullptr) {
      communicator_.reset(new T(std::ref(envs)));
      communicator_->InitEnvs();
      communicator_->InitBrpcClient(dist_desc, host_sign_list);
      communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
    }
  }

  PSClient *GetPsClient() { return _worker_ptr.get(); }

T
Thunderbrook 已提交
392 393
  RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }

394
  std::shared_ptr<PSClient> _worker_ptr;  // pointer to worker
T
tangwei12 已提交
395 396 397 398 399 400 401 402 403 404 405 406

 protected:
  bool running_ = false;
  bool waiting_ = true;
  bool flushing_ = false;
  bool do_server_profiler_ = false;
  static std::shared_ptr<Communicator> communicator_;
  static std::once_flag init_flag_;

  std::unordered_map<std::string, std::string> envs;

  // 计算每个shard 对 dense的存储量
Z
zhaocaibei123 已提交
407 408
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
T
tangwei12 已提交
409 410 411
    return dense_dim_total / shard_num + 1;
  }

Z
zhaocaibei123 已提交
412
  void InitGFlag(const std::string &gflags);
T
tangwei12 已提交
413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
  int servers_ = 0;
  int trainers_;
  int trainer_id_ = 0;
  int barrier_table_id_ = 0;
  RpcCtxMap send_varname_to_ctx_;
  RecvCtxMap recv_varname_to_ctx_;

  Scope *recv_scope_;  // should be global scope
  std::unique_ptr<Scope> xpu_temp_scope_;
  std::atomic<uint32_t> _async_call_num{0};
};

class AsyncCommunicator : public Communicator {
 public:
  AsyncCommunicator() : Communicator() {}

  explicit AsyncCommunicator(const std::map<std::string, std::string> &envs)
      : Communicator(envs) {}

  ~AsyncCommunicator();

  void InitEnvs() {
    independent_recv_ = static_cast<bool>(
        std::stoi(envs.at("communicator_independent_recv_thread")));
    min_send_grad_num_before_recv_ =
        std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));
  }

  void Start() override;

  void Stop() override;

  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) override;

  virtual void MainThread();
  virtual void RecvThread();

  virtual bool Check(const int table_id);
  virtual bool Check(const std::vector<std::string> &var_tables);

  void Send(const std::vector<std::string> &var_names,
            const framework::Scope &scope) override;

  virtual void SendByCommunicator();

  virtual void RecvByCommunicator();

  virtual void RecvNoBarrier();

  virtual int BatchesCounter() { return 1; }

  virtual void BarrierSend() {}

  virtual void BarrierRecv() {}

  virtual void BarrierWeakUp() {}

Z
zhaocaibei123 已提交
480 481
  void PushDensePostProcessing();

Y
yaoxuefeng 已提交
482
  void PullSparseToTensorSync(
483 484 485 486 487
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      platform::Place place,
      bool is_training,
Y
yaoxuefeng 已提交
488 489 490 491
      std::vector<const framework::LoDTensor *> *inputs,  // NOLINT
      std::vector<framework::LoDTensor *> *outputs);      // NOLINT

  void PushSparseFromTensorAsync(
492 493 494 495 496 497 498
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      platform::Place place,
      std::vector<const framework::LoDTensor *> *inputs,
      const framework::LoDTensor *shows,
      const framework::LoDTensor *clicks,
Y
yaoxuefeng 已提交
499 500
      std::vector<framework::LoDTensor *> *outputs);

T
tangwei12 已提交
501 502 503 504 505 506 507 508 509 510 511 512 513 514
 protected:
  std::unordered_map<std::string,
                     std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
      send_varname_to_queue_;
  std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};

  int min_send_grad_num_before_recv_;
  int thread_pool_size_;
  int max_merge_var_num_;
  int send_wait_times_;
  int send_queue_size_;
  bool need_global_step_ = false;
  bool independent_recv_ = true;
  int parallel_task_nums_ = 0;
Y
yaoxuefeng 已提交
515
  int32_t sleep_seconds_before_fail_exit_;
T
tangwei12 已提交
516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541

  std::unique_ptr<std::thread> main_thread_{nullptr};
  std::unique_ptr<std::thread> recv_thread_{nullptr};

  std::unique_ptr<Scope> send_scope_;  // an independent scope
  std::atomic_uint grad_num_{0};  // the num of gradient sent since last recv
};

class HalfAsyncCommunicator : public AsyncCommunicator {
 public:
  HalfAsyncCommunicator() {}

  explicit HalfAsyncCommunicator(const std::map<std::string, std::string> &envs)
      : AsyncCommunicator(envs) {}

  void InitEnvs() {
    // enfore to recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));

542
    VLOG(1) << "HalfAsyncCommunicator Initialized";
T
tangwei12 已提交
543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586
  }

  void MainThread() override;

  void SendByCommunicator() override;

  void Clean() override;

  void Barrier() override;

  void BarrierTriggerDecrement() override;

  void BarrierTriggerReset(int initial_val) override;

  int BatchesCounter();

  void BarrierWeakUp();

 protected:
  // mutex for Wait for barrier
  std::mutex barrier_mutex_;
  std::condition_variable barrier_cond_;
  std::atomic<int64_t> barrier_trigger_{0};
  std::atomic<int64_t> barrier_counter_{0};
};

class SyncCommunicator : public HalfAsyncCommunicator {
 public:
  SyncCommunicator() : HalfAsyncCommunicator() {}

  explicit SyncCommunicator(const std::map<std::string, std::string> &envs)
      : HalfAsyncCommunicator(envs) {}

  void InitEnvs() {
    // enfore to recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));

587
    VLOG(1) << "SyncCommunicator Initialized";
T
tangwei12 已提交
588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605
  }

  void BarrierSend();

  void BarrierRecv();

 private:
  std::vector<std::string> pserver_endpoints_{};
};

class GeoCommunicator : public AsyncCommunicator {
 public:
  GeoCommunicator() : AsyncCommunicator() {}

  explicit GeoCommunicator(const std::map<std::string, std::string> &envs)
      : AsyncCommunicator(envs) {}

  void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
Z
zhaocaibei123 已提交
606
  void InitDense(std::vector<std::string> &varnames, int table_id);  // NOLINT
T
tangwei12 已提交
607 608 609 610 611 612
  void InitSparse(const std::string &var_name, int table_id);

  void SendDense(const CommContext &send_ctx);
  void RecvDense(const CommContext &send_ctx);

  std::vector<int64_t> MergeSparseIds(const std::string &varname);
Z
zhaocaibei123 已提交
613 614
  void SendSparse(const std::string &varname,
                  std::vector<int64_t> &sparse_ids,  // NOLINT
615 616
                  int table_id,
                  int ep_idx);
T
tangwei12 已提交
617 618 619 620
  void RecvSparse(const std::string &varname, int table_id, int ep_idx);

  void MainThread() override;

621
  virtual void InitEnvs() {
T
tangwei12 已提交
622 623 624 625 626 627 628
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    // id_queue's size
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_queue_size_ = max_merge_var_num_;
629
    VLOG(1) << "GeoCommunicator Initialized";
T
tangwei12 已提交
630 631
  }

632 633 634 635
  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) override;

T
tangwei12 已提交
636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654
  void Send(const std::vector<std::string> &var_names,
            const framework::Scope &scope) override;

  void SendByCommunicator() { return; }

  void RecvByCommunicator() override { return; }

  inline std::string GradToParam(const std::string var_name) {
    std::string param_name = var_name.substr(0, var_name.size() - 5);
    return param_name;
  }

  inline std::string SplitedGradToParam(const std::string delta_name) {
    // delta_name: emb.delta0
    auto pos = delta_name.find(".block");
    std::string param_name = delta_name.substr(0, pos);
    return param_name;
  }

655
 public:
T
tangwei12 已提交
656 657 658 659 660 661 662
  // parameter for delta calc and send
  std::shared_ptr<Scope> delta_scope_;
  // parameter for storage the pserver param after last recv
  std::shared_ptr<Scope> old_scope_;
  // parameter on pserver
  std::shared_ptr<Scope> pserver_scope_;

663 664 665
  std::unordered_map<
      std::string,
      paddle::framework::Channel<std::shared_ptr<std::vector<int64_t>>>>
T
tangwei12 已提交
666 667 668
      sparse_id_queues_;
};

669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687
class FLCommunicator : public GeoCommunicator {
 public:
  FLCommunicator() : GeoCommunicator() {}

  ~FLCommunicator() {
    is_running_ = false;
    async_send_thread_->join();
  }

  explicit FLCommunicator(const std::map<std::string, std::string> &envs)
      : GeoCommunicator(envs) {}

  void InitEnvs() override {}

  virtual void InitBrpcClient(const std::string &dist_desc,
                              const std::vector<std::string> &host_sign_list);

  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
688
                Scope *recv_scope) {}
689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713

  void StartCoordinatorClient(
      const std::vector<std::string> &trainer_endpoints);

  void StartCoordinatorServer();

  void StartCoordinator(
      const std::string &self_endpoint,
      const std::vector<std::string> &trainer_endpoints) override;

  std::unordered_map<uint32_t, std::string> QueryFLClientsInfo();
  void SaveFLStrategy(
      const std::unordered_map<uint32_t, std::string> &fl_strategy);

  void SendThreadAsync();
  void RpcSendFLStrategy();

 private:
  int thread_pool_size_ = 1;
  bool is_running_ = true;
  PaddlePSEnvironment ps_env_;
  std::shared_ptr<CoordinatorClient> coordinator_client_ptr_{nullptr};
  std::unique_ptr<std::thread> async_send_thread_{nullptr};
};

T
tangwei12 已提交
714 715
}  // namespace distributed
}  // namespace paddle