/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed/communicator.h"
#include <gflags/gflags.h>
#include <paddle/fluid/framework/program_desc.h>
#include <algorithm>
#include <chrono>  // NOLINT
#include <map>
#include <sys/time.h>  // for gettimeofday() used in GetCurrentUS()
#include <thread>  // NOLINT
#include <unordered_set>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"

namespace paddle {
namespace operators {
namespace distributed {

using Tree =
    std::map<std::string, std::map<std::string, std::vector<std::string>>>;
using RpcCtxMap = operators::distributed::RpcCtxMap;

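// Returns the current wall-clock time in microseconds; used below for the
// coarse timing reported in the merge/send/recv VLOG messages.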
inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+6 * time.tv_sec + time.tv_usec;
}

Communicator::Communicator() {}

std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);

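// Initialize the async communicator: record the send/recv RPC contexts,
// create one blocking queue per send variable, and set up the thread pools
// used by the send and recv workers.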
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RpcCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  if (send_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be sent, will not start send_thread";
  } else {
    send_scope_.reset(new Scope());
    for (auto &iter : send_varname_to_ctx_) {
      send_varname_to_queue_[iter.first] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    }
    send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }
}

AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) main_thread_->join();
}

void AsyncCommunicator::SendGlobalStep(int batches) {
  if (!need_global_step_) {
    return;
  }

  if (batches == 0) {
    return;
  }

  auto &var_name = STEP_COUNTER;
  auto *out_var = send_scope_->Var(var_name);
  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);

  auto &ctx = send_varname_to_ctx_.at(var_name);
  auto send_functor = distributed::ParameterSend<float>();
  send_functor(ctx, *send_scope_, true, 1);
}

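// Merge up to `batches` queued gradient snapshots for every send variable and
// push the merged result to the parameter servers, one thread-pool task per
// variable. STEP_COUNTER is handled separately by SendGlobalStep.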
void AsyncCommunicator::SendByCommunicator(int batches) {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(send_varname_to_ctx_.size());
  VLOG(3) << "run send graph";
  auto before_run_send_graph = GetCurrentUS();
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    auto send_task = [this, batches, &var_name, &var_queue] {
      if (var_name == STEP_COUNTER) {
        return;
      }

      VLOG(3) << var_name << " merge and send";
      std::vector<std::shared_ptr<Variable>> vars;
      vars.reserve(batches);

      for (int i = 0; i < batches; ++i) {
        vars.push_back(var_queue->Pop());
      }

      auto &ctx = send_varname_to_ctx_.at(var_name);

      auto before_merge = GetCurrentUS();
      MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
      auto after_merge = GetCurrentUS();
      VLOG(3) << "merge " << batches << " " << var_name << " use time "
              << after_merge - before_merge;

      auto send_functor = distributed::ParameterSend<float>();
      send_functor(ctx, *send_scope_, true, 1);
      auto after_send = GetCurrentUS();
      VLOG(3) << "send " << var_name << " use time "
              << after_send - after_merge;
    };
    task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
  }
  for (auto &task_f : task_futures) {
    task_f.wait();
  }
  auto after_run_send_graph = GetCurrentUS();

  VLOG(3) << "run send graph use time "
          << after_run_send_graph - before_run_send_graph;
}

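// Main loop of the communicator: idle until the first Send() call clears
// `waiting_`, then repeatedly count pending batches, send the global step and
// merged gradients, and run the barrier/recv cycle.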
void AsyncCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";

  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }

  while (running_) {
    int batches = BatchesCounter();

    if (batches > 0) {
      SendGlobalStep(batches);
      SendByCommunicator(batches);
      BarrierSend();
      RecvByCommunicator();
      BarrierRecv();
      BarrierWeakUp();
    } else {
      VLOG(1) << "get nothing from sending queue, will skip send/recv";
    }
  }
  VLOG(1) << "communicator stopped, send thread exit";
}

void AsyncCommunicator::RecvByCommunicator() {
  VLOG(3) << "parallel run recv graph";
  if (!running_) return;
  RecvNoBarrier();
  VLOG(3) << "run recv graph done";
}

void AsyncCommunicator::RecvNoBarrier() {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto recv_task = [this, &iter] {
      auto &var_name = iter.first;
      VLOG(4) << "recv var " << var_name;
      auto recv_functor = distributed::ParameterRecv<float>();
      recv_functor(iter.second, *recv_scope_, false);
    };
    task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : task_futures) {
    task.wait();
  }
}

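// Drain the STEP_COUNTER queue to count how many mini-batches were pushed
// since the last send, waiting up to `send_wait_times_` rounds of 10 ms for
// new entries before returning what has accumulated.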
int AsyncCommunicator::BatchesCounter() {
  auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER);

  size_t merged_var_num = 0;
  size_t wait_times = 0;

  while (merged_var_num < static_cast<size_t>(max_merge_var_num_)) {
    if (step_queue->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    } else {
      step_queue->Pop();
      wait_times = 0;
      merged_var_num++;
    }
  }

  return merged_var_num;
}

void AsyncCommunicator::Start() {
  VLOG(1) << "Communicator start";
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    VLOG(1) << "start send thread and recv thread";
    waiting_ = true;
    running_ = true;
    BarrierTriggerReset(max_merge_var_num_);
    // start send and recv thread
    main_thread_.reset(
        new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
  }
}

void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop";
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    if (main_thread_) {
      VLOG(1) << "stop send thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}

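// Called by the trainer to hand one gradient (or a STEP_COUNTER tick) to the
// communicator: the variable is copied into a temporary Variable and pushed
// onto the blocking queue of its target table.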
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const std::vector<std::string> &var_tables,
                             const framework::Scope &scope) {
  waiting_ = false;

  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() must be 1"));

  auto table_name = var_tables[0];
  auto &queue = send_varname_to_queue_.at(table_name);

  if (table_name == STEP_COUNTER) {
    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
    queue->Push(tmp_var);
  } else {
    PADDLE_ENFORCE_GE(var_names.size(), 1,
                      platform::errors::InvalidArgument(
                          "var_names.size() must be at least 1"));

    auto *var = scope.FindVar(var_names[0]);

    PADDLE_ENFORCE_EQ(
        var->IsInitialized(), true,
        platform::errors::InvalidArgument("grad var should be inited"));

    auto tmp_var = std::make_shared<Variable>();
    if (var->IsType<framework::SelectedRows>()) {
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else if (var->IsType<framework::LoDTensor>()) {
      // push var into send queue by var_name
      auto var_name = var_names[0];
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "unknown var type to copy, only support LoDTensor/SelectedRows"));
    }
  }
}

void HalfAsyncCommunicator::Clean() {
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    while (var_queue->Size() > 0) {
      var_queue->Pop();
    }

    VLOG(3) << "clean var: " << var_name << " done";
  }
}

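// Half-async mode: instead of counting queued steps, block until every
// trainer thread has reached the barrier (barrier_counter_ >= barrier_trigger_)
// and use the counter value as the number of batches to merge.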
int HalfAsyncCommunicator::BatchesCounter() {
  while (running_) {
    if (barrier_counter_.load() >= barrier_trigger_.load() &&
        barrier_trigger_.load() != 0) {
      break;
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }

  return barrier_counter_.load();
}

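// Called by trainer threads at the end of a step: bump the barrier counter and
// block until the communicator finishes a send/recv round and wakes the
// waiters up via BarrierWeakUp().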
void HalfAsyncCommunicator::Barrier() {
  barrier_counter_++;

  if (!running_) {
    VLOG(3) << "Communicator is not running, release barrier";
    return;
  }

  {
    std::unique_lock<std::mutex> lk(barrier_mutex_);
    barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
  }
}

void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  barrier_trigger_--;
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_.store(initial_val);

  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_.store(0);
  barrier_cond_.notify_all();
}

void SyncCommunicator::BarrierSend() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;

  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierSend with SyncCommunicator";
}

void SyncCommunicator::BarrierRecv() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;
  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierRecv with SyncCommunicator";
}

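// GEO mode setup: every sparse send variable gets a queue of touched ids,
// STEP_COUNTER keeps a variable queue, and three extra scopes are created to
// hold the local delta, the last-synced copy, and the freshly pulled pserver
// copy of each parameter.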
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RpcCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  PADDLE_ENFORCE_GT(
      send_varname_to_ctx.size(), 0,
      platform::errors::InvalidArgument("send var contexts can not be zero"));

  send_scope_.reset(new Scope());
  for (auto &iter : send_varname_to_ctx_) {
    auto &varname = iter.first;

    if (varname == STEP_COUNTER) {
      send_varname_to_queue_[varname] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    } else {
      auto &send_ctx = iter.second;

      if (!send_ctx.is_sparse) {
        continue;
      }

      send_ids_to_queue_[varname] =
          std::make_shared<BlockingQueue<std::vector<int64_t>>>(
              send_queue_size_);
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  delta_scope_.reset(new Scope());
  old_scope_.reset(new Scope());
  pserver_scope_.reset(new Scope());

  Init();
}

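// In GEO mode the trainer only reports which rows of a sparse table were
// touched in this step (plus STEP_COUNTER ticks); dense deltas are computed
// later from the scopes, so no tensor data is queued here.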
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                           const std::vector<std::string> &var_tables,
                           const framework::Scope &scope) {
  waiting_ = false;

  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() must be 1"));

  auto table_name = var_tables[0];

  if (table_name == STEP_COUNTER) {
    auto &queue = send_varname_to_queue_.at(table_name);

    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
    queue->Push(tmp_var);
  } else {
    auto &queue = send_ids_to_queue_.at(table_name);
    PADDLE_ENFORCE_EQ(var_names.size(), 1,
                      platform::errors::InvalidArgument(
                          "var_names.size() must be 1"));

    auto *var = scope.FindVar(var_names[0]);

    PADDLE_ENFORCE_EQ(
        var->IsInitialized(), true,
        platform::errors::InvalidArgument("grad var should be inited"));

    if (!var->IsType<framework::SelectedRows>()) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Only SelectedRows can be sent to GeoCommunicator::Send"));
    }

    std::vector<int64_t> ids;
    auto &rows = var->Get<framework::SelectedRows>().rows();
    ids.assign(rows.begin(), rows.end());
    queue->Push(ids);
  }
}

void GeoCommunicator::SendByCommunicator(int batches) {
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());

  for (auto &iter : send_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &send_ctx = iter.second;

    auto send_task = [this, batches, &var_name, &send_ctx] {
      if (var_name == STEP_COUNTER) {
        return;
      }

      if (send_ctx.is_sparse) {
        SendSparse(var_name, batches);
      } else {
        VLOG(1) << "send dense " << var_name << " begin";
        SendDense(var_name);
        VLOG(1) << "send dense " << var_name << " done";
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
  }

  for (auto &task : tasks) {
    task.wait();
  }
}

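// Send the GEO delta of one sparse table: collect the ids touched in the last
// `batches` steps, compute delta = (latest - old) / trainers for those rows,
// push the delta to the pservers, and fold the same delta back into the
// locally cached "Param" values held by LargeScaleKV.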
void GeoCommunicator::SendSparse(const std::string &varname, int batches) {
  std::vector<int64_t> ids;
  auto &ids_queue = send_ids_to_queue_.at(varname);

  for (int i = 0; i < batches; ++i) {
    auto pop_ids = ids_queue->Pop();
    std::copy(pop_ids.begin(), pop_ids.end(), back_inserter(ids));
  }

  auto size = ids.size();

  std::set<int64_t> st(ids.begin(), ids.end());
  ids.assign(st.begin(), st.end());
  VLOG(1) << "SendSparse var: " << varname << " ids before dedup: " << size
          << " after dedup: " << ids.size();

  if (ids.empty()) {
    LOG(WARNING) << "GEO has nothing to send, return directly";
    return;
  }

  auto *var_latest = recv_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  auto &t_latest = var_latest->Get<framework::LoDTensor>();

  auto dims1 = t_latest.dims()[1];

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
  t_delta->set_height(ids.size());
  t_delta->mutable_rows()->assign(ids.begin(), ids.end());
  auto *t_value = t_delta->mutable_value();
  t_value->mutable_data<float>(
      framework::make_ddim({static_cast<int64_t>(ids.size()), dims1}),
      cpu_ctx.GetPlace());

  std::vector<std::vector<std::vector<float> *>> values;
  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(ids, {"Param"}, &values);

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);

  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
    blas.VSUB(dims1, t_latest.data<float>() + ids[j] * dims1,
              values[j][0]->data(), t_value->data<float>() + j * dims1);
    blas.SCAL(dims1, coefficient, t_value->data<float>() + j * dims1);
    blas.VADD(dims1, values[j][0]->data(), t_value->data<float>() + j * dims1,
              values[j][0]->data());
  }

  auto &ctx = send_varname_to_ctx_.at(varname);
  auto send = distributed::ParameterSend<float>();
  send(ctx, *delta_scope_, true, 1);
}

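// Send the GEO delta of one dense variable: delta = (latest - old) / trainers
// is written into delta_scope_ and pushed to the pservers, and the "old" copy
// in old_scope_ is advanced by the same delta.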
void GeoCommunicator::SendDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));

  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest.numel(), t_latest.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());

  float coefficient = 1.0 / static_cast<float>(trainers_);
  blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());

  blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
            t_delta->data<float>(), t_timestamp->data<float>());

  auto &ctx = send_varname_to_ctx_.at(varname);
  auto send = distributed::ParameterSend<float>();
  send(ctx, *delta_scope_, true, 1);
}

void GeoCommunicator::RecvByCommunicator() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &recv_ctx = iter.second;

    auto recv_task = [this, &var_name, &recv_ctx] {
      if (recv_ctx.is_sparse) {
        RecvSparse(var_name);
      } else {
        RecvDense(var_name);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
}

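// Pull one sparse table from the pservers into pserver_scope_, then for every
// received row add (pserver - old) to the latest local parameter and refresh
// the cached "old" value with the pserver copy.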
void GeoCommunicator::RecvSparse(const std::string &varname) {
  VLOG(1) << "RecvSparse receive var: " << varname;

  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_psrever = pserver_scope_->Var(varname);

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *pserver_scope_, true);

  PADDLE_ENFORCE_EQ(
      var_psrever->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  std::vector<int64_t> ids;
  ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
             var_psrever->Get<framework::SelectedRows>().rows().end());

  VLOG(1) << "RecvSparse receive var: " << varname
          << " ids Size: " << ids.size();

  auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();

  std::vector<std::vector<std::vector<float> *>> old_values;

  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(ids, {"Param"}, &old_values);

  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();

  auto dims1 = t_latest->dims()[1];
  auto numel = ids.size() * dims1;

  std::vector<float> v_delta;
  v_delta.resize(numel);

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);

  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
    blas.VSUB(dims1, t_psrever.data<float>() + j * dims1,
              old_values[j][0]->data(), v_delta.data() + j * dims1);
    blas.VADD(dims1, t_latest->data<float>() + ids[j] * dims1,
              v_delta.data() + j * dims1,
              t_latest->data<float>() + ids[j] * dims1);
    blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
               old_values[j][0]->data());
  }
}

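// Pull one dense variable from the pservers and apply the same GEO update as
// RecvSparse: latest += (pserver - old), then overwrite "old" with the
// pserver copy.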
void GeoCommunicator::RecvDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);
  auto *var_psrever = pserver_scope_->Var(varname);

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *pserver_scope_, true);

  PADDLE_ENFORCE_EQ(
      var_psrever->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  auto t_psrever = var_psrever->Get<framework::LoDTensor>();
  auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest->numel(), t_psrever.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());
  blas.VADD(t_latest->numel(), t_latest->data<float>(), t_delta->data<float>(),
            t_latest->data<float>());
  blas.VCOPY(t_latest->numel(), t_psrever.data<float>(),
             t_timestamp->data<float>());
}

void GeoCommunicator::Init() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &recv_ctx = iter.second;

    auto recv_task = [this, &var_name, &recv_ctx] {
      if (!recv_ctx.is_sparse) {
        InitDense(var_name);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : tasks) {
    task.wait();
  }
  InitSparse();
}

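// Pull the initial value of a dense parameter from the pservers into
// old_scope_ so that later SendDense calls have a baseline to diff against.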
void GeoCommunicator::InitDense(const std::string varname) {
  auto *var = old_scope_->Var(varname);
  var->GetMutable<framework::LoDTensor>();

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *old_scope_);
  VLOG(1) << "init dense variable " << varname << " done";
}

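// Initialize the large-scale sparse tables. sparse_attrs_ appears to encode
// one table per '#'-separated entry as "name:dict_size,emb_dim:init_attrs"
// (format inferred from the parsing below). After registering the metas with
// LargeScaleKV, the current parameter shards are fetched from every pserver
// and copied into local storage.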
void GeoCommunicator::InitSparse() {
  auto sparse_metas = string::split_string<std::string>(sparse_attrs_, "#");

  std::vector<distributed::SparseMeta> metas;
  std::vector<int64_t> dicts;

  for (auto &sparse_meta : sparse_metas) {
    auto attrs = string::split_string<std::string>(sparse_meta, ":");

    auto meta = distributed::SparseMeta();
    meta.name = attrs[0];
    meta.value_names = {"Param"};

    auto dic = string::split_string<std::string>(attrs[1], ",");
    dicts.push_back(std::stoi(dic[0]));
    meta.value_dims = {std::stoi(dic[1])};
    meta.mode = distributed::Mode::training;
    meta.grad_name = "none";
    meta.cached_varnames = {};
    meta.initializer_attrs = string::split_string<std::string>(attrs[2]);
    meta.entry = "none";

    VLOG(3) << "add sparse meta: " << meta.ToString();
    metas.push_back(meta);
  }

  LargeScaleKV::Init(metas);

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto cpu_place = platform::CPUPlace();
  auto &cpu_ctx = *pool.Get(cpu_place);

  framework::Scope &local_scope = send_scope_->NewScope();

  for (size_t i = 0; i < metas.size(); i++) {
    auto &meta = metas[i];
    auto &ctx = recv_varname_to_ctx_.at(meta.name);
    auto pserver_num = ctx.splited_varnames.size();
    for (size_t j = 0; j < pserver_num; j++) {
      // each split var corresponds to one pserver shard of this parameter
      auto &recv_var_name = ctx.splited_varnames[j];

      // make sure the output variable exists in the local scope before the
      // asynchronous fetch writes into it
      local_scope.Var(recv_var_name);

      // ctx.epmap is assumed to hold one pserver endpoint per split shard,
      // aligned with ctx.splited_varnames
      distributed::VarHandlePtr ret;
      ret = rpc_client->AsyncGetVarNoBarrier(ctx.epmap[j], cpu_ctx, local_scope,
                                             recv_var_name, recv_var_name);

      PADDLE_ENFORCE_NE(ret->Wait(), 0U, platform::errors::ExecutionTimeout(
                                             "internal error in RPCClient"));

      auto *recv_var = local_scope.FindVar(recv_var_name);
      auto &recv_t = recv_var->Get<framework::SelectedRows>();
      auto width = recv_t.value().dims()[1];

      PADDLE_ENFORCE_EQ(
          width, meta.value_dims[0],
          platform::errors::InvalidArgument("sparse params do not match"));

      std::vector<int64_t> ids;
      // recover global ids from the shard-local rows, assuming ids are split
      // round-robin across pservers: global_id = local_row * pserver_num + j
      std::transform(recv_t.rows().begin(), recv_t.rows().end(),
                     std::back_inserter(ids),
                     [&](int64_t id) { return id * pserver_num + j; });

      std::vector<std::vector<std::vector<float> *>> values;
      auto *ins = distributed::LargeScaleKV::GetInstance();

      ins->Get(meta.name)->Init(ids);
      ins->Get(meta.name)->Get(ids, {"Param"}, &values);

      // copy the fetched shard values into the local large-scale KV storage;
      // blas needs a concrete CPUDeviceContext (the pooled ctx above is only
      // the base DeviceContext), as in the other functions of this file
      auto blas_cpu_ctx = paddle::platform::CPUDeviceContext();
      auto blas =
          math::GetBlas<platform::CPUDeviceContext, float>(blas_cpu_ctx);

      for (size_t k = 0; k < ids.size(); ++k) {
        blas.VCOPY(width, recv_t.value().data<float>() + k * width,
                   values[k][0]->data());
      }
    }
  }

  send_scope_->DeleteScope(&local_scope);

  VLOG(3) << "GeoCommunicator init sparse variables done";
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle