graph_brpc_server.cc 28.2 KB
Newer Older
S
seemingwang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
S
seemingwang 已提交
16 17

#include <thread>  // NOLINT
S
seemingwang 已提交
18
#include <utility>
19

S
seemingwang 已提交
20 21
#include "butil/endpoint.h"
#include "iomanip"
22
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
23
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
S
seemingwang 已提交
24 25 26 27 28
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {

29 30 31 32 33 34 35 36
// Guard placed at the top of table-bound RPC handlers. When `table` is null it
// stores error code -1 plus the requested table id in `response` and returns
// -1 from the enclosing handler. The expansion is a bare `if` block, so call
// sites use it without a trailing semicolon.
#define CHECK_TABLE_EXIST(table, request, response)         \
  if ((table) == nullptr) {                                 \
    std::string err_msg = "table not found with table_id:"; \
    err_msg += std::to_string((request).table_id());        \
    set_response_code(response, -1, err_msg.c_str());       \
    return -1;                                              \
  }

Z
zhaocaibei123 已提交
37
int32_t GraphBrpcServer::Initialize() {
S
seemingwang 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }

  _service.reset(service);
Z
zhaocaibei123 已提交
52
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
S
seemingwang 已提交
53 54 55 56 57 58 59 60 61 62 63 64
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
65
// Returns a non-owning pointer to the channel for peer server `server_index`,
// as created by build_peer2peer_connection(). The index is not range-checked.
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  auto &peer_channel = _pserver_channels[server_index];
  return peer_channel.get();
}

Z
zhaocaibei123 已提交
69
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
S
seemingwang 已提交
70 71 72 73 74 75 76
  std::unique_lock<std::mutex> lock(mutex_);

  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;

  int num_threads = std::thread::hardware_concurrency();
Z
zhaocaibei123 已提交
77
  auto trainers = _environment->GetTrainers();
S
seemingwang 已提交
78 79 80 81 82 83
  options.num_threads = trainers > num_threads ? trainers : num_threads;

  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
Z
zhaocaibei123 已提交
84
  _environment->RegistePsServer(ip, port, _rank);
S
seemingwang 已提交
85 86 87
  return 0;
}

S
seemingwang 已提交
88 89
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
Z
zhaocaibei123 已提交
90
  auto _env = Environment();
S
seemingwang 已提交
91 92 93 94 95 96 97
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;

Z
zhaocaibei123 已提交
98
  std::vector<PSHost> server_list = _env->GetPsServers();
S
seemingwang 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}

124 125 126 127
// Handler for PS_GRAPH_CLEAR: clears the nodes of (type_id, idx) in the graph
// table. Request params (raw bytes): params(0) = int type_id,
// params(1) = int idx.
// A missing table returns -1 via CHECK_TABLE_EXIST. Fewer than 2 params sets
// error -1 in `response` and returns 0, matching the other graph handlers.
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  ((GraphTable *)table)->clear_nodes(type_id, idx_);
  return 0;
}

// Handler for PS_GRAPH_ADD_GRAPH_NODE. Request params (raw bytes):
//   params(0): int idx
//   params(1): packed int64_t node ids
//   params(2): packed bool "is weighted" flags; only read when the request
//              has exactly 3 params
// Fewer than 2 params sets error -1 in `response` and returns 0.
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "add_graph_node request requires at least 2 arguments");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  const int64_t *first_id = (const int64_t *)(id_bytes.c_str());
  const int64_t *last_id = first_id + id_bytes.size() / sizeof(int64_t);
  std::vector<int64_t> node_ids(first_id, last_id);

  std::vector<bool> is_weighted_list;
  if (request.params_size() == 3) {
    const std::string &flag_bytes = request.params(2);
    const bool *first_flag = (const bool *)(flag_bytes.c_str());
    is_weighted_list.assign(first_flag,
                            first_flag + flag_bytes.size() / sizeof(bool));
  }

  ((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
// Handler for PS_GRAPH_REMOVE_GRAPH_NODE. Request params (raw bytes):
//   params(0): int idx
//   params(1): packed int64_t node ids
// Fewer than 2 params sets error -1 in `response` and returns 0.
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "remove_graph_node request requires at least 2 arguments");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  const int64_t *first_id = (const int64_t *)(id_bytes.c_str());
  const int64_t *last_id = first_id + id_bytes.size() / sizeof(int64_t);
  std::vector<int64_t> node_ids(first_id, last_id);

  ((GraphTable *)table)->remove_graph_node(idx_, node_ids);
  return 0;
}
Z
zhaocaibei123 已提交
191
// Returns the port of the address the brpc server is listening on.
int32_t GraphBrpcServer::Port() {
  const auto listen_addr = _server.listen_address();
  return listen_addr.port;
}
S
seemingwang 已提交
192

Z
zhaocaibei123 已提交
193
int32_t GraphBrpcService::Initialize() {
S
seemingwang 已提交
194
  _is_initialize_shard_info = false;
Z
zhaocaibei123 已提交
195 196 197
  _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
S
seemingwang 已提交
198

Z
zhaocaibei123 已提交
199 200 201 202
  _service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
  _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
S
seemingwang 已提交
203 204

  _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
205 206
  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
S
seemingwang 已提交
207 208 209 210
  _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
      &GraphBrpcService::graph_random_sample_nodes;
  _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
      &GraphBrpcService::graph_get_node_feat;
211 212 213 214 215
  _service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  _service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
      &GraphBrpcService::add_graph_node;
  _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
      &GraphBrpcService::remove_graph_node;
S
seemingwang 已提交
216 217
  _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
      &GraphBrpcService::graph_set_node_feat;
S
seemingwang 已提交
218
  _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
219
      &GraphBrpcService::sample_neighbors_across_multi_servers;
220 221 222 223
  // _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
  //     &GraphBrpcService::use_neighbors_sample_cache;
  // _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
  //     &GraphBrpcService::load_graph_split_config;
S
seemingwang 已提交
224
  // shard初始化,server启动后才可从env获取到server_list的shard信息
Z
zhaocaibei123 已提交
225
  InitializeShardInfo();
S
seemingwang 已提交
226 227 228 229

  return 0;
}

Z
zhaocaibei123 已提交
230
int32_t GraphBrpcService::InitializeShardInfo() {
S
seemingwang 已提交
231 232 233 234 235
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
Z
zhaocaibei123 已提交
236 237
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
S
seemingwang 已提交
238
    for (auto itr : table_map) {
Z
zhaocaibei123 已提交
239
      itr.second->SetShard(_rank, server_size);
S
seemingwang 已提交
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}

// RPC entry point. It rejects requests without a table id, resets the
// response status, and dispatches to the handler registered for
// request->cmd_id() (unknown ids are reported as errors). A non-zero handler
// return becomes the response error code, and a generic message is added if
// the handler left none.
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  // Runs `done` when this function returns, on every path.
  brpc::ClosureGuard done_guard(done);
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
    return;
  }

  response->set_err_code(0);
  response->set_err_msg("");
  // May be null; each handler decides whether it needs a table.
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    if (!response->has_err_msg()) {
      response->set_err_msg("server internal error");
    }
  }
}

279 280
int32_t GraphBrpcService::Barrier(Table *table,
                                  const PsRequestMessage &request,
S
seemingwang 已提交
281 282 283 284 285
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  if (request.params_size() < 1) {
286 287
    set_response_code(response,
                      -1,
S
seemingwang 已提交
288 289 290 291 292 293 294
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
Z
zhaocaibei123 已提交
295
  table->Barrier(trainer_id, barrier_type);
S
seemingwang 已提交
296 297 298
  return 0;
}

Z
zhaocaibei123 已提交
299 300 301 302
// Handler for PS_PRINT_TABLE_STAT. It serializes the two int64 values
// returned by table->PrintTableStat() with a BinaryArchive and stores the
// bytes as the response data.
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  std::pair<int64_t, int64_t> stat = table->PrintTableStat();

  paddle::framework::BinaryArchive archive;
  archive << stat.first << stat.second;
  response.set_data(std::string(archive.Buffer(), archive.Length()));
  return 0;
}

Z
zhaocaibei123 已提交
313 314 315 316
// Handler for PS_LOAD_ONE_TABLE. It calls table->Load(params(0), params(1))
// (path and load_param, per the error text). Returns -1 with an error in
// `response` when fewer than 2 params are given or the load fails, else 0.
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }

  const auto &path = request.params(0);
  const auto &load_param = request.params(1);
  if (table->Load(path, load_param) == 0) {
    return 0;
  }
  set_response_code(response, -1, "table load failed");
  return -1;
}

Z
zhaocaibei123 已提交
332 333 334 335 336
// Handler for PS_LOAD_ALL_TABLE. It runs LoadOneTable() with the same request
// on every table of the server and stops with -1 at the first failure.
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    const int32_t load_ret =
        LoadOneTable(entry.second.get(), request, response, cntl);
    if (load_ret != 0) {
      LOG(ERROR) << "load table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
346 347 348 349
// Handler for PS_STOP_SERVER. It calls Stop() on a detached background
// thread, notifies the server's exported condition variable, and returns 0
// immediately.
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto *graph_server = (GraphBrpcServer *)_server;
  auto stop_task = [graph_server]() {
    graph_server->Stop();
    LOG(INFO) << "Server Stoped";
  };
  std::thread stopper(stop_task);
  graph_server->export_cv()->notify_all();
  stopper.detach();
  return 0;
}

Z
zhaocaibei123 已提交
360 361 362 363
// Handler for PS_STOP_PROFILER. It disables the profiler with the default
// event sorting and uses "server_<rank>_profile" as the output name.
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  const auto profile_name = string::Sprintf("server_%s_profile", _rank);
  platform::DisableProfiler(platform::EventSortingKey::kDefault, profile_name);
  return 0;
}

Z
zhaocaibei123 已提交
369 370 371 372
// Handler for PS_START_PROFILER: enables the CPU profiler.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  const auto cpu_state = platform::ProfilerState::kCPU;
  platform::EnableProfiler(cpu_state);
  return 0;
}

// Handler for PS_PULL_GRAPH_LIST. Request params (each a raw int):
//   params(0) type_id, params(1) idx, params(2) start, params(3) size,
//   params(4) step.
// The bytes produced by GraphTable::pull_graph_list are appended to the
// response attachment. Fewer than 5 params sets error -1 in `response` and
// returns 0.
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(
        response, -1, "pull_graph_list request requires at least 5 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx = *(int *)(request.params(1).c_str());
  int start = *(int *)(request.params(2).c_str());
  int size = *(int *)(request.params(3).c_str());
  int step = *(int *)(request.params(4).c_str());
  std::unique_ptr<char[]> buffer;
  // Start at 0 so a table call that leaves the outputs untouched appends
  // nothing, rather than using an indeterminate length with a null buffer.
  int actual_size = 0;
  ((GraphTable *)table)
      ->pull_graph_list(
          type_id, idx, start, size, buffer, actual_size, false, step);
  cntl->response_attachment().append(buffer.get(), actual_size);
  return 0;
}
403
int32_t GraphBrpcService::graph_random_sample_neighbors(
404 405 406
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
S
seemingwang 已提交
407 408
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
409
  if (request.params_size() < 4) {
S
seemingwang 已提交
410
    set_response_code(
411 412
        response,
        -1,
413
        "graph_random_sample_neighbors request requires at least 3 arguments");
S
seemingwang 已提交
414 415
    return 0;
  }
416 417 418 419 420 421 422 423 424
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  // size_t node_num = request.params(0).size() / sizeof(int64_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(bool *)(request.params(2).c_str());
425
  std::vector<std::shared_ptr<char>> buffers(node_num);
S
seemingwang 已提交
426
  std::vector<int> actual_sizes(node_num, 0);
S
seemingwang 已提交
427
  ((GraphTable *)table)
428 429
      ->random_sample_neighbors(
          idx_, node_data, sample_size, buffers, actual_sizes, need_weight);
S
seemingwang 已提交
430 431 432 433 434 435 436 437 438 439

  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
int32_t GraphBrpcService::graph_random_sample_nodes(
440 441 442
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
S
seemingwang 已提交
443
    brpc::Controller *cntl) {
444 445 446 447
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  size_t size = *(int64_t *)(request.params(2).c_str());
  // size_t size = *(int64_t *)(request.params(0).c_str());
S
seemingwang 已提交
448 449
  std::unique_ptr<char[]> buffer;
  int actual_size;
450 451
  if (((GraphTable *)table)
          ->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
S
seemingwang 已提交
452
      0) {
S
seemingwang 已提交
453 454 455 456 457 458 459 460 461 462 463 464
    cntl->response_attachment().append(buffer.get(), actual_size);
  } else
    cntl->response_attachment().append(NULL, 0);

  return 0;
}

// Handler for PS_GRAPH_GET_NODE_FEAT. Request params (raw bytes):
//   params(0): int idx
//   params(1): packed int64_t node ids
//   params(2): tab-separated feature names
// Response attachment: for each feature (outer loop) and node (inner loop), a
// size_t length followed by that many bytes of the feature value. Fewer than
// 3 params sets error -1 in `response` and returns 0.
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response,
        -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());
  const std::string &id_bytes = request.params(1);
  const size_t node_num = id_bytes.size() / sizeof(int64_t);
  const int64_t *first_id = (const int64_t *)(id_bytes.c_str());
  std::vector<int64_t> node_ids(first_id, first_id + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  std::vector<std::vector<std::string>> feature(
      feature_names.size(), std::vector<std::string>(node_num));

  ((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);

  auto &attachment = cntl->response_attachment();
  for (size_t f = 0; f < feature_names.size(); ++f) {
    const std::vector<std::string> &values = feature[f];
    for (size_t n = 0; n < node_num; ++n) {
      const std::string &value = values[n];
      size_t value_len = value.size();
      attachment.append(&value_len, sizeof(size_t));
      attachment.append(value.data(), value_len);
    }
  }
  return 0;
}
498
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
499 500 501
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
S
seemingwang 已提交
502 503 504
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
505
  if (request.params_size() < 4) {
506 507
    set_response_code(response,
                      -1,
508
                      "sample_neighbors_across_multi_servers request requires "
509
                      "at least 4 arguments");
S
seemingwang 已提交
510 511
    return 0;
  }
512 513

  int idx_ = *(int *)(request.params(0).c_str());
514
  size_t node_num = request.params(1).size() / sizeof(int64_t);
515 516 517 518 519 520 521 522 523
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(int64_t *)(request.params(3).c_str());

  // size_t node_num = request.params(0).size() / sizeof(int64_t),
  //        size_of_size_t = sizeof(size_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(int64_t *)(request.params(2).c_str());
524
  // std::vector<int64_t> res = ((GraphTable
S
seemingwang 已提交
525 526 527
  // *)table).filter_out_non_exist_nodes(node_data, sample_size);
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
528
  std::vector<int64_t> local_id;
S
seemingwang 已提交
529
  std::vector<int> local_query_idx;
Z
zhaocaibei123 已提交
530
  size_t rank = GetRank();
531
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
S
seemingwang 已提交
532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
              request2server[(int)request2server.size() - 1]);
    server2request[request2server[pos]] = pos;
    server2request[request2server[(int)request2server.size() - 1]] =
        request2server.size() - 1;
  }
  size_t request_call_num = request2server.size();
548
  std::vector<std::shared_ptr<char>> local_buffers;
S
seemingwang 已提交
549 550
  std::vector<int> local_actual_sizes;
  std::vector<size_t> seq;
551
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
552
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
553
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
S
seemingwang 已提交
554 555 556 557 558 559 560 561
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
Z
zhangchunle 已提交
562 563
  if (request2server.size() != 0 &&
      static_cast<size_t>(request2server.back()) == rank) {
S
seemingwang 已提交
564 565 566 567 568 569 570 571
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
572 573 574
  std::function<void(void *)> func = [&,
                                      node_id_buckets,
                                      query_idx_buckets,
S
seemingwang 已提交
575 576 577 578 579 580 581 582
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
    auto *closure = (DownpourBrpcClosure *)done;
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
583
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
S
seemingwang 已提交
584 585 586 587 588 589 590 591 592 593 594 595 596 597 598
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    int size;
    int local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;
Z
zhangchunle 已提交
599
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
S
seemingwang 已提交
600 601 602 603 604 605 606 607 608 609 610 611 612
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));

    local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
Z
zhangchunle 已提交
613
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
S
seemingwang 已提交
614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

631
  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
S
seemingwang 已提交
632
    int server_index = request2server[request_idx];
633
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
S
seemingwang 已提交
634 635 636 637
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();

638 639
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));

S
seemingwang 已提交
640 641
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
642
                     sizeof(int64_t) * node_num);
S
seemingwang 已提交
643 644
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
645 646
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
S
seemingwang 已提交
647
    PsService_Stub rpc_stub(
Z
zhaocaibei123 已提交
648
        ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
S
seemingwang 已提交
649
    // GraphPsService_Stub rpc_stub =
Z
zhaocaibei123 已提交
650
    //     getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
651
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
652 653 654 655
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
S
seemingwang 已提交
656 657 658
  }
  if (server2request[rank] != -1) {
    ((GraphTable *)table)
659 660 661 662 663 664
        ->random_sample_neighbors(idx_,
                                  node_id_buckets.back().data(),
                                  sample_size,
                                  local_buffers,
                                  local_actual_sizes,
                                  need_weight);
S
seemingwang 已提交
665 666 667 668 669 670
  }
  local_promise.get()->set_value(0);
  if (remote_call_num == 0) func(closure);
  fut.get();
  return 0;
}
S
seemingwang 已提交
671 672 673 674 675
// Handler for PS_GRAPH_SET_NODE_FEAT. Request params (raw bytes):
//   params(0): int idx
//   params(1): packed int64_t node ids
//   params(2): tab-separated feature names
//   params(3): for each feature (outer) and node (inner), a size_t length
//              followed by that many bytes of feature value
// On success it calls GraphTable::set_node_feat and returns 0. If fewer than
// 4 params are given, or the payload ends before every declared value is
// read, it sets error -1 in `response` and returns 0 without touching the
// table.
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    set_response_code(
        response,
        -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());

  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");

  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));

  // The payload arrives over the network, so bound every read by its real
  // size instead of trusting the embedded lengths. Invariant:
  // offset <= payload.size().
  const std::string &payload = request.params(3);
  size_t offset = 0;
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      if (payload.size() - offset < sizeof(size_t)) {
        set_response_code(
            response,
            -1,
            "graph_set_node_feat request has a truncated feature payload");
        return 0;
      }
      size_t feat_len = 0;
      // copy() also avoids an unaligned size_t load from the byte buffer.
      payload.copy(reinterpret_cast<char *>(&feat_len), sizeof(size_t), offset);
      offset += sizeof(size_t);
      if (feat_len > payload.size() - offset) {
        set_response_code(
            response,
            -1,
            "graph_set_node_feat request has a truncated feature payload");
        return 0;
      }
      features[feat_idx][node_idx].assign(payload, offset, feat_len);
      offset += feat_len;
    }
  }

  ((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);
  return 0;
}

S
seemingwang 已提交
718 719
}  // namespace distributed
}  // namespace paddle