graph_brpc_client.cc 28.5 KB
Newer Older
S
seemingwang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
16

S
seemingwang 已提交
17 18 19 20 21 22
#include <algorithm>
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
23

S
seemingwang 已提交
24
#include "Eigen/Dense"
25 26
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
S
seemingwang 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {

// Routes a PS request either to the in-process graph service or over brpc.
//
// The local path is taken only when a graph service object is attached and
// this stub's channel is the local channel; the request is then handed to
// `task_pool` (graph_service->service is called from the pool with the same
// controller/request/response/done). Otherwise the regular
// PsService_Stub::service RPC path is used.
void GraphPsService_Stub::service(
    ::google::protobuf::RpcController *controller,
    const ::paddle::distributed::PsRequestMessage *request,
    ::paddle::distributed::PsResponseMessage *response,
    ::google::protobuf::Closure *done) {
  const bool use_local_service =
      graph_service != nullptr && local_channel == channel();
  if (!use_local_service) {
    // VLOG(0)<<"use server";
    PsService_Stub::service(controller, request, response, done);
    return;
  }
  // VLOG(0)<<"use local";
  auto run_local = [this, controller, request, response, done]() -> int {
    graph_service->service(controller, request, response, done);
    return 0;
  };
  task_pool->enqueue(std::move(run_local));
}

49
int GraphBrpcClient::get_server_index_by_id(int64_t id) {
S
seemingwang 已提交
50 51 52 53 54 55 56 57
  int shard_num = get_shard_num();
  int shard_per_server = shard_num % server_size == 0
                             ? shard_num / server_size
                             : shard_num / server_size + 1;
  return id % shard_num / shard_per_server;
}

std::future<int32_t> GraphBrpcClient::get_node_feat(
58 59 60
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
S
seemingwang 已提交
61 62 63 64
    const std::vector<std::string> &feature_names,
    std::vector<std::vector<std::string>> &res) {
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
65
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
66 67 68 69 70 71 72
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
73
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
74
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
75
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
76 77 78 79 80 81 82 83 84 85 86
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
87
        size_t fail_num = 0;
88
        for (size_t request_idx = 0; request_idx < request_call_num;
S
seemingwang 已提交
89
             ++request_idx) {
90 91
          if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
              0) {
S
seemingwang 已提交
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
            ++fail_num;
          } else {
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);

            for (size_t feat_idx = 0; feat_idx < feature_names.size();
                 ++feat_idx) {
              for (size_t node_idx = 0;
                   node_idx < query_idx_buckets.at(request_idx).size();
                   ++node_idx) {
                int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
                size_t feat_len = *(size_t *)(buffer);
                buffer += sizeof(size_t);
                auto feature = std::string(buffer, feat_len);
                res[feat_idx][query_idx] = feature;
                buffer += feat_len;
              }
            }
          }
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

127
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
S
seemingwang 已提交
128 129 130
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
131

S
seemingwang 已提交
132 133 134
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();

135
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
S
seemingwang 已提交
136 137
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
138
                     sizeof(int64_t) * node_num);
S
seemingwang 已提交
139 140 141 142 143
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());

Z
zhaocaibei123 已提交
144
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
145
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
146 147 148 149
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
S
seemingwang 已提交
150 151 152 153
  }

  return fut;
}
154

155
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id,
156 157
                                                  int type_id,
                                                  int idx_) {
158
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
159
      server_size, [&, server_size = this->server_size](void *done) {
160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
            ++fail_num;
            break;
          }
        }
        ret = fail_num == 0 ? 0 : -1;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < server_size; i++) {
    int server_index = i;
    closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
    closure->request(server_index)->set_table_id(table_id);
    closure->request(server_index)->set_client_id(_client_id);
180 181
    closure->request(server_index)->add_params((char *)&type_id, sizeof(int));
    closure->request(server_index)->add_params((char *)&idx_, sizeof(int));
Z
zhaocaibei123 已提交
182
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
183 184 185
    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(server_index),
                     closure->request(server_index),
186 187
                     closure->response(server_index),
                     closure);
188 189 190 191
  }
  return fut;
}
std::future<int32_t> GraphBrpcClient::add_graph_node(
192 193 194
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> &node_id_list,
195
    std::vector<bool> &is_weighted_list) {
196
  std::vector<std::vector<int64_t>> request_bucket;
197 198 199 200 201 202 203 204 205
  std::vector<std::vector<bool>> is_weighted_bucket;
  bool add_weight = is_weighted_list.size() > 0;
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
206
      request_bucket.push_back(std::vector<int64_t>());
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
      if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
    if (add_weight)
      is_weighted_bucket[index_mapping[server_index]].push_back(
          query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
                                              : false);
  }
  size_t request_call_num = request_bucket.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_ADD_GRAPH_NODE) !=
              0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();
242
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
243 244
    closure->request(request_idx)
        ->add_params((char *)request_bucket[request_idx].data(),
245
                     sizeof(int64_t) * node_num);
246 247 248 249 250 251 252 253
    if (add_weight) {
      bool weighted[is_weighted_bucket[request_idx].size() + 1];
      for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
        weighted[j] = is_weighted_bucket[request_idx][j];
      closure->request(request_idx)
          ->add_params((char *)weighted,
                       sizeof(bool) * is_weighted_bucket[request_idx].size());
    }
Z
zhaocaibei123 已提交
254 255
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
256
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
257 258 259 260
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
261 262 263 264
  }
  return fut;
}
std::future<int32_t> GraphBrpcClient::remove_graph_node(
265
    uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list) {
266
  std::vector<std::vector<int64_t>> request_bucket;
267 268 269 270 271 272 273
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
274
      request_bucket.push_back(std::vector<int64_t>());
275 276 277 278 279 280 281 282 283
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
  }
  size_t request_call_num = request_bucket.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
284
        size_t fail_num = 0;
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx,
                                      PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();

306
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
307 308
    closure->request(request_idx)
        ->add_params((char *)request_bucket[request_idx].data(),
309
                     sizeof(int64_t) * node_num);
Z
zhaocaibei123 已提交
310 311
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
312
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
313 314 315 316
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
317 318 319
  }
  return fut;
}
S
seemingwang 已提交
320
// char* &buffer,int &actual_size
321
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
322 323 324 325
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> node_ids,
    int sample_size,
326 327
    // std::vector<std::vector<std::pair<int64_t, float>>> &res,
    std::vector<std::vector<int64_t>> &res,
328 329
    std::vector<std::vector<float>> &res_weight,
    bool need_weight,
S
seemingwang 已提交
330 331 332
    int server_index) {
  if (server_index != -1) {
    res.resize(node_ids.size());
333 334 335
    if (need_weight) {
      res_weight.resize(node_ids.size());
    }
S
seemingwang 已提交
336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358
    DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
      int ret = 0;
      auto *closure = (DownpourBrpcClosure *)done;
      if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER) !=
          0) {
        ret = -1;
      } else {
        auto &res_io_buffer = closure->cntl(0)->response_attachment();
        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
        size_t bytes_size = io_buffer_itr.bytes_left();
        std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
        char *buffer = buffer_wrapper.get();
        io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);

        size_t node_num = *(size_t *)buffer;
        int *actual_sizes = (int *)(buffer + sizeof(size_t));
        char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;

        int offset = 0;
        for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
          int actual_size = actual_sizes[node_idx];
          int start = 0;
          while (start < actual_size) {
359
            res[node_idx].emplace_back(
360
                *(int64_t *)(node_buffer + offset + start));
361 362 363 364 365 366
            start += GraphNode::id_size;
            if (need_weight) {
              res_weight[node_idx].emplace_back(
                  *(float *)(node_buffer + offset + start));
              start += GraphNode::weight_size;
            }
S
seemingwang 已提交
367 368 369 370 371 372 373 374 375 376 377 378 379
          }
          offset += actual_size;
        }
      }
      closure->set_promise_value(ret);
    });
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure->add_promise(promise);
    std::future<int> fut = promise->get_future();
    ;
    closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
    closure->request(0)->set_table_id(table_id);
    closure->request(0)->set_client_id(_client_id);
380
    closure->request(0)->add_params((char *)&idx_, sizeof(int));
S
seemingwang 已提交
381
    closure->request(0)->add_params((char *)node_ids.data(),
382
                                    sizeof(int64_t) * node_ids.size());
S
seemingwang 已提交
383
    closure->request(0)->add_params((char *)&sample_size, sizeof(int));
384
    closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
S
seemingwang 已提交
385
    ;
Z
zhaocaibei123 已提交
386 387
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
388
    closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
389 390
    rpc_stub.service(
        closure->cntl(0), closure->request(0), closure->response(0), closure);
S
seemingwang 已提交
391 392
    return fut;
  }
S
seemingwang 已提交
393 394 395
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  res.clear();
396
  res_weight.clear();
397
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
398 399 400 401 402
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
403
    // res.push_back(std::vector<std::pair<int64_t, float>>());
404 405 406 407
    res.push_back({});
    if (need_weight) {
      res_weight.push_back({});
    }
S
seemingwang 已提交
408 409
  }
  size_t request_call_num = request2server.size();
410
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
411
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
412
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
413 414 415 416 417 418 419 420 421 422 423
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
424 425
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
S
seemingwang 已提交
426
             ++request_idx) {
427 428
          if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
              0) {
S
seemingwang 已提交
429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449
            ++fail_num;
          } else {
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);

            size_t node_num = *(size_t *)buffer;
            int *actual_sizes = (int *)(buffer + sizeof(size_t));
            char *node_buffer =
                buffer + sizeof(size_t) + sizeof(int) * node_num;

            int offset = 0;
            for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
              int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
              int actual_size = actual_sizes[node_idx];
              int start = 0;
              while (start < actual_size) {
450
                res[query_idx].emplace_back(
451
                    *(int64_t *)(node_buffer + offset + start));
452 453 454 455 456 457
                start += GraphNode::id_size;
                if (need_weight) {
                  res_weight[query_idx].emplace_back(
                      *(float *)(node_buffer + offset + start));
                  start += GraphNode::weight_size;
                }
S
seemingwang 已提交
458 459 460 461 462 463 464 465 466 467 468 469 470 471 472
              }
              offset += actual_size;
            }
          }
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

473
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
S
seemingwang 已提交
474
    int server_index = request2server[request_idx];
475
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
S
seemingwang 已提交
476 477 478 479
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();

480
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
S
seemingwang 已提交
481 482
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
483
                     sizeof(int64_t) * node_num);
S
seemingwang 已提交
484 485
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
486 487
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
Z
zhaocaibei123 已提交
488 489
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
490
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
491 492 493 494
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
S
seemingwang 已提交
495 496 497 498 499
  }

  return fut;
}
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
500 501 502 503 504
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int sample_size,
505
    std::vector<int64_t> &ids) {
S
seemingwang 已提交
506 507 508 509 510 511 512 513 514
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
515
      char *buffer = new char[bytes_size];
516
      size_t index = 0;
S
seemingwang 已提交
517
      while (index < bytes_size) {
518
        ids.push_back(*(int64_t *)(buffer + index));
S
seemingwang 已提交
519 520
        index += GraphNode::id_size;
      }
521
      delete[] buffer;
S
seemingwang 已提交
522 523 524 525 526 527 528 529 530 531
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  ;
  closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
532 533
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
S
seemingwang 已提交
534 535
  closure->request(0)->add_params((char *)&sample_size, sizeof(int));
  ;
Z
zhaocaibei123 已提交
536 537
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
538
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
539 540
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
S
seemingwang 已提交
541 542
  return fut;
}
543

S
seemingwang 已提交
544
std::future<int32_t> GraphBrpcClient::pull_graph_list(
545 546 547 548 549 550 551 552
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int start,
    int size,
    int step,
    std::vector<FeatureNode> &res) {
S
seemingwang 已提交
553 554 555 556 557 558 559 560 561
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
562
      char *buffer = new char[bytes_size];
S
seemingwang 已提交
563
      io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
564
      size_t index = 0;
S
seemingwang 已提交
565 566 567 568 569 570
      while (index < bytes_size) {
        FeatureNode node;
        node.recover_from_buffer(buffer + index);
        index += node.get_size(false);
        res.push_back(node);
      }
571
      delete[] buffer;
S
seemingwang 已提交
572 573 574 575 576 577 578 579 580
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
581 582
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
S
seemingwang 已提交
583 584 585
  closure->request(0)->add_params((char *)&start, sizeof(int));
  closure->request(0)->add_params((char *)&size, sizeof(int));
  closure->request(0)->add_params((char *)&step, sizeof(int));
Z
zhaocaibei123 已提交
586 587
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
588
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
589 590
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
S
seemingwang 已提交
591 592
  return fut;
}
S
seemingwang 已提交
593 594

std::future<int32_t> GraphBrpcClient::set_node_feat(
595 596 597
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
S
seemingwang 已提交
598 599 600 601
    const std::vector<std::string> &feature_names,
    const std::vector<std::vector<std::string>> &features) {
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
602
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
603 604 605 606 607 608 609
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
610
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
611 612 613
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
      request_call_num);
614
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
615 616 617 618 619 620 621
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    if (features_idx_buckets[request_idx].size() == 0) {
      features_idx_buckets[request_idx].resize(feature_names.size());
    }
622
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
S
seemingwang 已提交
623 624 625 626 627 628 629 630 631 632 633
      features_idx_buckets[request_idx][feat_idx].push_back(
          features[feat_idx][query_idx]);
    }
  }

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
634
        for (size_t request_idx = 0; request_idx < request_call_num;
S
seemingwang 已提交
635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
              0) {
            ++fail_num;
          }
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

651
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
S
seemingwang 已提交
652 653 654 655 656 657
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();

658
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
S
seemingwang 已提交
659 660
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
661
                     sizeof(int64_t) * node_num);
S
seemingwang 已提交
662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());

    // set features
    std::string set_feature = "";
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
        size_t feat_len =
            features_idx_buckets[request_idx][feat_idx][node_idx].size();
        set_feature.append((char *)&feat_len, sizeof(size_t));
        set_feature.append(
            features_idx_buckets[request_idx][feat_idx][node_idx].data(),
            feat_len);
      }
    }
    closure->request(request_idx)
        ->add_params(set_feature.c_str(), set_feature.size());

Z
zhaocaibei123 已提交
682
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
683
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
684 685 686 687
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
S
seemingwang 已提交
688 689 690 691 692
  }

  return fut;
}

Z
zhaocaibei123 已提交
693
int32_t GraphBrpcClient::Initialize() {
S
seemingwang 已提交
694
  // set_shard_num(_config.shard_num());
Z
zhaocaibei123 已提交
695 696
  BrpcPsClient::Initialize();
  server_size = GetServerNums();
S
seemingwang 已提交
697 698 699 700
  graph_service = NULL;
  local_channel = NULL;
  return 0;
}
701 702
}  // namespace distributed
}  // namespace paddle