graph_brpc_server.cc 27.3 KB
Newer Older
S
seemingwang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
S
seemingwang 已提交
16

17
#include <cstring>
#include <string>
#include <thread>  // NOLINT
#include <utility>
20

S
seemingwang 已提交
21 22
#include "butil/endpoint.h"
#include "iomanip"
23
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
24
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
S
seemingwang 已提交
25 26 27 28 29
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {

30 31 32 33 34 35 36 37
// Handler guard: when `table` is null, store an error naming the requested
// table id in `response` and make the enclosing handler return -1.
// Expands to a bare `if` statement, so call sites use it without a
// trailing semicolon.
#define CHECK_TABLE_EXIST(table, request, response)                    \
  if ((table) == nullptr) {                                           \
    const std::string err_msg =                                       \
        std::string("table not found with table_id:") +               \
        std::to_string((request).table_id());                         \
    set_response_code((response), -1, err_msg.c_str());               \
    return -1;                                                        \
  }

Z
zhaocaibei123 已提交
38
int32_t GraphBrpcServer::Initialize() {
S
seemingwang 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }

  _service.reset(service);
Z
zhaocaibei123 已提交
53
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
S
seemingwang 已提交
54 55 56 57 58 59 60 61 62 63 64 65
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}

Z
zhaocaibei123 已提交
66
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  // Non-owning pointer to the peer channel built by
  // build_peer2peer_connection(); `server_index` is not bounds-checked.
  auto &channel = _pserver_channels[server_index];
  return channel.get();
}

Z
zhaocaibei123 已提交
70
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
S
seemingwang 已提交
71 72 73 74 75 76 77
  std::unique_lock<std::mutex> lock(mutex_);

  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;

  int num_threads = std::thread::hardware_concurrency();
Z
zhaocaibei123 已提交
78
  auto trainers = _environment->GetTrainers();
S
seemingwang 已提交
79 80 81 82 83 84
  options.num_threads = trainers > num_threads ? trainers : num_threads;

  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
Z
zhaocaibei123 已提交
85
  _environment->RegistePsServer(ip, port, _rank);
S
seemingwang 已提交
86 87 88
  return 0;
}

S
seemingwang 已提交
89 90
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
Z
zhaocaibei123 已提交
91
  auto _env = Environment();
S
seemingwang 已提交
92 93 94 95 96 97 98
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;

Z
zhaocaibei123 已提交
99
  std::vector<PSHost> server_list = _env->GetPsServers();
S
seemingwang 已提交
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}

125 126 127 128
// Handles PS_GRAPH_CLEAR: clears the nodes of (type_id, idx) in the graph
// table. params(0) and params(1) are decimal text.
// Like the sibling handlers, it validates the table and the argument count
// before indexing params, reporting problems through `response`.
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  int type_id = std::stoi(request.params(0).c_str());
  int idx_ = std::stoi(request.params(1).c_str());
  (reinterpret_cast<GraphTable *>(table))->clear_nodes(type_id, idx_);
  return 0;
}

// Handles PS_GRAPH_ADD_GRAPH_NODE.
//   params(0): idx, decimal text
//   params(1): packed raw uint64_t node ids
//   params(2): optional, packed raw bool "is weighted" flags; only read
//              when exactly three params are present
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "add_graph_node request requires at least 2 arguments");
    return 0;
  }

  int idx_ = std::stoi(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  const auto *id_begin = reinterpret_cast<const uint64_t *>(id_bytes.c_str());
  std::vector<uint64_t> node_ids(
      id_begin, id_begin + id_bytes.size() / sizeof(uint64_t));

  std::vector<bool> is_weighted_list;
  if (request.params_size() == 3) {
    const std::string &flag_bytes = request.params(2);
    const auto *flag_begin =
        reinterpret_cast<const bool *>(flag_bytes.c_str());
    is_weighted_list.assign(flag_begin,
                            flag_begin + flag_bytes.size() / sizeof(bool));
  }

  auto *graph_table = reinterpret_cast<GraphTable *>(table);
  graph_table->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
// Handles PS_GRAPH_REMOVE_GRAPH_NODE.
//   params(0): idx, decimal text
//   params(1): packed raw uint64_t ids of the nodes to remove
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "remove_graph_node request requires at least 2 arguments");
    return 0;
  }

  int idx_ = std::stoi(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  const size_t node_num = id_bytes.size() / sizeof(uint64_t);
  const auto *first = reinterpret_cast<const uint64_t *>(id_bytes.c_str());
  std::vector<uint64_t> node_ids(first, first + node_num);

  auto *graph_table = reinterpret_cast<GraphTable *>(table);
  graph_table->remove_graph_node(idx_, node_ids);
  return 0;
}
Z
zhaocaibei123 已提交
192
// Port the underlying brpc server is listening on.
int32_t GraphBrpcServer::Port() {
  const auto listen_point = _server.listen_address();
  return listen_point.port;
}
S
seemingwang 已提交
193

Z
zhaocaibei123 已提交
194
int32_t GraphBrpcService::Initialize() {
S
seemingwang 已提交
195
  _is_initialize_shard_info = false;
Z
zhaocaibei123 已提交
196 197 198
  _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
S
seemingwang 已提交
199

Z
zhaocaibei123 已提交
200 201 202 203
  _service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
  _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
S
seemingwang 已提交
204 205

  _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
206 207
  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
S
seemingwang 已提交
208 209 210 211
  _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
      &GraphBrpcService::graph_random_sample_nodes;
  _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
      &GraphBrpcService::graph_get_node_feat;
212 213 214 215 216
  _service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  _service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
      &GraphBrpcService::add_graph_node;
  _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
      &GraphBrpcService::remove_graph_node;
S
seemingwang 已提交
217 218
  _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
      &GraphBrpcService::graph_set_node_feat;
S
seemingwang 已提交
219
  _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
220
      &GraphBrpcService::sample_neighbors_across_multi_servers;
Z
zhaocaibei123 已提交
221
  InitializeShardInfo();
S
seemingwang 已提交
222 223 224 225

  return 0;
}

Z
zhaocaibei123 已提交
226
int32_t GraphBrpcService::InitializeShardInfo() {
S
seemingwang 已提交
227 228 229 230 231
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
Z
zhaocaibei123 已提交
232 233
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
S
seemingwang 已提交
234
    for (auto itr : table_map) {
Z
zhaocaibei123 已提交
235
      itr.second->SetShard(_rank, server_size);
S
seemingwang 已提交
236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}

// RPC entry point: validates the request, resets the response status,
// looks up the handler registered for request->cmd_id() and invokes it with
// the table named by request->table_id() (which may be null; handlers use
// CHECK_TABLE_EXIST). A non-zero handler result is copied into err_code,
// and a generic message is filled in if the handler did not provide one.
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.table_id is required");
    return;
  }

  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    // err_msg was explicitly set to "" above. With explicit field presence
    // (proto2 `optional`) has_err_msg() is then always true, so the old
    // has_err_msg() test never fired. Checking for emptiness works for
    // either presence model.
    if (response->err_msg().empty()) {
      response->set_err_msg("server internal error");
    }
  }
}

275 276
int32_t GraphBrpcService::Barrier(Table *table,
                                  const PsRequestMessage &request,
S
seemingwang 已提交
277 278 279 280 281
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)

  if (request.params_size() < 1) {
282 283
    set_response_code(response,
                      -1,
S
seemingwang 已提交
284 285 286 287 288 289 290
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
Z
zhaocaibei123 已提交
291
  table->Barrier(trainer_id, barrier_type);
S
seemingwang 已提交
292 293 294
  return 0;
}

Z
zhaocaibei123 已提交
295 296 297 298
// Handles PS_PRINT_TABLE_STAT: serialises the two int64 values returned by
// Table::PrintTableStat() with a BinaryArchive into response.data.
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  std::pair<int64_t, int64_t> stat = table->PrintTableStat();

  paddle::framework::BinaryArchive archive;
  archive << stat.first;
  archive << stat.second;
  response.set_data(std::string(archive.Buffer(), archive.Length()));
  return 0;
}

Z
zhaocaibei123 已提交
309 310 311 312
// Handles PS_LOAD_ONE_TABLE: calls Table::Load(params(0), params(1)).
// Returns -1 (and sets an error) on missing params or load failure.
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }

  const std::string &path = request.params(0);
  const std::string &load_param = request.params(1);
  const auto load_ret = table->Load(path, load_param);
  if (load_ret == 0) {
    return 0;
  }
  set_response_code(response, -1, "table load failed");
  return -1;
}

Z
zhaocaibei123 已提交
328 329 330 331 332
// Handles PS_LOAD_ALL_TABLE: runs LoadOneTable() with the same request on
// every table of the server, stopping at the first failure.
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  auto &table_map = *(_server->GetTable());
  for (auto &entry : table_map) {
    auto *current = entry.second.get();
    const int32_t ret = LoadOneTable(current, request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "load table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
342 343 344 345
// Handles PS_STOP_SERVER: stops the owning GraphBrpcServer on a detached
// thread (presumably so this RPC can still reply) and wakes anything
// waiting on the server's exported condition variable.
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto *graph_server = reinterpret_cast<GraphBrpcServer *>(_server);
  std::thread stopper([graph_server]() {
    graph_server->Stop();
    LOG(INFO) << "Server Stoped";
  });
  graph_server->export_cv()->notify_all();
  stopper.detach();
  return 0;
}

Z
zhaocaibei123 已提交
356 357 358 359
// Handles PS_STOP_PROFILER: disables the profiler, naming the output
// "server_<rank>_profile".
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  const auto profile_path = string::Sprintf("server_%s_profile", _rank);
  platform::DisableProfiler(platform::EventSortingKey::kDefault,
                            profile_path);
  return 0;
}

Z
zhaocaibei123 已提交
365 366 367 368
// Handles PS_START_PROFILER: enables the CPU profiler. The matching
// StopProfiler() handler disables it.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  const auto state = platform::ProfilerState::kCPU;
  platform::EnableProfiler(state);
  return 0;
}

// Handles PS_PULL_GRAPH_LIST. params(0..4) are decimal text:
// type_id, idx, start, size, step. The table's serialised result is
// appended to the response attachment.
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(
        response, -1, "pull_graph_list request requires at least 5 arguments");
    return 0;
  }
  int type_id = std::stoi(request.params(0).c_str());
  int idx = std::stoi(request.params(1).c_str());
  int start = std::stoi(request.params(2).c_str());
  int size = std::stoi(request.params(3).c_str());
  int step = std::stoi(request.params(4).c_str());
  std::unique_ptr<char[]> buffer;
  // Initialised so that a table call which leaves it untouched cannot make
  // us append an indeterminate byte count from an empty buffer.
  int actual_size = 0;
  (reinterpret_cast<GraphTable *>(table))
      ->pull_graph_list(
          type_id, idx, start, size, buffer, actual_size, false, step);
  if (buffer != nullptr && actual_size > 0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  }
  return 0;
}
396
int32_t GraphBrpcService::graph_random_sample_neighbors(
397 398 399
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
S
seemingwang 已提交
400 401
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
402
  if (request.params_size() < 4) {
S
seemingwang 已提交
403
    set_response_code(
404 405
        response,
        -1,
406
        "graph_random_sample_neighbors request requires at least 3 arguments");
S
seemingwang 已提交
407 408
    return 0;
  }
409
  int idx_ = std::stoi(request.params(0).c_str());
D
danleifeng 已提交
410
  size_t node_num = request.params(1).size() / sizeof(uint64_t);
411 412 413 414 415
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());  // NOLINT
  const int sample_size =
      *reinterpret_cast<const int *>(request.params(2).c_str());
  const bool need_weight =
      *reinterpret_cast<const bool *>(request.params(3).c_str());
416
  std::vector<std::shared_ptr<char>> buffers(node_num);
S
seemingwang 已提交
417
  std::vector<int> actual_sizes(node_num, 0);
418
  (reinterpret_cast<GraphTable *>(table))
419 420
      ->random_sample_neighbors(
          idx_, node_data, sample_size, buffers, actual_sizes, need_weight);
S
seemingwang 已提交
421 422 423 424 425 426 427 428 429 430

  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
int32_t GraphBrpcService::graph_random_sample_nodes(
431 432 433
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
S
seemingwang 已提交
434
    brpc::Controller *cntl) {
435 436 437
  int type_id = std::stoi(request.params(0).c_str());
  int idx_ = std::stoi(request.params(1).c_str());
  size_t size = std::stoull(request.params(2).c_str());
438
  // size_t size = *(int64_t *)(request.params(0).c_str());
S
seemingwang 已提交
439 440
  std::unique_ptr<char[]> buffer;
  int actual_size;
441 442
  if (reinterpret_cast<GraphTable *>(table)->random_sample_nodes(
          type_id, idx_, size, buffer, actual_size) == 0) {
S
seemingwang 已提交
443
    cntl->response_attachment().append(buffer.get(), actual_size);
444
  } else {
S
seemingwang 已提交
445
    cntl->response_attachment().append(NULL, 0);
446
  }
S
seemingwang 已提交
447 448 449 450 451 452 453 454 455

  return 0;
}

// Handles PS_GRAPH_GET_NODE_FEAT.
//   params(0): idx, decimal text
//   params(1): packed raw uint64_t node ids
//   params(2): tab-separated feature names
// Response attachment, feature-major: for every (feature, node) pair a
// size_t length followed by that many bytes of feature value.
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response,
        -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }

  int idx_ = std::stoi(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  const size_t node_num = id_bytes.size() / sizeof(uint64_t);
  const auto *id_ptr = reinterpret_cast<const uint64_t *>(id_bytes.c_str());
  std::vector<uint64_t> node_ids(id_ptr, id_ptr + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  const size_t feat_num = feature_names.size();
  std::vector<std::vector<std::string>> feature(
      feat_num, std::vector<std::string>(node_num));

  (reinterpret_cast<GraphTable *>(table))
      ->get_node_feat(idx_, node_ids, feature_names, feature);

  auto &attachment = cntl->response_attachment();
  for (size_t f = 0; f < feat_num; ++f) {
    for (size_t n = 0; n < node_num; ++n) {
      const std::string &value = feature[f][n];
      size_t feat_len = value.size();
      attachment.append(&feat_len, sizeof(size_t));
      attachment.append(value.data(), feat_len);
    }
  }
  return 0;
}
489
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
490 491 492
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
S
seemingwang 已提交
493 494 495
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
496
  if (request.params_size() < 4) {
497 498
    set_response_code(response,
                      -1,
499
                      "sample_neighbors_across_multi_servers request requires "
500
                      "at least 4 arguments");
S
seemingwang 已提交
501 502
    return 0;
  }
503

504
  int idx_ = std::stoi(request.params(0).c_str());
D
danleifeng 已提交
505
  size_t node_num = request.params(1).size() / sizeof(uint64_t);
506 507 508 509
  const uint64_t *node_data =
      reinterpret_cast<const uint64_t *>(request.params(1).c_str());
  int sample_size = std::stoi(request.params(2).c_str());
  bool need_weight = std::stoi(request.params(3).c_str());
D
danleifeng 已提交
510

S
seemingwang 已提交
511 512
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
D
danleifeng 已提交
513
  std::vector<uint64_t> local_id;
S
seemingwang 已提交
514
  std::vector<int> local_query_idx;
Z
zhaocaibei123 已提交
515
  size_t rank = GetRank();
516
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
517 518
    int server_index = (reinterpret_cast<GraphTable *>(table))
                           ->get_server_index_by_id(node_data[query_idx]);
S
seemingwang 已提交
519 520 521 522 523 524 525 526
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
527
              request2server[static_cast<int>(request2server.size()) - 1]);
S
seemingwang 已提交
528
    server2request[request2server[pos]] = pos;
529 530
    server2request[request2server[static_cast<int>(request2server.size()) -
                                  1]] = request2server.size() - 1;
S
seemingwang 已提交
531 532
  }
  size_t request_call_num = request2server.size();
533
  std::vector<std::shared_ptr<char>> local_buffers;
S
seemingwang 已提交
534 535
  std::vector<int> local_actual_sizes;
  std::vector<size_t> seq;
D
danleifeng 已提交
536
  std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
537
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
538
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
539 540
    int server_index = (reinterpret_cast<GraphTable *>(table))
                           ->get_server_index_by_id(node_data[query_idx]);
S
seemingwang 已提交
541 542 543 544 545 546
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
Z
zhangchunle 已提交
547 548
  if (request2server.size() != 0 &&
      static_cast<size_t>(request2server.back()) == rank) {
S
seemingwang 已提交
549 550 551 552 553 554 555 556
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
557 558 559
  std::function<void(void *)> func = [&,
                                      node_id_buckets,
                                      query_idx_buckets,
S
seemingwang 已提交
560 561 562
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
563
    auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
S
seemingwang 已提交
564 565 566 567
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
568
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
S
seemingwang 已提交
569 570 571 572 573 574 575 576 577 578 579 580 581 582 583
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    int size;
    int local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;
Z
zhangchunle 已提交
584
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
S
seemingwang 已提交
585 586 587 588 589 590 591 592 593 594 595 596 597
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));

    local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
Z
zhangchunle 已提交
598
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
S
seemingwang 已提交
599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

616
  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
S
seemingwang 已提交
617
    int server_index = request2server[request_idx];
618
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
S
seemingwang 已提交
619 620 621 622
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();

623 624
    closure->request(request_idx)
        ->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
625

S
seemingwang 已提交
626
    closure->request(request_idx)
627 628 629
        ->add_params(
            reinterpret_cast<char *>(node_id_buckets[request_idx].data()),
            sizeof(uint64_t) * node_num);
S
seemingwang 已提交
630
    closure->request(request_idx)
631
        ->add_params(reinterpret_cast<char *>(&sample_size), sizeof(int));
632
    closure->request(request_idx)
633 634 635
        ->add_params(reinterpret_cast<char *>(&need_weight), sizeof(bool));
    PsService_Stub rpc_stub((reinterpret_cast<GraphBrpcServer *>(GetServer())
                                 ->GetCmdChannel(server_index)));
S
seemingwang 已提交
636
    // GraphPsService_Stub rpc_stub =
Z
zhaocaibei123 已提交
637
    //     getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
638
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
639 640 641 642
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
S
seemingwang 已提交
643 644
  }
  if (server2request[rank] != -1) {
645
    (reinterpret_cast<GraphTable *>(table))
646 647 648 649 650 651
        ->random_sample_neighbors(idx_,
                                  node_id_buckets.back().data(),
                                  sample_size,
                                  local_buffers,
                                  local_actual_sizes,
                                  need_weight);
S
seemingwang 已提交
652 653 654 655 656 657
  }
  local_promise.get()->set_value(0);
  if (remote_call_num == 0) func(closure);
  fut.get();
  return 0;
}
S
seemingwang 已提交
658 659 660 661 662
// Handles PS_GRAPH_SET_NODE_FEAT.
//   params(0): idx, decimal text
//   params(1): packed raw uint64_t node ids
//   params(2): tab-separated feature names
//   params(3): feature-major records, one per (feature, node) pair, each a
//              size_t length followed by that many value bytes
// Records are variable-length, so the size_t headers are generally
// misaligned. They are read with memcpy, and every read is bounds-checked
// against the payload so a truncated or malformed request cannot cause an
// out-of-bounds read.
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    set_response_code(
        response,
        -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  int idx_ = std::stoi(request.params(0).c_str());

  const std::string &id_bytes = request.params(1);
  size_t node_num = id_bytes.size() / sizeof(uint64_t);
  const uint64_t *node_data =
      reinterpret_cast<const uint64_t *>(id_bytes.c_str());
  std::vector<uint64_t> node_ids(node_data, node_data + node_num);

  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");

  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));

  const std::string &payload = request.params(3);
  const char *cursor = payload.data();
  size_t remaining = payload.size();
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      if (remaining < sizeof(size_t)) {
        set_response_code(response,
                          -1,
                          "graph_set_node_feat request has truncated feature "
                          "payload");
        return 0;
      }
      size_t feat_len = 0;
      std::memcpy(&feat_len, cursor, sizeof(size_t));
      cursor += sizeof(size_t);
      remaining -= sizeof(size_t);
      if (remaining < feat_len) {
        set_response_code(response,
                          -1,
                          "graph_set_node_feat request has truncated feature "
                          "payload");
        return 0;
      }
      features[feat_idx][node_idx].assign(cursor, feat_len);
      cursor += feat_len;
      remaining -= feat_len;
    }
  }

  (reinterpret_cast<GraphTable *>(table))
      ->set_node_feat(idx_, node_ids, feature_names, features);
  return 0;
}

S
seemingwang 已提交
705 706
}  // namespace distributed
}  // namespace paddle