/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed/communicator.h"

#include <gflags/gflags.h>
#include <paddle/fluid/framework/program_desc.h>
#include <algorithm>
#include <chrono>  // NOLINT
#include <map>
#include <thread>  // NOLINT
#include <unordered_set>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"

namespace paddle {
namespace operators {
namespace distributed {

using Tree =
    std::map<std::string, std::map<std::string, std::vector<std::string>>>;
using RpcCtxMap = operators::distributed::RpcCtxMap;

inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+6 * time.tv_sec + time.tv_usec;
}

Communicator::Communicator() {}

std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);

void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RpcCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);

  if (send_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be sent, will not start send_thread";
  } else {
    send_scope_.reset(new Scope());
    for (auto &iter : send_varname_to_ctx_) {
      send_varname_to_queue_[iter.first] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    }
    send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }
}

AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) main_thread_->join();
}

void AsyncCommunicator::SendGlobalStep(int batches) {
  if (!need_global_step_) {
    return;
  }

  if (batches == 0) {
    return;
  }

  auto &var_name = STEP_COUNTER;
  auto *out_var = send_scope_->Var(var_name);
  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);

  auto &ctx = send_varname_to_ctx_.at(var_name);
  auto send_functor = distributed::ParameterSend<float>();
  send_functor(ctx, *send_scope_, true, 1);
}

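// Drain each variable's queue, merge the queued gradients for this round, and
// push the merged result to the parameter servers (one thread-pool task per
// variable; the STEP_COUNTER queue is handled separately by SendGlobalStep).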
void AsyncCommunicator::SendByCommunicator(int batches) {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(send_varname_to_ctx_.size());
  VLOG(3) << "run send graph";
  auto before_run_send_graph = GetCurrentUS();
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    auto send_task = [this, batches, &var_name, &var_queue] {
      if (var_name == STEP_COUNTER) {
        return;
      }

      VLOG(3) << var_name << " merge and send";
      std::vector<std::shared_ptr<Variable>> vars;
      vars.reserve(batches);

      for (int i = 0; i < batches; ++i) {
        vars.push_back(var_queue->Pop());
      }

      auto &ctx = send_varname_to_ctx_.at(var_name);

      auto before_merge = GetCurrentUS();
      MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
      auto after_merge = GetCurrentUS();
      VLOG(3) << "merge " << batches << " " << var_name << " use time "
              << after_merge - before_merge;

      auto send_functor = distributed::ParameterSend<float>();
      send_functor(ctx, *send_scope_, true, 1);
      auto after_send = GetCurrentUS();
      VLOG(3) << "send " << var_name << " use time "
              << after_send - after_merge;
    };
    task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
  }
  for (auto &task_f : task_futures) {
    task_f.wait();
  }
  auto after_run_send_graph = GetCurrentUS();

  VLOG(3) << "run send graph use time "
          << after_run_send_graph - before_run_send_graph;
}

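// Communicator main loop: block until the first Send() call clears waiting_,
// then repeatedly count the merged batches (Meet), send the global step and
// merged gradients, and pull fresh parameters back from the servers.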
void AsyncCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";

  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }

  while (running_) {
    int meet = Meet();

    VLOG(1) << "async_meet: " << meet;

    SendGlobalStep(meet);
    SendByCommunicator(meet);
    BarrierSend();
    RecvByCommunicator();
    BarrierRecv();
    BarrierWeakUp();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}

void AsyncCommunicator::RecvByCommunicator() {
  VLOG(3) << "parallel run recv graph";
  if (!running_) return;
  RecvNoBarrier();
  VLOG(3) << "run recv graph use time";
}

void AsyncCommunicator::RecvNoBarrier() {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto recv_task = [this, &iter] {
      auto &var_name = iter.first;
      VLOG(4) << "recv var " << var_name;
      auto recv_functor = distributed::ParameterRecv<float>();
      recv_functor(iter.second, *recv_scope_, false);
    };
    task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : task_futures) {
    task.wait();
  }
}

int AsyncCommunicator::Meet() {
  auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER);

  size_t merged_var_num = 0;
  size_t wait_times = 0;

  while (merged_var_num < static_cast<size_t>(max_merge_var_num_)) {
    if (step_queue->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    } else {
      step_queue->Pop();
      wait_times = 0;
      merged_var_num++;
    }
  }

  return merged_var_num;
}

void AsyncCommunicator::Start() {
  VLOG(1) << "Communicator start";
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    VLOG(1) << "start send thread and recv thread";
    waiting_ = true;
    running_ = true;
    BarrierTriggerReset(max_merge_var_num_);
    // start send and recv thread
    main_thread_.reset(
        new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
  }
}

void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop";
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    if (main_thread_) {
      VLOG(1) << "stop send thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}

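// Called from the trainer after each mini-batch: copies the gradient (or a
// step-counter increment) into the per-table blocking queue that the
// communicator thread drains in SendByCommunicator.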
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const std::vector<std::string> &var_tables,
                             const framework::Scope &scope) {
  waiting_ = false;

  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));

  auto table_name = var_tables[0];
  auto &queue = send_varname_to_queue_.at(table_name);

  if (table_name == STEP_COUNTER) {
    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
    queue->Push(tmp_var);
  } else {
    PADDLE_ENFORCE_GE(var_names.size(), 1,
                      platform::errors::InvalidArgument(
                          "var_names.size() >= 1 is permitted"));

    auto *var = scope.FindVar(var_names[0]);

    PADDLE_ENFORCE_EQ(
        var->IsInitialized(), true,
        platform::errors::InvalidArgument("grad var should be inited"));

    auto tmp_var = std::make_shared<Variable>();
    if (var->IsType<framework::SelectedRows>()) {
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else if (var->IsType<framework::LoDTensor>()) {
      // push var into send queue by var_name
      auto var_name = var_names[0];
      framework::CopyVariable(*var, tmp_var.get());
      VLOG(3) << "send to " << table_name << " with queue size "
              << queue->Size();
      queue->Push(tmp_var);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "unknown var type to copy, only support LoDTensor/SelectedRows"));
    }
  }
}

void HalfAsyncCommunicator::Clean() {
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &var_queue = iter.second;

    while (var_queue->Size() > 0) {
      var_queue->Pop();
    }

    VLOG(3) << "clean var: " << var_name << " done";
  }
}

int HalfAsyncCommunicator::Meet() {
  while (running_) {
    if (barrier_counter_.load() >= barrier_trigger_.load() &&
        barrier_trigger_.load() != 0) {
      break;
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }

  return barrier_counter_.load();
}

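// Trainer-side barrier for half-async mode: bump the counter polled by Meet()
// and block until the communicator finishes the round and resets it to zero
// in BarrierWeakUp.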
void HalfAsyncCommunicator::Barrier() {
  barrier_counter_++;

  if (!running_) {
    VLOG(3) << "Communicator is not running, release barrier";
    return;
  }

  {
    std::unique_lock<std::mutex> lk(barrier_mutex_);
    barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
  }
}

void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  barrier_trigger_--;
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_.store(initial_val);

  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}

void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_.store(0);
  barrier_cond_.notify_all();
}

void SyncCommunicator::BarrierSend() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;

  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierSend with SyncCommunicator";
}

void SyncCommunicator::BarrierRecv() {
  if (!running_) return;

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);

  std::vector<distributed::VarHandlePtr> rets;
  for (auto &ep : pserver_endpoints_) {
    rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
  }

  for (size_t i = 0; i < rets.size(); i++) {
    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
                                               "internal error in RPCClient"));
  }

  VLOG(4) << "BarrierRecv with SyncCommunicator";
}

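// GEO mode works on four scopes: recv_scope_ holds the trainer's latest
// parameters, old_scope_ the values at the last sync, delta_scope_ the
// computed updates, and pserver_scope_ the copy pulled from the servers.
// Sparse tables are tracked by queues of touched row ids rather than by
// full gradient queues.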
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RpcCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);
  PADDLE_ENFORCE_GT(
      send_varname_to_ctx.size(), 0,
      platform::errors::InvalidArgument("send var contexts can not be zero"));

  send_scope_.reset(new Scope());
  for (auto &iter : send_varname_to_ctx_) {
    auto &varname = iter.first;

    if (varname == STEP_COUNTER) {
      send_varname_to_queue_[varname] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    } else {
      auto &send_ctx = iter.second;

      if (!send_ctx.is_sparse) {
        continue;
      }

      send_ids_to_queue_[varname] =
          std::make_shared<BlockingQueue<std::vector<int64_t>>>(
              send_queue_size_);
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));

  if (recv_varname_to_ctx.size() == 0) {
    VLOG(0) << "nothing needs to be received, will not start recv_thread";
  } else {
    recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  }

  delta_scope_.reset(new Scope());
  old_scope_.reset(new Scope());
  pserver_scope_.reset(new Scope());

  Init();
}

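// In GEO mode Send() does not queue full gradients: it only records a
// step-counter increment or the sparse row ids touched by this mini-batch;
// dense variables are diffed lazily in SendDense.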
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                           const std::vector<std::string> &var_tables,
                           const framework::Scope &scope) {
  waiting_ = false;

  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));

  auto table_name = var_tables[0];

  if (table_name == STEP_COUNTER) {
    auto &queue = send_varname_to_queue_.at(table_name);

    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
    queue->Push(tmp_var);
  } else {
    auto &queue = send_ids_to_queue_.at(table_name);
    PADDLE_ENFORCE_EQ(var_names.size(), 1,
                      platform::errors::InvalidArgument(
                          "var_names.size() == 1 is permitted"));

    auto *var = scope.FindVar(var_names[0]);

    PADDLE_ENFORCE_EQ(
        var->IsInitialized(), true,
        platform::errors::InvalidArgument("grad var should be inited"));

    if (!var->IsType<framework::SelectedRows>()) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Only SelectedRows can be sent in GeoCommunicator::Send"));
    }

    std::vector<int64_t> ids;
    auto &rows = var->Get<framework::SelectedRows>().rows();
    ids.assign(rows.begin(), rows.end());
    queue->Push(ids);
  }
}

void GeoCommunicator::SendByCommunicator(int batches) {
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());

  for (auto &iter : send_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &send_ctx = iter.second;

    auto send_task = [this, batches, &var_name, &send_ctx] {
      if (var_name == STEP_COUNTER) {
        return;
      }

      if (send_ctx.is_sparse) {
        SendSparse(var_name, batches);
      } else {
        VLOG(1) << "send dense " << var_name << " begin";
        SendDense(var_name);
        VLOG(1) << "send dense " << var_name << " done";
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
  }

  for (auto &task : tasks) {
    task.wait();
  }
}

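// GEO sparse update: collect the row ids touched during the last `batches`
// steps, deduplicate them, send (latest - old) / trainers_ for each row as a
// SelectedRows delta, and fold the same delta back into the cached old value.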
void GeoCommunicator::SendSparse(const std::string &varname, int batches) {
  std::vector<int64_t> ids;
  auto &ids_queue = send_ids_to_queue_.at(varname);

  for (int i = 0; i < batches; ++i) {
    auto pop_ids = ids_queue->Pop();
    std::copy(pop_ids.begin(), pop_ids.end(), back_inserter(ids));
  }

  auto size = ids.size();

  std::set<int64_t> st(ids.begin(), ids.end());
  ids.assign(st.begin(), st.end());
  VLOG(1) << "SendSparse receive var: " << varname << " unset: " << size
          << " set: " << ids.size();

  if (ids.empty()) {
    LOG(WARNING) << "WARNING: GEO has nothing to send, return directly ";
    return;
  }

  auto *var_latest = recv_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  auto &t_latest = var_latest->Get<framework::LoDTensor>();

  auto dims1 = t_latest.dims()[1];

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
  t_delta->set_height(ids.size());
  t_delta->mutable_rows()->assign(ids.begin(), ids.end());
  auto *t_value = t_delta->mutable_value();
  t_value->mutable_data<float>(
      framework::make_ddim({static_cast<int64_t>(ids.size()), dims1}),
      cpu_ctx.GetPlace());

  std::vector<std::vector<std::vector<float> *>> values;
  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(ids, {"Param"}, &values);

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);

  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
    blas.VSUB(dims1, t_latest.data<float>() + ids[j] * dims1,
              values[j][0]->data(), t_value->data<float>() + j * dims1);
    blas.SCAL(dims1, coefficient, t_value->data<float>() + j * dims1);
    blas.VADD(dims1, values[j][0]->data(), t_value->data<float>() + j * dims1,
              values[j][0]->data());
  }

  auto &ctx = send_varname_to_ctx_.at(varname);
  auto send = distributed::ParameterSend<float>();
  send(ctx, *delta_scope_, true, 1);
}

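// GEO dense update: send (latest - old) / trainers_ for the whole tensor and
// advance the cached old copy in old_scope_ by the same delta.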
void GeoCommunicator::SendDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);

  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));
  PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", varname));

  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest.numel(), t_latest.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());

  float coefficient = 1.0 / static_cast<float>(trainers_);
  blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());

  blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
            t_delta->data<float>(), t_timestamp->data<float>());

  auto &ctx = send_varname_to_ctx_.at(varname);
  auto send = distributed::ParameterSend<float>();
  send(ctx, *delta_scope_, true, 1);
}

void GeoCommunicator::RecvByCommunicator() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &recv_ctx = iter.second;

    auto recv_task = [this, &var_name, &recv_ctx] {
      if (recv_ctx.is_sparse) {
        RecvSparse(var_name);
      } else {
        VLOG(1) << "recv dense " << var_name << " begin";
        RecvDense(var_name);
        VLOG(1) << "recv dense " << var_name << " done";
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
}

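// GEO sparse recv: pull the rows published by the pserver, add the difference
// between the pserver value and the cached old value onto the local latest
// parameters, then refresh the cached old value.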
void GeoCommunicator::RecvSparse(const std::string &varname) {
  VLOG(1) << "RecvSparse receive var: " << varname;

  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_pserver = pserver_scope_->Var(varname);

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *pserver_scope_, true);

  PADDLE_ENFORCE_EQ(
      var_pserver->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  std::vector<int64_t> ids;
  ids.assign(var_pserver->Get<framework::SelectedRows>().rows().begin(),
             var_pserver->Get<framework::SelectedRows>().rows().end());

  VLOG(1) << "RecvSparse receive var: " << varname
          << " ids Size: " << ids.size();

  auto t_pserver = var_pserver->Get<framework::SelectedRows>().value();

  std::vector<std::vector<std::vector<float> *>> old_values;

  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Get(ids, {"Param"}, &old_values);

  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();

  auto dims1 = t_latest->dims()[1];
  auto numel = ids.size() * dims1;

  std::vector<float> v_delta;
  v_delta.resize(numel);

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);

  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
    blas.VSUB(dims1, t_pserver.data<float>() + j * dims1,
              old_values[j][0]->data(), v_delta.data() + j * dims1);
    blas.VADD(dims1, t_latest->data<float>() + ids[j] * dims1,
              v_delta.data() + j * dims1,
              t_latest->data<float>() + ids[j] * dims1);
    blas.VCOPY(dims1, t_pserver.data<float>() + j * dims1,
               old_values[j][0]->data());
  }
}

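// GEO dense recv: pull the dense tensor from the pserver, apply
// (pserver - old) to the local latest copy, and store the pserver value as
// the new old copy in old_scope_.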
void GeoCommunicator::RecvDense(const std::string &varname) {
  auto *var_latest = recv_scope_->FindVar(varname);
  auto *var_timestamp = old_scope_->FindVar(varname);
  auto *var_pserver = pserver_scope_->Var(varname);

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *pserver_scope_, true);

  PADDLE_ENFORCE_EQ(
      var_pserver->IsInitialized(), true,
      platform::errors::Unavailable(
          "%s in pserver scope is not initialized, please check", varname));

  auto t_pserver = var_pserver->Get<framework::LoDTensor>();
  auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();

  auto cpu_ctx = paddle::platform::CPUDeviceContext();
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
  t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
  blas.VSUB(t_latest->numel(), t_pserver.data<float>(),
            t_timestamp->data<float>(), t_delta->data<float>());
  blas.VADD(t_latest->numel(), t_latest->data<float>(), t_delta->data<float>(),
            t_latest->data<float>());
  blas.VCOPY(t_latest->numel(), t_pserver.data<float>(),
             t_timestamp->data<float>());
}

void GeoCommunicator::Init() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &recv_ctx = iter.second;

    auto recv_task = [this, &var_name, &recv_ctx] {
      if (!recv_ctx.is_sparse) {
        InitDense(var_name);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }

  for (auto &task : tasks) {
    task.wait();
  }
  InitSparse();
}

void GeoCommunicator::InitDense(const std::string varname) {
  auto *var = old_scope_->Var(varname);
  var->GetMutable<framework::LoDTensor>();

  auto &ctx = recv_varname_to_ctx_.at(varname);
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *old_scope_);
  VLOG(1) << "init dense variable " << varname << " done";
}

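// Parse sparse_attrs_ (entries joined by '#', each expected to look like
// "name:dict_size,dim:initializer_attrs") into SparseMeta descriptions,
// register them with LargeScaleKV, and pre-initialize every row id of each
// sparse table.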
void GeoCommunicator::InitSparse() {
  auto sparse_metas = string::split_string<std::string>(sparse_attrs_, "#");

  std::vector<distributed::SparseMeta> metas;
  std::vector<int64_t> dicts;

  for (auto &sparse_meta : sparse_metas) {
    auto attrs = string::split_string<std::string>(sparse_meta, ":");

    auto meta = distributed::SparseMeta();
    meta.name = attrs[0];
    meta.value_names = {"Param"};

    auto dic = string::split_string<std::string>(attrs[1], ",");
    dicts.push_back(std::stoi(dic[0]));
    meta.value_dims = {std::stoi(dic[1])};
    meta.mode = distributed::Mode::training;
    meta.grad_name = "none";
    meta.cached_varnames = {};
    meta.initializer_attrs = string::split_string<std::string>(attrs[2]);
    meta.entry = "none";

    VLOG(3) << "add sparse meta: " << meta.ToString();
    metas.push_back(meta);
  }

  LargeScaleKV::Init(metas);

  for (size_t i = 0; i < metas.size(); i++) {
    auto &varname = metas[i].name;
    auto &dict = dicts[i];

    std::vector<int64_t> ids;
    ids.reserve(dict);

    for (auto j = 0; j < dict; ++j) {
      ids.push_back(j);
    }

    auto *ins = distributed::LargeScaleKV::GetInstance();
    ins->Get(varname)->Init(ids);

    VLOG(3) << "GeoCommunicator init sparse " << varname << " with size "
            << ids.size();
  }

  VLOG(3) << "init sparse variable done";
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle