/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/ps/wrapper/fleet.h"

#include <google/protobuf/text_format.h>

#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#endif

namespace paddle {
namespace distributed {

using framework::ProgramDesc;
using framework::VarDesc;
using framework::Variable;

const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;

std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL;
std::shared_ptr<paddle::distributed::PSClient> FleetWrapper::worker_ptr_ = NULL;

int FleetWrapper::RegisterHeterCallback(HeterCallBackFunc handler) {
  VLOG(0) << "RegisterHeterCallback support later";
  return 0;
}

int32_t FleetWrapper::CopyTable(const uint64_t src_table_id,
                                const uint64_t dest_table_id) {
  VLOG(0) << "CopyTable support later";
  return 0;
}

int32_t FleetWrapper::CopyTableByFeasign(
    const uint64_t src_table_id,
    const uint64_t dest_table_id,
    const std::vector<uint64_t>& feasign_list) {
  VLOG(0) << "CopyTableByFeasign support later";
  return 0;
}

void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
                                          int connect_timeout_ms,
                                          int max_retry) {
  client2client_request_timeout_ms_ = request_timeout_ms;
  client2client_connect_timeout_ms_ = connect_timeout_ms;
  client2client_max_retry_ = max_retry;
}
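// Illustrative call (a sketch; the timeout/retry values below are
// hypothetical, not defaults taken from this file):
//   fleet->SetClient2ClientConfig(/*request_timeout_ms=*/300000,
//                                 /*connect_timeout_ms=*/10000,
//                                 /*max_retry=*/3);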

void FleetWrapper::LoadSparseOnServer(const std::string& path,
                                      const std::string& meta,
                                      uint32_t table_id) {
  VLOG(3) << "load sparse table " << table_id << " with " << path << " meta "
          << meta;
  pserver_ptr_->_server_ptr->GetTable(table_id)->Load(path, meta);
}

void FleetWrapper::InitServer(
    const std::string& dist_desc,
    const std::vector<std::string>& host_sign_list,
    int index,
    int trainers,
    const std::vector<framework::ProgramDesc>& server_sub_program) {
  if (!is_initialized_) {
    VLOG(3) << "Going to init server";
    pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
        new paddle::distributed::PSCore());
    pserver_ptr_->InitServer(dist_desc,
                             &host_sign_list,
                             host_sign_list.size(),
                             index,
                             trainers,
                             server_sub_program);
    is_initialized_ = true;
  } else {
    VLOG(3) << "Server can be initialized only once";
  }
}

void FleetWrapper::InitGFlag(const std::string& gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.size() < 1) {
    flags.push_back("-max_body_size=314217728");
    flags.push_back("-bthread_concurrency=40");
    flags.push_back("-socket_max_unwritten_bytes=2048000000");
    flags.push_back("-max_connection_pool_size=1950");
  }
  auto it = flags.begin();
  flags.insert(it, "exe default");
  char* flags_ptr[flags.size()];
  for (size_t i = 0; i < flags.size(); ++i) {
    flags_ptr[i] = (char*)(flags[i].c_str());  // NOLINT
  }
  int params_cnt = flags.size();
  char** params_ptr = &(flags_ptr[0]);
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
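// Example of the space-separated gflags string InitGFlag parses (a sketch
// mirroring the defaults above; the real string comes from
// ps_param.init_gflags() in the job configuration):
//   "-max_body_size=314217728 -bthread_concurrency=40"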

void FleetWrapper::InitWorker(const std::string& dist_desc,
                              const std::vector<std::string>& host_sign_list,
                              int index) {
  if (!is_initialized_) {
    // not used, just for psclient's init
    // TODO(zhaocaibei123): remove this later
    std::map<uint64_t, std::vector<paddle::distributed::Region>>
        dense_pull_regions;

    if (worker_ptr_.get() == nullptr) {
      paddle::distributed::PSParameter ps_param;
      google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param);
      InitGFlag(ps_param.init_gflags());
      int servers = host_sign_list.size();
      ps_env_.SetPsServers(&host_sign_list, servers);
      worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
          paddle::distributed::PSClientFactory::Create(ps_param));
      worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index);
#if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE
      VLOG(3) << "FleetWrapper::InitWorker InitializeGPUServer";
      auto* accessor = worker_ptr_->GetTableAccessor(0);
      auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance();
      ps_gpu_wrapper->InitializeGPUServer(ps_param);
      ps_gpu_wrapper->SetTableAccessor(accessor);
#endif
    }
  } else {
    VLOG(3) << "Client can be initialized only once";
  }
}
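// Typical trainer-side sequence (a sketch; in practice this wrapper is
// driven from the Python-side fleet API):
//   InitWorker(dist_desc, host_sign_list, rank);  // once per process
//   CreateClient2ClientConnection();              // if c2c messaging is used
//   ... pull/push calls during training ...
//   FinalizeWorker();                             // before exit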

void FleetWrapper::InitFlWorker(const std::vector<std::string>& host_list,
                                int index,
                                const std::string& self_endpoint) {
  assert(worker_ptr_.get() != nullptr);
  uint32_t coordinator_num = host_list.size();
  ps_env_.SetCoordinators(&host_list, coordinator_num);
  auto ptr = dynamic_cast<BrpcPsClient*>(worker_ptr_.get());
  ptr->InitializeFlWorker(self_endpoint);
  return;
}

void FleetWrapper::PushFLClientInfoSync(const std::string& fl_client_info) {
  // FLClientInfo fci;
  // google::protobuf::TextFormat::ParseFromString(fl_client_info, &fci);
  // InitGFlag(fci.init_gflags());
  auto ptr = dynamic_cast<BrpcPsClient*>(worker_ptr_.get());
  VLOG(0) << "fl-ps > PushFLClientInfoSync: " << typeid(worker_ptr_).name()
          << ", " << typeid(ptr).name() << ", " << typeid(BrpcPsClient).name();
  ptr->PushFLClientInfoSync(fl_client_info);
  return;
}

std::string FleetWrapper::PullFlStrategy() {
  auto ptr = dynamic_cast<BrpcPsClient*>(worker_ptr_.get());
  std::string str = ptr->PullFlStrategy();
  return str;
}
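// FL client flow (a sketch based on the three calls above): InitFlWorker()
// registers this worker with the coordinators, PushFLClientInfoSync() reports
// the client's state, and PullFlStrategy() retrieves the strategy the
// coordinator distributes in response.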

void FleetWrapper::StopServer() {
  VLOG(3) << "Going to stop server";
  auto status = worker_ptr_->StopServer();
  status.wait();
}

void FleetWrapper::FinalizeWorker() {
  VLOG(3) << "Going to finalize worker";
  worker_ptr_->FinalizeWorker();
}

void FleetWrapper::BarrierWithTable(uint32_t barrier_type) {
  VLOG(3) << "Going to Barrier worker";
  auto* communicator = Communicator::GetInstance();
  communicator->BarrierWithTable(barrier_type);
}

uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
  VLOG(3) << "Going to run server with ip " << ip << " port " << port;
  auto ret = pserver_ptr_->RunServer(ip, port);
  return ret;
}
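// Typical server-side sequence (a sketch): InitServer(...) builds the PSCore
// instance once, RunServer(ip, port) starts the service, and a client later
// issues StopServer() to shut it down.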

std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
  VLOG(3) << "Going to get client info";
  std::vector<uint64_t> res = ps_env_.GetClientInfo();
  for (auto rr : res) {
    VLOG(2) << "FleetWrapper::GetClientInfo " << rr;
  }
  return res;
}

int FleetWrapper::SetClients(std::vector<uint64_t>& host_sign_list) {
  int node = host_sign_list.size();
  return ps_env_.SetPsClients(host_sign_list.data(), node);
}

void FleetWrapper::CreateClient2ClientConnection() {
  VLOG(1) << "Going to create client2client connection";
  worker_ptr_->CreateClient2ClientConnection(client2client_request_timeout_ms_,
                                             client2client_connect_timeout_ms_,
                                             client2client_max_retry_);
}

std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
    const Scope& scope,
    const uint64_t table_id,
    const std::vector<std::string>& var_names,
    std::vector<uint64_t>* fea_keys,
    std::vector<std::vector<float>>* fea_values,
    int fea_value_dim) {
  fea_keys->clear();
  fea_keys->resize(0);
  fea_keys->reserve(MAX_FEASIGN_NUM);
  for (auto name : var_names) {
    Variable* var = scope.FindVar(name);
    if (var == nullptr) {
      continue;
    }
    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
    CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
    int64_t* ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    for (auto i = 0u; i < len; ++i) {
      if (ids[i] == 0u) {
        continue;
      }
      fea_keys->push_back(static_cast<uint64_t>(ids[i]));
    }
  }
  fea_values->resize(fea_keys->size() + 1);
  for (auto& t : *fea_values) {
    t.resize(fea_value_dim);
  }
  std::vector<float*> pull_result_ptr;
  for (auto& t : *fea_values) {
    pull_result_ptr.push_back(t.data());
  }

  bool training = true;
  return pserver_ptr_->_worker_ptr->PullSparse(pull_result_ptr.data(),
                                               table_id,
                                               fea_keys->data(),
                                               fea_keys->size(),
                                               training);
}

void FleetWrapper::PullSparseVarsSync(
    const Scope& scope,
    const uint64_t table_id,
    const std::vector<std::string>& var_names,
    std::vector<uint64_t>* fea_keys,
    std::vector<std::vector<float>>* fea_values,
    int fea_value_dim,
    const std::vector<std::string>& var_emb_names) {
  std::vector<std::future<int32_t>> pull_sparse_status;
  pull_sparse_status.resize(0);
  fea_keys->clear();
  fea_keys->resize(0);
  fea_keys->reserve(MAX_FEASIGN_NUM);
  for (size_t var_index = 0; var_index < var_names.size(); ++var_index) {
    const std::string& name = var_names[var_index];
    Variable* var = scope.FindVar(name);
    if (var == nullptr) {
      continue;
    }
    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
    CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
    int64_t* ids = tensor->data<int64_t>();
    size_t len = tensor->numel();

    // skip slots which do not have embedding
    const std::string& emb_name = var_emb_names[var_index];
    Variable* emb_var = scope.FindVar(emb_name);
    if (emb_var == nullptr) {
      continue;
    }

    for (auto i = 0u; i < len; ++i) {
      if (ids[i] == 0u) {
        continue;
      }
      fea_keys->push_back(static_cast<uint64_t>(ids[i]));
    }
  }
  fea_values->resize(fea_keys->size() + 1);
  for (auto& t : *fea_values) {
    t.resize(fea_value_dim);
  }
  std::vector<float*> pull_result_ptr;
  for (auto& t : *fea_values) {
    pull_result_ptr.push_back(t.data());
  }
  bool training = true;
  auto status = pserver_ptr_->_worker_ptr->PullSparse(pull_result_ptr.data(),
                                                      table_id,
                                                      fea_keys->data(),
                                                      fea_keys->size(),
                                                      training);
  pull_sparse_status.push_back(std::move(status));
  for (auto& t : pull_sparse_status) {
    t.wait();
    auto status = t.get();
    if (status != 0) {
      LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
      sleep(sleep_seconds_before_fail_exit_);
      exit(-1);
    }
  }
}

// is_training is true means training, false means inference, the behavior is
// different on pserver

void FleetWrapper::PullSparseToTensorSync(
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    platform::Place place,
    bool is_training,
    std::vector<const phi::DenseTensor*>* inputs,
    std::vector<phi::DenseTensor*>* outputs) {
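  // Gathers all ids from `inputs`, pulls fea_dim floats per id from the
  // sparse table into the matching positions of `outputs`, and fills ids
  // equal to padding_id with zeros locally instead of pulling them.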
  std::vector<uint64_t> fea_keys;
  std::vector<float*> pull_result_ptr;
  fea_keys.reserve(MAX_FEASIGN_NUM / 100);
  pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100);
  std::vector<float> init_value(fea_dim, 0);
  phi::DenseTensor* output = nullptr;
  float* output_data = nullptr;
  size_t output_index = -1;
  size_t output_len = 0;
  for (size_t index = 0; index < inputs->size(); ++index) {
    const phi::DenseTensor* tensor = inputs->at(index);
    const int64_t* ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
      if (!output || output_len == size_t(output->numel())) {
        ++output_index;
        CHECK(output_index < outputs->size());  // NOLINT
        output = outputs->at(output_index);
        output->set_lod(tensor->lod());
        output_data = output->mutable_data<float>(place);
        output_len = 0;
        CHECK(output->numel() % fea_dim == 0);  // NOLINT
        CHECK(output_data != nullptr);          // NOLINT
      }
      uint64_t real_id = static_cast<uint64_t>(ids[i]);
      if (real_id == padding_id) {
        memcpy(output_data + output_len,
               init_value.data(),
               sizeof(float) * fea_dim);
        continue;
      }
      fea_keys.push_back(real_id);
      pull_result_ptr.push_back(output_data + output_len);
    }
  }

  auto status = worker_ptr_->PullSparse(pull_result_ptr.data(),
                                        table_id,
                                        fea_keys.data(),
                                        fea_keys.size(),
                                        is_training);
  status.wait();
  auto ret = status.get();
  if (ret != 0) {
    LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]";
    sleep(sleep_seconds_before_fail_exit_);
  }
}

void FleetWrapper::PullDenseVarsAsync(
    const Scope& scope,
    const uint64_t tid,
    const std::vector<std::string>& var_names,
    std::vector<std::future<int32_t>>* pull_dense_status,
    bool in_cpu) {
  auto& regions = regions_[tid];
  regions.clear();
  regions.resize(var_names.size());
  for (auto i = 0u; i < var_names.size(); ++i) {
    std::string varname = var_names[i];
    if (!in_cpu) {
      varname = var_names[i] + "pin";
    }
    Variable* var = scope.FindVar(varname);
    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
    float* w = tensor->data<float>();
    paddle::distributed::Region reg(w, tensor->numel());
    regions[i] = std::move(reg);
  }

  auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid);
  pull_dense_status->push_back(std::move(status));
}

void FleetWrapper::PullDenseVarsSync(
    const Scope& scope,
    const uint64_t tid,
    const std::vector<std::string>& var_names) {
  auto& regions = regions_[tid];
  regions.clear();
  regions.reserve(var_names.size());
  for (auto& t : var_names) {
    Variable* var = scope.FindVar(t);
    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
    if (!platform::is_gpu_place(tensor->place())) {
      float* w = tensor->data<float>();
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(std::move(reg));
    }
  }
  auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid);
  status.wait();
}

void FleetWrapper::PushDenseParamSync(
    const Scope& scope,
    const uint64_t table_id,
    const std::vector<std::string>& var_names) {
  auto place = platform::CPUPlace();
  std::vector<paddle::distributed::Region> regions;
  for (auto& t : var_names) {
    Variable* var = scope.FindVar(t);
    CHECK(var != nullptr) << "var[" << t << "] not found";
    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
    if (!platform::is_gpu_place(tensor->place())) {
      float* g = tensor->mutable_data<float>(place);
      paddle::distributed::Region reg(g, tensor->numel());
      regions.emplace_back(std::move(reg));
    }
  }
  auto push_status =
      worker_ptr_->PushDenseParam(regions.data(), regions.size(), table_id);
  push_status.wait();
  auto status = push_status.get();
  CHECK(status == 0) << "push dense param failed, status[" << status << "]";
}

void FleetWrapper::PushDenseVarsSync(
    Scope* scope,
    const uint64_t table_id,
    const std::vector<std::string>& var_names) {}

void FleetWrapper::PushDenseVarsAsync(
    const Scope& scope,
    const uint64_t table_id,
    const std::vector<std::string>& var_names,
    std::vector<std::future<int32_t>>* push_sparse_status,
    float scale_datanorm,
    int batch_size) {
  auto place = platform::CPUPlace();
  std::vector<paddle::distributed::Region> regions;
  for (auto& t : var_names) {
    Variable* var = scope.FindVar(t);
    CHECK(var != nullptr) << "var[" << t << "] not found";
    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
    int count = tensor->numel();
    float* g = tensor->mutable_data<float>(place);
    // TODO(zhaocaibei123): how to get batch_size in op?
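    // Data-norm gradients are rescaled before the push: *.batch_size@GRAD and
    // *.batch_sum@GRAD are divided by batch_size, while *.batch_square_sum@GRAD
    // is mapped as g = (g - batch_size * epsilon) / batch_size +
    // batch_size * epsilon, with epsilon = scale_datanorm.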
    if (scale_datanorm >= 0) {
      if (t.find(".batch_size@GRAD") != std::string::npos ||
          t.find(".batch_sum@GRAD") != std::string::npos) {
        Eigen::Map<Eigen::MatrixXf> mat(g, 1, count);
        float scale = 1.0 / batch_size;
        mat *= scale;
      } else if (t.find(".batch_square_sum@GRAD") != std::string::npos) {
        VLOG(3) << "epsilon: " << scale_datanorm;
        for (int i = 0; i < count; ++i) {
          g[i] = (g[i] - batch_size * scale_datanorm) / batch_size +
                 batch_size * scale_datanorm;
        }
      }
    }

    paddle::distributed::Region reg(g, tensor->numel());
    regions.emplace_back(std::move(reg));
    VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
            << table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] "
            << g[tensor->numel() - 1];
  }

  auto push_status =
      worker_ptr_->PushDense(regions.data(), regions.size(), table_id);
}

void FleetWrapper::PushSparseVarsAsync(
    const Scope& scope,
    const uint64_t table_id,
    const std::string& grad_varname,
    std::vector<std::future<int32_t>>* push_sparse_status) {
  std::vector<std::string> varnames;
  varnames.push_back(grad_varname);

  auto* communicator = Communicator::GetInstance();
  PADDLE_ENFORCE_EQ(
      communicator->Check(table_id),
      true,
      platform::errors::InvalidArgument(
          "can not find table: %s, please check your config", table_id));
  communicator->Send(varnames, scope);
}

void FleetWrapper::PushSparseVarsWithLabelAsync(
    const Scope& scope,
    const uint64_t table_id,
    const std::vector<uint64_t>& fea_keys,
    const std::vector<float>& fea_labels,
    const std::vector<std::string>& sparse_key_names,
    const std::vector<std::string>& sparse_grad_names,
    const int emb_dim,
    std::vector<std::vector<float>>* push_values,
    std::vector<std::future<int32_t>>* push_sparse_status,
    const int batch_size,
    const bool use_cvm,
    const bool dump_slot,
    std::vector<uint64_t>* sparse_push_keys,
    const bool no_cvm) {
  // not support
  return;
}

void FleetWrapper::PushSparseFromTensorWithLabelAsync(
    const Scope& scope,
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    bool scale_sparse,
    const std::string& accessor,
    const std::string& click_name,
    platform::Place place,
    const std::vector<std::string>& input_names,
    std::vector<const phi::DenseTensor*>* inputs,
    std::vector<const phi::DenseTensor*>* outputs) {
  // not support
  return;
}

void FleetWrapper::PushSparseFromTensorAsync(
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    platform::Place place,
    std::vector<const phi::DenseTensor*>* inputs,
    std::vector<int>& slots,
    const phi::DenseTensor* shows,
    const phi::DenseTensor* clks,
    std::vector<phi::DenseTensor*>* outputs,
    bool use_cvm_op) {
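  // Builds one push value per feasign. With use_cvm_op the layout is
  // [slot, grad...] (fea_dim + 1 floats); otherwise it is
  // [slot, show, clk, grad...] (fea_dim + 3 floats), matching
  // CtrCommonPushValue in ctr_accessor.h.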
  CHECK(slots.size() == inputs->size());
  int batch_size = -1;
  bool batch_size_consist = true;
  for (auto* input : *inputs) {
    size_t cur_batch_size =
        input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
    if (batch_size == -1) {
      batch_size = static_cast<int>(cur_batch_size);
    } else if (batch_size != static_cast<int>(cur_batch_size)) {
      // CHECK(batch_size == cur_batch_size);  // NOLINT
      batch_size_consist = false;
      break;
    }
  }
  CHECK(batch_size > 0);  // NOLINT

  size_t show_size =
      shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
  CHECK(show_size == size_t(batch_size) || show_size == 1);
  size_t clk_size =
      clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
  CHECK(clk_size == size_t(batch_size) || clk_size == 1);

  CHECK(outputs->size() == inputs->size());
  std::vector<uint64_t> push_keys;
  push_keys.reserve(MAX_FEASIGN_NUM / 100);
  std::vector<std::vector<float>> push_values;
  push_values.reserve(MAX_FEASIGN_NUM / 100);
  size_t output_len = 0;
  size_t input_idx = 0;

  VLOG(2) << "fleet.cc::emb_dim: " << fea_dim;

  // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
  // const long int* show_tensor = shows->data<int64_t>();
  // const long int* clk_tensor = clks->data<int64_t>();
  const float* show_tensor = shows->data<float>();
  const float* clk_tensor = clks->data<float>();

  for (size_t index = 0; index < inputs->size(); ++index) {
    phi::DenseTensor* g_tensor = outputs->at(index);
    float* g = g_tensor->data<float>();
    // no cvm
    if (batch_size_consist) {  // TODO(zhaocaibei123): add config
                               // scale_sparse_gradient_with_batch_size_
      Eigen::Map<
          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          g_mat(g, g_tensor->numel() / fea_dim, fea_dim);
      if (use_cvm_op) {
        g_mat.rightCols(fea_dim - 2) *= batch_size;
      } else {
        g_mat.rightCols(fea_dim) *= batch_size;
      }
    }

    const phi::DenseTensor* tensor = inputs->at(index);
    const int64_t* ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    output_len = 0;

    if (tensor->lod().size() > 0) {
      for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
        for (size_t j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
             ++j, output_len += fea_dim) {
          uint64_t real_id = static_cast<uint64_t>(ids[j]);
          if (real_id == padding_id) {
            continue;
          }
          push_keys.emplace_back(real_id);
          if (use_cvm_op) {
            push_values.emplace_back(fea_dim + 1);
            push_values.back()[0] = static_cast<float>(slots[index]);
            float* data = push_values.back().data() + 1;
            memcpy(data, g + output_len, sizeof(float) * fea_dim);
          } else {
            push_values.emplace_back(fea_dim + 3);
            // slot show clk grad... consistent with CtrCommonPushValue defined
            // in ctr_accessor.h
            push_values.back()[0] = static_cast<float>(slots[index]);
            push_values.back()[1] =
                (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
            push_values.back()[2] =
                (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
            float* data = push_values.back().data() + 3;
            memcpy(data, g + output_len, sizeof(float) * fea_dim);
          }
          ++input_idx;
        }
      }
    } else {
      for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
        uint64_t real_id = static_cast<uint64_t>(ids[i]);
        if (real_id == padding_id) {
          continue;
        }
        push_keys.emplace_back(real_id);
        if (use_cvm_op) {
          push_values.emplace_back(fea_dim + 1);
          push_values.back()[0] = static_cast<float>(slots[index]);
          float* data = push_values.back().data() + 1;
          memcpy(data, g + output_len, sizeof(float) * fea_dim);
        } else {
          push_values.emplace_back(fea_dim + 3);
          // slot show clk grad... consistent with CtrCommonPushValue defined in
          // ctr_accessor.h
          push_values.back()[0] = static_cast<float>(slots[index]);
          push_values.back()[1] = (i >= show_size ? 1 : show_tensor[i]);
          push_values.back()[2] = (i >= clk_size ? 0 : clk_tensor[i]);
          float* data = push_values.back().data() + 3;
          memcpy(data, g + output_len, sizeof(float) * fea_dim);
        }
        ++input_idx;
      }
    }
    CHECK(static_cast<int64_t>(output_len) == g_tensor->numel());
  }

  std::vector<float*> push_g_vec(input_idx, nullptr);

  for (auto i = 0u; i < push_keys.size(); ++i) {
    push_g_vec[i] = push_values.at(i).data();
  }

  auto status = worker_ptr_->PushSparse(table_id,
                                        push_keys.data(),
                                        (const float**)push_g_vec.data(),
                                        push_keys.size());
}

void FleetWrapper::LoadModel(const std::string& path, const int mode) {
  auto ret = worker_ptr_->Load(path, std::to_string(mode));
  ret.wait();
  if (ret.get() != 0) {
    LOG(ERROR) << "load model from path:" << path << " failed";
  }
}
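// Illustrative call (the path is hypothetical): LoadModel("/path/to/model", 0);
// the mode is forwarded to the tables as a string, and its meaning is defined
// by the table/accessor implementation.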

void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
                                     const std::string& path,
                                     const int mode) {
  auto ret = worker_ptr_->Load(table_id, path, std::to_string(mode));
  ret.wait();
  if (ret.get() != 0) {
    LOG(ERROR) << "load model of table id: " << table_id
               << ", from path: " << path << " failed";
  }
}

void FleetWrapper::SaveModel(const std::string& path, const int mode) {
  auto ret = worker_ptr_->Save(path, std::to_string(mode));
  ret.wait();
  int32_t feasign_cnt = ret.get();
  if (feasign_cnt == -1) {
    LOG(ERROR) << "save model failed";
  }
}

void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
                                     const std::string& path,
                                     const int mode) {
  auto ret = worker_ptr_->Save(table_id, path, std::to_string(mode));
  ret.wait();
  if (ret.get() != 0) {
    LOG(ERROR) << "save model of table id: " << table_id
               << ", to path: " << path << " failed";
  }
}

void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
                                    const std::string& path) {
  auto ret = worker_ptr_->RecvAndSaveTable(table_id, path);
  if (ret != 0) {
    LOG(ERROR) << "save model of table id: " << table_id
               << ", to path: " << path << " failed";
  }
}

void FleetWrapper::PrintTableStat(const uint64_t table_id) {
  auto ret = worker_ptr_->PrintTableStat(table_id);
  ret.wait();
  int32_t err_code = ret.get();
  if (err_code == -1) {
    LOG(ERROR) << "print table stat failed";
  }
}

void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
  auto ret = worker_ptr_->Shrink(table_id, std::to_string(threshold));
  ret.wait();
  int32_t err_code = ret.get();
  if (err_code == -1) {
    LOG(ERROR) << "shrink sparse table stat failed";
  }
}

void FleetWrapper::ClearModel() {
  auto ret = pserver_ptr_->_worker_ptr->Clear();
  ret.wait();
}

void FleetWrapper::ClearOneTable(const uint64_t table_id) {
  auto ret = pserver_ptr_->_worker_ptr->Clear(table_id);
  ret.wait();
}

void FleetWrapper::ShrinkDenseTable(int table_id,
                                    Scope* scope,
                                    std::vector<std::string> var_list,
                                    float decay,
                                    int emb_dim) {
  std::vector<paddle::distributed::Region> regions;
  for (std::string& name : var_list) {
    if (name.find("batch_sum") != std::string::npos) {
      Variable* var = scope->FindVar(name);
      CHECK(var != nullptr) << "var[" << name << "] not found";
      VLOG(3) << "prepare shrink dense batch_sum";
      phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
      float* g = tensor->data<float>();

      // show_batch_sum += N * log(decay)
      std::string size_name = name;
      size_name.replace(
          size_name.find("batch_sum"), size_name.length(), "batch_size");
      Variable* var_size = scope->FindVar(size_name);
      CHECK(var_size != nullptr) << "var[" << size_name << "] not found";
      VLOG(3) << "shrink dense batch_sum: " << name << ", " << size_name;
      float* g_size = var_size->GetMutable<phi::DenseTensor>()->data<float>();

      for (int k = 0; k < tensor->numel(); k += emb_dim) {
        g[k] = g[k] + g_size[k] * log(decay);
      }
      paddle::distributed::Region reg(g, tensor->numel());
      regions.emplace_back(std::move(reg));
    } else {
      Variable* var = scope->FindVar(name);
      CHECK(var != nullptr) << "var[" << name << "] not found";
      phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
      float* g = tensor->data<float>();
      paddle::distributed::Region reg(g, tensor->numel());
      regions.emplace_back(std::move(reg));
    }
  }
  auto push_status = pserver_ptr_->_worker_ptr->PushDenseParam(
      regions.data(), regions.size(), table_id);
  push_status.wait();
  auto status = push_status.get();
  if (status != 0) {
    // PADDLE_THROW(platform::errors::Fatal(
    //    "push shrink dense param failed, status is [%d].", status));
    sleep(sleep_seconds_before_fail_exit_);
    exit(-1);
  }
}

void FleetWrapper::ClientFlush() {
  if (worker_ptr_.get() == nullptr) {
    VLOG(0) << "worker_ptr null, do nothing";
    return;
  }
  auto ret = worker_ptr_->Flush();
  ret.wait();
  int32_t err_code = ret.get();
  if (err_code == -1) {
    LOG(ERROR) << "Client Flush failed";
  }
}

int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
                                                   MsgHandlerFunc handler) {
  if (worker_ptr_.get() == nullptr) {
    VLOG(0) << "FleetWrapper::Client is null";
    return -1;
  } else {
    return worker_ptr_->RegisteClient2ClientMsgHandler(msg_type, handler);
  }
}

std::future<int32_t> FleetWrapper::SendClientToClientMsg(
    int msg_type, int to_client_id, const std::string& msg) {
  return worker_ptr_->SendClient2ClientMsg(msg_type, to_client_id, msg);
}

double FleetWrapper::GetCacheThreshold(int table_id) {
  double cache_threshold = 0.0;
  auto ret = worker_ptr_->Flush();
  ret.wait();
  ret = worker_ptr_->GetCacheThreshold(table_id, cache_threshold);
  ret.wait();
  if (cache_threshold < 0) {
    LOG(ERROR) << "get cache threshold failed";
    sleep(sleep_seconds_before_fail_exit_);
    exit(-1);
  }
  return cache_threshold;
}
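// Typical cache-save sequence (a sketch built from the calls below):
//   double threshold = GetCacheThreshold(table_id);
//   CacheShuffle(table_id, path, mode, threshold);
//   SaveCache(table_id, path, mode);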

void FleetWrapper::CacheShuffle(int table_id,
                                const std::string& path,
                                const int mode,
                                const double cache_threshold) {
  auto ret = worker_ptr_->CacheShuffle(
      table_id, path, std::to_string(mode), std::to_string(cache_threshold));
  ret.wait();
  int32_t feasign_cnt = ret.get();
  if (feasign_cnt == -1) {
    LOG(ERROR) << "cache shuffle failed";
    sleep(sleep_seconds_before_fail_exit_);
    exit(-1);
  }
}

int32_t FleetWrapper::SaveCache(int table_id,
                                const std::string& path,
                                const int mode) {
  auto ret = worker_ptr_->SaveCache(table_id, path, std::to_string(mode));
  ret.wait();
  int32_t feasign_cnt = ret.get();
  if (feasign_cnt == -1) {
    LOG(ERROR) << "table save cache failed";
    sleep(sleep_seconds_before_fail_exit_);
    exit(-1);
  }
  return feasign_cnt;
}

void FleetWrapper::Revert() {
  auto ret = worker_ptr_->Revert();
  ret.wait();
  if (ret.get() == -1) {
    LOG(ERROR) << "table revert failed";
    exit(-1);
  }
}

void FleetWrapper::CheckSavePrePatchDone() {
  auto ret = worker_ptr_->CheckSavePrePatchDone();
  ret.wait();
  if (ret.get() == -1) {
    LOG(ERROR) << "table revert failed";
    exit(-1);
  }
}

std::default_random_engine& FleetWrapper::LocalRandomEngine() {
  struct engine_wrapper_t {
    std::default_random_engine engine;

    engine_wrapper_t() {
      struct timespec tp;
      clock_gettime(CLOCK_REALTIME, &tp);
      double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
      static std::atomic<uint64_t> x(0);
      std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
      engine.seed(sseq);
    }
  };
  thread_local engine_wrapper_t r;
  return r.engine;
}

size_t FleetWrapper::GetAbsoluteSum(size_t start,
                                    size_t end,
                                    size_t level,
                                    const framework::LoD& lod) {
  if (level >= lod.size() - 1) {
    return end - start;
  }
  size_t ret = 0;
  for (size_t i = start; i < end - 1; ++i) {
    size_t pos1 = lod[level][i];
    size_t pos2 = lod[level][i + 1];
    ret += GetAbsoluteSum(pos1, pos2, level + 1, lod);
  }
  return ret;
}

}  // end namespace distributed
}  // end namespace paddle