graph_brpc_client.cc 29.9 KB
Newer Older
S
seemingwang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
16

S
seemingwang 已提交
17 18 19 20 21 22
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
23

S
seemingwang 已提交
24
#include "Eigen/Dense"
25 26
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
S
seemingwang 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {

void GraphPsService_Stub::service(
    ::google::protobuf::RpcController *controller,
    const ::paddle::distributed::PsRequestMessage *request,
    ::paddle::distributed::PsResponseMessage *response,
    ::google::protobuf::Closure *done) {
  // Dispatch one PS request.
  //
  // When a graph service object is attached and this stub's channel is the
  // designated `local_channel`, the request bypasses brpc: it is queued on
  // `task_pool` and handled by `graph_service->service(...)`. Otherwise it is
  // forwarded to the regular brpc stub.
  const bool is_local_target =
      graph_service != nullptr && local_channel == channel();
  if (!is_local_target) {
    PsService_Stub::service(controller, request, response, done);
    return;
  }
  task_pool->enqueue([this, controller, request, response, done]() -> int {
    graph_service->service(controller, request, response, done);
    return 0;
  });
}

49
int GraphBrpcClient::get_server_index_by_id(int64_t id) {
S
seemingwang 已提交
50 51 52 53 54 55 56 57
  int shard_num = get_shard_num();
  int shard_per_server = shard_num % server_size == 0
                             ? shard_num / server_size
                             : shard_num / server_size + 1;
  return id % shard_num / shard_per_server;
}

std::future<int32_t> GraphBrpcClient::get_node_feat(
58 59 60
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
S
seemingwang 已提交
61 62 63 64
    const std::vector<std::string> &feature_names,
    std::vector<std::vector<std::string>> &res) {
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
65
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
66 67 68 69 70 71 72
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
73
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
74
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
75
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
76 77 78 79 80 81 82 83 84 85
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
86
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
87
        size_t fail_num = 0;
88
        for (size_t request_idx = 0; request_idx < request_call_num;
S
seemingwang 已提交
89
             ++request_idx) {
90 91
          if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
              0) {
S
seemingwang 已提交
92 93 94 95 96 97 98 99
            ++fail_num;
          } else {
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
100 101
            io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(buffer),
                                           bytes_size);
S
seemingwang 已提交
102 103 104 105 106 107 108

            for (size_t feat_idx = 0; feat_idx < feature_names.size();
                 ++feat_idx) {
              for (size_t node_idx = 0;
                   node_idx < query_idx_buckets.at(request_idx).size();
                   ++node_idx) {
                int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
109
                size_t feat_len = *reinterpret_cast<size_t *>(buffer);
S
seemingwang 已提交
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
                buffer += sizeof(size_t);
                auto feature = std::string(buffer, feat_len);
                res[feat_idx][query_idx] = feature;
                buffer += feat_len;
              }
            }
          }
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

128
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
S
seemingwang 已提交
129 130 131
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
132

S
seemingwang 已提交
133 134 135 136
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();

    closure->request(request_idx)
137 138 139 140 141
        ->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
    closure->request(request_idx)
        ->add_params(
            reinterpret_cast<char *>(node_id_buckets[request_idx].data()),
            sizeof(int64_t) * node_num);
S
seemingwang 已提交
142 143 144 145 146
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());

Z
zhaocaibei123 已提交
147
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
148
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
149 150 151 152
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
S
seemingwang 已提交
153 154 155 156
  }

  return fut;
}
157

158
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id,
159 160
                                                  int type_id,
                                                  int idx_) {
161
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
162
      server_size, [&, server_size = this->server_size](void *done) {
163
        int ret = 0;
164
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
            ++fail_num;
            break;
          }
        }
        ret = fail_num == 0 ? 0 : -1;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < server_size; i++) {
    int server_index = i;
    closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
    closure->request(server_index)->set_table_id(table_id);
    closure->request(server_index)->set_client_id(_client_id);
183 184 185 186
    closure->request(server_index)
        ->add_params(reinterpret_cast<char *>(&type_id), sizeof(int));
    closure->request(server_index)
        ->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
Z
zhaocaibei123 已提交
187
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
188 189 190
    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(server_index),
                     closure->request(server_index),
191 192
                     closure->response(server_index),
                     closure);
193 194 195 196
  }
  return fut;
}
std::future<int32_t> GraphBrpcClient::add_graph_node(
197 198 199
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> &node_id_list,
200
    std::vector<bool> &is_weighted_list) {
201
  std::vector<std::vector<int64_t>> request_bucket;
202 203 204 205 206 207 208 209 210
  std::vector<std::vector<bool>> is_weighted_bucket;
  bool add_weight = is_weighted_list.size() > 0;
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
211
      request_bucket.push_back(std::vector<int64_t>());
212 213 214 215 216 217 218 219 220 221 222 223 224
      if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
    if (add_weight)
      is_weighted_bucket[index_mapping[server_index]].push_back(
          query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
                                              : false);
  }
  size_t request_call_num = request_bucket.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
225
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_ADD_GRAPH_NODE) !=
              0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();
    closure->request(request_idx)
248 249 250 251 252
        ->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
    closure->request(request_idx)
        ->add_params(
            reinterpret_cast<char *>(request_bucket[request_idx].data()),
            sizeof(int64_t) * node_num);
253 254 255 256 257
    if (add_weight) {
      bool weighted[is_weighted_bucket[request_idx].size() + 1];
      for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
        weighted[j] = is_weighted_bucket[request_idx][j];
      closure->request(request_idx)
258
          ->add_params(reinterpret_cast<char *>(weighted),
259 260
                       sizeof(bool) * is_weighted_bucket[request_idx].size());
    }
Z
zhaocaibei123 已提交
261 262
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
263
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
264 265 266 267
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
268 269 270 271
  }
  return fut;
}
std::future<int32_t> GraphBrpcClient::remove_graph_node(
272
    uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list) {
273
  std::vector<std::vector<int64_t>> request_bucket;
274 275 276 277 278 279 280
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
281
      request_bucket.push_back(std::vector<int64_t>());
282 283 284 285 286 287 288 289
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
  }
  size_t request_call_num = request_bucket.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
290
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
291
        size_t fail_num = 0;
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx,
                                      PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();

    closure->request(request_idx)
314 315 316 317 318
        ->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
    closure->request(request_idx)
        ->add_params(
            reinterpret_cast<char *>(request_bucket[request_idx].data()),
            sizeof(int64_t) * node_num);
Z
zhaocaibei123 已提交
319 320
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
321
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
322 323 324 325
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
326 327 328
  }
  return fut;
}
S
seemingwang 已提交
329
// char* &buffer,int &actual_size
330
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
331 332 333 334
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> node_ids,
    int sample_size,
335 336
    // std::vector<std::vector<std::pair<int64_t, float>>> &res,
    std::vector<std::vector<int64_t>> &res,
337 338
    std::vector<std::vector<float>> &res_weight,
    bool need_weight,
S
seemingwang 已提交
339 340 341
    int server_index) {
  if (server_index != -1) {
    res.resize(node_ids.size());
342 343 344
    if (need_weight) {
      res_weight.resize(node_ids.size());
    }
S
seemingwang 已提交
345 346
    DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
      int ret = 0;
347
      auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
S
seemingwang 已提交
348 349 350 351 352 353 354 355 356
      if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER) !=
          0) {
        ret = -1;
      } else {
        auto &res_io_buffer = closure->cntl(0)->response_attachment();
        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
        size_t bytes_size = io_buffer_itr.bytes_left();
        std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
        char *buffer = buffer_wrapper.get();
357 358
        io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(buffer),
                                       bytes_size);
S
seemingwang 已提交
359

360 361
        size_t node_num = *reinterpret_cast<size_t *>(buffer);
        int *actual_sizes = reinterpret_cast<int *>(buffer + sizeof(size_t));
S
seemingwang 已提交
362 363 364 365 366 367 368
        char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;

        int offset = 0;
        for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
          int actual_size = actual_sizes[node_idx];
          int start = 0;
          while (start < actual_size) {
369
            res[node_idx].emplace_back(
370
                *reinterpret_cast<int64_t *>(node_buffer + offset + start));
371 372 373
            start += GraphNode::id_size;
            if (need_weight) {
              res_weight[node_idx].emplace_back(
374
                  *reinterpret_cast<float *>(node_buffer + offset + start));
375 376
              start += GraphNode::weight_size;
            }
S
seemingwang 已提交
377 378 379 380 381 382 383 384 385
          }
          offset += actual_size;
        }
      }
      closure->set_promise_value(ret);
    });
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure->add_promise(promise);
    std::future<int> fut = promise->get_future();
386

S
seemingwang 已提交
387 388 389
    closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
    closure->request(0)->set_table_id(table_id);
    closure->request(0)->set_client_id(_client_id);
390 391 392
    closure->request(0)->add_params(reinterpret_cast<char *>(&idx_),
                                    sizeof(int));
    closure->request(0)->add_params(reinterpret_cast<char *>(node_ids.data()),
393
                                    sizeof(int64_t) * node_ids.size());
394 395 396 397 398
    closure->request(0)->add_params(reinterpret_cast<char *>(&sample_size),
                                    sizeof(int));
    closure->request(0)->add_params(reinterpret_cast<char *>(&need_weight),
                                    sizeof(bool));

Z
zhaocaibei123 已提交
399 400
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
401
    closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
402 403
    rpc_stub.service(
        closure->cntl(0), closure->request(0), closure->response(0), closure);
S
seemingwang 已提交
404 405
    return fut;
  }
S
seemingwang 已提交
406 407 408
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  res.clear();
409
  res_weight.clear();
410
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
411 412 413 414 415
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
416
    // res.push_back(std::vector<std::pair<int64_t, float>>());
417 418 419 420
    res.push_back({});
    if (need_weight) {
      res_weight.push_back({});
    }
S
seemingwang 已提交
421 422
  }
  size_t request_call_num = request2server.size();
423
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
424
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
425
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
426 427 428 429 430 431 432 433 434 435
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
436
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
437 438
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
S
seemingwang 已提交
439
             ++request_idx) {
440 441
          if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
              0) {
S
seemingwang 已提交
442 443 444 445 446 447 448 449
            ++fail_num;
          } else {
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
450 451
            io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(buffer),
                                           bytes_size);
S
seemingwang 已提交
452

453 454 455
            size_t node_num = *reinterpret_cast<size_t *>(buffer);
            int *actual_sizes =
                reinterpret_cast<int *>(buffer + sizeof(size_t));
S
seemingwang 已提交
456 457 458 459 460 461 462 463 464
            char *node_buffer =
                buffer + sizeof(size_t) + sizeof(int) * node_num;

            int offset = 0;
            for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
              int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
              int actual_size = actual_sizes[node_idx];
              int start = 0;
              while (start < actual_size) {
465
                res[query_idx].emplace_back(
466
                    *reinterpret_cast<int64_t *>(node_buffer + offset + start));
467 468 469
                start += GraphNode::id_size;
                if (need_weight) {
                  res_weight[query_idx].emplace_back(
470
                      *reinterpret_cast<float *>(node_buffer + offset + start));
471 472
                  start += GraphNode::weight_size;
                }
S
seemingwang 已提交
473 474 475 476 477 478 479 480 481 482 483 484 485 486 487
              }
              offset += actual_size;
            }
          }
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

488
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
S
seemingwang 已提交
489
    int server_index = request2server[request_idx];
490
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
S
seemingwang 已提交
491 492 493 494 495
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();

    closure->request(request_idx)
496
        ->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
S
seemingwang 已提交
497
    closure->request(request_idx)
498 499 500
        ->add_params(
            reinterpret_cast<char *>(node_id_buckets[request_idx].data()),
            sizeof(int64_t) * node_num);
501
    closure->request(request_idx)
502 503 504
        ->add_params(reinterpret_cast<char *>(&sample_size), sizeof(int));
    closure->request(request_idx)
        ->add_params(reinterpret_cast<char *>(&need_weight), sizeof(bool));
Z
zhaocaibei123 已提交
505 506
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
507
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
508 509 510 511
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
S
seemingwang 已提交
512 513 514 515 516
  }

  return fut;
}
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
517 518 519 520 521
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int sample_size,
522
    std::vector<int64_t> &ids) {
S
seemingwang 已提交
523 524
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
525
    auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
S
seemingwang 已提交
526 527 528 529 530 531
    if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
532
      char *buffer = new char[bytes_size];
533
      size_t index = 0;
S
seemingwang 已提交
534
      while (index < bytes_size) {
535
        ids.push_back(*reinterpret_cast<int64_t *>(buffer + index));
S
seemingwang 已提交
536 537
        index += GraphNode::id_size;
      }
538
      delete[] buffer;
S
seemingwang 已提交
539 540 541 542 543 544
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
545

S
seemingwang 已提交
546 547 548
  closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
549 550 551 552 553 554
  closure->request(0)->add_params(reinterpret_cast<char *>(&type_id),
                                  sizeof(int));
  closure->request(0)->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
  closure->request(0)->add_params(reinterpret_cast<char *>(&sample_size),
                                  sizeof(int));

Z
zhaocaibei123 已提交
555 556
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
557
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
558 559
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
S
seemingwang 已提交
560 561
  return fut;
}
562

S
seemingwang 已提交
563
std::future<int32_t> GraphBrpcClient::pull_graph_list(
564 565 566 567 568 569 570 571
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int start,
    int size,
    int step,
    std::vector<FeatureNode> &res) {
S
seemingwang 已提交
572 573
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
574
    auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
S
seemingwang 已提交
575 576 577 578 579 580
    if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
581
      char *buffer = new char[bytes_size];
582 583
      io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(buffer),
                                     bytes_size);
584
      size_t index = 0;
S
seemingwang 已提交
585 586 587 588 589 590
      while (index < bytes_size) {
        FeatureNode node;
        node.recover_from_buffer(buffer + index);
        index += node.get_size(false);
        res.push_back(node);
      }
591
      delete[] buffer;
S
seemingwang 已提交
592 593 594 595 596 597 598 599 600
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
601 602 603 604 605 606 607
  closure->request(0)->add_params(reinterpret_cast<char *>(&type_id),
                                  sizeof(int));
  closure->request(0)->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
  closure->request(0)->add_params(reinterpret_cast<char *>(&start),
                                  sizeof(int));
  closure->request(0)->add_params(reinterpret_cast<char *>(&size), sizeof(int));
  closure->request(0)->add_params(reinterpret_cast<char *>(&step), sizeof(int));
Z
zhaocaibei123 已提交
608 609
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
610
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
611 612
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
S
seemingwang 已提交
613 614
  return fut;
}
S
seemingwang 已提交
615 616

std::future<int32_t> GraphBrpcClient::set_node_feat(
617 618 619
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
S
seemingwang 已提交
620 621 622 623
    const std::vector<std::string> &feature_names,
    const std::vector<std::vector<std::string>> &features) {
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
624
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
625 626 627 628 629 630 631
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
632
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
S
seemingwang 已提交
633 634 635
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
      request_call_num);
636
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
S
seemingwang 已提交
637 638 639 640 641 642 643
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    if (features_idx_buckets[request_idx].size() == 0) {
      features_idx_buckets[request_idx].resize(feature_names.size());
    }
644
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
S
seemingwang 已提交
645 646 647 648 649 650 651 652 653
      features_idx_buckets[request_idx][feat_idx].push_back(
          features[feat_idx][query_idx]);
    }
  }

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
654
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
S
seemingwang 已提交
655
        size_t fail_num = 0;
656
        for (size_t request_idx = 0; request_idx < request_call_num;
S
seemingwang 已提交
657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
              0) {
            ++fail_num;
          }
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

673
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
S
seemingwang 已提交
674 675 676 677 678 679 680
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();

    closure->request(request_idx)
681 682 683 684 685
        ->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
    closure->request(request_idx)
        ->add_params(
            reinterpret_cast<char *>(node_id_buckets[request_idx].data()),
            sizeof(int64_t) * node_num);
S
seemingwang 已提交
686 687 688 689 690 691 692 693 694 695 696
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());

    // set features
    std::string set_feature = "";
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
        size_t feat_len =
            features_idx_buckets[request_idx][feat_idx][node_idx].size();
697
        set_feature.append(reinterpret_cast<char *>(&feat_len), sizeof(size_t));
S
seemingwang 已提交
698 699 700 701 702 703 704 705
        set_feature.append(
            features_idx_buckets[request_idx][feat_idx][node_idx].data(),
            feat_len);
      }
    }
    closure->request(request_idx)
        ->add_params(set_feature.c_str(), set_feature.size());

Z
zhaocaibei123 已提交
706
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
S
seemingwang 已提交
707
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
708 709 710 711
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
S
seemingwang 已提交
712 713 714 715 716
  }

  return fut;
}

Z
zhaocaibei123 已提交
717
int32_t GraphBrpcClient::Initialize() {
S
seemingwang 已提交
718
  // set_shard_num(_config.shard_num());
Z
zhaocaibei123 已提交
719 720
  BrpcPsClient::Initialize();
  server_size = GetServerNums();
S
seemingwang 已提交
721 722 723 724
  graph_service = NULL;
  local_channel = NULL;
  return 0;
}
725 726
}  // namespace distributed
}  // namespace paddle