brpc_ps_server.cc 30.6 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
T
tangwei12 已提交
16
#include <cstring>
#include <thread>  // NOLINT
T
Thunderbrook 已提交
17
#include "butil/object_pool.h"
18
#include "paddle/fluid/distributed/common/cost_timer.h"
19 20
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
T
tangwei12 已提交
21 22 23
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"

24 25 26 27 28 29 30
// Forward declarations of the protobuf RPC types that appear in the signature
// of BrpcPsService::service() further down in this file.
namespace google {
namespace protobuf {
class Closure;
class RpcController;
}  // namespace protobuf
}  // namespace google

Z
zhaocaibei123 已提交
31 32 33 34 35 36 37
// Flags consumed by BrpcPsServer::StartS2S() when it fills in the
// brpc::ChannelOptions for the pserver-to-pserver channels.
// Assigned to ChannelOptions::timeout_ms.
DEFINE_int32(pserver_timeout_ms_s2s, 10000,
             "pserver request server timeout_ms");
// Assigned to ChannelOptions::connect_timeout_ms.
DEFINE_int32(pserver_connect_timeout_ms_s2s, 10000,
             "pserver connect server timeout_ms");
// Assigned to ChannelOptions::connection_type.
DEFINE_string(pserver_connection_type_s2s, "pooled",
              "pserver connection_type[pooled:single]");

T
tangwei12 已提交
38 39 40
namespace paddle {
namespace distributed {

Z
zhaocaibei123 已提交
41
int32_t BrpcPsServer::Initialize() {
T
tangwei12 已提交
42 43 44 45 46
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
T
tangwei12 已提交
47 48
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
T
tangwei12 已提交
49 50 51 52 53 54 55
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }

  _service.reset(service);
Z
zhaocaibei123 已提交
56
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
T
tangwei12 已提交
57 58 59 60 61 62 63 64 65 66 67 68
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
69
uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
T
tangwei12 已提交
70 71 72
  std::unique_lock<std::mutex> lock(mutex_);

  std::string ip_port = ip + ":" + std::to_string(port);
T
tangwei12 已提交
73 74
  VLOG(0) << "running server with rank id: " << _rank
          << ", endpoint: " << ip_port;
T
tangwei12 已提交
75
  brpc::ServerOptions options;
T
tangwei12 已提交
76 77

  int num_threads = std::thread::hardware_concurrency();
Z
zhaocaibei123 已提交
78
  auto trainers = _environment->GetTrainers();
T
tangwei12 已提交
79
  options.num_threads = trainers > num_threads ? trainers : num_threads;
T
tangwei12 已提交
80 81

  if (_server.Start(ip_port.c_str(), &options) != 0) {
82 83 84 85 86 87 88 89 90
    VLOG(0) << "BrpcPsServer start failed, ip_port= " << ip_port
            << " , Try Again.";

    std::string int_ip_port = GetIntTypeEndpoint(ip, port);

    if (_server.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
      return 0;
    }
T
tangwei12 已提交
91
  }
92

Z
zhaocaibei123 已提交
93
  _environment->RegistePsServer(ip, port, _rank);
T
tangwei12 已提交
94 95 96 97 98 99 100 101 102
  cv_.wait(lock, [&] { return stoped_; });

  PSHost host;
  host.ip = ip;
  host.port = port;
  host.rank = _rank;
  return host.rank;
}

Z
zhaocaibei123 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
int32_t BrpcPsServer::StartS2S() {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms_s2s;
  options.connection_type = FLAGS_pserver_connection_type_s2s;
  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms_s2s;
  options.max_retry = 3;

  std::vector<PSHost> pserver_list = _environment->GetPsServers();
  _pserver_channels.resize(pserver_list.size());
  VLOG(2) << "pserver start s2s server_list size: " << _pserver_channels.size();

  std::ostringstream os;
  std::string server_ip_port;

  for (size_t i = 0; i < pserver_list.size(); ++i) {
    server_ip_port.assign(pserver_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(pserver_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      LOG(ERROR) << "pserver connect to pserver:" << server_ip_port
                 << " Failed!";
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "pserver connect success: " << os.str();
  return 0;
}

std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
    int msg_type, int to_pserver_id, const std::string &msg) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> fut = promise->get_future();
  if (to_pserver_id >= _pserver_channels.size()) {
    LOG(FATAL) << "to_pserver_id is out of range pservers, which size is "
               << _pserver_channels.size();
    promise->set_value(-1);
    return fut;
  }
  auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
    auto *closure = (DownpourPServerBrpcClosure *)done;
    int32_t ret = closure->check_response(0, msg_type + 1000);
    closure->set_promise_value(ret);
  });

  closure->add_promise(promise);
  closure->request(0)->set_cmd_id(101);
  closure->request(0)->set_client_id(_rank);
  closure->request(0)->set_table_id(0);
  closure->request(0)->set_data(msg);
  PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
  rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
                   closure);
  return fut;
}

int32_t BrpcPsServer::ReceiveFromPServer(int msg_type, int pserver_id,
                                         const std::string &msg) {
  if (msg.length() == 0) {
    LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
    return 0;
  }
  paddle::framework::BinaryArchive ar;
  ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr);
  if (ar.Cursor() == ar.Finish()) {
    LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response";
    return 0;
  }
  std::vector<std::pair<uint64_t, std::string>> data;
  while (ar.Cursor() < ar.Finish()) {
    data.push_back(ar.Get<std::pair<uint64_t, std::string>>());
  }
  CHECK(ar.Cursor() == ar.Finish());
  this->_shuffled_ins->Write(std::move(data));
  return 0;
}

Z
zhaocaibei123 已提交
181
// Returns the `port` field of the embedded brpc server's listen address.
int32_t BrpcPsServer::Port() {
  const auto &address = _server.listen_address();
  return address.port;
}
T
tangwei12 已提交
182

Z
zhaocaibei123 已提交
183
int32_t BrpcPsService::Initialize() {
T
tangwei12 已提交
184
  _is_initialize_shard_info = false;
Z
zhaocaibei123 已提交
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
  _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer;
  _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense;
  _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense;
  _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse;
  _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse;
  _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable;
  _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable;
  _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable;
  _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable;
  _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable;
  _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam;
  _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat;
  _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam;
  _service_handler_map[PS_PUSH_SPARSE_PARAM] = &BrpcPsService::PushSparseParam;
  _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
  _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
Z
zhaocaibei123 已提交
205 206 207 208 209 210 211 212
  // for save cache

  _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
      &BrpcPsService::SaveCacheTable;
  _service_handler_map[PS_GET_CACHE_THRESHOLD] =
      &BrpcPsService::GetCacheThreshold;
  _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;

213 214 215 216 217
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_server_pull_dense");
  profiler.register_profiler("pserver_server_push_dense");
  profiler.register_profiler("pserver_server_pull_sparse");
  profiler.register_profiler("pserver_server_push_sparse");
T
tangwei12 已提交
218 219

  // shard初始化,server启动后才可从env获取到server_list的shard信息
Z
zhaocaibei123 已提交
220
  InitializeShardInfo();
T
tangwei12 已提交
221 222 223 224 225 226 227 228 229 230 231 232

  return 0;
}

// Early-return guard used at the top of the table handlers: when `table` is
// null it stores "table not found with table_id:<id>" with code -1 in
// `response` and makes the enclosing function return -1.
// Call sites in this file invoke it WITHOUT a trailing semicolon, so the
// expansion has to remain a plain `if` statement (no do/while(0) wrapper).
// Arguments are parenthesized so that expression arguments expand safely.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if ((table) == nullptr) {                                \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string((request).table_id()));  \
    set_response_code((response), -1, err_msg.c_str());    \
    return -1;                                             \
  }

Z
zhaocaibei123 已提交
233
int32_t BrpcPsService::InitializeShardInfo() {
T
tangwei12 已提交
234 235 236 237 238
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
Z
zhaocaibei123 已提交
239 240
    size_t shard_num = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
T
tangwei12 已提交
241
    for (auto itr : table_map) {
Z
zhaocaibei123 已提交
242
      itr.second->SetShard(_rank, shard_num);
T
tangwei12 已提交
243 244 245 246 247 248
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}

T
tangwei12 已提交
249 250 251 252
// brpc entry point for every PsRequestMessage.
// - Requests without a table_id are rejected with err_code -1.
// - cmd_id < 100 is dispatched through `_service_handler_map`; an unknown id
//   yields err_code -1, and a non-zero handler return is copied into err_code
//   with the message "server internal error".
// - Any other cmd_id is forwarded to
//   `_server->HandlePServer2PServerMsg(cmd_id, client_id, data)`.
// `done` is run when this function returns (via brpc::ClosureGuard).
void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
                            const PsRequestMessage *request,
                            PsResponseMessage *response,
                            google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
    return;
  }

  // Start from a clean "success" response; handlers overwrite it on error.
  response->set_err_code(0);
  response->set_err_msg("");
  // May be null; the handlers check it themselves with CHECK_TABLE_EXIST.
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);

  if (request->cmd_id() < 100) {
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      set_response_code(*response, -1, err_msg.c_str());
      return;
    }
    serviceHandlerFunc handler_func = itr->second;
    int service_ret = (this->*handler_func)(table, *request, *response, cntl);
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  } else {
    int service_ret = _server->HandlePServer2PServerMsg(
        request->cmd_id(), request->client_id(), request->data());
    if (service_ret != 0) {
      response->set_err_code(-1);
      response->set_err_msg("handle_pserver2pserver_msg failed");
    }
  }
}

Z
zhaocaibei123 已提交
290 291 292
// Handles a dense pull: reads the 4-byte element count from params(0), pulls
// the values from `table` into a pooled float vector and appends them to the
// response attachment.
// Returns -1 (through CHECK_TABLE_EXIST) when `table` is null; otherwise 0,
// with malformed requests reported through the response code.
int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 1 for num of dense");
    return 0;
  }
  // The count arrives as raw bytes; make sure all four are present before
  // reading them.
  const std::string &num_param = request.params(0);
  if (num_param.size() < sizeof(uint32_t)) {
    set_response_code(
        response, -1,
        "PsRequestMessage.params(0) is too short for num of dense");
    return 0;
  }
  CostTimer timer("pserver_server_pull_dense");
  // memcpy instead of a pointer cast: the bytes carry no alignment guarantee.
  uint32_t num = 0;
  std::memcpy(&num, num_param.data(), sizeof(num));

  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size /
                   sizeof(float));

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = res_data->data();
  table_context.num = num;
  table->Pull(table_context);

  cntl->response_attachment().append(
      reinterpret_cast<const char *>(res_data->data()),
      res_data->size() * sizeof(float));
  // Hand the vector back to the object pool it came from.
  butil::return_object(res_data);

  return 0;
}

Z
zhaocaibei123 已提交
323 324 325 326 327 328
// Handles a dense parameter push. The request attachment is laid out as a
// 4-byte count followed by float values; both are passed to table->Push()
// with `is_param` set.
// Returns -1 (through CHECK_TABLE_EXIST) when `table` is null; otherwise 0,
// with failures reported through the response code.
int32_t BrpcPsService::PushDenseParam(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushDenseParam", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  thread_local std::string push_buffer;
  auto &req_io_buffer = cntl->request_attachment();
  const auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  if (req_buffer_size < sizeof(uint32_t)) {
    set_response_code(response, -1,
                      "req attachment is too short for num of dense");
    return 0;
  }
  // fetch() may copy into the aux buffer, so the string must really be that
  // large: resize() it (reserve() alone leaves the bytes outside size()).
  push_buffer.resize(req_buffer_size);
  const char *data = static_cast<const char *>(
      req_io_buffer.fetch(&push_buffer[0], req_buffer_size));

  // memcpy instead of a pointer cast: `data` carries no alignment guarantee.
  uint32_t num = 0;
  std::memcpy(&num, data, sizeof(num));

  const float *values =
      reinterpret_cast<const float *>(data + sizeof(uint32_t));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;
  table_context.num = num;

  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushDenseParam failed");
  }
  return 0;
}

Z
zhaocaibei123 已提交
358 359 360
// Handles a dense push. request.data() is laid out as
//   |--num--|---valuesData---|
//   |--4B---|----------------|
// and is passed to table->Push(). An empty payload is accepted as a no-op.
// Returns -1 (through CHECK_TABLE_EXIST) when `table` is null; otherwise 0,
// with failures reported through the response code.
int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  const std::string &push_data = request.data();
  if (push_data.empty()) {
    return 0;
  }
  // A non-empty payload must at least hold the 4-byte count header.
  if (push_data.size() < sizeof(uint32_t)) {
    set_response_code(response, -1,
                      "push dense data is too short for num of dense");
    return 0;
  }

  CostTimer timer("pserver_server_push_dense");
  // memcpy instead of a pointer cast: the bytes carry no alignment guarantee.
  uint32_t num = 0;
  std::memcpy(&num, push_data.data(), sizeof(num));

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values =
      reinterpret_cast<const float *>(push_data.data() + sizeof(uint32_t));
  table_context.num = num;
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushDense failed");
  }

  return 0;
}

Z
zhaocaibei123 已提交
392
int32_t BrpcPsService::Barrier(Table *table, const PsRequestMessage &request,
T
tangwei12 已提交
393 394
                               PsResponseMessage &response,
                               brpc::Controller *cntl) {
T
tangwei12 已提交
395 396 397 398 399 400 401 402 403 404 405
  CHECK_TABLE_EXIST(table, request, response)

  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
Z
zhaocaibei123 已提交
406
  table->Barrier(trainer_id, barrier_type);
T
tangwei12 已提交
407 408 409
  return 0;
}

Z
zhaocaibei123 已提交
410 411 412 413 414
// Handles a sparse parameter push. params(0) carries the 4-byte key count
// `num`; request.data() is laid out as
//   |---keysData---|---valuesData---|
//   |---8*{num}B---|----------------|
// Both regions are passed to table->Push() with `is_param` set. An empty
// payload is accepted as a no-op.
// Returns -1 (through CHECK_TABLE_EXIST) when `table` is null; otherwise 0,
// with failures reported through the response code.
int32_t BrpcPsService::PushSparseParam(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  platform::RecordEvent record_event("PsService->PushSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  CHECK_TABLE_EXIST(table, request, response)
  const auto &push_data = request.data();
  if (push_data.size() < 1) {
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  // The count arrives as raw bytes; make sure all four are present.
  const std::string &num_param = request.params(0);
  if (num_param.size() < sizeof(uint32_t)) {
    set_response_code(
        response, -1,
        "PsRequestMessage.params(0) is too short for num of sparse_key");
    return 0;
  }
  // memcpy instead of a pointer cast: the bytes carry no alignment guarantee.
  uint32_t num = 0;
  std::memcpy(&num, num_param.data(), sizeof(num));

  // The values region starts right after the keys, so the payload has to be
  // at least as long as the key region.
  const size_t keys_bytes = sizeof(uint64_t) * static_cast<size_t>(num);
  if (push_data.size() < keys_bytes) {
    set_response_code(response, -1,
                      "push sparse data is shorter than its key region");
    return 0;
  }
  const uint64_t *keys = reinterpret_cast<const uint64_t *>(push_data.data());
  const float *values =
      reinterpret_cast<const float *>(push_data.data() + keys_bytes);

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;
  table_context.num = num;
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushSparseParam error");
  }
  return 0;
}

Z
zhaocaibei123 已提交
452 453 454 455
// Handles a geo pull for the requesting trainer: table->Pull() fills `ids`
// and `values`, which are written to the response attachment as
//   | num (uint32_t) | ids (num * uint64_t) | values (floats) |
// where num is ids.size().
// Returns -1 (through CHECK_TABLE_EXIST) when `table` is null, otherwise 0.
int32_t BrpcPsService::PullGeoParam(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->pull_geo_param", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)

  auto trainer_id = request.client_id();

  std::vector<float> values;
  std::vector<uint64_t> ids;

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.geo_pull_keys = &ids;
  table_context.pull_context.geo_pull_values = &values;
  table_context.trainer_id = trainer_id;
  table->Pull(table_context);

  uint32_t num = ids.size();
  auto &attachment = cntl->response_attachment();
  attachment.append(reinterpret_cast<const char *>(&num), sizeof(uint32_t));
  attachment.append(reinterpret_cast<const char *>(ids.data()),
                    ids.size() * sizeof(uint64_t));
  attachment.append(reinterpret_cast<const char *>(values.data()),
                    values.size() * sizeof(float));
  return 0;
}

Z
zhaocaibei123 已提交
483 484 485
// Handles a sparse pull. params(0) carries the 4-byte key count `num`; the
// request attachment is deserialized into a PullSparseValue, the table is
// asked to pull num * select_dim floats, and those are appended to the
// response attachment.
// Returns -1 (through CHECK_TABLE_EXIST) when `table` is null; otherwise 0,
// with malformed requests reported through the response code.
int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)

  auto &req_io_buffer = cntl->request_attachment();
  const auto req_buffer_size = req_io_buffer.size();

  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }

  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  // The count arrives as raw bytes; make sure all four are present.
  const std::string &num_param = request.params(0);
  if (num_param.size() < sizeof(uint32_t)) {
    set_response_code(
        response, -1,
        "PsRequestMessage.params(0) is too short for num of sparse_key");
    return 0;
  }

  CostTimer timer("pserver_server_pull_sparse");
  // memcpy instead of a pointer cast: the bytes carry no alignment guarantee.
  uint32_t num = 0;
  std::memcpy(&num, num_param.data(), sizeof(num));
  auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;

  // fetch() may copy into the aux buffer, so the string must really be that
  // large: resize() it (reserve() alone leaves the bytes outside size()).
  thread_local std::string req_buffer;
  req_buffer.resize(req_buffer_size);
  const void *data = req_io_buffer.fetch(&req_buffer[0], req_buffer_size);

  auto value = PullSparseValue(num, dim);
  // NOTE(review): the attachment length is not validated against `num` here;
  // confirm what DeserializeFromBytes expects of the buffer.
  value.DeserializeFromBytes(const_cast<void *>(data));

  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * dim);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.pull_value = value;
  table_context.pull_context.values = res_data->data();
  table->Pull(table_context);

  cntl->response_attachment().append(
      reinterpret_cast<const char *>(res_data->data()),
      res_data->size() * sizeof(float));
  // Hand the vector back to the object pool it came from.
  butil::return_object(res_data);
  return 0;
}

Z
zhaocaibei123 已提交
534 535 536
// Handles a sparse push. params(0) carries the 4-byte key count `num`;
// request.data() is laid out as
//   |---keysData---|---valuesData---|
//   |---8*{num}B---|----------------|
// Both regions are passed to table->Push(). An empty payload is accepted as a
// no-op.
// Returns -1 (through CHECK_TABLE_EXIST) when `table` is null; otherwise 0,
// with failures reported through the response code.
int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  const auto &push_data = request.data();
  if (push_data.size() < 1) {
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  // The count arrives as raw bytes; make sure all four are present.
  const std::string &num_param = request.params(0);
  if (num_param.size() < sizeof(uint32_t)) {
    set_response_code(
        response, -1,
        "PsRequestMessage.params(0) is too short for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_push_sparse");
  // memcpy instead of a pointer cast: the bytes carry no alignment guarantee.
  uint32_t num = 0;
  std::memcpy(&num, num_param.data(), sizeof(num));

  // The values region starts right after the keys, so the payload has to be
  // at least as long as the key region.
  const size_t keys_bytes = sizeof(uint64_t) * static_cast<size_t>(num);
  if (push_data.size() < keys_bytes) {
    set_response_code(response, -1,
                      "push sparse data is shorter than its key region");
    return 0;
  }

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys =
      reinterpret_cast<const uint64_t *>(push_data.data());
  table_context.push_context.values =
      reinterpret_cast<const float *>(push_data.data() + keys_bytes);
  table_context.num = num;
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushSparse error");
  }
  return 0;
}

Z
zhaocaibei123 已提交
574 575 576 577
// Serializes the two int64 values returned by table->PrintTableStat() with a
// BinaryArchive and returns the bytes in the response's data field.
// Returns -1 (through CHECK_TABLE_EXIST) when `table` is null, otherwise 0.
int32_t BrpcPsService::PrintTableStat(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  std::pair<int64_t, int64_t> stat = table->PrintTableStat();

  paddle::framework::BinaryArchive archive;
  archive << stat.first << stat.second;

  response.set_data(std::string(archive.Buffer(), archive.Length()));
  return 0;
}

Z
zhaocaibei123 已提交
588 589 590 591
// Loads `table` by calling table->Load(params(0), params(1)).
// Returns 0 on success; -1 when the table is null, fewer than two params were
// sent, or Load() reports failure (the response code is set as well).
int32_t BrpcPsService::LoadOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  const bool has_required_params = request.params_size() >= 2;
  if (!has_required_params) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }

  const int32_t load_ret = table->Load(request.params(0), request.params(1));
  if (load_ret == 0) {
    return 0;
  }
  set_response_code(response, -1, "table load failed");
  return -1;
}

Z
zhaocaibei123 已提交
606 607 608 609 610
// Runs LoadOneTable() with the same request on every table of the server and
// stops at the first failure. The `table` argument is not used.
// Returns 0 when every table loaded, -1 otherwise.
int32_t BrpcPsService::LoadAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  auto &all_tables = *(_server->GetTable());
  for (auto &entry : all_tables) {
    Table *current = entry.second.get();
    if (LoadOneTable(current, request, response, cntl) == 0) {
      continue;
    }
    LOG(ERROR) << "load table[" << entry.first << "] failed";
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
620 621 622 623
// Flushes `table` and saves it with table->Save(params(0), params(1)).
// Returns the value reported by Save() when it is non-negative; returns -1
// when the table is null, fewer than two params were sent, or Save() reports
// a negative value (the response code is set as well).
int32_t BrpcPsService::SaveOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  const bool has_required_params = request.params_size() >= 2;
  if (!has_required_params) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }

  table->Flush();

  VLOG(3) << "save table " << request.params(0) << " " << request.params(1);
  const int32_t feasign_size =
      table->Save(request.params(0), request.params(1));
  if (feasign_size >= 0) {
    return feasign_size;
  }
  set_response_code(response, -1, "table save failed");
  return -1;
}

Z
zhaocaibei123 已提交
644 645 646 647 648
// Runs SaveOneTable() with the same request on every table of the server and
// stops at the first table whose save reports a negative value. The `table`
// argument is not used.
// Returns 0 when every table was saved, -1 otherwise; the per-table values
// returned by SaveOneTable() are not accumulated.
int32_t BrpcPsService::SaveAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  auto &table_map = *(_server->GetTable());

  for (auto &itr : table_map) {
    const int32_t feasign_size =
        SaveOneTable(itr.second.get(), request, response, cntl);
    if (feasign_size < 0) {
      LOG(ERROR) << "save table[" << itr.first << "] failed";
      return -1;
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745
// Flushes `table` and saves its cache with
// table->SaveCache(params(0), params(1), _server->_shuffled_ins).
// Returns the value reported by SaveCache() when it is non-negative; returns
// -1 when the table is null, fewer than two params were sent, or SaveCache()
// reports a negative value (the response code is set as well).
int32_t BrpcPsService::SaveCacheTable(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  // Only params(0) and params(1) are consumed below, so two params are
  // required; the message states the same number as the check.
  if (request.params_size() < 2) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }
  table->Flush();
  const int32_t feasign_size = table->SaveCache(
      request.params(0), request.params(1), _server->_shuffled_ins);
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}

// Starts a cache shuffle on `table`.
// params(0) and params(1) are forwarded to table->CacheShuffle(), params(2) is
// parsed as the cache threshold, and any further params are parsed as table
// ids whose tables take part in the shuffle (just `table` when none are
// given). Messages to other pservers go through
// BrpcPsServer::SendPServer2PServerMsg().
// Returns 0 on success; -1 when `table` is null, fewer than three params were
// sent, or one of the listed table ids does not resolve to a table.
// NOTE(review): std::stod / std::stoi throw on malformed input and nothing
// here catches that — confirm callers always send well-formed numbers.
int32_t BrpcPsService::CacheShuffle(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // start cache shuffle
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(response, -1,
                      "PsRequestMessage.datas is requeired at least 3, "
                      "path&mode&cache_threshold");
    return -1;
  }
  table->Flush();
  double cache_threshold = std::stod(request.params(2));
  LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold;
  _server->StartS2S();
  std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
                                     const std::string &msg)>
      send_msg_func = [this](int msg_type, int to_pserver_id,
                             const std::string &msg) -> std::future<int32_t> {
    return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
  };

  std::vector<Table *> table_ptrs;
  // Index with int: params_size() and params(i) both work with int, so this
  // avoids a signed/unsigned comparison and an implicit narrowing.
  const int params_num = request.params_size();
  for (int i = 3; i < params_num; ++i) {
    const int table_id = std::stoi(request.params(i));
    Table *table_ptr = _server->GetTable(table_id);
    // Do not hand a null table to the shuffle; report it to the caller.
    if (table_ptr == nullptr) {
      std::string err_msg("table not found with table_id:");
      err_msg.append(std::to_string(table_id));
      set_response_code(response, -1, err_msg.c_str());
      return -1;
    }
    table_ptrs.push_back(table_ptr);
  }
  if (table_ptrs.empty()) {
    table_ptrs.push_back(table);
  }

  table->CacheShuffle(request.params(0), request.params(1), cache_threshold,
                      send_msg_func, _server->_shuffled_ins, table_ptrs);
  return 0;
}

// Reports the cache threshold of `table`.
//
// The table is flushed, its threshold is queried, and the value is written
// into the response data as text with a stream precision of 15 digits.
// A negative threshold is only logged as a warning; the handler still
// returns 0. A missing `table` is handled by CHECK_TABLE_EXIST.
int32_t BrpcPsService::GetCacheThreshold(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();

  const double threshold = table->GetCacheThreshold();
  if (threshold < 0) {
    LOG(WARNING) << "wrong threshold: " << threshold;
  }

  // Serialize through a stream so the precision can be raised above the
  // stream default.
  std::stringstream formatter;
  formatter << std::setprecision(15) << threshold;
  response.set_data(formatter.str());
  return 0;
}

Z
zhaocaibei123 已提交
746 747 748 749
// Shrinks `table` using the value carried in params(0) (named "threshold"
// by the error text below).
//
// Returns 0 when Table::Shrink reports success. Returns -1, with the
// response error code/message set, when no param is supplied or when
// Table::Shrink returns non-zero. A missing `table` is handled by
// CHECK_TABLE_EXIST.
int32_t BrpcPsService::ShrinkTable(Table *table,
                                   const PsRequestMessage &request,
                                   PsResponseMessage &response,
                                   brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  const bool has_threshold = request.params_size() >= 1;
  if (!has_threshold) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 1, threshold");
    return -1;
  }

  // Flush the table before shrinking it.
  table->Flush();
  const auto shrink_status = table->Shrink(request.params(0));
  if (shrink_status != 0) {
    set_response_code(response, -1, "table shrink failed");
    return -1;
  }

  VLOG(3) << "Pserver Shrink Finished";
  return 0;
}

Z
zhaocaibei123 已提交
766 767 768 769
// Flushes and then clears a single table.
//
// Always returns 0 once the table has been flushed and cleared; the results
// of Flush() and Clear() are not inspected. A missing `table` is handled by
// CHECK_TABLE_EXIST.
int32_t BrpcPsService::ClearOneTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  // Flush first, then clear the table.
  table->Flush();
  table->Clear();

  return 0;
}

Z
zhaocaibei123 已提交
776 777 778 779 780
// Clears every table registered on the server by delegating to
// ClearOneTable for each entry of the server's table map.
//
// The `table` argument is not used. Iteration stops at the first table for
// which ClearOneTable returns non-zero, in which case -1 is returned;
// otherwise returns 0.
int32_t BrpcPsService::ClearAllTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto *all_tables = _server->GetTable();
  for (auto &entry : *all_tables) {
    auto *current = entry.second.get();
    const int32_t status = ClearOneTable(current, request, response, cntl);
    if (status != 0) {
      return -1;
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
789 790 791
// Requests a server shutdown.
//
// Stop() is invoked from a separate, detached thread and this handler
// returns 0 immediately without waiting for it.
// NOTE(review): presumably the stop is done off-thread so the RPC can be
// answered while the server is shutting down -- confirm.
int32_t BrpcPsService::StopServer(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  // Copy the pointer so the thread body captures it rather than `this`.
  auto *server = _server;
  std::thread stopper([server]() {
    server->Stop();
    VLOG(3) << "Server Stoped";
  });
  stopper.detach();
  return 0;
}

Z
zhaocaibei123 已提交
801 802 803 804
// Disables the profiler with the default event sorting key, passing it a
// name built from this service's rank ("server_<rank>_profile").
// None of the handler arguments are used. Always returns 0.
int32_t BrpcPsService::StopProfiler(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  const auto profile_name = string::Sprintf("server_%s_profile", _rank);
  platform::DisableProfiler(platform::EventSortingKey::kDefault, profile_name);
  return 0;
}

Z
zhaocaibei123 已提交
810 811 812 813
// Enables the profiler in the kCPU state.
// None of the handler arguments are used. Always returns 0.
int32_t BrpcPsService::StartProfiler(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  const auto profiler_state = platform::ProfilerState::kCPU;
  platform::EnableProfiler(profiler_state);
  return 0;
}

Z
zhaocaibei123 已提交
818 819 820 821
// Pushes global-step values from a trainer into `table`.
//
// The request data is read as a uint32_t header followed by int64_t values.
// A pointer to those values, together with the sender's client id, is handed
// to Table::Push through a TableContext.
//
// Returns 0 when the data is empty (response code 0) and after Table::Push
// has been called. Returns -1, with the response error code/message set,
// when the buffer is too short for the header or for the payload it
// declares. A missing `table` is handled by CHECK_TABLE_EXIST.
int32_t BrpcPsService::PushGlobalStep(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response);
  const auto &req_data = request.data();
  const size_t req_buffer_size = req_data.size();
  if (req_buffer_size < 1) {
    set_response_code(response, 0, "run_program data is empty");
    return 0;
  }
  // A non-empty buffer can still be shorter than the 4-byte header; reading
  // the header from it would run past the end of the request data.
  if (req_buffer_size < sizeof(uint32_t)) {
    set_response_code(response, -1, "run_program data is truncated");
    return -1;
  }
  const uint32_t num = *reinterpret_cast<const uint32_t *>(req_data.data());
  // Presumably `num` is the number of int64_t values following the header
  // (TODO confirm against the client encoder). Reject buffers that are too
  // short for it so Table::Push never reads past the request data. The
  // comparison divides instead of multiplying so it cannot overflow.
  const size_t payload_size = req_buffer_size - sizeof(uint32_t);
  if (payload_size / sizeof(int64_t) < num) {
    set_response_code(response, -1,
                      "run_program data is shorter than declared");
    return -1;
  }
  const int64_t *values =
      reinterpret_cast<const int64_t *>(req_data.data() + sizeof(uint32_t));
  auto trainer_id = request.client_id();

  TableContext context;
  context.trainer_id = trainer_id;
  context.push_context.push_steps = values;

  //  if (table->PushDense(values, trainer_id) != 0) {
  if (table->Push(context) != 0) {
    // NOTE(review): the failure is reported through the response code only;
    // the handler still returns 0, as in the original code -- confirm this is
    // intended, since sibling handlers return -1 on failure.
    set_response_code(response, -1, "run_program failed");
  }

  return 0;
}

T
tangwei12 已提交
845 846
}  // namespace distributed
}  // namespace paddle