graph_brpc_server.cc 27.9 KB
Newer Older
S
seemingwang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
S
seemingwang 已提交
17 18

#include <thread>  // NOLINT
S
seemingwang 已提交
19
#include <utility>
S
seemingwang 已提交
20 21
#include "butil/endpoint.h"
#include "iomanip"
22
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
S
seemingwang 已提交
23 24 25 26 27
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {

28 29 30 31 32 33 34 35
// Handler guard: when `table` is null, records
// "table not found with table_id:<id>" (error code -1) in `response` and
// makes the enclosing handler return -1.
// NOTE: expands to a bare `if` statement because call sites invoke it without
// a trailing semicolon; do not wrap it in do { } while (0).
#define CHECK_TABLE_EXIST(table, request, response)               \
  if ((table) == nullptr) {                                       \
    std::string missing_table_msg =                               \
        "table not found with table_id:" +                        \
        std::to_string((request).table_id());                     \
    set_response_code((response), -1, missing_table_msg.c_str()); \
    return -1;                                                    \
  }

Z
zhaocaibei123 已提交
36
int32_t GraphBrpcServer::Initialize() {
S
seemingwang 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }

  _service.reset(service);
Z
zhaocaibei123 已提交
51
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
S
seemingwang 已提交
52 53 54 55 56 57 58 59 60 61 62 63
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
64
// Returns a non-owning pointer to the channel for peer server `server_index`
// (channels are created in build_peer2peer_connection). No bounds check.
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  const auto &channel = _pserver_channels[server_index];
  return channel.get();
}

Z
zhaocaibei123 已提交
68
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
S
seemingwang 已提交
69 70 71 72 73 74 75
  std::unique_lock<std::mutex> lock(mutex_);

  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;

  int num_threads = std::thread::hardware_concurrency();
Z
zhaocaibei123 已提交
76
  auto trainers = _environment->GetTrainers();
S
seemingwang 已提交
77 78 79 80 81 82
  options.num_threads = trainers > num_threads ? trainers : num_threads;

  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
Z
zhaocaibei123 已提交
83
  _environment->RegistePsServer(ip, port, _rank);
S
seemingwang 已提交
84 85 86
  return 0;
}

S
seemingwang 已提交
87 88
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
Z
zhaocaibei123 已提交
89
  auto _env = Environment();
S
seemingwang 已提交
90 91 92 93 94 95 96
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;

Z
zhaocaibei123 已提交
97
  std::vector<PSHost> server_list = _env->GetPsServers();
S
seemingwang 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}

123 124 125 126
// Handler for PS_GRAPH_CLEAR: forwards (type_id, idx) to
// GraphTable::clear_nodes.
//   params(0): int type_id
//   params(1): int idx
// Like the other graph handlers in this file, it returns -1 (via
// CHECK_TABLE_EXIST) when the table is missing. Too few params are reported
// through `response` (code -1) with a 0 return value. Previously both params
// were read unconditionally, even when the request did not carry them.
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(response, -1,
                      "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  ((GraphTable *)table)->clear_nodes(type_id, idx_);
  return 0;
}

// Handler for PS_GRAPH_ADD_GRAPH_NODE.
//   params(0): int idx
//   params(1): packed int64_t node ids
//   params(2): packed bool "is weighted" flags; read only when the request
//              carries exactly 3 params, otherwise an empty list is passed.
// Too few params are reported through `response` (code -1) with a 0 return.
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  const int param_count = request.params_size();
  if (param_count < 2) {
    set_response_code(response, -1,
                      "add_graph_node request requires at least 2 arguments");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  const int64_t *first_id = (const int64_t *)id_bytes.c_str();
  std::vector<int64_t> node_ids(first_id,
                                first_id + id_bytes.size() / sizeof(int64_t));

  std::vector<bool> is_weighted_list;
  if (param_count == 3) {
    const std::string &flag_bytes = request.params(2);
    const bool *first_flag = (const bool *)flag_bytes.c_str();
    is_weighted_list.assign(first_flag,
                            first_flag + flag_bytes.size() / sizeof(bool));
  }

  ((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
// Handler for PS_GRAPH_REMOVE_GRAPH_NODE.
//   params(0): int idx
//   params(1): packed int64_t node ids to remove
// Too few params are reported through `response` (code -1) with a 0 return.
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() >= 2) {
    int idx_ = *(int *)(request.params(0).c_str());
    const std::string &id_bytes = request.params(1);
    const int64_t *first_id = (const int64_t *)id_bytes.c_str();
    const int64_t *last_id = first_id + id_bytes.size() / sizeof(int64_t);
    std::vector<int64_t> node_ids(first_id, last_id);
    ((GraphTable *)table)->remove_graph_node(idx_, node_ids);
    return 0;
  }
  set_response_code(
      response, -1,
      "remove_graph_node request requires at least 2 arguments");
  return 0;
}
Z
zhaocaibei123 已提交
189
int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
S
seemingwang 已提交
190

Z
zhaocaibei123 已提交
191
int32_t GraphBrpcService::Initialize() {
S
seemingwang 已提交
192
  _is_initialize_shard_info = false;
Z
zhaocaibei123 已提交
193 194 195
  _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
S
seemingwang 已提交
196

Z
zhaocaibei123 已提交
197 198 199 200
  _service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
  _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
S
seemingwang 已提交
201 202

  _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
203 204
  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
S
seemingwang 已提交
205 206 207 208
  _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
      &GraphBrpcService::graph_random_sample_nodes;
  _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
      &GraphBrpcService::graph_get_node_feat;
209 210 211 212 213
  _service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  _service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
      &GraphBrpcService::add_graph_node;
  _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
      &GraphBrpcService::remove_graph_node;
S
seemingwang 已提交
214 215
  _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
      &GraphBrpcService::graph_set_node_feat;
S
seemingwang 已提交
216
  _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
217
      &GraphBrpcService::sample_neighbors_across_multi_servers;
218 219 220 221
  // _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
  //     &GraphBrpcService::use_neighbors_sample_cache;
  // _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
  //     &GraphBrpcService::load_graph_split_config;
S
seemingwang 已提交
222
  // shard初始化,server启动后才可从env获取到server_list的shard信息
Z
zhaocaibei123 已提交
223
  InitializeShardInfo();
S
seemingwang 已提交
224 225 226 227

  return 0;
}

Z
zhaocaibei123 已提交
228
int32_t GraphBrpcService::InitializeShardInfo() {
S
seemingwang 已提交
229 230 231 232 233
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
Z
zhaocaibei123 已提交
234 235
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
S
seemingwang 已提交
236
    for (auto itr : table_map) {
Z
zhaocaibei123 已提交
237
      itr.second->SetShard(_rank, server_size);
S
seemingwang 已提交
238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}

// RPC entry point: requires a table_id, resets the response status, then
// dispatches to the handler registered for request->cmd_id().
// When a handler returns non-zero, that value becomes the response err_code.
// If the handler left no message, err_msg becomes "server internal error".
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
    return;
  }

  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    // err_msg was explicitly set to "" above, so has_err_msg() is always true
    // here and the fallback could never apply. Test for an empty message
    // instead.
    if (response->err_msg().empty()) {
      response->set_err_msg("server internal error");
    }
  }
}

Z
zhaocaibei123 已提交
277
int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request,
S
seemingwang 已提交
278 279 280 281 282 283 284 285 286 287 288 289 290
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
Z
zhaocaibei123 已提交
291
  table->Barrier(trainer_id, barrier_type);
S
seemingwang 已提交
292 293 294
  return 0;
}

Z
zhaocaibei123 已提交
295 296 297 298
// Handler for PS_PRINT_TABLE_STAT: serializes the (int64, int64) pair
// returned by table->PrintTableStat() with BinaryArchive into response.data.
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  std::pair<int64_t, int64_t> stat = table->PrintTableStat();

  paddle::framework::BinaryArchive archive;
  archive << stat.first;
  archive << stat.second;
  response.set_data(std::string(archive.Buffer(), archive.Length()));

  return 0;
}

Z
zhaocaibei123 已提交
309 310 311 312
// Handler for PS_LOAD_ONE_TABLE: table->Load(params(0), params(1)), i.e.
// (path, load_param). Returns -1 and sets `response` on bad input or a failed
// load, otherwise 0.
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }

  const std::string &path = request.params(0);
  const std::string &load_param = request.params(1);
  if (table->Load(path, load_param) == 0) {
    return 0;
  }
  set_response_code(response, -1, "table load failed");
  return -1;
}

Z
zhaocaibei123 已提交
327 328 329 330 331
// Handler for PS_LOAD_ALL_TABLE: runs LoadOneTable with the same request on
// every table of the owning server. Stops and returns -1 at the first
// failure.
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    auto *current = entry.second.get();
    const bool loaded =
        LoadOneTable(current, request, response, cntl) == 0;
    if (!loaded) {
      LOG(ERROR) << "load table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
341 342 343 344
// Handler for PS_STOP_SERVER: calls Stop() on the owning GraphBrpcServer from
// a detached thread, and notifies the server's exported condition variable.
// Presumably the thread is detached so this handler can return (and the RPC
// complete) while the server shuts down; TODO confirm.
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto *graph_server = static_cast<GraphBrpcServer *>(_server);
  std::thread stopper([graph_server]() {
    graph_server->Stop();
    LOG(INFO) << "Server Stoped";
  });
  graph_server->export_cv()->notify_all();
  stopper.detach();
  return 0;
}

Z
zhaocaibei123 已提交
355 356 357 358
// Handler for PS_STOP_PROFILER: disables the profiler, passing
// "server_<rank>_profile" as the profile path.
// The path relies on string::Sprintf formatting `_rank` through "%s".
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  const auto profile_path = string::Sprintf("server_%s_profile", _rank);
  platform::DisableProfiler(platform::EventSortingKey::kDefault, profile_path);
  return 0;
}

Z
zhaocaibei123 已提交
364 365 366 367
// Handler for PS_START_PROFILER: enables the profiler in CPU mode. The
// session is ended by the PS_STOP_PROFILER handler.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  const auto profiler_state = platform::ProfilerState::kCPU;
  platform::EnableProfiler(profiler_state);
  return 0;
}

// Handler for PS_PULL_GRAPH_LIST: pulls a slice of the graph list from the
// table and appends the raw serialized bytes to the response attachment.
//   params(0): int type_id   params(1): int idx   params(2): int start
//   params(3): int size      params(4): int step
// Too few params are reported through `response` (code -1) with a 0 return.
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(response, -1,
                      "pull_graph_list request requires at least 5 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx = *(int *)(request.params(1).c_str());
  int start = *(int *)(request.params(2).c_str());
  int size = *(int *)(request.params(3).c_str());
  int step = *(int *)(request.params(4).c_str());

  std::unique_ptr<char[]> buffer;
  // Initialized so a table call that does not fill the outputs can no longer
  // make us append an indeterminate byte count from a null buffer.
  int actual_size = 0;
  ((GraphTable *)table)
      ->pull_graph_list(type_id, idx, start, size, buffer, actual_size, false,
                        step);
  if (buffer != nullptr && actual_size > 0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  }
  return 0;
}
398
int32_t GraphBrpcService::graph_random_sample_neighbors(
S
seemingwang 已提交
399 400 401
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
402
  if (request.params_size() < 4) {
S
seemingwang 已提交
403 404
    set_response_code(
        response, -1,
405
        "graph_random_sample_neighbors request requires at least 3 arguments");
S
seemingwang 已提交
406 407
    return 0;
  }
408 409 410 411 412 413 414 415 416
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  // size_t node_num = request.params(0).size() / sizeof(int64_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(bool *)(request.params(2).c_str());
417
  std::vector<std::shared_ptr<char>> buffers(node_num);
S
seemingwang 已提交
418
  std::vector<int> actual_sizes(node_num, 0);
S
seemingwang 已提交
419
  ((GraphTable *)table)
420 421
      ->random_sample_neighbors(idx_, node_data, sample_size, buffers,
                                actual_sizes, need_weight);
S
seemingwang 已提交
422 423 424 425 426 427 428 429 430 431 432 433

  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
// Handler for PS_GRAPH_SAMPLE_NODES: samples nodes from the table and appends
// the sampled bytes to the response attachment (nothing is appended when the
// table call returns non-zero).
//   params(0): int type_id
//   params(1): int idx
//   params(2): int64 sample size
// The table-null check and params-count check were missing: params(0..2) used
// to be read unconditionally. Failures are now reported like the sibling
// handlers.
int32_t GraphBrpcService::graph_random_sample_nodes(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(response, -1,
                      "graph_random_sample_nodes request requires at least 3 "
                      "arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  size_t size = *(int64_t *)(request.params(2).c_str());

  std::unique_ptr<char[]> buffer;
  int actual_size = 0;
  if (((GraphTable *)table)
          ->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
      0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  }
  return 0;
}

// Handler for PS_GRAPH_GET_NODE_FEAT.
//   params(0): int idx
//   params(1): packed int64_t node ids
//   params(2): tab-separated feature names
// The attachment receives one (size_t length, bytes) record per
// (feature, node), feature-major.
// Too few params are reported through `response` (code -1) with a 0 return.
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response, -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }

  int idx_ = *(int *)(request.params(0).c_str());
  const std::string &id_bytes = request.params(1);
  const size_t node_num = id_bytes.size() / sizeof(int64_t);
  const int64_t *first_id = (const int64_t *)id_bytes.c_str();
  std::vector<int64_t> node_ids(first_id, first_id + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  std::vector<std::vector<std::string>> feature(
      feature_names.size(), std::vector<std::string>(node_num));
  ((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);

  auto &attachment = cntl->response_attachment();
  for (size_t f = 0; f < feature_names.size(); ++f) {
    const std::vector<std::string> &values = feature[f];
    for (size_t n = 0; n < node_num; ++n) {
      const std::string &value = values[n];
      size_t feat_len = value.size();
      attachment.append(&feat_len, sizeof(size_t));
      attachment.append(value.data(), feat_len);
    }
  }

  return 0;
}
487
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
S
seemingwang 已提交
488 489 490 491
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
492
  if (request.params_size() < 4) {
493 494
    set_response_code(response, -1,
                      "sample_neighbors_across_multi_servers request requires "
495
                      "at least 4 arguments");
S
seemingwang 已提交
496 497
    return 0;
  }
498 499 500

  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t),
S
seemingwang 已提交
501
         size_of_size_t = sizeof(size_t);
502 503 504 505 506 507 508 509 510
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(int64_t *)(request.params(3).c_str());

  // size_t node_num = request.params(0).size() / sizeof(int64_t),
  //        size_of_size_t = sizeof(size_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(int64_t *)(request.params(2).c_str());
511
  // std::vector<int64_t> res = ((GraphTable
S
seemingwang 已提交
512 513 514
  // *)table).filter_out_non_exist_nodes(node_data, sample_size);
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
515
  std::vector<int64_t> local_id;
S
seemingwang 已提交
516
  std::vector<int> local_query_idx;
Z
zhaocaibei123 已提交
517
  size_t rank = GetRank();
S
seemingwang 已提交
518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534
  for (int query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
              request2server[(int)request2server.size() - 1]);
    server2request[request2server[pos]] = pos;
    server2request[request2server[(int)request2server.size() - 1]] =
        request2server.size() - 1;
  }
  size_t request_call_num = request2server.size();
535
  std::vector<std::shared_ptr<char>> local_buffers;
S
seemingwang 已提交
536 537
  std::vector<int> local_actual_sizes;
  std::vector<size_t> seq;
538
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (int query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
  if (request2server.size() != 0 && request2server.back() == rank) {
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
  std::function<void(void *)> func = [&, node_id_buckets, query_idx_buckets,
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
    auto *closure = (DownpourBrpcClosure *)done;
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
567
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
S
seemingwang 已提交
568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
        size_t node_size;
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    int size;
    int local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;
      } else if (request2server[seq[i]] != rank) {
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));

    local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
      } else if (request2server[seq[i]] != rank) {
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
618
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
S
seemingwang 已提交
619 620 621 622
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();

623 624
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));

S
seemingwang 已提交
625 626
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
627
                     sizeof(int64_t) * node_num);
S
seemingwang 已提交
628 629
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
630 631
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
S
seemingwang 已提交
632
    PsService_Stub rpc_stub(
Z
zhaocaibei123 已提交
633
        ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
S
seemingwang 已提交
634
    // GraphPsService_Stub rpc_stub =
Z
zhaocaibei123 已提交
635
    //     getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
636 637 638 639 640 641
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
                     closure->response(request_idx), closure);
  }
  if (server2request[rank] != -1) {
    ((GraphTable *)table)
642 643 644
        ->random_sample_neighbors(idx_, node_id_buckets.back().data(),
                                  sample_size, local_buffers,
                                  local_actual_sizes, need_weight);
S
seemingwang 已提交
645 646 647 648 649 650
  }
  local_promise.get()->set_value(0);
  if (remote_call_num == 0) func(closure);
  fut.get();
  return 0;
}
S
seemingwang 已提交
651 652 653 654 655
// Handler for PS_GRAPH_SET_NODE_FEAT.
//   params(0): int idx
//   params(1): packed int64_t node ids
//   params(2): tab-separated feature names
//   params(3): feature-major (size_t length, bytes) records, one per
//              (feature, node)
// Too few params, or a params(3) payload shorter than its own length
// prefixes, are reported through `response` (code -1) with a 0 return.
// In that case nothing is written to the table.
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // The message previously said "3" although 4 params are required.
    set_response_code(
        response, -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());

  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");

  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));

  // The payload comes from the network: check every length prefix against
  // the bytes actually present instead of reading past the end of params(3).
  const std::string &feat_bytes = request.params(3);
  const char *buffer = feat_bytes.data();
  const char *const buffer_end = buffer + feat_bytes.size();
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      if (static_cast<size_t>(buffer_end - buffer) < sizeof(size_t)) {
        set_response_code(response, -1,
                          "graph_set_node_feat request has truncated feature "
                          "data");
        return 0;
      }
      size_t feat_len = *(size_t *)(buffer);
      buffer += sizeof(size_t);
      if (static_cast<size_t>(buffer_end - buffer) < feat_len) {
        set_response_code(response, -1,
                          "graph_set_node_feat request has truncated feature "
                          "data");
        return 0;
      }
      features[feat_idx][node_idx].assign(buffer, feat_len);
      buffer += feat_len;
    }
  }

  ((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);

  return 0;
}

S
seemingwang 已提交
697 698
}  // namespace distributed
}  // namespace paddle