/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed/communicator.h"

#include <gflags/gflags.h>
#include <paddle/fluid/framework/program_desc.h>

#include <algorithm>
#include <chrono>  // NOLINT
#include <map>
#include <thread>  // NOLINT
#include <unordered_set>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"

namespace paddle {
namespace operators {
namespace distributed {

using Tree =
    std::map<std::string, std::map<std::string, std::vector<std::string>>>;
using RpcCtxMap = operators::distributed::RpcCtxMap;

inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+6 * time.tv_sec + time.tv_usec;
}

Communicator::Communicator() {}

std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);

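// Register the send/recv RPC contexts and the trainer-side scope. One blocking
// queue is created per variable to be sent, plus thread pools for the send and
// recv paths; InitParams() then pulls the initial parameters from the pservers.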
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RpcCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  if (send_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be sent, will not start send_thread";
  } else {
    send_scope_.reset(new Scope());
    for (auto &iter : send_varname_to_ctx_) {
      send_varname_to_queue_[iter.first] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    }
    send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  InitParams();
}

void AsyncCommunicator::InitParams() { RecvNoBarrier(); }

AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) main_thread_->join();
}

void AsyncCommunicator::SendGlobalStep(int batches) {
  if (!need_global_step_) {
    return;
  }

  if (batches == 0) {
    return;
  }

  auto &var_name = STEP_COUNTER;
  auto *out_var = send_scope_->Var(var_name);
  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);

  auto &ctx = send_varname_to_ctx_.at(var_name);
  auto send_functor = distributed::ParameterSend<float>();
  send_functor(ctx, *send_scope_, true, 1);
}

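// For every queued variable, pop `batches` pending gradients, merge them into
// send_scope_, and send the merged result to the parameter servers. Each
// variable is handled by its own task on the send thread pool.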
void AsyncCommunicator::SendByCommunicator(int batches) {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(send_varname_to_ctx_.size());
  VLOG(3) << "run send graph";
  auto before_run_send_graph = GetCurrentUS();
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    auto send_task = [this, batches, &var_name, &var_queue] {
      if (var_name == STEP_COUNTER) {
        return;
      }

      VLOG(3) << var_name << " merge and send";
      std::vector<std::shared_ptr<Variable>> vars;
      vars.reserve(batches);

      for (int i = 0; i < batches; ++i) {
        vars.push_back(var_queue->Pop());
      }

      auto &ctx = send_varname_to_ctx_.at(var_name);

      auto before_merge = GetCurrentUS();
      MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
      auto after_merge = GetCurrentUS();
      VLOG(3) << "merge " << batches << " " << var_name << " use time "
              << after_merge - before_merge;

      auto send_functor = distributed::ParameterSend<float>();
      send_functor(ctx, *send_scope_, true, 1);
      auto after_send = GetCurrentUS();
      VLOG(3) << "send " << var_name << " use time "
              << after_send - after_merge;
    };
    task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
  }
  for (auto &task_f : task_futures) {
    task_f.wait();
  }
  auto after_run_send_graph = GetCurrentUS();

  VLOG(3) << "run send graph use time "
          << after_run_send_graph - before_run_send_graph;
}

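// Main loop of the async communicator: wait until the first Send() arrives,
// then repeatedly count the merged batches, send the global step and merged
// gradients, pull updated parameters back, and trigger the barrier hooks used
// by the half-sync/sync variants.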
void AsyncCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";

  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }

  while (running_) {
    int batches = BatchesCounter();

    if (batches > 0) {
      SendGlobalStep(batches);
      SendByCommunicator(batches);
      BarrierSend();
      RecvByCommunicator();
      BarrierRecv();
      BarrierWeakUp();
    } else {
      VLOG(1) << "get nothing from sending queue, will skip send/recv";
    }
  }
  VLOG(1) << "communicator stopped, send thread exit";
}

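// Pull every variable registered in recv_varname_to_ctx_ from the parameter
// servers, one recv task per variable on the recv thread pool.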
void AsyncCommunicator::RecvByCommunicator() {
  VLOG(3) << "parallel run recv graph";
  if (!running_) return;
  RecvNoBarrier();
  VLOG(3) << "run recv graph use time";
}

void AsyncCommunicator::RecvNoBarrier() {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto recv_task = [this, &iter] {
      auto &var_name = iter.first;
      VLOG(4) << "recv var " << var_name;
      auto recv_functor = distributed::ParameterRecv<float>();
      recv_functor(iter.second, *recv_scope_);
    };
    task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : task_futures) {
    task.wait();
  }
}

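// Count how many mini-batches have been accumulated by draining the
// STEP_COUNTER queue: merge at most max_merge_var_num_ entries and wait at
// most send_wait_times_ rounds for new ones.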
int AsyncCommunicator::BatchesCounter() {
  auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER);

  size_t merged_var_num = 0;
  size_t wait_times = 0;

  while (merged_var_num < static_cast<size_t>(max_merge_var_num_)) {
    if (step_queue->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    } else {
      step_queue->Pop();
      wait_times = 0;
      merged_var_num++;
    }
  }

  return merged_var_num;
}

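// Start() launches the background MainThread once the communicator has been
// initialized; Stop() clears running_ and joins the thread.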
void AsyncCommunicator::Start() {
  VLOG(1) << "Communicator start";
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    VLOG(1) << "start send thread and recv thread";
    waiting_ = true;
    running_ = true;
    BarrierTriggerReset(max_merge_var_num_);
    // start send and recv thread
    main_thread_.reset(
        new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
  }
}

void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop";
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    if (main_thread_) {
      VLOG(1) << "stop send thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}

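// Entry point called from the trainer: copy the gradient (or a step-counter
// increment) into a temporary variable and push it onto the per-table
// blocking queue consumed by SendByCommunicator.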
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const std::vector<std::string> &var_tables,
                             const framework::Scope &scope) {
  waiting_ = false;

  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));

  auto table_name = var_tables[0];
  auto &queue = send_varname_to_queue_.at(table_name);

  if (table_name == STEP_COUNTER) {
    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
    queue->Push(tmp_var);
  } else {
    PADDLE_ENFORCE_GE(var_names.size(), 1,
                      platform::errors::InvalidArgument(
                          "var_names.size() >= 1 is permitted"));

    auto *var = scope.FindVar(var_names[0]);

    PADDLE_ENFORCE_EQ(
        var->IsInitialized(), true,
        platform::errors::InvalidArgument("grad var should be inited"));

    auto tmp_var = std::make_shared<Variable>();
    if (var->IsType<framework::SelectedRows>()) {
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else if (var->IsType<framework::LoDTensor>()) {
      // push var into send queue by var_name
      auto var_name = var_names[0];
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "unknown var type to copy, only support LoDTensor/SelectedRows"));
    }
  }
}

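// HalfAsyncCommunicator drains the send queues on Clean() and replaces the
// time-based batch counting with an explicit barrier: trainers call Barrier(),
// and the main loop proceeds once barrier_counter_ reaches barrier_trigger_.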
void HalfAsyncCommunicator::Clean() {
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    while (var_queue->Size() > 0) {
      var_queue->Pop();
    }

    VLOG(3) << "clean var: " << var_name << " done";
  }
}

int HalfAsyncCommunicator::BatchesCounter() {
  while (running_) {
    if (barrier_counter_.load() >= barrier_trigger_.load() &&
        barrier_trigger_.load() != 0) {
      break;
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }

  return barrier_counter_.load();
}

void HalfAsyncCommunicator::Barrier() {
  barrier_counter_++;

  if (!running_) {
    VLOG(3) << "Communicator is not running, release barrier";
    return;
  }

  {
    std::unique_lock<std::mutex> lk(barrier_mutex_);
    barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
  }
}

void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  barrier_trigger_--;
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_.store(initial_val);

  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_.store(0);
  barrier_cond_.notify_all();
}

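// SyncCommunicator additionally sends batch/fetch barriers to every parameter
// server endpoint so that all trainers stay in lock step.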
void SyncCommunicator::BarrierSend() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;

  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierSend with SyncCommunicator";
}

void SyncCommunicator::BarrierRecv() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;
  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierRecv with SyncCommunicator";
}

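// GeoCommunicator (GEO-SGD style): each trainer keeps a local copy of the
// parameters and periodically exchanges deltas with the parameter servers.
// InitImpl creates one sparse-id queue per pserver shard of every sparse
// variable, plus the delta/old/pserver scopes used to stage those deltas.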
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RpcCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  PADDLE_ENFORCE_GT(
      send_varname_to_ctx.size(), 0,
      platform::errors::InvalidArgument("send var contexts can not be zero"));

  send_scope_.reset(new Scope());
  for (auto &iter : send_varname_to_ctx_) {
    auto &varname = iter.first;

    if (varname == STEP_COUNTER) {
      send_varname_to_queue_[varname] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    } else {
      auto &send_ctx = iter.second;

      send_var_nums_ += send_ctx.splited_varnames.size();
      if (!send_ctx.is_sparse) {
        continue;
      }
      int pserver_num = static_cast<int>(send_ctx.epmap.size());
      for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
        sparse_id_queues_.insert(
            std::pair<std::string, std::shared_ptr<BlockingQueue<
                                       std::shared_ptr<std::vector<int64_t>>>>>(
                send_ctx.splited_varnames[ep_idx],
                std::make_shared<
                    BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>(
                    send_queue_size_)));
      }
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing need to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  delta_scope_.reset(new Scope());
  old_scope_.reset(new Scope());
  pserver_scope_.reset(new Scope());

  InitParams();
}

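// Collect the sparse ids touched by this mini-batch, bucket them by the
// pserver shard that owns them (id % shard_num), and push each bucket onto
// the matching sparse-id queue.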
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                           const std::vector<std::string> &var_tables,
                           const framework::Scope &scope) {
  waiting_ = false;

  auto before_send = GetCurrentUS();
  std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;

  for (size_t i = 0; i < var_tables.size(); i++) {
    auto table_name = var_tables[i];
    if (table_name == STEP_COUNTER) {
      continue;
    } else {
      size_t splited_var_nums =
          send_varname_to_ctx_[table_name].splited_varnames.size();

      for (size_t j = 0; j < splited_var_nums; j++) {
        if (ids_table.find(
                send_varname_to_ctx_[table_name].splited_varnames[j]) ==
            ids_table.end()) {
          ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
              send_varname_to_ctx_[table_name].splited_varnames[j],
              std::unordered_set<int64_t>()));
        }
      }

      auto *var = scope.FindVar(var_names[i]);
      auto var_tensor = var->Get<framework::LoDTensor>();
      int element_number = var_tensor.numel();
      const int64_t *var_mutable_data = var_tensor.data<int64_t>();

      // insert ids which have not been recorded
      for (int j = 0; j < element_number; j++) {
        auto ep_idx = var_mutable_data[j] % splited_var_nums;
        ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
            .insert(var_mutable_data[j]);
      }
    }
  }
  auto before_push = GetCurrentUS();
  for (auto &iter : ids_table) {
    auto &key = iter.first;
    auto &sparse_ids_set = iter.second;
    auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
    sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
    sparse_id_queues_.at(key)->Push(sparse_ids_vec);
    VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
            << "'s queue";
  }
  auto after_send = GetCurrentUS();
  VLOG(3) << "run send_op finish. using " << (before_push - before_send) << "; "
          << (after_send - before_push);
}

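// Main loop of the GEO communicator: for every sparse shard, merge the queued
// ids and do a send-then-recv of the affected rows; for dense variables, send
// the local delta and pull the aggregated value back.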
void GeoCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";

  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }

  while (running_) {
    std::vector<std::future<void>> tasks;
    tasks.reserve(send_var_nums_);

    for (auto &iter : send_varname_to_ctx_) {
      auto &var_name = iter.first;
      auto &send_ctx = iter.second;
      int pserver_num = static_cast<int>(send_ctx.epmap.size());
      if (send_ctx.is_sparse) {
        for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
          auto send_recv_task = [this, ep_idx, &var_name] {
            auto before_send_sparse = GetCurrentUS();
            if (var_name == STEP_COUNTER) {
              return;
            }
            auto send_varname =
                send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx];
            auto sparse_ids = MergeSparseIds(send_varname);
            if (sparse_ids.size() == 0) {
              return;
            }
            SendSparse(var_name, ep_idx, sparse_ids);
            auto after_send_sparse = GetCurrentUS();
            RecvSparse(var_name, ep_idx);
            auto after_recv_sparse = GetCurrentUS();
            VLOG(3)
                << "send recv "
                << send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
                << " finish, using " << (after_send_sparse - before_send_sparse)
                << " and " << (after_recv_sparse - after_send_sparse)
                << "; total = " << (after_recv_sparse - before_send_sparse);
          };
          tasks.emplace_back(
              send_threadpool_->enqueue(std::move(send_recv_task)));
        }
      } else {
        auto send_recv_task = [this, &var_name, &send_ctx] {
          if (var_name == STEP_COUNTER) {
            return;
          }
          SendDense(var_name);
          RecvDense(var_name);
        };
        tasks.emplace_back(
            send_threadpool_->enqueue(std::move(send_recv_task)));
      }
    }
    for (auto &task : tasks) {
      task.wait();
    }
  }
}

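// Merge up to max_merge_var_num_ queued id batches of one shard into a single
// deduplicated id list, waiting at most send_wait_times_ rounds for new input.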
std::vector<int64_t> GeoCommunicator::MergeSparseIds(
    const std::string &send_varname) {
  size_t merge_num = 0, wait_times = 0;
  std::unordered_set<int64_t> sparse_ids;
  while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
    VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
    if (sparse_id_queues_.at(send_varname)->Size() > 0) {
      wait_times = 0;
      std::shared_ptr<std::vector<int64_t>> pop_ids =
          sparse_id_queues_.at(send_varname)->Pop();
      for (size_t j = 0; j < pop_ids->size(); j++) {
        sparse_ids.insert(pop_ids->at(j));
      }
      merge_num += 1;
      VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed";
    } else if (sparse_id_queues_.at(send_varname)->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    }
  }
  std::vector<int64_t> res;
  res.assign(sparse_ids.begin(), sparse_ids.end());
  return res;
}
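
// Send the GEO delta of the selected sparse rows to one pserver shard:
// delta = (param_latest - param_old) / trainers; the locally cached old value
// is advanced by the same delta, and row ids are mapped to the shard-local
// space (id / pserver_num) before being sent as a SelectedRows variable.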
void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx,
                                 const std::vector<int64_t> &sparse_ids) {
  auto &rpc_ctx = send_varname_to_ctx_.at(varname);
  auto send_varname = rpc_ctx.splited_varnames[ep_idx];
  auto trainer_id = rpc_ctx.trainer_id;
  auto endpoint = rpc_ctx.epmap[ep_idx];
  auto pserver_num = rpc_ctx.epmap.size();

  auto *var_latest = recv_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  auto &t_latest = var_latest->Get<framework::LoDTensor>();

  auto dims1 = t_latest.dims()[1];

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(send_varname);
  auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();

  auto *t_value = t_delta->mutable_value();
  t_value->mutable_data<float>(
      framework::make_ddim({static_cast<int64_t>(sparse_ids.size()), dims1}),
      cpu_ctx.GetPlace());

  std::vector<std::vector<std::vector<float> *>> values;
  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(sparse_ids, {"Param"}, &values);

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);

  for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
    blas.VSUB(dims1, t_latest.data<float>() + sparse_ids[j] * dims1,
              values[j][0]->data(), t_value->data<float>() + j * dims1);
    blas.SCAL(dims1, coefficient, t_value->data<float>() + j * dims1);
    blas.VADD(dims1, values[j][0]->data(), t_value->data<float>() + j * dims1,
              values[j][0]->data());
  }

  std::vector<int64_t> send_rows;
  send_rows.reserve(sparse_ids.size());
  for (auto idx : sparse_ids) {
    send_rows.push_back(idx / pserver_num);
  }
  t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
  t_delta->set_rows(send_rows);

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);

  auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
                                      *delta_scope_.get(), send_varname);
  ret->Wait();
}

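// Dense counterpart of SendSparse: stage delta = (param_latest - param_old) /
// trainers in delta_scope_, advance the old snapshot by the same delta, and
// ship the delta with ParameterSend.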
void GeoCommunicator::SendDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));

  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest.numel(), t_latest.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());

  float coefficient = 1.0 / static_cast<float>(trainers_);
  blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());

  blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
            t_delta->data<float>(), t_timestamp->data<float>());

  auto &ctx = send_varname_to_ctx_.at(varname);
  auto send = distributed::ParameterSend<float>();
  send(ctx, *delta_scope_, true, 1);
}

void GeoCommunicator::RecvByCommunicator() { return; }

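// Pull one shard of a sparse parameter from its pserver, map the received rows
// back to global ids, add the pserver-side delta (new - old) onto the local
// latest copy, and refresh the cached old values.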
void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
  auto train_id = recv_varname_to_ctx_.at(varname).trainer_id;
  auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx];
  auto splited_var_name =
      recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
  auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id);

  auto *var_psrever = pserver_scope_->Var(splited_var_name);
  auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv,
                                        *pserver_scope_.get(), splited_var_name,
                                        splited_var_name, splited_var_name);
  handle->Wait();

  auto *var_latest = recv_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(
      var_psrever->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  std::vector<int64_t> ids;
  ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
             var_psrever->Get<framework::SelectedRows>().rows().end());

  for (size_t j = 0; j < ids.size(); j++) {
    ids[j] = ids[j] * pserver_num + ep_idx;
  }

  VLOG(3) << "RecvSparse receive var: " << splited_var_name
          << " ids Size: " << ids.size();

  auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();

  std::vector<std::vector<std::vector<float> *>> old_values;

  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(ids, {"Param"}, &old_values);

  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();

  auto dims1 = t_latest->dims()[1];
  auto numel = ids.size() * dims1;

  std::vector<float> v_delta;
  v_delta.resize(numel);

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);

  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
    blas.VSUB(dims1, t_psrever.data<float>() + j * dims1,
              old_values[j][0]->data(), v_delta.data() + j * dims1);
    blas.VADD(dims1, t_latest->data<float>() + ids[j] * dims1,
              v_delta.data() + j * dims1,
              t_latest->data<float>() + ids[j] * dims1);
    blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
               old_values[j][0]->data());
  }
}

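// Dense counterpart of RecvSparse: fetch the variable into pserver_scope_,
// apply the pserver delta (new - old) to the latest copy, and refresh the old
// snapshot.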
void GeoCommunicator::RecvDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);
  auto *var_psrever = pserver_scope_->Var(varname);

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *pserver_scope_);

  PADDLE_ENFORCE_EQ(
      var_psrever->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  auto t_psrever = var_psrever->Get<framework::LoDTensor>();
  auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest->numel(), t_psrever.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());
  blas.VADD(t_latest->numel(), t_latest->data<float>(), t_delta->data<float>(),
            t_latest->data<float>());
  blas.VCOPY(t_latest->numel(), t_psrever.data<float>(),
             t_timestamp->data<float>());
}

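// Pull the initial values of all dense variables in parallel, then set up the
// sparse tables.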
void GeoCommunicator::InitParams() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &recv_ctx = iter.second;

    auto recv_task = [this, &var_name, &recv_ctx] {
      if (!recv_ctx.is_sparse) {
        InitDense(var_name);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : tasks) {
    task.wait();
  }
  InitSparse();
}

void GeoCommunicator::InitDense(const std::string varname) {
  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *recv_scope_);

  auto *global_var = recv_scope_->FindVar(varname);
  global_var->GetMutable<framework::LoDTensor>();

  auto *old_var = old_scope_->Var(varname);
  old_var->GetMutable<framework::LoDTensor>();

  framework::CopyVariable(*global_var, old_var);
  VLOG(1) << "init dense variable " << varname << " done";
}

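// Parse sparse_attrs_, register each sparse table with LargeScaleKV, pull its
// initial value from the parameter servers, and copy the rows into the local
// large-scale KV storage.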
void GeoCommunicator::InitSparse() {
  auto sparse_metas = string::split_string<std::string>(sparse_attrs_, "#");

  std::vector<distributed::SparseMeta> metas;
  std::vector<int64_t> dicts;

  for (auto &sparse_meta : sparse_metas) {
    auto attrs = string::split_string<std::string>(sparse_meta, ":");

    auto meta = distributed::SparseMeta();
    meta.name = attrs[0];
    meta.value_names = {"Param"};

    auto dic = string::split_string<std::string>(attrs[1], ",");
    dicts.push_back(std::stoi(dic[0]));
    meta.value_dims = {std::stoi(dic[1])};
    meta.mode = distributed::Mode::training;
    meta.grad_name = "none";
    meta.cached_varnames = {};
    meta.initializer_attrs = string::split_string<std::string>(attrs[2]);
    meta.entry = "none";

    VLOG(3) << "add sparse meta: " << meta.ToString();
    metas.push_back(meta);
  }

  LargeScaleKV::Init(metas);

  for (auto &meta : metas) {
    auto &ctx = recv_varname_to_ctx_.at(meta.name);
    auto recv = distributed::ParameterRecv<float>();

    auto *global_var = recv_scope_->FindVar(meta.name);
    auto global_value = global_var->Get<framework::LoDTensor>();
    auto rows = global_value.dims()[0];
    auto dim1 = global_value.dims()[1];

    recv(ctx, *recv_scope_);
    VLOG(1) << "recv " << meta.name << " with global scope for init";

    auto n_rows = global_var->Get<framework::LoDTensor>().dims()[0];

    PADDLE_ENFORCE_EQ(
        rows, n_rows,
        platform::errors::InvalidArgument(
            "global var: %s origin dim must equal recved rows", meta.name));

    std::vector<int64_t> ids(rows);
    std::iota(ids.begin(), ids.end(), 0);

    auto *ins = distributed::LargeScaleKV::GetInstance();
    std::vector<std::vector<std::vector<float> *>> values;

    ins->Get(meta.name)->Init(ids);
    ins->Get(meta.name)->Get(ids, {"Param"}, &values);

    auto blas = math::GetBlas<platform::CPUDeviceContext, float>(
        paddle::platform::CPUDeviceContext());

    for (auto &id : ids) {
      blas.VCOPY(dim1, global_value.data<float>() + id * dim1,
                 values[id][0]->data());
    }
  }

  VLOG(3) << "init sparse variable done";
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle