graph_brpc_server.cc 27.9 KB
Newer Older
S
seemingwang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
S
seemingwang 已提交
16 17

#include <thread>  // NOLINT
S
seemingwang 已提交
18
#include <utility>
19

S
seemingwang 已提交
20 21
#include "butil/endpoint.h"
#include "iomanip"
22
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
23
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
S
seemingwang 已提交
24 25 26 27 28
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {

29 30 31 32 33 34 35 36
// Early-return guard for request handlers. When `table` is null it records
// err_code -1 and "table not found with table_id:<id>" in `response` and makes
// the enclosing handler return -1. It expands to a bare `if` statement, so the
// call sites in this file invoke it without a trailing semicolon.
#define CHECK_TABLE_EXIST(table, request, response)                   \
  if ((table) == nullptr) {                                           \
    const std::string err_msg = "table not found with table_id:" +    \
                                std::to_string((request).table_id()); \
    set_response_code((response), -1, err_msg.c_str());               \
    return -1;                                                        \
  }

Z
zhaocaibei123 已提交
37
int32_t GraphBrpcServer::Initialize() {
S
seemingwang 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }

  _service.reset(service);
Z
zhaocaibei123 已提交
52
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
S
seemingwang 已提交
53 54 55 56 57 58 59 60 61 62 63 64
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
65
// Returns a non-owning pointer to the channel for the pserver at
// `server_index`, as created by build_peer2peer_connection(). The index is not
// bounds-checked.
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  auto &channel_holder = _pserver_channels[server_index];
  return channel_holder.get();
}

Z
zhaocaibei123 已提交
69
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
S
seemingwang 已提交
70 71 72 73 74 75 76
  std::unique_lock<std::mutex> lock(mutex_);

  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;

  int num_threads = std::thread::hardware_concurrency();
Z
zhaocaibei123 已提交
77
  auto trainers = _environment->GetTrainers();
S
seemingwang 已提交
78 79 80 81 82 83
  options.num_threads = trainers > num_threads ? trainers : num_threads;

  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
Z
zhaocaibei123 已提交
84
  _environment->RegistePsServer(ip, port, _rank);
S
seemingwang 已提交
85 86 87
  return 0;
}

S
seemingwang 已提交
88 89
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
Z
zhaocaibei123 已提交
90
  auto _env = Environment();
S
seemingwang 已提交
91 92 93 94 95 96 97
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;

Z
zhaocaibei123 已提交
98
  std::vector<PSHost> server_list = _env->GetPsServers();
S
seemingwang 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}

124 125 126 127
// Handles PS_GRAPH_CLEAR: forwards (type_id, idx) to GraphTable::clear_nodes.
// Request layout:
//   params(0): int type_id
//   params(1): int idx_
// Like the other handlers in this file, it first checks that the table exists
// and that the request carries enough well-sized params. A malformed request is
// answered with err_code -1 instead of reading missing or short params.
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(response, -1,
                      "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  if (request.params(0).size() < sizeof(int) ||
      request.params(1).size() < sizeof(int)) {
    set_response_code(response, -1,
                      "clear_nodes request has malformed type_id/idx");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  ((GraphTable *)table)->clear_nodes(type_id, idx_);
  return 0;
}

// Handles PS_GRAPH_ADD_GRAPH_NODE. Request layout:
//   params(0): int idx_, forwarded to GraphTable::add_graph_node
//   params(1): packed int64_t node ids
//   params(2): optional packed bool per-node "is weighted" flags; it is read
//              only when the request has exactly 3 params
// Too few params are reported through the response; the handler returns 0.
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  const bool enough_params = request.params_size() >= 2;
  if (!enough_params) {
    set_response_code(response, -1,
                      "add_graph_node request requires at least 2 arguments");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  int64_t *first_id = (int64_t *)(id_bytes.c_str());
  std::vector<int64_t> node_ids(first_id,
                                first_id + id_bytes.size() / sizeof(int64_t));

  std::vector<bool> is_weighted_list;
  if (request.params_size() == 3) {
    const std::string &flag_bytes = request.params(2);
    bool *first_flag = (bool *)(flag_bytes.c_str());
    is_weighted_list.assign(first_flag,
                            first_flag + flag_bytes.size() / sizeof(bool));
  }

  auto *graph_table = (GraphTable *)table;
  graph_table->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
// Handles PS_GRAPH_REMOVE_GRAPH_NODE. Request layout:
//   params(0): int idx_, forwarded to GraphTable::remove_graph_node
//   params(1): packed int64_t ids of the nodes to remove
// Too few params are reported through the response; the handler returns 0.
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  const bool enough_params = request.params_size() >= 2;
  if (!enough_params) {
    set_response_code(
        response, -1,
        "remove_graph_node request requires at least 2 arguments");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());
  const std::string &id_bytes = request.params(1);
  int64_t *first_id = (int64_t *)(id_bytes.c_str());
  std::vector<int64_t> node_ids(first_id,
                                first_id + id_bytes.size() / sizeof(int64_t));

  auto *graph_table = (GraphTable *)table;
  graph_table->remove_graph_node(idx_, node_ids);
  return 0;
}
Z
zhaocaibei123 已提交
190
// Returns the port the embedded brpc server is actually listening on.
int32_t GraphBrpcServer::Port() {
  const auto listen_addr = _server.listen_address();
  return listen_addr.port;
}
S
seemingwang 已提交
191

Z
zhaocaibei123 已提交
192
// Registers every command this graph service handles in
// _service_handler_map (consulted by service() using the request's cmd_id),
// then initializes per-table shard info. Always returns 0.
int32_t GraphBrpcService::Initialize() {
  _is_initialize_shard_info = false;

  // Server lifecycle, table management and barrier commands.
  _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
  _service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
  _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;

  // Profiling commands.
  _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;

  // Graph read / sampling commands.
  _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
  _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
      &GraphBrpcService::graph_random_sample_nodes;
  _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
      &GraphBrpcService::graph_get_node_feat;
  _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
      &GraphBrpcService::sample_neighbors_across_multi_servers;

  // Graph mutation commands.
  _service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  _service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
      &GraphBrpcService::add_graph_node;
  _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
      &GraphBrpcService::remove_graph_node;
  _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
      &GraphBrpcService::graph_set_node_feat;

  // Not registered: use_neighbors_sample_cache and load_graph_split_config.

  // Initialize shard info. The shard info of the server list can only be
  // obtained from the environment after the server has started.
  InitializeShardInfo();

  return 0;
}

Z
zhaocaibei123 已提交
229
int32_t GraphBrpcService::InitializeShardInfo() {
S
seemingwang 已提交
230 231 232 233 234
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
Z
zhaocaibei123 已提交
235 236
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
S
seemingwang 已提交
237
    for (auto itr : table_map) {
Z
zhaocaibei123 已提交
238
      itr.second->SetShard(_rank, server_size);
S
seemingwang 已提交
239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}

// brpc entry point for every PsService request. It rejects requests without a
// table id, looks up the handler registered for request->cmd_id() in
// _service_handler_map and runs it against the requested table.
// A non-zero handler result is copied into err_code. When the handler left
// err_msg empty, a generic "server internal error" message is filled in.
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.table_id is required");
    return;
  }

  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    // err_msg was explicitly set to "" above, so has_err_msg() is always true
    // at this point. Test for an empty message so the fallback can apply.
    if (response->err_msg().empty()) {
      response->set_err_msg("server internal error");
    }
  }
}

Z
zhaocaibei123 已提交
278
int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request,
S
seemingwang 已提交
279 280 281 282 283 284 285 286 287 288 289 290 291
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
Z
zhaocaibei123 已提交
292
  table->Barrier(trainer_id, barrier_type);
S
seemingwang 已提交
293 294 295
  return 0;
}

Z
zhaocaibei123 已提交
296 297 298 299
// Handles PS_PRINT_TABLE_STAT: serializes the pair returned by
// Table::PrintTableStat() (first, then second) with a BinaryArchive and stores
// the bytes in response.data(). Always returns 0 once the table exists.
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  std::pair<int64_t, int64_t> stat = table->PrintTableStat();

  paddle::framework::BinaryArchive archive;
  archive << stat.first;
  archive << stat.second;
  response.set_data(std::string(archive.Buffer(), archive.Length()));

  return 0;
}

Z
zhaocaibei123 已提交
310 311 312 313
// Handles PS_LOAD_ONE_TABLE: calls Table::Load(params(0), params(1)).
// Returns 0 on success. It returns -1, with an error recorded in `response`,
// when fewer than two params are supplied or the load fails.
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() >= 2) {
    if (table->Load(request.params(0), request.params(1)) == 0) {
      return 0;
    }
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  set_response_code(
      response, -1,
      "PsRequestMessage.datas is requeired at least 2 for path & load_param");
  return -1;
}

Z
zhaocaibei123 已提交
328 329 330 331 332
// Handles PS_LOAD_ALL_TABLE: runs LoadOneTable with the same request against
// every table owned by the server. It stops and returns -1 at the first table
// that fails, and returns 0 when all succeed. The `table` argument is unused.
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  auto &all_tables = *(_server->GetTable());
  for (auto &entry : all_tables) {
    auto *current = entry.second.get();
    if (LoadOneTable(current, request, response, cntl) == 0) {
      continue;
    }
    LOG(ERROR) << "load table[" << entry.first << "] failed";
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
342 343 344 345
// Handles PS_STOP_SERVER. It spawns a detached thread that calls Stop() on the
// owning GraphBrpcServer and then logs, and it notifies all waiters on the
// server's exported condition variable. It returns 0 without waiting for the
// stop to complete.
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto *graph_server = static_cast<GraphBrpcServer *>(_server);
  std::thread stopper([graph_server]() {
    graph_server->Stop();
    LOG(INFO) << "Server Stoped";
  });
  graph_server->export_cv()->notify_all();
  stopper.detach();
  return 0;
}

Z
zhaocaibei123 已提交
356 357 358 359
// Handles PS_STOP_PROFILER: disables the platform profiler, sorting events by
// the default key. It passes "server_<rank>_profile" as the profile name.
// Always returns 0.
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  const auto profile_name = string::Sprintf("server_%s_profile", _rank);
  platform::DisableProfiler(platform::EventSortingKey::kDefault, profile_name);
  return 0;
}

Z
zhaocaibei123 已提交
365 366 367 368
// Handles PS_START_PROFILER: enables the platform profiler in CPU mode.
// The request, response and controller are not inspected. Always returns 0.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  const auto profiler_state = platform::ProfilerState::kCPU;
  platform::EnableProfiler(profiler_state);
  return 0;
}

// Handles PS_PULL_GRAPH_LIST. Request layout (one int per param):
//   params(0) type_id, params(1) idx, params(2) start, params(3) size,
//   params(4) step.
// It calls GraphTable::pull_graph_list with its boolean flag set to false and
// appends the produced buffer (actual_size bytes) to the response attachment.
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(response, -1,
                      "pull_graph_list request requires at least 5 arguments");
    return 0;
  }

  auto read_int_param = [&request](int pos) {
    return *(int *)(request.params(pos).c_str());
  };
  int type_id = read_int_param(0);
  int idx = read_int_param(1);
  int start = read_int_param(2);
  int size = read_int_param(3);
  int step = read_int_param(4);

  std::unique_ptr<char[]> buffer;
  int actual_size;
  auto *graph_table = (GraphTable *)table;
  graph_table->pull_graph_list(type_id, idx, start, size, buffer, actual_size,
                               false, step);
  cntl->response_attachment().append(buffer.get(), actual_size);
  return 0;
}
399
int32_t GraphBrpcService::graph_random_sample_neighbors(
S
seemingwang 已提交
400 401 402
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
403
  if (request.params_size() < 4) {
S
seemingwang 已提交
404 405
    set_response_code(
        response, -1,
406
        "graph_random_sample_neighbors request requires at least 3 arguments");
S
seemingwang 已提交
407 408
    return 0;
  }
409 410 411 412 413 414 415 416 417
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  // size_t node_num = request.params(0).size() / sizeof(int64_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(bool *)(request.params(2).c_str());
418
  std::vector<std::shared_ptr<char>> buffers(node_num);
S
seemingwang 已提交
419
  std::vector<int> actual_sizes(node_num, 0);
S
seemingwang 已提交
420
  ((GraphTable *)table)
421 422
      ->random_sample_neighbors(idx_, node_data, sample_size, buffers,
                                actual_sizes, need_weight);
S
seemingwang 已提交
423 424 425 426 427 428 429 430 431 432 433 434

  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
// Handles PS_GRAPH_SAMPLE_NODES. Request layout:
//   params(0): int type_id
//   params(1): int idx_
//   params(2): number of nodes to sample, as an int64_t (a 4-byte int is also
//              accepted rather than reading past the payload)
// When GraphTable::random_sample_nodes returns 0, the sampled buffer is
// appended to the response attachment. Otherwise the attachment stays empty.
// A missing table or a malformed request is rejected like in the other
// handlers.
int32_t GraphBrpcService::graph_random_sample_nodes(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response, -1,
        "graph_random_sample_nodes request requires at least 3 arguments");
    return 0;
  }
  if (request.params(0).size() < sizeof(int) ||
      request.params(1).size() < sizeof(int) ||
      request.params(2).size() < sizeof(int)) {
    set_response_code(response, -1,
                      "graph_random_sample_nodes request has malformed "
                      "type_id/idx/size arguments");
    return 0;
  }

  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  const std::string &size_bytes = request.params(2);
  size_t size;
  if (size_bytes.size() >= sizeof(int64_t)) {
    size = *(int64_t *)(size_bytes.c_str());
  } else {
    size = *(int *)(size_bytes.c_str());
  }

  std::unique_ptr<char[]> buffer;
  int actual_size;
  if (((GraphTable *)table)
          ->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
      0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  }
  return 0;
}

// Handles PS_GRAPH_GET_NODE_FEAT. Request layout:
//   params(0): int idx_
//   params(1): packed int64_t node ids
//   params(2): tab-separated feature names
// Response attachment: for each feature (outer loop) and each node (inner
// loop), a size_t byte length followed by that many bytes of feature value.
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response, -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());
  const std::string &id_bytes = request.params(1);
  size_t node_num = id_bytes.size() / sizeof(int64_t);
  int64_t *first_id = (int64_t *)(id_bytes.c_str());
  std::vector<int64_t> node_ids(first_id, first_id + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  std::vector<std::vector<std::string>> feature(
      feature_names.size(), std::vector<std::string>(node_num));
  ((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);

  auto &attachment = cntl->response_attachment();
  for (size_t f = 0; f < feature_names.size(); ++f) {
    const std::vector<std::string> &row = feature[f];
    for (size_t n = 0; n < node_num; ++n) {
      const std::string &value = row[n];
      size_t feat_len = value.size();
      attachment.append(&feat_len, sizeof(size_t));
      attachment.append(value.data(), feat_len);
    }
  }
  return 0;
}
488
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
S
seemingwang 已提交
489 490 491 492
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
493
  if (request.params_size() < 4) {
494 495
    set_response_code(response, -1,
                      "sample_neighbors_across_multi_servers request requires "
496
                      "at least 4 arguments");
S
seemingwang 已提交
497 498
    return 0;
  }
499 500

  int idx_ = *(int *)(request.params(0).c_str());
501
  size_t node_num = request.params(1).size() / sizeof(int64_t);
502 503 504 505 506 507 508 509 510
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(int64_t *)(request.params(3).c_str());

  // size_t node_num = request.params(0).size() / sizeof(int64_t),
  //        size_of_size_t = sizeof(size_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(int64_t *)(request.params(2).c_str());
511
  // std::vector<int64_t> res = ((GraphTable
S
seemingwang 已提交
512 513 514
  // *)table).filter_out_non_exist_nodes(node_data, sample_size);
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
515
  std::vector<int64_t> local_id;
S
seemingwang 已提交
516
  std::vector<int> local_query_idx;
Z
zhaocaibei123 已提交
517
  size_t rank = GetRank();
518
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
S
seemingwang 已提交
519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
              request2server[(int)request2server.size() - 1]);
    server2request[request2server[pos]] = pos;
    server2request[request2server[(int)request2server.size() - 1]] =
        request2server.size() - 1;
  }
  size_t request_call_num = request2server.size();
535
  std::vector<std::shared_ptr<char>> local_buffers;
S
seemingwang 已提交
536 537
  std::vector<int> local_actual_sizes;
  std::vector<size_t> seq;
538
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
539
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
540
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
S
seemingwang 已提交
541 542 543 544 545 546 547 548
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
Z
zhangchunle 已提交
549 550
  if (request2server.size() != 0 &&
      static_cast<size_t>(request2server.back()) == rank) {
S
seemingwang 已提交
551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
  std::function<void(void *)> func = [&, node_id_buckets, query_idx_buckets,
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
    auto *closure = (DownpourBrpcClosure *)done;
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
568
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
S
seemingwang 已提交
569 570 571 572 573 574 575 576 577 578 579 580 581 582 583
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    int size;
    int local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;
Z
zhangchunle 已提交
584
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
S
seemingwang 已提交
585 586 587 588 589 590 591 592 593 594 595 596 597
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));

    local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
Z
zhangchunle 已提交
598
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
S
seemingwang 已提交
599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

616
  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
S
seemingwang 已提交
617
    int server_index = request2server[request_idx];
618
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
S
seemingwang 已提交
619 620 621 622
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();

623 624
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));

S
seemingwang 已提交
625 626
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
627
                     sizeof(int64_t) * node_num);
S
seemingwang 已提交
628 629
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
630 631
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
S
seemingwang 已提交
632
    PsService_Stub rpc_stub(
Z
zhaocaibei123 已提交
633
        ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
S
seemingwang 已提交
634
    // GraphPsService_Stub rpc_stub =
Z
zhaocaibei123 已提交
635
    //     getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
636 637 638 639 640 641
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
                     closure->response(request_idx), closure);
  }
  if (server2request[rank] != -1) {
    ((GraphTable *)table)
642 643 644
        ->random_sample_neighbors(idx_, node_id_buckets.back().data(),
                                  sample_size, local_buffers,
                                  local_actual_sizes, need_weight);
S
seemingwang 已提交
645 646 647 648 649 650
  }
  local_promise.get()->set_value(0);
  if (remote_call_num == 0) func(closure);
  fut.get();
  return 0;
}
S
seemingwang 已提交
651 652 653 654 655
// Handles PS_GRAPH_SET_NODE_FEAT. Request layout:
//   params(0): int idx_
//   params(1): packed int64_t node ids
//   params(2): tab-separated feature names
//   params(3): for each feature (outer) and node (inner), a size_t byte length
//              followed by that many bytes of feature value
// The decoded values are passed to GraphTable::set_node_feat. A request whose
// params(3) is shorter than the lengths it declares is rejected with
// err_code -1 instead of being read past its end.
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    set_response_code(
        response, -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  if (request.params(0).size() < sizeof(int)) {
    set_response_code(response, -1,
                      "graph_set_node_feat request has malformed idx");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));

  // Decode the length-prefixed values. `offset` never exceeds feat_bytes.size().
  // std::string::copy avoids an unaligned size_t load from the byte stream.
  const std::string &feat_bytes = request.params(3);
  size_t offset = 0;
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      if (feat_bytes.size() - offset < sizeof(size_t)) {
        set_response_code(
            response, -1,
            "graph_set_node_feat request has truncated feature data");
        return 0;
      }
      size_t feat_len = 0;
      feat_bytes.copy(reinterpret_cast<char *>(&feat_len), sizeof(size_t),
                      offset);
      offset += sizeof(size_t);
      if (feat_bytes.size() - offset < feat_len) {
        set_response_code(
            response, -1,
            "graph_set_node_feat request has truncated feature data");
        return 0;
      }
      features[feat_idx][node_idx] = feat_bytes.substr(offset, feat_len);
      offset += feat_len;
    }
  }

  ((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);
  return 0;
}

S
seemingwang 已提交
697 698
}  // namespace distributed
}  // namespace paddle