brpc_ps_server.cc 23.8 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
T
tangwei12 已提交
16
#include <thread>  // NOLINT
T
Thunderbrook 已提交
17
#include "butil/object_pool.h"
18
#include "paddle/fluid/distributed/common/cost_timer.h"
19 20
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
T
tangwei12 已提交
21 22 23
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"

24 25 26 27 28 29 30
// Forward declarations of the protobuf RPC types that appear in the signature
// of BrpcPsService::service() further down in this file.
namespace google {
namespace protobuf {
class Closure;
class RpcController;
}  // namespace protobuf
}  // namespace google

T
tangwei12 已提交
31 32 33
namespace paddle {
namespace distributed {

Z
zhaocaibei123 已提交
34
int32_t BrpcPsServer::Initialize() {
T
tangwei12 已提交
35 36 37 38 39
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
T
tangwei12 已提交
40 41
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
T
tangwei12 已提交
42 43 44 45 46 47 48
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }

  _service.reset(service);
Z
zhaocaibei123 已提交
49
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
T
tangwei12 已提交
50 51 52 53 54 55 56 57 58 59 60 61
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
62
uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
T
tangwei12 已提交
63 64 65
  std::unique_lock<std::mutex> lock(mutex_);

  std::string ip_port = ip + ":" + std::to_string(port);
T
tangwei12 已提交
66 67
  VLOG(0) << "running server with rank id: " << _rank
          << ", endpoint: " << ip_port;
T
tangwei12 已提交
68
  brpc::ServerOptions options;
T
tangwei12 已提交
69 70

  int num_threads = std::thread::hardware_concurrency();
Z
zhaocaibei123 已提交
71
  auto trainers = _environment->GetTrainers();
T
tangwei12 已提交
72
  options.num_threads = trainers > num_threads ? trainers : num_threads;
T
tangwei12 已提交
73 74

  if (_server.Start(ip_port.c_str(), &options) != 0) {
75 76 77 78 79 80 81 82 83
    VLOG(0) << "BrpcPsServer start failed, ip_port= " << ip_port
            << " , Try Again.";

    std::string int_ip_port = GetIntTypeEndpoint(ip, port);

    if (_server.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
      return 0;
    }
T
tangwei12 已提交
84
  }
85

Z
zhaocaibei123 已提交
86
  _environment->RegistePsServer(ip, port, _rank);
T
tangwei12 已提交
87 88 89 90 91 92 93 94 95
  cv_.wait(lock, [&] { return stoped_; });

  PSHost host;
  host.ip = ip;
  host.port = port;
  host.rank = _rank;
  return host.rank;
}

Z
zhaocaibei123 已提交
96
// Returns the port of the address the brpc server is listening on.
int32_t BrpcPsServer::Port() {
  const auto listen_addr = _server.listen_address();
  return listen_addr.port;
}
T
tangwei12 已提交
97

Z
zhaocaibei123 已提交
98
int32_t BrpcPsService::Initialize() {
T
tangwei12 已提交
99
  _is_initialize_shard_info = false;
Z
zhaocaibei123 已提交
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
  _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer;
  _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense;
  _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense;
  _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse;
  _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse;
  _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable;
  _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable;
  _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable;
  _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable;
  _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable;
  _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam;
  _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat;
  _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam;
  _service_handler_map[PS_PUSH_SPARSE_PARAM] = &BrpcPsService::PushSparseParam;
  _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
  _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
120 121 122 123 124
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_server_pull_dense");
  profiler.register_profiler("pserver_server_push_dense");
  profiler.register_profiler("pserver_server_pull_sparse");
  profiler.register_profiler("pserver_server_push_sparse");
T
tangwei12 已提交
125 126

  // shard初始化,server启动后才可从env获取到server_list的shard信息
Z
zhaocaibei123 已提交
127
  InitializeShardInfo();
T
tangwei12 已提交
128 129 130 131 132 133 134 135 136 137 138 139

  return 0;
}

// Guard placed at the top of request handlers: when `table` is NULL it
// records a "table not found" error (code -1) on `response` and makes the
// enclosing function return -1.
// NOTE: this expands to a bare `if` block (no do/while wrapper), and most
// handlers in this file invoke it without a trailing semicolon.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if (table == NULL) {                                     \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string(request.table_id()));    \
    set_response_code(response, -1, err_msg.c_str());      \
    return -1;                                             \
  }

Z
zhaocaibei123 已提交
140
int32_t BrpcPsService::InitializeShardInfo() {
T
tangwei12 已提交
141 142 143 144 145
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
Z
zhaocaibei123 已提交
146 147
    size_t shard_num = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
T
tangwei12 已提交
148
    for (auto itr : table_map) {
Z
zhaocaibei123 已提交
149
      itr.second->SetShard(_rank, shard_num);
T
tangwei12 已提交
150 151 152 153 154 155
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}

T
tangwei12 已提交
156 157 158 159
// brpc entry point: looks up the handler registered for request->cmd_id() and
// invokes it with the table selected by request->table_id().
//
// - A request without table_id, or with an unknown cmd_id, is answered with
//   err_code -1 and a descriptive message.
// - A non-zero handler return value is copied into err_code. If the handler
//   already recorded a specific message it is kept; otherwise the generic
//   "server internal error" is used.
// `table` may be null here; handlers that need it guard with
// CHECK_TABLE_EXIST.
void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
                            const PsRequestMessage *request,
                            PsResponseMessage *response,
                            google::protobuf::Closure *done) {
  // Runs `done` when this function returns, on every path.
  brpc::ClosureGuard done_guard(done);
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
    return;
  }

  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);

  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }

  serviceHandlerFunc handler_func = itr->second;
  const int service_ret =
      (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    // Do not clobber a more specific message set by the handler.
    if (response->err_msg().empty()) {
      response->set_err_msg("server internal error");
    }
  }
}

Z
zhaocaibei123 已提交
187 188 189
// Handles PS_PULL_DENSE_TABLE: reads the requested count `num` from
// params(0), pulls from `table`, and appends the resulting floats to the
// response attachment.
// Returns -1 (via CHECK_TABLE_EXIST) when `table` is null; otherwise 0, with
// request-validation failures reported through set_response_code().
int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 1 for num of dense");
    return 0;
  }
  CostTimer timer("pserver_server_pull_dense");

  // params(0) carries a raw uint32_t. Reject anything shorter so the read
  // below cannot run past the end of the string. (The former `num < 0` test
  // on an unsigned value could never fire and has been dropped.)
  const auto &num_param = request.params(0);
  if (num_param.size() < sizeof(uint32_t)) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas[0] is invalid, expect a 4-byte num");
    return 0;
  }
  const uint32_t num = *reinterpret_cast<const uint32_t *>(num_param.data());

  // Pooled buffer; handed back with return_object() below.
  auto res_data = butil::get_object<std::vector<float>>();
  if (res_data == nullptr) {
    set_response_code(response, -1, "get response buffer failed");
    return 0;
  }
  res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size /
                   sizeof(float));

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = res_data->data();
  table_context.num = num;
  table->Pull(table_context);

  cntl->response_attachment().append(
      reinterpret_cast<const char *>(res_data->data()),
      res_data->size() * sizeof(float));
  butil::return_object(res_data);

  return 0;
}

Z
zhaocaibei123 已提交
225 226 227 228 229 230
// Handles PS_PUSH_DENSE_PARAM: the request attachment holds a uint32_t count
// followed by float values, which are pushed to `table` as parameters
// (is_param = true).
// Returns -1 (via CHECK_TABLE_EXIST) when `table` is null; otherwise 0, with
// failures reported through set_response_code().
int32_t BrpcPsService::PushDenseParam(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushDenseParam", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  thread_local std::string push_buffer;
  auto &req_io_buffer = cntl->request_attachment();
  const auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  // The leading uint32_t must be fully present before it is read.
  if (req_buffer_size < sizeof(uint32_t)) {
    set_response_code(response, -1, "req attachment is too short for num");
    return 0;
  }

  // resize() rather than reserve(): fetch() may copy req_buffer_size bytes
  // into this buffer, so those bytes must be inside the string's size.
  push_buffer.resize(req_buffer_size);
  const char *data = static_cast<const char *>(
      req_io_buffer.fetch(&push_buffer[0], req_buffer_size));
  if (data == nullptr) {
    set_response_code(response, -1, "fetch req attachment failed");
    return 0;
  }

  const uint32_t num = *reinterpret_cast<const uint32_t *>(data);
  const float *values =
      reinterpret_cast<const float *>(data + sizeof(uint32_t));

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;
  table_context.num = num;

  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushDenseParam failed");
  }
  return 0;
}

Z
zhaocaibei123 已提交
260 261 262
// Handles PS_PUSH_DENSE_TABLE. request.data() layout:
//   |--num--|---valuesData---|
//   |--4B---|----------------|
// An empty payload is silently accepted (returns 0 without touching the
// response). Returns -1 (via CHECK_TABLE_EXIST) when `table` is null;
// otherwise 0, with failures reported through set_response_code().
int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  const auto &push_data = request.data();
  if (push_data.size() < 1) {
    return 0;
  }
  // A non-empty payload must at least contain the 4-byte num header.
  if (push_data.size() < sizeof(uint32_t)) {
    set_response_code(response, -1, "push dense data is too short for num");
    return 0;
  }

  CostTimer timer("pserver_server_push_dense");
  const uint32_t num = *reinterpret_cast<const uint32_t *>(push_data.data());

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values =
      reinterpret_cast<const float *>(push_data.data() + sizeof(uint32_t));
  table_context.num = num;
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushDense failed");
  }

  return 0;
}

Z
zhaocaibei123 已提交
294
int32_t BrpcPsService::Barrier(Table *table, const PsRequestMessage &request,
T
tangwei12 已提交
295 296
                               PsResponseMessage &response,
                               brpc::Controller *cntl) {
T
tangwei12 已提交
297 298 299 300 301 302 303 304 305 306 307
  CHECK_TABLE_EXIST(table, request, response)

  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
Z
zhaocaibei123 已提交
308
  table->Barrier(trainer_id, barrier_type);
T
tangwei12 已提交
309 310 311
  return 0;
}

Z
zhaocaibei123 已提交
312 313 314 315 316
// Handles PS_PUSH_SPARSE_PARAM. params(0) holds the key count `num` as a raw
// uint32_t; request.data() layout:
//   |---keysData---|---valuesData---|
//   |---8*{num}B---|----------------|
// The keys/values are pushed to `table` as parameters (is_param = true).
// An empty payload is silently accepted. Returns -1 (via CHECK_TABLE_EXIST)
// when `table` is null; otherwise 0, with failures reported through
// set_response_code().
int32_t BrpcPsService::PushSparseParam(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  platform::RecordEvent record_event("PsService->PushSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  // params(0) must be long enough to hold the uint32_t it is read as.
  const auto &num_param = request.params(0);
  if (num_param.size() < sizeof(uint32_t)) {
    set_response_code(response, -1,
                      "PsRequestMessage.params[0] is too short for num");
    return 0;
  }
  const uint32_t num = *reinterpret_cast<const uint32_t *>(num_param.data());

  // The key section alone is 8*num bytes; make sure the payload covers it so
  // the values pointer below stays inside the buffer.
  const size_t keys_bytes = sizeof(uint64_t) * static_cast<size_t>(num);
  if (push_data.size() < keys_bytes) {
    set_response_code(response, -1,
                      "push sparse data is shorter than num keys");
    return 0;
  }

  const uint64_t *keys = reinterpret_cast<const uint64_t *>(push_data.data());
  const float *values =
      reinterpret_cast<const float *>(push_data.data() + keys_bytes);

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;
  table_context.num = num;
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushSparseParam error");
  }
  return 0;
}

Z
zhaocaibei123 已提交
354 355 356 357
// Handles PS_PULL_GEO_PARAM: pulls the geo keys/values for the calling
// client from `table` and writes them to the response attachment as
//   | num (uint32_t) | ids (num * uint64_t) | values (floats) |
// Returns -1 (via CHECK_TABLE_EXIST) when `table` is null; otherwise 0.
int32_t BrpcPsService::PullGeoParam(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->pull_geo_param", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)

  auto trainer_id = request.client_id();

  std::vector<float> values;
  std::vector<uint64_t> ids;

  // The table fills `ids` and `values` through the pointers set here.
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.geo_pull_keys = &ids;
  table_context.pull_context.geo_pull_values = &values;
  table_context.trainer_id = trainer_id;
  table->Pull(table_context);

  const uint32_t num = ids.size();
  auto &attachment = cntl->response_attachment();
  attachment.append(reinterpret_cast<const char *>(&num), sizeof(uint32_t));
  attachment.append(reinterpret_cast<const char *>(ids.data()),
                    ids.size() * sizeof(uint64_t));
  attachment.append(reinterpret_cast<const char *>(values.data()),
                    values.size() * sizeof(float));
  return 0;
}

Z
zhaocaibei123 已提交
385 386 387
// Handles PS_PULL_SPARSE_TABLE: params(0) holds the key count `num` as a raw
// uint32_t and the request attachment holds the serialized pull request.
// Pulls num * select_dim floats from `table` and appends them to the
// response attachment.
// Returns -1 (via CHECK_TABLE_EXIST) when `table` is null; otherwise 0, with
// failures reported through set_response_code().
int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)

  auto &req_io_buffer = cntl->request_attachment();
  const auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }

  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  CostTimer timer("pserver_server_pull_sparse");

  // params(0) must be long enough to hold the uint32_t it is read as.
  const auto &num_param = request.params(0);
  if (num_param.size() < sizeof(uint32_t)) {
    set_response_code(response, -1,
                      "PsRequestMessage.params[0] is too short for num");
    return 0;
  }
  const uint32_t num = *reinterpret_cast<const uint32_t *>(num_param.data());
  auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;

  // resize() rather than reserve(): fetch() may copy req_buffer_size bytes
  // into this buffer, so those bytes must be inside the string's size.
  thread_local std::string req_buffer;
  req_buffer.resize(req_buffer_size);
  const void *data = req_io_buffer.fetch(&req_buffer[0], req_buffer_size);
  if (data == nullptr) {
    set_response_code(response, -1, "fetch req attachment failed");
    return 0;
  }

  auto value = PullSparseValue(num, dim);
  value.DeserializeFromBytes(const_cast<void *>(data));

  // Pooled buffer; handed back with return_object() below.
  auto res_data = butil::get_object<std::vector<float>>();
  if (res_data == nullptr) {
    set_response_code(response, -1, "get response buffer failed");
    return 0;
  }
  res_data->resize(num * dim);

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.pull_value = value;
  table_context.pull_context.values = res_data->data();
  table->Pull(table_context);

  cntl->response_attachment().append(
      reinterpret_cast<const char *>(res_data->data()),
      res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}

Z
zhaocaibei123 已提交
436 437 438
// Handles PS_PUSH_SPARSE_TABLE. params(0) holds the key count `num` as a raw
// uint32_t; request.data() layout:
//   |---keysData---|---valuesData---|
//   |---8*{num}B---|----------------|
// An empty payload is silently accepted. Returns -1 (via CHECK_TABLE_EXIST)
// when `table` is null; otherwise 0, with failures reported through
// set_response_code().
int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_push_sparse");

  // params(0) must be long enough to hold the uint32_t it is read as.
  const auto &num_param = request.params(0);
  if (num_param.size() < sizeof(uint32_t)) {
    set_response_code(response, -1,
                      "PsRequestMessage.params[0] is too short for num");
    return 0;
  }
  const uint32_t num = *reinterpret_cast<const uint32_t *>(num_param.data());

  // The key section alone is 8*num bytes; make sure the payload covers it so
  // the values pointer below stays inside the buffer.
  const size_t keys_bytes = sizeof(uint64_t) * static_cast<size_t>(num);
  if (push_data.size() < keys_bytes) {
    set_response_code(response, -1,
                      "push sparse data is shorter than num keys");
    return 0;
  }

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys =
      reinterpret_cast<const uint64_t *>(push_data.data());
  table_context.push_context.values =
      reinterpret_cast<const float *>(push_data.data() + keys_bytes);
  table_context.num = num;
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushSparse error");
  }
  return 0;
}

Z
zhaocaibei123 已提交
476 477 478 479
// Handles PS_PRINT_TABLE_STAT: serializes the pair returned by
// table->PrintTableStat() (first, then second) with a BinaryArchive and
// returns the bytes in response.data.
// Returns -1 (via CHECK_TABLE_EXIST) when `table` is null; otherwise 0.
int32_t BrpcPsService::PrintTableStat(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  std::pair<int64_t, int64_t> stat = table->PrintTableStat();

  paddle::framework::BinaryArchive archive;
  archive << stat.first << stat.second;

  std::string serialized(archive.Buffer(), archive.Length());
  response.set_data(serialized);
  return 0;
}

Z
zhaocaibei123 已提交
490 491 492 493
// Handles PS_LOAD_ONE_TABLE: calls table->Load(params(0), params(1)).
// Returns 0 on success; -1 when `table` is null, when fewer than two params
// are supplied, or when Load() reports failure.
int32_t BrpcPsService::LoadOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  const bool has_required_params = request.params_size() >= 2;
  if (!has_required_params) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }

  const auto load_ret = table->Load(request.params(0), request.params(1));
  if (load_ret != 0) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
508 509 510 511 512
// Handles PS_LOAD_ALL_TABLE: runs LoadOneTable() on every table owned by the
// server, using the same request. The `table` argument is not used.
// Stops at the first failure and returns -1; returns 0 if every load worked.
int32_t BrpcPsService::LoadAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  auto &all_tables = *(_server->GetTable());
  for (auto &entry : all_tables) {
    const int32_t load_ret =
        LoadOneTable(entry.second.get(), request, response, cntl);
    if (load_ret == 0) {
      continue;
    }
    LOG(ERROR) << "load table[" << entry.first << "] failed";
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
522 523 524 525
// Handles PS_SAVE_ONE_TABLE: flushes `table`, then calls
// table->Save(params(0), params(1)).
// Returns the (non-negative) value reported by Save() on success; -1 when
// `table` is null, when fewer than two params are supplied, or when Save()
// returns a negative value.
// Note that service() copies any non-zero return value into the response's
// err_code, so a positive result reaches the caller that way.
int32_t BrpcPsService::SaveOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  const bool has_required_params = request.params_size() >= 2;
  if (!has_required_params) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }

  table->Flush();

  VLOG(3) << "save table " << request.params(0) << " " << request.params(1);
  const int32_t saved = table->Save(request.params(0), request.params(1));
  if (saved < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return saved;
}

Z
zhaocaibei123 已提交
546 547 548 549 550
// Handles PS_SAVE_ALL_TABLE: runs SaveOneTable() on every table owned by the
// server, using the same request. The `table` argument is not used.
// Stops at the first failure and returns -1; otherwise returns 0 (the
// per-table results of SaveOneTable() are not aggregated).
int32_t BrpcPsService::SaveAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  auto &table_map = *(_server->GetTable());
  for (auto &itr : table_map) {
    const int32_t feasign_size =
        SaveOneTable(itr.second.get(), request, response, cntl);
    if (feasign_size < 0) {
      LOG(ERROR) << "save table[" << itr.first << "] failed";
      return -1;
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
564 565 566 567
// Handles PS_SHRINK_TABLE: flushes `table`, then calls
// table->Shrink(params(0)).
// Returns 0 on success; -1 when `table` is null, when no param is supplied,
// or when Shrink() reports failure.
int32_t BrpcPsService::ShrinkTable(Table *table,
                                   const PsRequestMessage &request,
                                   PsResponseMessage &response,
                                   brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  const bool has_threshold = request.params_size() >= 1;
  if (!has_threshold) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 1, threshold");
    return -1;
  }

  table->Flush();
  const auto shrink_ret = table->Shrink(request.params(0));
  if (shrink_ret != 0) {
    set_response_code(response, -1, "table shrink failed");
    return -1;
  }

  VLOG(3) << "Pserver Shrink Finished";
  return 0;
}

Z
zhaocaibei123 已提交
584 585 586 587
// Handles PS_CLEAR_ONE_TABLE: flushes `table` and then clears it.
// Returns -1 (via CHECK_TABLE_EXIST) when `table` is null; otherwise 0.
int32_t BrpcPsService::ClearOneTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  // Flush first, then clear.
  table->Flush();
  table->Clear();

  return 0;
}

Z
zhaocaibei123 已提交
594 595 596 597 598
// Handles PS_CLEAR_ALL_TABLE: runs ClearOneTable() on every table owned by
// the server. The `table` argument is not used.
// Returns -1 as soon as one clear fails; 0 otherwise.
int32_t BrpcPsService::ClearAllTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto &all_tables = *(_server->GetTable());
  for (auto &entry : all_tables) {
    const int32_t clear_ret =
        ClearOneTable(entry.second.get(), request, response, cntl);
    if (clear_ret != 0) {
      return -1;
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
607 608 609
// Handles PS_STOP_SERVER: calls _server->Stop() on a detached thread, so this
// handler returns without waiting for the stop to finish. Always returns 0.
// NOTE(review): presumably done so the RPC reply can go out before the server
// shuts down — confirm against PSServer::Stop().
int32_t BrpcPsService::StopServer(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  // Capture the raw server pointer by value; the thread outlives this call.
  auto *server = _server;
  std::thread([server]() {
    server->Stop();
    VLOG(3) << "Server Stoped";
  }).detach();
  return 0;
}

Z
zhaocaibei123 已提交
619 620 621 622
// Handles PS_STOP_PROFILER: disables the platform profiler, passing the
// default sorting key and a "server_<rank>_profile" string built from _rank.
// None of the handler arguments are used. Always returns 0.
int32_t BrpcPsService::StopProfiler(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  const auto sort_key = platform::EventSortingKey::kDefault;
  platform::DisableProfiler(sort_key,
                            string::Sprintf("server_%s_profile", _rank));
  return 0;
}

Z
zhaocaibei123 已提交
628 629 630 631
// Handles PS_START_PROFILER: enables the platform profiler in kCPU state.
// None of the handler arguments are used. Always returns 0.
int32_t BrpcPsService::StartProfiler(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  const auto profiler_state = platform::ProfilerState::kCPU;
  platform::EnableProfiler(profiler_state);
  return 0;
}

Z
zhaocaibei123 已提交
636 637 638 639
// Handles PS_PUSH_GLOBAL_STEP: request.data() holds a 4-byte header followed
// by int64_t step values, which are pushed to `table` together with the
// caller's client id.
// Returns -1 (via CHECK_TABLE_EXIST) when `table` is null; otherwise 0, with
// failures reported through set_response_code().
int32_t BrpcPsService::PushGlobalStep(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  const auto &step_data = request.data();
  if (step_data.size() < 1) {
    set_response_code(response, 0, "run_program data is empty");
    return 0;
  }
  // The step values start after a 4-byte header; a shorter payload would put
  // the values pointer past the end of the buffer.
  if (step_data.size() < sizeof(uint32_t)) {
    set_response_code(response, -1, "run_program data is too short");
    return 0;
  }
  const int64_t *values =
      reinterpret_cast<const int64_t *>(step_data.data() + sizeof(uint32_t));
  auto trainer_id = request.client_id();

  TableContext context;
  context.trainer_id = trainer_id;
  context.push_context.push_steps = values;

  if (table->Push(context) != 0) {
    set_response_code(response, -1, "run_program failed");
  }

  return 0;
}

T
tangwei12 已提交
663 664
}  // namespace distributed
}  // namespace paddle