communicator.h 19.5 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <ThreadPool.h>
#include <stdint.h>
#include <atomic>
#include <condition_variable>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "gflags/gflags.h"
#include "paddle/fluid/distributed/communicator_common.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"

#include "paddle/fluid/distributed/service/ps_client.h"

46 47 48 49 50 51 52
// Forward declarations of the distributed types named below. Their full
// definitions presumably come from the ps_client.h / communicator_common.h
// headers included above.
namespace paddle {
namespace distributed {
struct CommContext;
class PSClient;
}  // namespace distributed
}  // namespace paddle

T
tangwei12 已提交
53 54 55 56 57 58 59 60 61 62 63
// Declares (gflags DECLARE_bool) the boolean FLAGS_communicator_is_sgd_optimizer
// so code including this header can read it; the flag is defined with
// DEFINE_bool in some other translation unit.
DECLARE_bool(communicator_is_sgd_optimizer);

namespace paddle {
namespace distributed {

// Short local names for the framework types used throughout this header.
using Variable = framework::Variable;
using Scope = framework::Scope;

template <typename T>
class BlockingQueue {
 public:
S
seemingwang 已提交
64 65 66 67 68
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
    PADDLE_ENFORCE_GT(capacity_, 0,
                      platform::errors::InvalidArgument(
                          "The capacity must be greater than 0."));
  }
T
tangwei12 已提交
69 70

  bool Push(const T &elem) {
S
seemingwang 已提交
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);

    queue_.push_back(elem);

    Notify();
    return true;
  }
  bool WaitForWrite(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (FullUnlocked()) {
      if (empty_waiters_ != 0) {
        empty_cond_.notify_one();
      }
      full_waiters_++;
      full_cond_.wait(lock);
      full_waiters_--;
T
tangwei12 已提交
87 88 89
    }
    return true;
  }
S
seemingwang 已提交
90 91 92 93 94 95 96 97
  bool WaitForRead(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (EmptyUnlocked()) {
      if (full_waiters_ != 0) {
        full_cond_.notify_one();
      }
      empty_waiters_++;
      empty_cond_.wait(lock);
      empty_waiters_--;
T
tangwei12 已提交
98 99 100
    }
    return true;
  }
S
seemingwang 已提交
101
  bool EmptyUnlocked() { return queue_.empty(); }
T
tangwei12 已提交
102

S
seemingwang 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
  bool FullUnlocked() { return queue_.size() >= capacity_; }
  void Notify() {
    if (empty_waiters_ != 0 && (!EmptyUnlocked())) {
      empty_cond_.notify_one();
    }
    if (full_waiters_ != 0 && (!FullUnlocked())) {
      full_cond_.notify_one();
    }
  }

  bool Push(T &&elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);
    queue_.emplace_back(std::move(elem));

    Notify();
    return true;
  }
T
tangwei12 已提交
121 122
  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
S
seemingwang 已提交
123
    WaitForRead(lock);
T
tangwei12 已提交
124 125
    T rc(std::move(queue_.front()));
    queue_.pop_front();
S
seemingwang 已提交
126
    Notify();
T
tangwei12 已提交
127 128 129 130 131 132 133 134 135 136 137 138 139 140
    return rc;
  }

  size_t Cap() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return capacity_;
  }

  size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
S
seemingwang 已提交
141 142 143 144
  int empty_waiters_ = 0;
  int full_waiters_ = 0;
  std::condition_variable empty_cond_;
  std::condition_variable full_cond_;
T
tangwei12 已提交
145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
  const size_t capacity_;
  std::deque<T> queue_;

  mutable std::mutex mutex_;
};

// Local alias for framework::EigenVector; every template argument, including
// the RowMajor / DenseIndex defaults, is forwarded unchanged.
template <class T, int MajorType = Eigen::RowMajor,
          class IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

// Merges the variables in `vars` into the variable `var_name` of `scope`
// (created via scope->Var if absent), on CPU.
//
// The type of vars[0] selects the strategy:
//  * LoDTensor: every input must have vars[0]'s dims. The output is zeroed
//    and the inputs are summed element-wise. When `merge_add` is false, the
//    sum is divided by vars.size() (i.e. averaged).
//  * SelectedRows: the output's rows are cleared and the inputs are combined
//    with scatter::MergeAdd (merge_add == true) or scatter::MergeAverage.
//  * Any other type: PADDLE_THROW with InvalidArgument.
// PADDLE_ENFORCE also fails with InvalidArgument if `vars` is empty or the
// LoDTensor dims differ.
template <typename T>
inline void MergeVars(const std::string &var_name,
                      const std::vector<std::shared_ptr<Variable>> &vars,
                      Scope *scope, bool merge_add = true) {
  PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument(
                                            "vector vars are empty."));
  auto cpu_place = platform::CPUPlace();
  auto &var0 = vars[0];
  auto *out_var = scope->Var(var_name);

  if (var0->IsType<framework::LoDTensor>()) {
    auto dims = var0->Get<framework::LoDTensor>().dims();
    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
            << "; merge add: " << merge_add;
    // Allocate the output tensor with the first input's dims.
    auto *out_t = out_var->GetMutable<framework::LoDTensor>();
    out_t->mutable_data<T>(dims, cpu_place);
    // Validate every input's dims before any arithmetic touches the output.
    for (auto &var : vars) {
      auto &var_t = var->Get<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(
          var_t.dims(), dims,
          platform::errors::InvalidArgument("vars should have the same dims."));
    }

    // Zero the output, then accumulate each input into it.
    auto cpu_ctx = paddle::platform::CPUDeviceContext();
    paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext, T>
        constant_functor;
    constant_functor(cpu_ctx, out_t, static_cast<T>(0));
    auto result = EigenVector<T>::Flatten(*out_t);
    for (auto &var : vars) {
      auto &in_t = var->Get<framework::LoDTensor>();
      auto in = EigenVector<T>::Flatten(in_t);
      result.device(*cpu_ctx.eigen_device()) = result + in;
    }
    // Averaging mode: turn the sum into a mean over the inputs.
    if (!merge_add) {
      result.device(*cpu_ctx.eigen_device()) =
          result / static_cast<T>(vars.size());
    }
  } else if (var0->IsType<framework::SelectedRows>()) {
    auto &slr0 = var0->Get<framework::SelectedRows>();
    auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
    out_slr->mutable_rows()->clear();
    out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
    std::vector<const paddle::framework::SelectedRows *> inputs;
    inputs.reserve(vars.size());
    for (auto &var : vars) {
      inputs.push_back(&var->Get<framework::SelectedRows>());
    }
    auto dev_ctx = paddle::platform::CPUDeviceContext();
    // The functor objects get distinct names so they do not shadow the
    // `merge_add` flag parameter.
    if (merge_add) {
      paddle::operators::math::scatter::MergeAdd<
          paddle::platform::CPUDeviceContext, T>
          merge_add_functor;
      merge_add_functor(dev_ctx, inputs, out_slr);
    } else {
      paddle::operators::math::scatter::MergeAverage<
          paddle::platform::CPUDeviceContext, T>
          merge_average_functor;
      merge_average_functor(dev_ctx, inputs, out_slr);
    }

    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
            << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
                                                   var0->Type()));
  }
}

// int64_t key (presumably a sparse feature id -- confirm) -> float values.
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;
// uint64_t key (presumably a table id -- confirm) -> list of variable names.
using RecvCtxMap = std::unordered_map<uint64_t, std::vector<std::string>>;
// String key (the member names suggest a variable name) -> communication
// context.
using RpcCtxMap = std::unordered_map<std::string, CommContext>;

class Communicator {
 public:
  Communicator();

  explicit Communicator(const std::map<std::string, std::string> &envs_) {
236
    VLOG(3) << "Communicator Init Envs";
T
tangwei12 已提交
237 238
    for (auto &iter : envs_) {
      envs[iter.first] = iter.second;
239
      VLOG(3) << iter.first << ": " << iter.second;
T
tangwei12 已提交
240 241 242 243 244 245 246 247
    }
    barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
    trainer_id_ = std::stoi(envs.at("trainer_id"));
    trainers_ = std::stoi(envs.at("trainers"));
  }

  virtual void InitBrpcClient(const std::string &dist_desc,
                              const std::vector<std::string> &host_sign_list);
Z
zhaocaibei123 已提交
248 249 250 251 252

  virtual std::vector<uint64_t> GetClientInfo();

  virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT

T
tangwei12 已提交
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269
  // 1. recv dense param
  virtual void RpcRecvDense(const std::vector<std::string> &varnames,
                            int table_id, Scope *scope);
  // 2. send dense param
  virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
                                 int table_id, const Scope &scope);
  // 3. send dense grad
  virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
  // 4. send sparse grad
  virtual void RpcSendSparse(const std::string &var_name, int table_id,
                             const Scope &scope);
  // 5. send sparse param
  virtual void RpcSendSparseParam(const std::string &varname, int table_id,
                                  const Scope &scope);
  // 6. recv sparse param
  virtual void RpcRecvSparse(const std::string &varname, int table_id,
                             Scope *scope);
270 271 272
  // 7. send gloabl step
  virtual void SendGlobalStep(const CommContext &ctx, int batches,
                              Scope *send_scope);
T
tangwei12 已提交
273 274 275 276 277 278

  virtual ~Communicator() {}
  virtual void RpcProfilerControl();

  virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

Z
zhaocaibei123 已提交
279
  // note: only for pull dense param first before training
280 281
  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

T
tangwei12 已提交
282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
  virtual void Start() = 0;

  virtual void Stop() = 0;

  virtual bool IsRunning() { return running_; }

  virtual void Clean() {}

  virtual bool Check(const int table_id) = 0;
  virtual bool Check(const std::vector<std::string> &var_tables) = 0;

  virtual void Send(const std::vector<std::string> &var_names,
                    const framework::Scope &scope) = 0;

  virtual void RecvNoBarrier() {}

  virtual void Barrier() {}

  virtual void BarrierWithTable(uint32_t barrier_type) {
    auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type);
    rets.wait();
  }

Z
zhaocaibei123 已提交
305 306 307 308 309 310 311
  virtual void CreateC2CConnection(int pserver_timeout_ms,
                                   int pserver_connect_timeout_ms,
                                   int max_retry) {
    _worker_ptr->create_client2client_connection(
        pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
  }

T
tangwei12 已提交
312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357
  virtual void BarrierTriggerDecrement() {}

  virtual void BarrierTriggerReset(int init_counter) {}

  virtual void InitEnvs() = 0;

  virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                        const RecvCtxMap &recv_varname_to_ctx,
                        Scope *recv_scope) {}

  static Communicator *GetInstance() { return communicator_.get(); }

  static std::shared_ptr<Communicator> GetInstantcePtr() {
    return communicator_;
  }

  template <typename T>
  static Communicator *InitInstance(
      const RpcCtxMap &send_ctx, const RecvCtxMap &recv_ctx,
      const std::string &dist_desc,
      const std::vector<std::string> &host_sign_list, Scope *recv_scope,
      const std::map<std::string, std::string> &envs) {
    std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>, send_ctx,
                   recv_ctx, dist_desc, host_sign_list, recv_scope,
                   std::ref(envs));
    return communicator_.get();
  }

  // Init is called by InitInstance.
  template <typename T>
  static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
                             const RecvCtxMap &recv_ctx,
                             const std::string &dist_desc,
                             const std::vector<std::string> &host_sign_list,
                             Scope *recv_scope,
                             const std::map<std::string, std::string> &envs) {
    if (communicator_.get() == nullptr) {
      communicator_.reset(new T(std::ref(envs)));
      communicator_->InitEnvs();
      communicator_->InitBrpcClient(dist_desc, host_sign_list);
      communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
    }
  }

  PSClient *GetPsClient() { return _worker_ptr.get(); }

Z
zhaocaibei123 已提交
358 359
  std::unique_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
    return std::move(_worker_ptr);
T
tangwei12 已提交
360 361
  }

T
Thunderbrook 已提交
362 363
  RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }

Z
zhaocaibei123 已提交
364
  std::unique_ptr<PSClient> _worker_ptr;  // pointer to worker
T
tangwei12 已提交
365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449

 protected:
  bool running_ = false;
  bool waiting_ = true;
  bool flushing_ = false;
  bool do_server_profiler_ = false;
  static std::shared_ptr<Communicator> communicator_;
  static std::once_flag init_flag_;

  std::unordered_map<std::string, std::string> envs;

  // 计算每个shard 对 dense的存储量
  inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
                                      uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  void init_gflag(const std::string &gflags);
  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
  int servers_ = 0;
  int trainers_;
  int trainer_id_ = 0;
  int barrier_table_id_ = 0;
  RpcCtxMap send_varname_to_ctx_;
  RecvCtxMap recv_varname_to_ctx_;

  Scope *recv_scope_;  // should be global scope
  std::unique_ptr<Scope> xpu_temp_scope_;
  std::atomic<uint32_t> _async_call_num{0};
};

class AsyncCommunicator : public Communicator {
 public:
  AsyncCommunicator() : Communicator() {}

  explicit AsyncCommunicator(const std::map<std::string, std::string> &envs)
      : Communicator(envs) {}

  ~AsyncCommunicator();

  void InitEnvs() {
    independent_recv_ = static_cast<bool>(
        std::stoi(envs.at("communicator_independent_recv_thread")));
    min_send_grad_num_before_recv_ =
        std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));
  }

  void Start() override;

  void Stop() override;

  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) override;

  virtual void MainThread();
  virtual void RecvThread();

  virtual bool Check(const int table_id);
  virtual bool Check(const std::vector<std::string> &var_tables);

  void Send(const std::vector<std::string> &var_names,
            const framework::Scope &scope) override;

  virtual void SendByCommunicator();

  virtual void RecvByCommunicator();

  virtual void RecvNoBarrier();

  virtual int BatchesCounter() { return 1; }

  virtual void BarrierSend() {}

  virtual void BarrierRecv() {}

  virtual void BarrierWeakUp() {}

Z
zhaocaibei123 已提交
450 451
  void PushDensePostProcessing();

T
tangwei12 已提交
452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491
 protected:
  std::unordered_map<std::string,
                     std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
      send_varname_to_queue_;
  std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};

  int min_send_grad_num_before_recv_;
  int thread_pool_size_;
  int max_merge_var_num_;
  int send_wait_times_;
  int send_queue_size_;
  bool need_global_step_ = false;
  bool independent_recv_ = true;
  int parallel_task_nums_ = 0;

  std::unique_ptr<std::thread> main_thread_{nullptr};
  std::unique_ptr<std::thread> recv_thread_{nullptr};

  std::unique_ptr<Scope> send_scope_;  // an independent scope
  std::atomic_uint grad_num_{0};  // the num of gradient sent since last recv
};

class HalfAsyncCommunicator : public AsyncCommunicator {
 public:
  HalfAsyncCommunicator() {}

  explicit HalfAsyncCommunicator(const std::map<std::string, std::string> &envs)
      : AsyncCommunicator(envs) {}

  void InitEnvs() {
    // enfore to recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));

492
    VLOG(1) << "HalfAsyncCommunicator Initialized";
T
tangwei12 已提交
493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536
  }

  void MainThread() override;

  void SendByCommunicator() override;

  void Clean() override;

  void Barrier() override;

  void BarrierTriggerDecrement() override;

  void BarrierTriggerReset(int initial_val) override;

  int BatchesCounter();

  void BarrierWeakUp();

 protected:
  // mutex for Wait for barrier
  std::mutex barrier_mutex_;
  std::condition_variable barrier_cond_;
  std::atomic<int64_t> barrier_trigger_{0};
  std::atomic<int64_t> barrier_counter_{0};
};

class SyncCommunicator : public HalfAsyncCommunicator {
 public:
  SyncCommunicator() : HalfAsyncCommunicator() {}

  explicit SyncCommunicator(const std::map<std::string, std::string> &envs)
      : HalfAsyncCommunicator(envs) {}

  void InitEnvs() {
    // enfore to recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));

537
    VLOG(1) << "SyncCommunicator Initialized";
T
tangwei12 已提交
538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559
  }

  void BarrierSend();

  void BarrierRecv();

 private:
  std::vector<std::string> pserver_endpoints_{};
};

class GeoCommunicator : public AsyncCommunicator {
 public:
  GeoCommunicator() : AsyncCommunicator() {}

  explicit GeoCommunicator(const std::map<std::string, std::string> &envs)
      : AsyncCommunicator(envs) {}

  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) override;

  void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
Z
zhaocaibei123 已提交
560
  void InitDense(std::vector<std::string> &varnames, int table_id);  // NOLINT
T
tangwei12 已提交
561 562 563 564 565 566
  void InitSparse(const std::string &var_name, int table_id);

  void SendDense(const CommContext &send_ctx);
  void RecvDense(const CommContext &send_ctx);

  std::vector<int64_t> MergeSparseIds(const std::string &varname);
Z
zhaocaibei123 已提交
567 568
  void SendSparse(const std::string &varname,
                  std::vector<int64_t> &sparse_ids,  // NOLINT
T
tangwei12 已提交
569 570 571 572 573 574 575 576 577 578 579 580 581
                  int table_id, int ep_idx);
  void RecvSparse(const std::string &varname, int table_id, int ep_idx);

  void MainThread() override;

  void InitEnvs() {
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    // id_queue's size
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_queue_size_ = max_merge_var_num_;
582
    VLOG(1) << "GeoCommunicator Initialized";
T
tangwei12 已提交
583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619
  }

  void Send(const std::vector<std::string> &var_names,
            const framework::Scope &scope) override;

  void SendByCommunicator() { return; }

  void RecvByCommunicator() override { return; }

  inline std::string GradToParam(const std::string var_name) {
    std::string param_name = var_name.substr(0, var_name.size() - 5);
    return param_name;
  }

  inline std::string SplitedGradToParam(const std::string delta_name) {
    // delta_name: emb.delta0
    auto pos = delta_name.find(".block");
    std::string param_name = delta_name.substr(0, pos);
    return param_name;
  }

 private:
  // parameter for delta calc and send
  std::shared_ptr<Scope> delta_scope_;
  // parameter for storage the pserver param after last recv
  std::shared_ptr<Scope> old_scope_;
  // parameter on pserver
  std::shared_ptr<Scope> pserver_scope_;

  std::unordered_map<
      std::string,
      std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>>
      sparse_id_queues_;
};

}  // namespace distributed
}  // namespace paddle