/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed/communicator.h"
#include <gflags/gflags.h>
#include <paddle/fluid/framework/program_desc.h>
#include <algorithm>
#include <chrono>  // NOLINT
#include <map>
#include <numeric>
#include <set>
#include <thread>  // NOLINT
#include <unordered_set>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"

namespace paddle {
namespace operators {
namespace distributed {

using Tree =
    std::map<std::string, std::map<std::string, std::vector<std::string>>>;
using RpcCtxMap = operators::distributed::RpcCtxMap;

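// Wall-clock time in microseconds; used below only for the coarse timing
// numbers reported through VLOG.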
inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+6 * time.tv_sec + time.tv_usec;
}

Communicator::Communicator() {}

std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);

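// Stores the send/recv RPC contexts, creates one blocking queue per variable
// to be sent plus a thread pool for sending, a thread pool for receiving, and
// finally pulls the initial parameter values via InitParams().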
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RpcCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  if (send_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing need to be send, will not start send_thread";
  } else {
    send_scope_.reset(new Scope());
    for (auto &iter : send_varname_to_ctx_) {
      send_varname_to_queue_[iter.first] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    }
    send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing need to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  InitParams();
}

void AsyncCommunicator::InitParams() { RecvNoBarrier(); }

AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) main_thread_->join();
}

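// Writes the merged batch count into the STEP_COUNTER variable of send_scope_
// and pushes it through the matching RPC context; skipped when
// need_global_step_ is false or there is nothing to send.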
void AsyncCommunicator::SendGlobalStep(int batches) {
  if (!need_global_step_) {
    return;
  }

  if (batches == 0) {
    return;
  }

  auto &var_name = STEP_COUNTER;
  auto *out_var = send_scope_->Var(var_name);
  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);

  auto &ctx = send_varname_to_ctx_.at(var_name);
  auto send_functor = distributed::ParameterSend<float>();
  send_functor(ctx, *send_scope_, true, 1);
}

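// For every queue except STEP_COUNTER: pop `batches` pending gradients, merge
// them into send_scope_ (summed or averaged according to ctx.merge_add), then
// push the merged result to the parameter servers. One thread-pool task is
// enqueued per variable and all tasks are joined before returning.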
void AsyncCommunicator::SendByCommunicator(int batches) {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(send_varname_to_ctx_.size());
  VLOG(3) << "run send graph";
  auto before_run_send_graph = GetCurrentUS();
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    auto send_task = [this, batches, &var_name, &var_queue] {
      if (var_name == STEP_COUNTER) {
        return;
      }

      VLOG(3) << var_name << " merge and send";
      std::vector<std::shared_ptr<Variable>> vars;
      vars.reserve(batches);

      for (int i = 0; i < batches; ++i) {
        vars.push_back(var_queue->Pop());
      }

      auto &ctx = send_varname_to_ctx_.at(var_name);

      auto before_merge = GetCurrentUS();
      MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
      auto after_merge = GetCurrentUS();
      VLOG(3) << "merge " << batches << " " << var_name << " use time "
              << after_merge - before_merge;

      auto send_functor = distributed::ParameterSend<float>();
      send_functor(ctx, *send_scope_, true, 1);
      auto after_send = GetCurrentUS();
      VLOG(3) << "send " << var_name << " use time "
              << after_send - after_merge;
    };
    task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
  }
  for (auto &task_f : task_futures) {
    task_f.wait();
  }
  auto after_run_send_graph = GetCurrentUS();

  VLOG(3) << "run send graph use time "
          << after_run_send_graph - before_run_send_graph;
}

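// The single background loop: wait until the first Send() call clears
// waiting_, then keep counting the queued batches and, whenever there are
// any, run one merge-send / recv round together with the barrier hooks.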
void AsyncCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";

  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }

  while (running_) {
    int batches = BatchesCounter();

    if (batches > 0) {
      SendGlobalStep(batches);
      SendByCommunicator(batches);
      BarrierSend();
      RecvByCommunicator();
      BarrierRecv();
      BarrierWeakUp();
    } else {
      VLOG(1) << "get nothing from sending queue, will skip send/recv";
    }
  }
  VLOG(1) << "communicator stopped, send thread exit";
}

void AsyncCommunicator::RecvByCommunicator() {
  VLOG(3) << "parallel run recv graph";
  if (!running_) return;
  RecvNoBarrier();
  VLOG(3) << "run recv graph use time";
}

void AsyncCommunicator::RecvNoBarrier() {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto recv_task = [this, &iter] {
      auto &var_name = iter.first;
      VLOG(4) << "recv var " << var_name;
      auto recv_functor = distributed::ParameterRecv<float>();
      recv_functor(iter.second, *recv_scope_);
    };
    task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : task_futures) {
    task.wait();
  }
}

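// Counts pending mini-batches by draining the STEP_COUNTER queue, up to
// max_merge_var_num_ entries; while the queue is empty it retries every 10 ms
// and gives up after send_wait_times_ attempts.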
int AsyncCommunicator::BatchesCounter() {
  auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER);

  size_t merged_var_num = 0;
  size_t wait_times = 0;

  while (merged_var_num < static_cast<size_t>(max_merge_var_num_)) {
    if (step_queue->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    } else {
      step_queue->Pop();
      wait_times = 0;
      merged_var_num++;
    }
  }

  return merged_var_num;
}

void AsyncCommunicator::Start() {
  VLOG(1) << "Communicator start";
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    VLOG(1) << "start send thread and recv thread";
    waiting_ = true;
    running_ = true;
    BarrierTriggerReset(max_merge_var_num_);
    // start the main communicator thread, which handles both send and recv
    main_thread_.reset(
        new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
  }
}

void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop";
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    if (main_thread_) {
      VLOG(1) << "stop send thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}

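// Producer side of the queues: copies the single gradient variable (or a
// constant 1 for STEP_COUNTER) into a temporary Variable and pushes it onto
// the corresponding blocking queue for the main thread to consume.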
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const std::vector<std::string> &var_tables,
                             const framework::Scope &scope) {
  waiting_ = false;

  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));

  auto table_name = var_tables[0];
  auto &queue = send_varname_to_queue_.at(table_name);

  if (table_name == STEP_COUNTER) {
    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
    queue->Push(tmp_var);
  } else {
    PADDLE_ENFORCE_GE(var_names.size(), 1,
                      platform::errors::InvalidArgument(
                          "var_names.size() >= 1 is permitted"));

    auto *var = scope.FindVar(var_names[0]);

    PADDLE_ENFORCE_EQ(
        var->IsInitialized(), true,
        platform::errors::InvalidArgument("grad var should be inited"));

    auto tmp_var = std::make_shared<Variable>();
    if (var->IsType<framework::SelectedRows>()) {
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else if (var->IsType<framework::LoDTensor>()) {
      // push var into send queue by var_name
      auto var_name = var_names[0];
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "unknown var type to copy, only support LoDTensor/SelectedRows"));
    }
  }
}

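// Drains every send queue so no stale gradient is carried over into the next
// round.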
void HalfAsyncCommunicator::Clean() {
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    while (var_queue->Size() > 0) {
      var_queue->Pop();
    }

    VLOG(3) << "clean var: " << var_name << " done";
  }
}

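// In the half-async mode the batch count comes from the barrier counter:
// block until barrier_counter_ reaches barrier_trigger_ (or the communicator
// stops), then report the counter value.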
int HalfAsyncCommunicator::BatchesCounter() {
  while (running_) {
    if (barrier_counter_.load() >= barrier_trigger_.load() &&
        barrier_trigger_.load() != 0) {
      break;
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }

  return barrier_counter_.load();
}

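// Caller-side barrier: increments barrier_counter_ and blocks until the main
// loop finishes the round and resets the counter in BarrierWeakUp().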
void HalfAsyncCommunicator::Barrier() {
  barrier_counter_++;

  if (!running_) {
    VLOG(3) << "Communicator is not running, release barrier";
    return;
  }

  {
    std::unique_lock<std::mutex> lk(barrier_mutex_);
    barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
  }
}

void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  barrier_trigger_--;
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_.store(initial_val);

  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_.store(0);
  barrier_cond_.notify_all();
}

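// Sends a batch barrier RPC to every parameter-server endpoint and waits for
// all of them to be acknowledged.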
void SyncCommunicator::BarrierSend() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;

  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierSend with SyncCommunicator";
}

void SyncCommunicator::BarrierRecv() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;
  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierRecv with SyncCommunicator";
}

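// GEO mode keeps three extra scopes: delta_scope_ for the deltas being pushed,
// old_scope_ for the snapshot taken at the last push, and pserver_scope_ for
// values pulled from the servers. Sparse variables get a queue of row ids
// instead of a variable queue, since only the touched rows are communicated.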
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RpcCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  PADDLE_ENFORCE_GT(
      send_varname_to_ctx.size(), 0,
      platform::errors::InvalidArgument("send var contexts can not be zero"));

  send_scope_.reset(new Scope());
  for (auto &iter : send_varname_to_ctx_) {
    auto &varname = iter.first;

    if (varname == STEP_COUNTER) {
      send_varname_to_queue_[varname] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    } else {
      auto &send_ctx = iter.second;

      if (!send_ctx.is_sparse) {
        continue;
      }

      send_ids_to_queue_[varname] =
          std::make_shared<BlockingQueue<std::vector<int64_t>>>(
              send_queue_size_);
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing need to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  delta_scope_.reset(new Scope());
  old_scope_.reset(new Scope());
  pserver_scope_.reset(new Scope());

  InitParams();
}

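// In GEO mode the trainer does not queue gradient tensors: a sparse table only
// records which rows were touched (the SelectedRows row indices), STEP_COUNTER
// queues a constant 1, and dense variables are handled purely by diffing
// scopes at send time.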
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                           const std::vector<std::string> &var_tables,
                           const framework::Scope &scope) {
  waiting_ = false;

  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));

  auto table_name = var_tables[0];

  if (table_name == STEP_COUNTER) {
    auto &queue = send_varname_to_queue_.at(table_name);

    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
    queue->Push(tmp_var);
  } else {
    auto &queue = send_ids_to_queue_.at(table_name);
    PADDLE_ENFORCE_EQ(var_names.size(), 1,
                      platform::errors::InvalidArgument(
                          "var_names.size() == 1 is permitted"));

    auto *var = scope.FindVar(var_names[0]);

    PADDLE_ENFORCE_EQ(
        var->IsInitialized(), true,
        platform::errors::InvalidArgument("grad var should be inited"));

    if (!var->IsType<framework::SelectedRows>()) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Only LodTensor can be send in GeoCommunicator::Send"));
    }

    std::vector<int64_t> ids;
    auto &rows = var->Get<framework::SelectedRows>().rows();
    ids.assign(rows.begin(), rows.end());
    queue->Push(ids);
  }
}

void GeoCommunicator::SendByCommunicator(int batches) {
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());

  for (auto &iter : send_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &send_ctx = iter.second;

    auto send_task = [this, batches, &var_name, &send_ctx] {
      if (var_name == STEP_COUNTER) {
        return;
      }

      if (send_ctx.is_sparse) {
        SendSparse(var_name, batches);
      } else {
        VLOG(1) << "send dense " << var_name << " begin";
        SendDense(var_name);
        VLOG(1) << "send dense " << var_name << " done";
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
  }

  for (auto &task : tasks) {
    task.wait();
  }
}

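// Pushes the GEO delta for the rows touched during the last `batches` steps:
//   delta = (param_latest - param_old) / trainers_
// The delta is assembled as a SelectedRows in delta_scope_ and sent, and the
// local "Param" values held by LargeScaleKV are advanced by the same delta so
// the next diff starts from the state that was just sent.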
void GeoCommunicator::SendSparse(const std::string &varname, int batches) {
  std::vector<int64_t> ids;
  auto &ids_queue = send_ids_to_queue_.at(varname);

  for (int i = 0; i < batches; ++i) {
    auto pop_ids = ids_queue->Pop();
    std::copy(pop_ids.begin(), pop_ids.end(), back_inserter(ids));
  }

  auto size = ids.size();

  std::set<int64_t> st(ids.begin(), ids.end());
  ids.assign(st.begin(), st.end());
  VLOG(1) << "SendSparse receive var: " << varname << " unset: " << size
          << " set: " << ids.size();

  if (ids.empty()) {
    LOG(WARNING) << "WARNING: GEO has nothing to send, return directly ";
    return;
  }

  auto *var_latest = recv_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  auto &t_latest = var_latest->Get<framework::LoDTensor>();

  auto dims1 = t_latest.dims()[1];

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
  t_delta->set_height(ids.size());
  t_delta->mutable_rows()->assign(ids.begin(), ids.end());
  auto *t_value = t_delta->mutable_value();
  t_value->mutable_data<float>(
      framework::make_ddim({static_cast<int64_t>(ids.size()), dims1}),
      cpu_ctx.GetPlace());

  std::vector<std::vector<std::vector<float> *>> values;
  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(ids, {"Param"}, &values);

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);

  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
    blas.VSUB(dims1, t_latest.data<float>() + ids[j] * dims1,
              values[j][0]->data(), t_value->data<float>() + j * dims1);
    blas.SCAL(dims1, coefficient, t_value->data<float>() + j * dims1);
    blas.VADD(dims1, values[j][0]->data(), t_value->data<float>() + j * dims1,
              values[j][0]->data());
  }

  auto &ctx = send_varname_to_ctx_.at(varname);
  auto send = distributed::ParameterSend<float>();
  send(ctx, *delta_scope_, true, 1);
}

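// Dense counterpart of SendSparse: delta = (param_latest - old_scope_ copy) /
// trainers_; the delta tensor lives in delta_scope_ and is sent, and the
// old_scope_ copy is advanced by the same delta.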
void GeoCommunicator::SendDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));

  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest.numel(), t_latest.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());

  float coefficient = 1.0 / static_cast<float>(trainers_);
  blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());

  blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
            t_delta->data<float>(), t_timestamp->data<float>());

  auto &ctx = send_varname_to_ctx_.at(varname);
  auto send = distributed::ParameterSend<float>();
  send(ctx, *delta_scope_, true, 1);
}

void GeoCommunicator::RecvByCommunicator() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &recv_ctx = iter.second;

    auto recv_task = [this, &var_name, &recv_ctx] {
      if (recv_ctx.is_sparse) {
        RecvSparse(var_name);
      } else {
        RecvDense(var_name);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
}

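// Pulls the server-side rows of `varname` into pserver_scope_, applies
//   param_latest += (param_pserver - param_old)
// row by row, and refreshes the local "Param" snapshot with the pulled values.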
void GeoCommunicator::RecvSparse(const std::string &varname) {
  VLOG(1) << "RecvSparse receive var: " << varname;

  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_psrever = pserver_scope_->Var(varname);

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *pserver_scope_, true);

  PADDLE_ENFORCE_EQ(
      var_psrever->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  std::vector<int64_t> ids;
  ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
             var_psrever->Get<framework::SelectedRows>().rows().end());

  VLOG(1) << "RecvSparse receive var: " << varname
          << " ids Size: " << ids.size();

  auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();

  std::vector<std::vector<std::vector<float> *>> old_values;

  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(ids, {"Param"}, &old_values);

  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();

  auto dims1 = t_latest->dims()[1];
  auto numel = ids.size() * dims1;

  std::vector<float> v_delta;
  v_delta.resize(numel);

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);

  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
    blas.VSUB(dims1, t_psrever.data<float>() + j * dims1,
              old_values[j][0]->data(), v_delta.data() + j * dims1);
    blas.VADD(dims1, t_latest->data<float>() + ids[j] * dims1,
              v_delta.data() + j * dims1,
              t_latest->data<float>() + ids[j] * dims1);
    blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
               old_values[j][0]->data());
  }
}

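// Dense counterpart of RecvSparse: param_latest += (param_pserver - old) over
// the whole tensor, after which old_scope_ is refreshed with the pulled
// values.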
void GeoCommunicator::RecvDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);
  auto *var_psrever = pserver_scope_->Var(varname);

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *pserver_scope_);

  PADDLE_ENFORCE_EQ(
      var_psrever->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  auto t_psrever = var_psrever->Get<framework::LoDTensor>();
  auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest->numel(), t_psrever.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());
  blas.VADD(t_latest->numel(), t_latest->data<float>(), t_delta->data<float>(),
            t_latest->data<float>());
  blas.VCOPY(t_latest->numel(), t_psrever.data<float>(),
             t_timestamp->data<float>());
}

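// Pulls every dense parameter once (InitDense), then initializes the sparse
// tables (InitSparse).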
void GeoCommunicator::InitParams() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &recv_ctx = iter.second;

    auto recv_task = [this, &var_name, &recv_ctx] {
      if (!recv_ctx.is_sparse) {
        InitDense(var_name);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : tasks) {
    task.wait();
  }
  InitSparse();
}

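// Fetches one dense parameter from the parameter server into recv_scope_ and
// snapshots it into old_scope_, giving the first SendDense a baseline to diff
// against.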
void GeoCommunicator::InitDense(const std::string varname) {
  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *recv_scope_);

  auto *global_var = recv_scope_->FindVar(varname);
  global_var->GetMutable<framework::LoDTensor>();

  auto *old_var = old_scope_->Var(varname);
  old_var->GetMutable<framework::LoDTensor>();

  framework::CopyVariable(*global_var, old_var);
  VLOG(1) << "init dense variable " << varname << " done";
}

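// Parses sparse_attrs_: '#'-separated entries that, judging from the parsing
// below, look like "name:rows,dim:initializer_attrs". Each table is registered
// with LargeScaleKV, its initial values are pulled from the servers and copied
// row by row into the local table.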
void GeoCommunicator::InitSparse() {
  auto sparse_metas = string::split_string<std::string>(sparse_attrs_, "#");

  std::vector<distributed::SparseMeta> metas;
  std::vector<int64_t> dicts;

  for (auto &sparse_meta : sparse_metas) {
    auto attrs = string::split_string<std::string>(sparse_meta, ":");

    auto meta = distributed::SparseMeta();
    meta.name = attrs[0];
    meta.value_names = {"Param"};

    auto dic = string::split_string<std::string>(attrs[1], ",");
    dicts.push_back(std::stoi(dic[0]));
    meta.value_dims = {std::stoi(dic[1])};
    meta.mode = distributed::Mode::training;
    meta.grad_name = "none";
    meta.cached_varnames = {};
    meta.initializer_attrs = string::split_string<std::string>(attrs[2]);
    meta.entry = "none";

    VLOG(3) << "add sparse meta: " << meta.ToString();
    metas.push_back(meta);
  }

  LargeScaleKV::Init(metas);

  for (auto &meta : metas) {
    auto &ctx = recv_varname_to_ctx_.at(meta.name);
    auto recv = distributed::ParameterRecv<float>();

    auto *global_var = recv_scope_->FindVar(meta.name);
    auto global_value = global_var->Get<framework::LoDTensor>();
    auto rows = global_value.dims()[0];
    auto dim1 = global_value.dims()[1];

    recv(ctx, *recv_scope_);
    VLOG(1) << "recv " << meta.name << " with global scope for init";

    auto n_rows = global_var->Get<framework::LoDTensor>().dims()[0];

    PADDLE_ENFORCE_EQ(
        rows, n_rows,
        platform::errors::InvalidArgument(
            "global var: %s origin dim must equal recved rows", meta.name));

    std::vector<int64_t> ids(rows);
    std::iota(ids.begin(), ids.end(), 0);

    auto *ins = distributed::LargeScaleKV::GetInstance();
    std::vector<std::vector<std::vector<float> *>> values;

    ins->Get(meta.name)->Init(ids);
    ins->Get(meta.name)->Get(ids, {"Param"}, &values);

    auto blas = math::GetBlas<platform::CPUDeviceContext, float>(
        paddle::platform::CPUDeviceContext());

    for (auto &id : ids) {
      blas.VCOPY(dim1, global_value.data<float>() + id * dim1,
                 values[id][0]->data());
    }
  }

  VLOG(3) << "init sparse variable done";
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle