graph_brpc_server.cc 27.9 KB
Newer Older
S
seemingwang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
S
seemingwang 已提交
16 17

#include <thread>  // NOLINT
S
seemingwang 已提交
18
#include <utility>
19

S
seemingwang 已提交
20 21
#include "butil/endpoint.h"
#include "iomanip"
22
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
23
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
S
seemingwang 已提交
24 25 26 27 28
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {

29 30 31 32 33 34 35 36
// Guard placed at the top of table-bound RPC handlers. If `table` is null it
// writes err_code -1 and "table not found with table_id:<id>" into `response`
// and returns -1 from the *enclosing* handler.
// Call sites invoke it without a trailing semicolon, so it must stay a bare
// `if` block rather than a `do { ... } while (0)` wrapper. Arguments are
// parenthesized so expression arguments expand safely.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if ((table) == nullptr) {                                \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string((request).table_id()));  \
    set_response_code(response, -1, err_msg.c_str());      \
    return -1;                                             \
  }

Z
zhaocaibei123 已提交
37
// Instantiates the PS service class named in the server config, configures
// and initializes it against this server, and adds it to the brpc server
// (brpc is told it does NOT own the service; `_service` holds it).
// Returns 0 on success, -1 on any failure (each failure is logged).
int32_t GraphBrpcServer::Initialize() {
  const auto &svc_param = _config.downpour_server_param().service_param();
  if (!svc_param.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }

  auto *created =
      CREATE_PSCORE_CLASS(PsBaseService, svc_param.service_class());
  if (created == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << svc_param.service_class();
    return -1;
  }
  _service.reset(created);

  // Short-circuits: Initialize() only runs when Configure() succeeded.
  const bool init_ok =
      created->Configure(this) == 0 && created->Initialize() == 0;
  if (!init_ok) {
    LOG(ERROR) << "service initialize failed, service_name:"
               << svc_param.service_class();
    return -1;
  }

  const int add_rc =
      _server.AddService(created, brpc::SERVER_DOESNT_OWN_SERVICE);
  if (add_rc != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << svc_param.service_class();
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
65
// Returns a non-owning pointer to the channel for the server at
// `server_index` (channels are created in build_peer2peer_connection()).
// No bounds checking is performed on `server_index`.
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  auto &channel_holder = _pserver_channels[server_index];
  return channel_holder.get();
}

Z
zhaocaibei123 已提交
69
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
S
seemingwang 已提交
70 71 72 73 74 75 76
  std::unique_lock<std::mutex> lock(mutex_);

  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;

  int num_threads = std::thread::hardware_concurrency();
Z
zhaocaibei123 已提交
77
  auto trainers = _environment->GetTrainers();
S
seemingwang 已提交
78 79 80 81 82 83
  options.num_threads = trainers > num_threads ? trainers : num_threads;

  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
Z
zhaocaibei123 已提交
84
  _environment->RegistePsServer(ip, port, _rank);
S
seemingwang 已提交
85 86 87
  return 0;
}

S
seemingwang 已提交
88 89
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
Z
zhaocaibei123 已提交
90
  auto _env = Environment();
S
seemingwang 已提交
91 92 93 94 95 96 97
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;

Z
zhaocaibei123 已提交
98
  std::vector<PSHost> server_list = _env->GetPsServers();
S
seemingwang 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}

124 125 126 127
// Handles PS_GRAPH_CLEAR by forwarding (type_id, idx) to
// GraphTable::clear_nodes.
// Request params: (0) int type_id, (1) int idx.
// Validates the table and the argument count like the other handlers, rather
// than indexing into missing params.
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(response, -1,
                      "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  ((GraphTable *)table)->clear_nodes(type_id, idx_);
  return 0;
}

// Handles PS_GRAPH_ADD_GRAPH_NODE by forwarding decoded ids to
// GraphTable::add_graph_node.
// Request params:
//   (0) int idx
//   (1) packed int64_t node ids
//   (2) optional packed bool "is weighted" flags; only read when exactly three
//       params are present.
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(response, -1,
                      "add_graph_node request requires at least 2 arguments");
    return 0;
  }

  const int idx_ = *(const int *)(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  const int64_t *id_begin = (const int64_t *)(id_bytes.c_str());
  std::vector<int64_t> node_ids(id_begin,
                                id_begin + id_bytes.size() / sizeof(int64_t));

  std::vector<bool> is_weighted_list;
  if (request.params_size() == 3) {
    const std::string &flag_bytes = request.params(2);
    const bool *flag_begin = (const bool *)(flag_bytes.c_str());
    is_weighted_list.assign(flag_begin,
                            flag_begin + flag_bytes.size() / sizeof(bool));
  }

  ((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
// Handles PS_GRAPH_REMOVE_GRAPH_NODE by forwarding decoded ids to
// GraphTable::remove_graph_node.
// Request params: (0) int idx, (1) packed int64_t node ids.
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1,
        "remove_graph_node request requires at least 2 arguments");
    return 0;
  }

  const int idx_ = *(const int *)(request.params(0).c_str());
  const std::string &id_bytes = request.params(1);
  const int64_t *first_id = (const int64_t *)(id_bytes.c_str());
  const size_t id_count = id_bytes.size() / sizeof(int64_t);
  std::vector<int64_t> node_ids(first_id, first_id + id_count);

  ((GraphTable *)table)->remove_graph_node(idx_, node_ids);
  return 0;
}
Z
zhaocaibei123 已提交
190
// Returns the port of the brpc server's listen address.
int32_t GraphBrpcServer::Port() {
  const auto listen_ep = _server.listen_address();
  return listen_ep.port;
}
S
seemingwang 已提交
191

Z
zhaocaibei123 已提交
192
// Builds the cmd_id -> handler dispatch table used by service(), then
// initializes shard info for every table on the server.
int32_t GraphBrpcService::Initialize() {
  _is_initialize_shard_info = false;

  auto &handlers = _service_handler_map;

  // Generic server-control commands.
  handlers[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  handlers[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  handlers[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;

  handlers[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
  handlers[PS_BARRIER] = &GraphBrpcService::Barrier;
  handlers[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  handlers[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;

  // Graph-specific commands.
  handlers[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
  handlers[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
  handlers[PS_GRAPH_SAMPLE_NODES] =
      &GraphBrpcService::graph_random_sample_nodes;
  handlers[PS_GRAPH_GET_NODE_FEAT] = &GraphBrpcService::graph_get_node_feat;
  handlers[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  handlers[PS_GRAPH_ADD_GRAPH_NODE] = &GraphBrpcService::add_graph_node;
  handlers[PS_GRAPH_REMOVE_GRAPH_NODE] = &GraphBrpcService::remove_graph_node;
  handlers[PS_GRAPH_SET_NODE_FEAT] = &GraphBrpcService::graph_set_node_feat;
  handlers[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
      &GraphBrpcService::sample_neighbors_across_multi_servers;

  // Currently disabled handlers:
  // handlers[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
  //     &GraphBrpcService::use_neighbors_sample_cache;
  // handlers[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
  //     &GraphBrpcService::load_graph_split_config;

  // Shard initialization: the shard info for server_list can only be obtained
  // from the environment after the server has started.
  InitializeShardInfo();

  return 0;
}

Z
zhaocaibei123 已提交
229
int32_t GraphBrpcService::InitializeShardInfo() {
S
seemingwang 已提交
230 231 232 233 234
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
Z
zhaocaibei123 已提交
235 236
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
S
seemingwang 已提交
237
    for (auto itr : table_map) {
Z
zhaocaibei123 已提交
238
      itr.second->SetShard(_rank, server_size);
S
seemingwang 已提交
239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}

// RPC entry point. It validates that a table_id is present, resets the
// response status, and dispatches on cmd_id through _service_handler_map.
// When a handler returns non-zero, that value becomes err_code. If the
// handler left no message, err_msg defaults to "server internal error".
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
    return;
  }

  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    // err_msg was explicitly set to "" above, which marks the field as
    // present, so has_err_msg() is always true here. Test for an empty
    // message instead so the default text is actually applied.
    if (response->err_msg().empty()) {
      response->set_err_msg("server internal error");
    }
  }
}

Z
zhaocaibei123 已提交
278
// Handles PS_BARRIER by forwarding (client_id, params(0)) to Table::Barrier.
// A request without params gets err_code -1 but the handler still returns 0.
int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  const bool has_barrier_type = request.params_size() >= 1;
  if (has_barrier_type) {
    auto client = request.client_id();
    auto kind = request.params(0);
    table->Barrier(client, kind);
    return 0;
  }

  set_response_code(response, -1,
                    "PsRequestMessage.params is requeired at "
                    "least 1 for num of sparse_key");
  return 0;
}

Z
zhaocaibei123 已提交
296 297 298 299
// Handles PS_PRINT_TABLE_STAT. It serializes the two int64 values from
// Table::PrintTableStat() with BinaryArchive into response.data.
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  std::pair<int64_t, int64_t> stat = table->PrintTableStat();

  paddle::framework::BinaryArchive archive;
  archive << stat.first << stat.second;
  response.set_data(std::string(archive.Buffer(), archive.Length()));

  return 0;
}

Z
zhaocaibei123 已提交
310 311 312 313
// Handles PS_LOAD_ONE_TABLE by calling Table::Load(params(0), params(1))
// (path and load parameter). Returns -1 with an error response when the
// arguments are missing or the load fails.
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  const bool has_args = request.params_size() >= 2;
  if (!has_args) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }
  if (table->Load(request.params(0), request.params(1)) == 0) {
    return 0;
  }
  set_response_code(response, -1, "table load failed");
  return -1;
}

Z
zhaocaibei123 已提交
328 329 330 331 332
// Handles PS_LOAD_ALL_TABLE by running LoadOneTable() on every table the
// server holds, with the same request. The `table` argument is ignored.
// Stops and returns -1 at the first table that fails to load.
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  for (auto &entry : *(_server->GetTable())) {
    const int32_t rc =
        LoadOneTable(entry.second.get(), request, response, cntl);
    if (rc == 0) {
      continue;
    }
    LOG(ERROR) << "load table[" << entry.first << "] failed";
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
342 343 344 345
// Handles PS_STOP_SERVER. It calls GraphBrpcServer::Stop() on a detached
// thread, presumably so this handler can reply before shutdown completes,
// and notifies the server's exported condition variable.
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto *graph_server = (GraphBrpcServer *)_server;
  std::thread stopper([graph_server]() {
    graph_server->Stop();
    LOG(INFO) << "Server Stoped";
  });
  graph_server->export_cv()->notify_all();
  stopper.detach();
  return 0;
}

Z
zhaocaibei123 已提交
356 357 358 359
// Handles PS_STOP_PROFILER. It disables the profiler with default event
// sorting and writes the profile to "server_<rank>_profile".
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  const auto profile_path = string::Sprintf("server_%s_profile", _rank);
  platform::DisableProfiler(platform::EventSortingKey::kDefault, profile_path);
  return 0;
}

Z
zhaocaibei123 已提交
365 366 367 368
// Handles PS_START_PROFILER by enabling the CPU profiler.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  const auto state = platform::ProfilerState::kCPU;
  platform::EnableProfiler(state);
  return 0;
}

// Handles PS_PULL_GRAPH_LIST. It asks the graph table for a serialized slice
// of nodes (without features) and appends the bytes to the response
// attachment.
// Request params: (0) int type_id, (1) int idx, (2) int start, (3) int size,
//                 (4) int step.
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(response, -1,
                      "pull_graph_list request requires at least 5 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx = *(int *)(request.params(1).c_str());
  int start = *(int *)(request.params(2).c_str());
  int size = *(int *)(request.params(3).c_str());
  int step = *(int *)(request.params(4).c_str());

  std::unique_ptr<char[]> buffer;
  // Initialized so an unset output never leads to appending an indeterminate
  // number of bytes.
  int actual_size = 0;
  ((GraphTable *)table)
      ->pull_graph_list(type_id, idx, start, size, buffer, actual_size, false,
                        step);
  if (buffer != nullptr && actual_size > 0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  }
  return 0;
}
399
int32_t GraphBrpcService::graph_random_sample_neighbors(
S
seemingwang 已提交
400 401 402
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
403
  if (request.params_size() < 4) {
S
seemingwang 已提交
404 405
    set_response_code(
        response, -1,
406
        "graph_random_sample_neighbors request requires at least 3 arguments");
S
seemingwang 已提交
407 408
    return 0;
  }
409 410 411 412 413 414 415 416 417
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  // size_t node_num = request.params(0).size() / sizeof(int64_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(bool *)(request.params(2).c_str());
418
  std::vector<std::shared_ptr<char>> buffers(node_num);
S
seemingwang 已提交
419
  std::vector<int> actual_sizes(node_num, 0);
S
seemingwang 已提交
420
  ((GraphTable *)table)
421 422
      ->random_sample_neighbors(idx_, node_data, sample_size, buffers,
                                actual_sizes, need_weight);
S
seemingwang 已提交
423 424 425 426 427 428 429 430 431 432 433 434

  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
// Handles PS_GRAPH_SAMPLE_NODES. It asks the graph table for a random sample
// of nodes and appends the serialized result to the response attachment. On
// failure the attachment is left empty.
// Request params: (0) int type_id, (1) int idx, (2) int64 sample size.
// Validates the table and the argument count like the other handlers, rather
// than indexing into missing params.
int32_t GraphBrpcService::graph_random_sample_nodes(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response, -1,
        "graph_random_sample_nodes request requires at least 3 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  size_t size = *(int64_t *)(request.params(2).c_str());

  std::unique_ptr<char[]> buffer;
  int actual_size = 0;
  if (((GraphTable *)table)
          ->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
      0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  }
  return 0;
}

// Handles PS_GRAPH_GET_NODE_FEAT. It looks up the requested features for the
// given nodes and streams them back feature-major. Each value is written as a
// size_t length followed by its bytes.
// Request params: (0) int idx, (1) packed int64_t node ids,
//                 (2) tab-separated feature names.
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response, -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }

  const int idx_ = *(const int *)(request.params(0).c_str());
  const std::string &id_bytes = request.params(1);
  const size_t node_num = id_bytes.size() / sizeof(int64_t);
  const int64_t *ids = (const int64_t *)(id_bytes.c_str());
  std::vector<int64_t> node_ids(ids, ids + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  const size_t feature_count = feature_names.size();

  std::vector<std::vector<std::string>> feature(
      feature_count, std::vector<std::string>(node_num));
  ((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);

  auto &attachment = cntl->response_attachment();
  for (size_t f = 0; f < feature_count; ++f) {
    for (size_t n = 0; n < node_num; ++n) {
      const std::string &value = feature[f][n];
      size_t feat_len = value.size();
      attachment.append(&feat_len, sizeof(size_t));
      attachment.append(value.data(), feat_len);
    }
  }

  return 0;
}
488
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
S
seemingwang 已提交
489 490 491 492
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
493
  if (request.params_size() < 4) {
494 495
    set_response_code(response, -1,
                      "sample_neighbors_across_multi_servers request requires "
496
                      "at least 4 arguments");
S
seemingwang 已提交
497 498
    return 0;
  }
499 500 501

  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t),
S
seemingwang 已提交
502
         size_of_size_t = sizeof(size_t);
503 504 505 506 507 508 509 510 511
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(int64_t *)(request.params(3).c_str());

  // size_t node_num = request.params(0).size() / sizeof(int64_t),
  //        size_of_size_t = sizeof(size_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(int64_t *)(request.params(2).c_str());
512
  // std::vector<int64_t> res = ((GraphTable
S
seemingwang 已提交
513 514 515
  // *)table).filter_out_non_exist_nodes(node_data, sample_size);
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
516
  std::vector<int64_t> local_id;
S
seemingwang 已提交
517
  std::vector<int> local_query_idx;
Z
zhaocaibei123 已提交
518
  size_t rank = GetRank();
S
seemingwang 已提交
519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535
  for (int query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
              request2server[(int)request2server.size() - 1]);
    server2request[request2server[pos]] = pos;
    server2request[request2server[(int)request2server.size() - 1]] =
        request2server.size() - 1;
  }
  size_t request_call_num = request2server.size();
536
  std::vector<std::shared_ptr<char>> local_buffers;
S
seemingwang 已提交
537 538
  std::vector<int> local_actual_sizes;
  std::vector<size_t> seq;
539
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (int query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
  if (request2server.size() != 0 && request2server.back() == rank) {
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
  std::function<void(void *)> func = [&, node_id_buckets, query_idx_buckets,
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
    auto *closure = (DownpourBrpcClosure *)done;
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
568
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
S
seemingwang 已提交
569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
        size_t node_size;
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    int size;
    int local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;
      } else if (request2server[seq[i]] != rank) {
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));

    local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
      } else if (request2server[seq[i]] != rank) {
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
619
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
S
seemingwang 已提交
620 621 622 623
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();

624 625
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));

S
seemingwang 已提交
626 627
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
628
                     sizeof(int64_t) * node_num);
S
seemingwang 已提交
629 630
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
631 632
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
S
seemingwang 已提交
633
    PsService_Stub rpc_stub(
Z
zhaocaibei123 已提交
634
        ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
S
seemingwang 已提交
635
    // GraphPsService_Stub rpc_stub =
Z
zhaocaibei123 已提交
636
    //     getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
637 638 639 640 641 642
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
                     closure->response(request_idx), closure);
  }
  if (server2request[rank] != -1) {
    ((GraphTable *)table)
643 644 645
        ->random_sample_neighbors(idx_, node_id_buckets.back().data(),
                                  sample_size, local_buffers,
                                  local_actual_sizes, need_weight);
S
seemingwang 已提交
646 647 648 649 650 651
  }
  local_promise.get()->set_value(0);
  if (remote_call_num == 0) func(closure);
  fut.get();
  return 0;
}
S
seemingwang 已提交
652 653 654 655 656
// Handles PS_GRAPH_SET_NODE_FEAT by decoding feature values and forwarding
// them to GraphTable::set_node_feat.
// Request params:
//   (0) int idx
//   (1) packed int64_t node ids
//   (2) tab-separated feature names
//   (3) feature values, feature-major, as feature_names.size() * node_num
//       records of [size_t len][len bytes]
// params(3) arrives from the network, so every length prefix is checked
// against the remaining bytes. A truncated buffer is reported as an error
// instead of being read out of bounds.
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    set_response_code(
        response, -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());

  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");

  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));

  const std::string &feat_buffer = request.params(3);
  // Invariant: offset <= feat_buffer.size(), so the subtraction cannot wrap.
  size_t offset = 0;
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      if (feat_buffer.size() - offset < sizeof(size_t)) {
        set_response_code(
            response, -1,
            "graph_set_node_feat request has a truncated feature buffer");
        return 0;
      }
      size_t feat_len = 0;
      // copy() avoids an unaligned size_t load from the packed buffer.
      feat_buffer.copy(reinterpret_cast<char *>(&feat_len), sizeof(size_t),
                       offset);
      offset += sizeof(size_t);
      if (feat_buffer.size() - offset < feat_len) {
        set_response_code(
            response, -1,
            "graph_set_node_feat request has a truncated feature buffer");
        return 0;
      }
      features[feat_idx][node_idx] = feat_buffer.substr(offset, feat_len);
      offset += feat_len;
    }
  }

  ((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);

  return 0;
}

S
seemingwang 已提交
698 699
}  // namespace distributed
}  // namespace paddle