/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed/communicator.h"

#include <gflags/gflags.h>
#include <paddle/fluid/framework/program_desc.h>

#include <sys/time.h>  // for gettimeofday

#include <algorithm>
#include <chrono>  // NOLINT
#include <map>
#include <thread>  // NOLINT
#include <unordered_set>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"

namespace paddle {
namespace operators {
namespace distributed {

using Tree =
    std::map<std::string, std::map<std::string, std::vector<std::string>>>;
using RpcCtxMap = operators::distributed::RpcCtxMap;

inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+6 * time.tv_sec + time.tv_usec;
}

Communicator::Communicator() {}

std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);

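// Set up the async communicator: keep the send/recv RPC contexts, create one
// blocking queue per send variable plus the send/recv thread pools, then pull
// the initial parameters from the parameter servers via InitParams().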
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RpcCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  if (send_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be sent, will not start send_thread";
  } else {
    send_scope_.reset(new Scope());
    for (auto &iter : send_varname_to_ctx_) {
      send_varname_to_queue_[iter.first] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    }
    send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  InitParams();
}

void AsyncCommunicator::InitParams() { RecvNoBarrier(); }

AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) main_thread_->join();
}

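// Pack the number of merged batches into the STEP_COUNTER variable and send
// it to the parameter servers; skipped when global step tracking is disabled
// or nothing has been merged.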
void AsyncCommunicator::SendGlobalStep(int batches) {
  if (!need_global_step_) {
    return;
  }

  if (batches == 0) {
    return;
  }

  auto &var_name = STEP_COUNTER;
  auto *out_var = send_scope_->Var(var_name);
  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);

  auto &ctx = send_varname_to_ctx_.at(var_name);
  auto send_functor = distributed::ParameterSend<float>();
  send_functor(ctx, *send_scope_, true, 1);
}

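// For every send variable except STEP_COUNTER, pop `batches` gradients from
// its queue, merge them into send_scope_ and send the merged result to the
// parameter servers; one task per variable runs on the send thread pool.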
void AsyncCommunicator::SendByCommunicator(int batches) {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(send_varname_to_ctx_.size());
  VLOG(3) << "run send graph";
  auto before_run_send_graph = GetCurrentUS();
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    auto send_task = [this, batches, &var_name, &var_queue] {
      if (var_name == STEP_COUNTER) {
        return;
      }

      VLOG(3) << var_name << " merge and send";
      std::vector<std::shared_ptr<Variable>> vars;
      vars.reserve(batches);

      for (int i = 0; i < batches; ++i) {
        vars.push_back(var_queue->Pop());
      }

      auto &ctx = send_varname_to_ctx_.at(var_name);

      auto before_merge = GetCurrentUS();
      MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
      auto after_merge = GetCurrentUS();
      VLOG(3) << "merge " << batches << " " << var_name << " use time "
              << after_merge - before_merge;

      auto send_functor = distributed::ParameterSend<float>();
      send_functor(ctx, *send_scope_, true, 1);
      auto after_send = GetCurrentUS();
      VLOG(3) << "send " << var_name << " use time "
              << after_send - after_merge;
    };
    task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
  }
  for (auto &task_f : task_futures) {
    task_f.wait();
  }
  auto after_run_send_graph = GetCurrentUS();

  VLOG(3) << "run send graph use time "
          << after_run_send_graph - before_run_send_graph;
}

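// Loop of the background communication thread: wait for the first Send()
// call, then repeatedly count the merged batches, send the global step and
// merged gradients, pull back the latest parameters and run the barrier
// hooks until the communicator is stopped.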
void AsyncCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";

  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }

  while (running_) {
    int batches = BatchesCounter();

    if (batches > 0) {
      SendGlobalStep(batches);
      SendByCommunicator(batches);
      BarrierSend();
      RecvByCommunicator();
      BarrierRecv();
      BarrierWeakUp();
    } else {
      VLOG(1) << "get nothing from sending queue, will skip send/recv";
    }
  }
  VLOG(1) << "communicator stopped, send thread exit";
}

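// Pull every recv variable from the parameter servers, one task per variable
// on the recv thread pool.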
void AsyncCommunicator::RecvByCommunicator() {
  VLOG(3) << "parallel run recv graph";
  if (!running_) return;
  RecvNoBarrier();
  VLOG(3) << "run recv graph use time";
}

void AsyncCommunicator::RecvNoBarrier() {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto recv_task = [this, &iter] {
      auto &var_name = iter.first;
      VLOG(4) << "recv var " << var_name;
      auto recv_functor = distributed::ParameterRecv<float>();
      recv_functor(iter.second, *recv_scope_);
    };
    task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : task_futures) {
    task.wait();
  }
}

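// Count how many batches are ready to merge by draining the STEP_COUNTER
// queue, waiting at most send_wait_times_ rounds of 10 ms when it is empty.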
int AsyncCommunicator::BatchesCounter() {
  auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER);

  size_t merged_var_num = 0;
  size_t wait_times = 0;

  while (merged_var_num < static_cast<size_t>(max_merge_var_num_)) {
    if (step_queue->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    } else {
      step_queue->Pop();
      wait_times = 0;
      merged_var_num++;
    }
  }

  return merged_var_num;
}

void AsyncCommunicator::Start() {
  VLOG(1) << "Communicator start";
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    VLOG(1) << "start send thread and recv thread";
    waiting_ = true;
    running_ = true;
    BarrierTriggerReset(max_merge_var_num_);
    // start send and recv thread
    main_thread_.reset(
        new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
  }
}

void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop";
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    if (main_thread_) {
      VLOG(1) << "stop send thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}

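// Called by the trainer after backward: copy the gradient (or a step-counter
// mark) into a temporary variable and push it into the queue of the
// corresponding table.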
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const std::vector<std::string> &var_tables,
                             const framework::Scope &scope) {
  waiting_ = false;

  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));

  auto table_name = var_tables[0];
  auto &queue = send_varname_to_queue_.at(table_name);

  if (table_name == STEP_COUNTER) {
    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
    queue->Push(tmp_var);
  } else {
    PADDLE_ENFORCE_GE(var_names.size(), 1,
                      platform::errors::InvalidArgument(
                          "var_names.size() >= 1 is permitted"));

    auto *var = scope.FindVar(var_names[0]);

    PADDLE_ENFORCE_EQ(
        var->IsInitialized(), true,
        platform::errors::InvalidArgument("grad var should be inited"));

    auto tmp_var = std::make_shared<Variable>();
    if (var->IsType<framework::SelectedRows>()) {
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else if (var->IsType<framework::LoDTensor>()) {
      // push var into send queue by var_name
      auto var_name = var_names[0];
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "unknown var type to copy, only support LoDTensor/SelectedRows"));
    }
  }
}

void HalfAsyncCommunicator::Clean() {
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    while (var_queue->Size() > 0) {
      var_queue->Pop();
    }

    VLOG(3) << "clean var: " << var_name << " done";
  }
}

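// In half-async mode the number of batches to merge comes from the trainer
// barrier counter instead of the STEP_COUNTER queue.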
int HalfAsyncCommunicator::BatchesCounter() {
  while (running_) {
    if (barrier_counter_.load() >= barrier_trigger_.load() &&
        barrier_trigger_.load() != 0) {
      break;
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }

  return barrier_counter_.load();
}

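// Called by trainer threads: increase the barrier counter and block until the
// communicator finishes this round and resets the counter in BarrierWeakUp().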
void HalfAsyncCommunicator::Barrier() {
  barrier_counter_++;

  if (!running_) {
    VLOG(3) << "Communicator is not running, release barrier";
    return;
  }

  {
    std::unique_lock<std::mutex> lk(barrier_mutex_);
    barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
  }
}

void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  barrier_trigger_--;
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_.store(initial_val);

  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_.store(0);
  barrier_cond_.notify_all();
}

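// Notify every parameter server that this trainer has finished sending the
// current batch.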
void SyncCommunicator::BarrierSend() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;

  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierSend with SyncCommunicator";
}

void SyncCommunicator::BarrierRecv() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;
  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierRecv with SyncCommunicator";
}

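// Set up the geo communicator: besides the scopes and thread pool, build one
// sparse-id queue per (sparse variable, pserver) pair so that the ids touched
// during training can be merged and sent shard by shard.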
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RpcCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  PADDLE_ENFORCE_GT(
      send_varname_to_ctx.size(), 0,
      platform::errors::InvalidArgument("send var contexts can not be zero"));

  send_scope_.reset(new Scope());
  for (auto &iter : send_varname_to_ctx_) {
    auto &varname = iter.first;

    if (varname == STEP_COUNTER) {
      send_varname_to_queue_[varname] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    } else {
      auto &send_ctx = iter.second;

      send_var_nums_ += send_ctx.splited_varnames.size();
      if (!send_ctx.is_sparse) {
        continue;
      }
      int pserver_num = static_cast<int>(send_ctx.epmap.size());
      for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
        sparse_id_queues_.insert(
            std::pair<std::string, std::shared_ptr<BlockingQueue<
                                       std::shared_ptr<std::vector<int64_t>>>>>(
                send_ctx.splited_varnames[ep_idx],
                std::make_shared<
                    BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>(
                    send_queue_size_)));
      }
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  delta_scope_.reset(new Scope());
  old_scope_.reset(new Scope());
  pserver_scope_.reset(new Scope());

  InitParams();
}

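// Record the sparse ids touched by this mini-batch for one table: shard the
// ids by parameter server and push each shard into its queue; STEP_COUNTER is
// ignored.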
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                           const std::vector<std::string> &var_tables,
                           const framework::Scope &scope) {
  waiting_ = false;
  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));

  auto table_name = var_tables[0];
  if (table_name == STEP_COUNTER) return;

  auto before_send = GetCurrentUS();
  size_t splited_var_nums =
      send_varname_to_ctx_[table_name].splited_varnames.size();

  std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;

  for (size_t j = 0; j < splited_var_nums; j++) {
    ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
        send_varname_to_ctx_[table_name].splited_varnames[j],
        std::unordered_set<int64_t>()));
  }
  auto *var = scope.FindVar(var_names[0]);
  auto &rows = var->Get<framework::SelectedRows>().rows();

  // insert ids which have not been recorded yet
  for (size_t j = 0; j < rows.size(); j++) {
    auto ep_idx = rows[j] % splited_var_nums;
    ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
        .insert(rows[j]);
  }

  auto before_push = GetCurrentUS();
  for (auto &iter : ids_table) {
    auto &key = iter.first;
    auto &sparse_ids_set = iter.second;
    auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
    sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
    sparse_id_queues_.at(key)->Push(sparse_ids_vec);
    VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
            << "'s queue";
  }
  auto after_send = GetCurrentUS();
  VLOG(3) << "run send " << table_name << " op finish. using "
          << (before_push - before_send) << "; " << (after_send - before_push);
}

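// Each round schedules one send+recv task per pserver shard for sparse tables
// and one task per variable for dense ones, then waits for all of them.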
void GeoCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";

  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }

  while (running_) {
    std::vector<std::future<void>> tasks;
    tasks.reserve(send_var_nums_);

    for (auto &iter : send_varname_to_ctx_) {
      auto &var_name = iter.first;
      auto &send_ctx = iter.second;
      int pserver_num = static_cast<int>(send_ctx.epmap.size());
      if (send_ctx.is_sparse) {
        for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
          auto send_recv_task = [this, ep_idx, &var_name] {
            auto before_send_sparse = GetCurrentUS();
            if (var_name == STEP_COUNTER) {
              return;
            }
            auto send_varname =
                send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx];
            auto sparse_ids = MergeSparseIds(send_varname);
            if (sparse_ids.size() == 0) {
              return;
            }
            SendSparse(var_name, ep_idx, sparse_ids);
            auto after_send_sparse = GetCurrentUS();
            RecvSparse(var_name, ep_idx);
            auto after_recv_sparse = GetCurrentUS();
            VLOG(3)
                << "send recv "
                << send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
                << " finish, using " << (after_send_sparse - before_send_sparse)
                << " and " << (after_recv_sparse - after_send_sparse)
                << "; total = " << (after_recv_sparse - before_send_sparse);
          };
          tasks.emplace_back(
              send_threadpool_->enqueue(std::move(send_recv_task)));
        }
      } else {
        auto send_recv_task = [this, &var_name, &send_ctx] {
          if (var_name == STEP_COUNTER) {
            return;
          }
          SendDense(var_name);
          RecvDense(var_name);
        };
        tasks.emplace_back(
            send_threadpool_->enqueue(std::move(send_recv_task)));
      }
    }
    for (auto &task : tasks) {
      task.wait();
    }
  }
}

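// Drain up to max_merge_var_num_ id batches from one shard's queue and
// deduplicate them into a single id vector.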
std::vector<int64_t> GeoCommunicator::MergeSparseIds(
    const std::string &send_varname) {
  size_t merge_num = 0, wait_times = 0;
  std::unordered_set<int64_t> sparse_ids;
  while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
    VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
    if (sparse_id_queues_.at(send_varname)->Size() > 0) {
      wait_times = 0;
      std::shared_ptr<std::vector<int64_t>> pop_ids =
          sparse_id_queues_.at(send_varname)->Pop();
      for (size_t j = 0; j < pop_ids->size(); j++) {
        sparse_ids.insert(pop_ids->at(j));
      }
      merge_num += 1;
      VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed";
    } else if (sparse_id_queues_.at(send_varname)->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    }
  }
  std::vector<int64_t> res;
  res.assign(sparse_ids.begin(), sparse_ids.end());
  return res;
}
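
// Compute (latest - cached) / trainers_ for the touched rows, fold the delta
// back into the cached copy kept in LargeScaleKV, and send it to one
// parameter server as a SelectedRows variable.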
void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx,
                                 const std::vector<int64_t> &sparse_ids) {
  auto &rpc_ctx = send_varname_to_ctx_.at(varname);
  auto send_varname = rpc_ctx.splited_varnames[ep_idx];
  auto trainer_id = rpc_ctx.trainer_id;
  auto endpoint = rpc_ctx.epmap[ep_idx];
  auto pserver_num = rpc_ctx.epmap.size();

  auto *var_latest = recv_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  auto &t_latest = var_latest->Get<framework::LoDTensor>();

  auto dims1 = t_latest.dims()[1];

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(send_varname);
  auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();

  auto *t_value = t_delta->mutable_value();
  t_value->mutable_data<float>(
      framework::make_ddim({static_cast<int64_t>(sparse_ids.size()), dims1}),
      cpu_ctx.GetPlace());

  std::vector<std::vector<std::vector<float> *>> values;
  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(sparse_ids, {"Param"}, &values);

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);

  for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
    blas.VSUB(dims1, t_latest.data<float>() + sparse_ids[j] * dims1,
              values[j][0]->data(), t_value->data<float>() + j * dims1);
    blas.SCAL(dims1, coefficient, t_value->data<float>() + j * dims1);
    blas.VADD(dims1, values[j][0]->data(), t_value->data<float>() + j * dims1,
              values[j][0]->data());
  }

  std::vector<int64_t> send_rows;
  send_rows.reserve(sparse_ids.size());
  for (auto idx : sparse_ids) {
    send_rows.push_back(idx / pserver_num);
  }
  t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
  t_delta->set_rows(send_rows);

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);

  auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
                                      *delta_scope_.get(), send_varname);
  ret->Wait();
}

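// Dense counterpart of SendSparse: send (latest - old) / trainers_ as the
// delta and advance the old copy in old_scope_ by the same amount.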
void GeoCommunicator::SendDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));

  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest.numel(), t_latest.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());

  float coefficient = 1.0 / static_cast<float>(trainers_);
  blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());

  blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
            t_delta->data<float>(), t_timestamp->data<float>());

  auto &ctx = send_varname_to_ctx_.at(varname);
  auto send = distributed::ParameterSend<float>();
  send(ctx, *delta_scope_, true, 1);
}

void GeoCommunicator::RecvByCommunicator() { return; }

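// Pull one shard of a sparse table from a parameter server, add the delta
// (remote - cached) onto the latest local rows, and refresh the cached copy
// in LargeScaleKV.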
void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
  auto train_id = recv_varname_to_ctx_.at(varname).trainer_id;
  auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx];
  auto splited_var_name =
      recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
  auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id);

  auto *var_psrever = pserver_scope_->Var(splited_var_name);
  auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv,
                                        *pserver_scope_.get(), splited_var_name,
                                        splited_var_name, splited_var_name);
  handle->Wait();

  auto *var_latest = recv_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(
      var_psrever->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  std::vector<int64_t> ids;
  ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
             var_psrever->Get<framework::SelectedRows>().rows().end());

  for (size_t j = 0; j < ids.size(); j++) {
    ids[j] = ids[j] * pserver_num + ep_idx;
  }

  VLOG(3) << "RecvSparse receive var: " << splited_var_name
          << " ids Size: " << ids.size();

  auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();

  std::vector<std::vector<std::vector<float> *>> old_values;

  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(ids, {"Param"}, &old_values);

  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();

  auto dims1 = t_latest->dims()[1];
  auto numel = ids.size() * dims1;

  std::vector<float> v_delta;
  v_delta.resize(numel);

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);

  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
    blas.VSUB(dims1, t_psrever.data<float>() + j * dims1,
              old_values[j][0]->data(), v_delta.data() + j * dims1);
    blas.VADD(dims1, t_latest->data<float>() + ids[j] * dims1,
              v_delta.data() + j * dims1,
              t_latest->data<float>() + ids[j] * dims1);
    blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
               old_values[j][0]->data());
  }
}

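// Pull a dense parameter from the parameter servers, apply the delta
// (remote - old) to the latest local copy, and refresh the old copy kept in
// old_scope_.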
void GeoCommunicator::RecvDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);
  auto *var_psrever = pserver_scope_->Var(varname);

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *pserver_scope_);

  PADDLE_ENFORCE_EQ(
      var_psrever->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  auto t_psrever = var_psrever->Get<framework::LoDTensor>();
  auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest->numel(), t_psrever.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());
  blas.VADD(t_latest->numel(), t_latest->data<float>(), t_delta->data<float>(),
            t_latest->data<float>());
  blas.VCOPY(t_latest->numel(), t_psrever.data<float>(),
             t_timestamp->data<float>());
}

void GeoCommunicator::InitParams() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &recv_ctx = iter.second;

    auto recv_task = [this, &var_name, &recv_ctx] {
      if (!recv_ctx.is_sparse) {
        InitDense(var_name);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : tasks) {
    task.wait();
  }
  InitSparse();
}

void GeoCommunicator::InitDense(const std::string varname) {
  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *recv_scope_);

  auto *global_var = recv_scope_->FindVar(varname);
  global_var->GetMutable<framework::LoDTensor>();

  auto *old_var = old_scope_->Var(varname);
  old_var->GetMutable<framework::LoDTensor>();

  framework::CopyVariable(*global_var, old_var);
  VLOG(1) << "init dense variable " << varname << " done";
}

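// Parse sparse_attrs_, register every sparse table with LargeScaleKV, pull
// the initial values from the parameter servers and copy them row by row into
// the local table.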
void GeoCommunicator::InitSparse() {
  auto sparse_metas = string::split_string<std::string>(sparse_attrs_, "#");

  std::vector<distributed::SparseMeta> metas;
  std::vector<int64_t> dicts;

  for (auto &sparse_meta : sparse_metas) {
    auto attrs = string::split_string<std::string>(sparse_meta, ":");

    auto meta = distributed::SparseMeta();
    meta.name = attrs[0];
    meta.value_names = {"Param"};

    auto dic = string::split_string<std::string>(attrs[1], ",");
    dicts.push_back(std::stoi(dic[0]));
    meta.value_dims = {std::stoi(dic[1])};
    meta.mode = distributed::Mode::training;
    meta.grad_name = "none";
    meta.cached_varnames = {};
    meta.initializer_attrs = string::split_string<std::string>(attrs[2]);
    meta.entry = "none";

    VLOG(3) << "add sparse meta: " << meta.ToString();
    metas.push_back(meta);
  }

  LargeScaleKV::Init(metas);

  for (auto &meta : metas) {
    auto &ctx = recv_varname_to_ctx_.at(meta.name);
    auto recv = distributed::ParameterRecv<float>();

    auto *global_var = recv_scope_->FindVar(meta.name);
    auto global_value = global_var->Get<framework::LoDTensor>();
    auto rows = global_value.dims()[0];
    auto dim1 = global_value.dims()[1];

    recv(ctx, *recv_scope_);
    VLOG(1) << "recv " << meta.name << " with global scope for init";

    auto n_rows = global_var->Get<framework::LoDTensor>().dims()[0];

    PADDLE_ENFORCE_EQ(
        rows, n_rows,
        platform::errors::InvalidArgument(
            "global var: %s origin dim must equal recved rows", meta.name));

    std::vector<int64_t> ids(rows);
    std::iota(ids.begin(), ids.end(), 0);

    auto *ins = distributed::LargeScaleKV::GetInstance();
    std::vector<std::vector<std::vector<float> *>> values;

    ins->Get(meta.name)->Init(ids);
    ins->Get(meta.name)->Get(ids, {"Param"}, &values);

    auto blas = math::GetBlas<platform::CPUDeviceContext, float>(
        paddle::platform::CPUDeviceContext());

    for (auto &id : ids) {
      blas.VCOPY(dim1, global_value.data<float>() + id * dim1,
                 values[id][0]->data());
    }
  }

  VLOG(3) << "init sparse variable done";
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle