// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
20

S
seemingwang 已提交
21 22 23 24 25 26
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
S
seemingwang 已提交
27
#include <list>
S
seemingwang 已提交
28
#include <map>
S
seemingwang 已提交
29 30
#include <memory>
#include <mutex>  // NOLINT
S
seemingwang 已提交
31 32 33
#include <numeric>
#include <queue>
#include <set>
S
seemingwang 已提交
34
#include <string>
S
seemingwang 已提交
35
#include <thread>
S
seemingwang 已提交
36
#include <unordered_map>
S
seemingwang 已提交
37
#include <unordered_set>
S
seemingwang 已提交
38 39
#include <utility>
#include <vector>
40

41 42
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
43
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
44
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
S
seemingwang 已提交
45
#include "paddle/fluid/string/string_helper.h"
46
#include "paddle/phi/core/utils/rw_lock.h"
47

48
#ifdef PADDLE_WITH_HETERPS
Z
zhaocaibei123 已提交
49
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
50 51
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
S
seemingwang 已提交
52 53 54 55 56 57
namespace paddle {
namespace distributed {
class GraphShard {
 public:
  size_t get_size();
  GraphShard() {}
58
  ~GraphShard();
S
seemingwang 已提交
59 60
  std::vector<Node *> &get_bucket() { return bucket; }
  std::vector<Node *> get_batch(int start, int end, int step);
D
danleifeng 已提交
61 62
  void get_ids_by_range(int start, int end, std::vector<uint64_t> *res) {
    res->reserve(res->size() + end - start);
L
lxsbupt 已提交
63
    for (int i = start; i < end && i < static_cast<int>(bucket.size()); i++) {
D
danleifeng 已提交
64
      res->emplace_back(bucket[i]->get_id());
S
seemingwang 已提交
65 66
    }
  }
D
danleifeng 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
  size_t get_all_id(std::vector<std::vector<uint64_t>> *shard_keys,
                    int slice_num) {
    int bucket_num = bucket.size();
    shard_keys->resize(slice_num);
    for (int i = 0; i < slice_num; ++i) {
      (*shard_keys)[i].reserve(bucket_num / slice_num);
    }
    for (int i = 0; i < bucket_num; i++) {
      uint64_t k = bucket[i]->get_id();
      (*shard_keys)[k % slice_num].emplace_back(k);
    }
    return bucket_num;
  }
  size_t get_all_neighbor_id(std::vector<std::vector<uint64_t>> *total_res,
                             int slice_num) {
    std::vector<uint64_t> keys;
    for (size_t i = 0; i < bucket.size(); i++) {
      size_t neighbor_size = bucket[i]->get_neighbor_size();
      size_t n = keys.size();
      keys.resize(n + neighbor_size);
      for (size_t j = 0; j < neighbor_size; j++) {
        keys[n + j] = bucket[i]->get_neighbor_id(j);
      }
    }
    return dedup2shard_keys(&keys, total_res, slice_num);
  }
  size_t get_all_feature_ids(std::vector<std::vector<uint64_t>> *total_res,
                             int slice_num) {
    std::vector<uint64_t> keys;
L
lxsbupt 已提交
96
    for (size_t i = 0; i < bucket.size(); i++) {
D
danleifeng 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109 110
      bucket[i]->get_feature_ids(&keys);
    }
    return dedup2shard_keys(&keys, total_res, slice_num);
  }
  size_t dedup2shard_keys(std::vector<uint64_t> *keys,
                          std::vector<std::vector<uint64_t>> *total_res,
                          int slice_num) {
    size_t num = keys->size();
    uint64_t last_key = 0;
    // sort key insert to vector
    std::sort(keys->begin(), keys->end());
    total_res->resize(slice_num);
    for (int shard_id = 0; shard_id < slice_num; ++shard_id) {
      (*total_res)[shard_id].reserve(num / slice_num);
111
    }
D
danleifeng 已提交
112 113 114 115 116 117 118 119 120
    for (size_t i = 0; i < num; ++i) {
      const uint64_t &k = (*keys)[i];
      if (i > 0 && last_key == k) {
        continue;
      }
      last_key = k;
      (*total_res)[k % slice_num].push_back(k);
    }
    return num;
121
  }
D
danleifeng 已提交
122
  GraphNode *add_graph_node(uint64_t id);
123
  GraphNode *add_graph_node(Node *node);
D
danleifeng 已提交
124 125 126
  FeatureNode *add_feature_node(uint64_t id, bool is_overlap = true);
  Node *find_node(uint64_t id);
  void delete_node(uint64_t id);
127
  void clear();
D
danleifeng 已提交
128 129
  void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
  std::unordered_map<uint64_t, int> &get_node_location() {
S
seemingwang 已提交
130 131 132
    return node_location;
  }

L
lxsbupt 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
  void shrink_to_fit() {
    bucket.shrink_to_fit();
    for (size_t i = 0; i < bucket.size(); i++) {
      bucket[i]->shrink_to_fit();
    }
  }

  void merge_shard(GraphShard *&shard) {  // NOLINT
    bucket.reserve(bucket.size() + shard->bucket.size());
    for (size_t i = 0; i < shard->bucket.size(); i++) {
      auto node_id = shard->bucket[i]->get_id();
      if (node_location.find(node_id) == node_location.end()) {
        node_location[node_id] = bucket.size();
        bucket.push_back(shard->bucket[i]);
      }
    }
    shard->node_location.clear();
    shard->bucket.clear();
    delete shard;
    shard = NULL;
  }

 public:
D
danleifeng 已提交
156
  std::unordered_map<uint64_t, int> node_location;
S
seemingwang 已提交
157 158
  std::vector<Node *> bucket;
};
S
seemingwang 已提交
159 160 161 162

// Result code of an LRU cache call: ok on success, blocked when the shared
// rwlock could not be acquired without waiting, err otherwise.
enum LRUResponse { ok = 0, blocked = 1, err = 2 };

// Cache key of one neighbor-sampling request: the edge-type index, the
// source node, the requested sample size and whether sampling is weighted.
// Hashable via the std::hash specialization at the bottom of this file.
struct SampleKey {
  int idx;
  uint64_t node_key;
  size_t sample_size;
  bool is_weighted;
  // Use a member-initializer list instead of assignments in the body
  // (clang-tidy: cppcoreguidelines-prefer-member-initializer).
  SampleKey(int _idx,
            uint64_t _node_key,
            size_t _sample_size,
            bool _is_weighted)
      : idx(_idx),
        node_key(_node_key),
        sample_size(_sample_size),
        is_weighted(_is_weighted) {}
  // Two keys are equal only when all four components match.
  bool operator==(const SampleKey &s) const {
    return idx == s.idx && node_key == s.node_key &&
           sample_size == s.sample_size && is_weighted == s.is_weighted;
  }
};

// One sampling answer: a byte buffer plus the number of bytes actually used.
// The raw-pointer constructor takes ownership and releases with delete[].
class SampleResult {
 public:
  size_t actual_size;
  std::shared_ptr<char> buffer;
  SampleResult(size_t _actual_size, std::shared_ptr<char> &_buffer)  // NOLINT
      : actual_size(_actual_size), buffer(_buffer) {}
  SampleResult(size_t _actual_size, char *_buffer)
      : actual_size(_actual_size),
        buffer(_buffer, std::default_delete<char[]>()) {}
  ~SampleResult() {}
};

// Doubly linked list node used by RandomSampleLRU: a key/value pair plus a
// per-entry time-to-live counter that is decremented on every cache hit.
template <typename K, typename V>
class LRUNode {
 public:
  LRUNode(K _key, V _data, size_t _ttl)
      : key(_key), data(_data), ttl(_ttl), pre(NULL), next(NULL) {}
  K key;
  V data;
  size_t ttl;  // remaining time-to-live (hit count) of this entry
  LRUNode<K, V> *pre, *next;  // list links: pre = older, next = newer
};
206
template <typename K, typename V>
class ScaledLRU;

// One shard of the sharded LRU cache (ScaledLRU). Entries live in a doubly
// linked list ordered from oldest (node_head) to most recently used
// (node_end); key_map gives O(1) lookup into the list. Each entry carries a
// TTL that is decremented on every hit. Eviction is lazy: the parent
// ScaledLRU raises remove_count when the global budget is exceeded, and the
// shard works that quota off during later query/insert calls.
// Both public calls acquire the parent's rwlock in shared (read) mode and
// return LRUResponse::blocked instead of waiting when it is write-locked.
template <typename K, typename V>
class RandomSampleLRU {
 public:
  explicit RandomSampleLRU(ScaledLRU<K, V> *_father) {
    father = _father;
    remove_count = 0;
    node_size = 0;
    node_head = node_end = NULL;
    global_ttl = father->ttl;
    total_diff = 0;
  }

  ~RandomSampleLRU() {
    // Walk the list and free every node; key_map holds no ownership.
    LRUNode<K, V> *p;
    while (node_head != NULL) {
      p = node_head->next;
      delete node_head;
      node_head = p;
    }
  }
  // Looks up `length` keys and appends every hit as (key, value) to `res`.
  // A hit consumes one TTL tick; an entry whose TTL reaches 0 is erased,
  // otherwise it is moved to the recently-used end of the list.
  LRUResponse query(K *keys,
                    size_t length,
                    std::vector<std::pair<K, V>> &res) {  // NOLINT
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    // Amortize pending evictions over this call.
    process_redundant(length * 3);

    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        res.emplace_back(keys[i], iter->second->data);
        iter->second->ttl--;
        if (iter->second->ttl == 0) {
          remove(iter->second);
          if (remove_count != 0) remove_count--;
        } else {
          move_to_tail(iter->second);
        }
      }
    }
    // Batch the net size change and report it to the parent only when it is
    // large, to keep contention on the shared counter low.
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  // Inserts (or refreshes) `length` key/value pairs: an existing key gets
  // its TTL reset and value overwritten; a new key is appended at the tail.
  LRUResponse insert(K *keys, V *data, size_t length) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);
    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        move_to_tail(iter->second);
        iter->second->ttl = global_ttl;
        iter->second->data = data[i];
      } else {
        LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
        add_new(temp);
      }
    }
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }

    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  // Unlinks a node, erases it from the map and frees it.
  void remove(LRUNode<K, V> *node) {
    fetch(node);
    node_size--;
    key_map.erase(node->key);
    delete node;
  }

  // Evicts up to process_size of the oldest entries charged against the
  // remove_count quota set by the parent ScaledLRU.
  void process_redundant(int process_size) {
    int length = std::min(remove_count, process_size);
    while (length--) {
      remove(node_head);
      remove_count--;
    }
  }

  // Re-links a resident node to the most-recently-used end.
  void move_to_tail(LRUNode<K, V> *node) {
    fetch(node);
    place_at_tail(node);
  }

  // Registers a freshly allocated node: full TTL, tail position, map entry.
  void add_new(LRUNode<K, V> *node) {
    node->ttl = global_ttl;
    place_at_tail(node);
    node_size++;
    key_map[node->key] = node;
  }
  // Appends an unlinked node at the most-recently-used end of the list.
  void place_at_tail(LRUNode<K, V> *node) {
    if (node_end == NULL) {
      node_head = node_end = node;
      node->next = node->pre = NULL;
    } else {
      node_end->next = node;
      node->pre = node_end;
      node->next = NULL;
      node_end = node;
    }
  }

  // Unlinks a node from the list without deleting it, maintaining both the
  // head and the tail pointers.
  void fetch(LRUNode<K, V> *node) {
    if (node->pre) {
      node->pre->next = node->next;
    } else {
      node_head = node->next;
    }
    if (node->next) {
      node->next->pre = node->pre;
    } else {
      node_end = node->pre;
    }
  }

 private:
  std::unordered_map<K, LRUNode<K, V> *> key_map;  // key -> list node
  ScaledLRU<K, V> *father;  // owning sharded cache (holds rwlock and ttl)
  size_t global_ttl, size_limit;
  int node_size, total_diff;  // resident entries; unreported size delta
  LRUNode<K, V> *node_head, *node_end;  // oldest / most recently used
  friend class ScaledLRU<K, V>;
  int remove_count;  // pending evictions requested by ScaledLRU::Shrink
};

346
// Sharded LRU cache with a global size budget. Each shard is a
// RandomSampleLRU; all shards share one pthread rwlock: query/insert take it
// in read mode (shards proceed concurrently), while Shrink() takes it in
// write mode to rebalance eviction quotas. A detached background thread
// wakes every 20 seconds and runs Shrink() through a one-thread ThreadPool.
template <typename K, typename V>
class ScaledLRU {
 public:
  ScaledLRU(size_t _shard_num, size_t size_limit, size_t _ttl)
      : size_limit(size_limit), ttl(_ttl) {
    shard_num = _shard_num;
    pthread_rwlock_init(&rwlock, NULL);
    stop = false;
    thread_pool.reset(new ::ThreadPool(1));
    global_count = 0;
    lru_pool = std::vector<RandomSampleLRU<K, V>>(shard_num,
                                                  RandomSampleLRU<K, V>(this));
    shrink_job = std::thread([this]() -> void {
      while (true) {
        {
          std::unique_lock<std::mutex> lock(mutex_);
          // Wake every 20 s, or immediately when the destructor signals.
          cv_.wait_for(lock, std::chrono::milliseconds(20000));
          if (stop) {
            return;
          }
        }
        auto status =
            thread_pool->enqueue([this]() -> int { return Shrink(); });
        status.wait();
      }
    });
    shrink_job.detach();
  }
  ~ScaledLRU() {
    // Ask the detached shrink thread to exit at its next wakeup.
    std::unique_lock<std::mutex> lock(mutex_);
    stop = true;
    cv_.notify_one();
  }
  // Delegates to shard `index`; may return LRUResponse::blocked while a
  // Shrink() holds the write lock.
  LRUResponse query(size_t index,
                    K *keys,
                    size_t length,
                    std::vector<std::pair<K, V>> &res) {  // NOLINT
    return lru_pool[index].query(keys, length, res);
  }
  LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
    return lru_pool[index].insert(keys, data, length);
  }
  // Rebalances when the cache exceeds ~110% of size_limit: under the write
  // lock the excess is distributed over the shards proportionally to their
  // current size, as remove_count quotas each shard works off lazily.
  int Shrink() {
    size_t node_size = 0;
    for (size_t i = 0; i < lru_pool.size(); i++) {
      node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
    }

    if (node_size <= static_cast<size_t>(1.1 * size_limit) + 1) return 0;
    if (pthread_rwlock_wrlock(&rwlock) == 0) {
      // Recount under the write lock: no shard can mutate now.
      global_count = 0;
      for (size_t i = 0; i < lru_pool.size(); i++) {
        global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
      }
      if (static_cast<size_t>(global_count) > size_limit) {
        size_t remove = global_count - size_limit;
        for (size_t i = 0; i < lru_pool.size(); i++) {
          lru_pool[i].total_diff = 0;
          lru_pool[i].remove_count +=
              1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
              global_count * remove;
        }
      }
      pthread_rwlock_unlock(&rwlock);
      return 0;
    }
    return 0;
  }

  // Called by shards to fold their batched size deltas into global_count;
  // schedules an early Shrink() once the cache is >125% of the limit.
  void handle_size_diff(int diff) {
    if (diff != 0) {
      __sync_fetch_and_add(&global_count, diff);
      if (global_count > static_cast<int>(1.25 * size_limit)) {
        thread_pool->enqueue([this]() -> int { return Shrink(); });
      }
    }
  }

  size_t get_ttl() { return ttl; }

 private:
  pthread_rwlock_t rwlock;  // shared: read = shard ops, write = Shrink
  size_t shard_num;
  int global_count;  // approximate total entries across all shards
  size_t size_limit, total, hit;
  size_t ttl;  // initial time-to-live handed to new cache entries
  bool stop;   // set by the destructor, read by the shrink thread
  std::thread shrink_job;
  std::vector<RandomSampleLRU<K, V>> lru_pool;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  std::shared_ptr<::ThreadPool> thread_pool;  // serializes Shrink() runs
  friend class RandomSampleLRU<K, V>;
};

441
/*
442 443 444 445 446 447 448 449 450 451 452 453
#ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable;
class GraphSampler {
 public:
  GraphSampler() {
    status = GraphSamplerStatus::waiting;
    thread_pool.reset(new ::ThreadPool(1));
    callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
      return;
    };
  }
454 455 456
  virtual int loadData(const std::string &path){
    return 0;
  }
457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496
  virtual int run_graph_sampling() = 0;
  virtual int start_graph_sampling() {
    if (status != GraphSamplerStatus::waiting) {
      return -1;
    }
    std::promise<int> prom;
    std::future<int> fut = prom.get_future();
    graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
      prom.set_value(0);
      status = GraphSamplerStatus::running;
      return run_graph_sampling();
    });
    return fut.get();
  }
  virtual void init(size_t gpu_num, GraphTable *graph_table,
                    std::vector<std::string> args) = 0;
  virtual void set_graph_sample_callback(
      std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
          callback) {
    this->callback = callback;
  }

  virtual int end_graph_sampling() {
    if (status == GraphSamplerStatus::running) {
      status = GraphSamplerStatus::terminating;
      return graph_sample_task_over.get();
    }
    return -1;
  }
  virtual GraphSamplerStatus get_graph_sampler_status() { return status; }

 protected:
  std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
      callback;
  std::shared_ptr<::ThreadPool> thread_pool;
  GraphSamplerStatus status;
  std::future<int> graph_sample_task_over;
  std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
497
*/
498

499
// Distributed graph storage table: edge shards and feature (node attribute)
// shards partitioned by node id, with optional SSD spill (RocksDB) and
// GPU-graph export under PADDLE_WITH_HETERPS, plus an optional ScaledLRU
// cache for neighbor-sampling results.
class GraphTable : public Table {
 public:
  GraphTable() {
    use_cache = false;
    shard_num = 0;
    rw_lock.reset(new pthread_rwlock_t());
#ifdef PADDLE_WITH_HETERPS
    next_partition = 0;
    total_memory_cost = 0;
#endif
  }
  virtual ~GraphTable();

  // Graph table has no dense shard storage; always returns null.
  virtual void *GetShard(size_t shard_idx) { return 0; }

  // Number of sparse shards each server owns (rounded up when shard_num is
  // not divisible by server_num).
  static int32_t sparse_local_shard_num(uint32_t shard_num,
                                        uint32_t server_num) {
    if (shard_num % server_num == 0) {
      return shard_num / server_num;
    }
    size_t local_shard_num = shard_num / server_num + 1;
    return local_shard_num;
  }

  // Maps a key to the index of the server that owns its shard.
  static size_t get_sparse_shard(uint32_t shard_num,
                                 uint32_t server_num,
                                 uint64_t key) {
    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
  }

  // Serializes a batch of nodes (edge or feature shards per type_id) into
  // `buffer`; `actual_size` receives the number of bytes written.
  virtual int32_t pull_graph_list(int type_id,
                                  int idx,
                                  int start,
                                  int size,
                                  std::unique_ptr<char[]> &buffer,  // NOLINT
                                  int &actual_size,                 // NOLINT
                                  bool need_feature,
                                  int step);

  // Samples up to sample_size neighbors for each node id; one serialized
  // buffer (and its size) is produced per input node.
  virtual int32_t random_sample_neighbors(
      int idx,
      uint64_t *node_ids,
      int sample_size,
      std::vector<std::shared_ptr<char>> &buffers,  // NOLINT
      std::vector<int> &actual_sizes,               // NOLINT
      bool need_weight);

  int32_t random_sample_nodes(int type_id,
                              int idx,
                              int sample_size,
                              std::unique_ptr<char[]> &buffers,  // NOLINT
                              int &actual_sizes);                // NOLINT

  virtual int32_t get_nodes_ids_by_ranges(
      int type_id,
      int idx,
      std::vector<std::pair<int, int>> ranges,
      std::vector<uint64_t> &res);  // NOLINT
  virtual int32_t Initialize() { return 0; }
  virtual int32_t Initialize(const TableParameter &config,
                             const FsClientParameter &fs_config);
  virtual int32_t Initialize(const GraphParameter &config);
  void init_worker_poll(int gpu_num);
  int32_t Load(const std::string &path, const std::string &param);
  // Loads node and edge files described by "type:path" style mappings.
  int32_t load_node_and_edge_file(std::string etype2files,
                                  std::string ntype2files,
                                  std::string graph_data_local_path,
                                  int part_num,
                                  bool reverse);
  int32_t parse_edge_and_load(std::string etype2files,
                              std::string graph_data_local_path,
                              int part_num,
                              bool reverse);
  int32_t parse_node_and_load(std::string ntype2files,
                              std::string graph_data_local_path,
                              int part_num);
  std::string get_inverse_etype(std::string &etype);  // NOLINT
  // Splits a "type1:path1,type2:path2" spec into type list and type->path map.
  int32_t parse_type_to_typepath(
      std::string &type2files,  // NOLINT
      std::string graph_data_local_path,
      std::vector<std::string> &res_type,                            // NOLINT
      std::unordered_map<std::string, std::string> &res_type2path);  // NOLINT
  int32_t load_edges(const std::string &path,
                     bool reverse,
                     const std::string &edge_type);
  // get_all_* helpers slice ids into slice_num vectors by id % slice_num.
  int get_all_id(int type,
                 int slice_num,
                 std::vector<std::vector<uint64_t>> *output);
  int get_all_neighbor_id(int type,
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
  int get_all_id(int type,
                 int idx,
                 int slice_num,
                 std::vector<std::vector<uint64_t>> *output);
  int get_all_neighbor_id(int type_id,
                          int id,
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
  int get_all_feature_ids(int type,
                          int idx,
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
  int get_node_embedding_ids(int slice_num,
                             std::vector<std::vector<uint64_t>> *output);
  int32_t load_nodes(const std::string &path,
                     std::string node_type = std::string());
  // parse_* return a (count, count) pair summarizing what was read.
  std::pair<uint64_t, uint64_t> parse_edge_file(const std::string &path,
                                                int idx,
                                                bool reverse);
  std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path,
                                                const std::string &node_type,
                                                int idx);
  std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path);
  int32_t add_graph_node(int idx,
                         std::vector<uint64_t> &id_list,      // NOLINT
                         std::vector<bool> &is_weight_list);  // NOLINT

  int32_t remove_graph_node(int idx, std::vector<uint64_t> &id_list);  // NOLINT

  int32_t get_server_index_by_id(uint64_t id);
  Node *find_node(int type_id, int idx, uint64_t id);
  Node *find_node(int type_id, uint64_t id);

  // Pull/Push are no-ops: graph data is accessed via the dedicated RPCs.
  virtual int32_t Pull(TableContext &context) { return 0; }  // NOLINT
  virtual int32_t Push(TableContext &context) { return 0; }  // NOLINT

  virtual int32_t clear_nodes(int type, int idx);
  virtual void Clear() {}
  virtual int32_t Flush() { return 0; }
  virtual int32_t Shrink(const std::string &param) { return 0; }
  // Specify the save path (no-op for the graph table).
  virtual int32_t Save(const std::string &path, const std::string &converter) {
    return 0;
  }
  virtual int32_t InitializeShard() { return 0; }
  virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
    _shard_idx = shard_idx;
    /*
    _shard_num is not used in graph_table, this following operation is for the
    purpose of
    being compatible with base class table.
    */
    _shard_num = server_num;
    this->server_num = server_num;
    return 0;
  }
  virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
  virtual uint32_t get_thread_pool_index(uint64_t node_id);
  // Parses one serialized feature string into `node`.
  virtual int parse_feature(int idx,
                            const char *feat_str,
                            size_t len,
                            FeatureNode *node);

  virtual int32_t get_node_feat(
      int idx,
      const std::vector<uint64_t> &node_ids,
      const std::vector<std::string> &feature_names,
      std::vector<std::vector<std::string>> &res);  // NOLINT

  virtual int32_t set_node_feat(
      int idx,
      const std::vector<uint64_t> &node_ids,              // NOLINT
      const std::vector<std::string> &feature_names,      // NOLINT
      const std::vector<std::vector<std::string>> &res);  // NOLINT

  size_t get_server_num() { return server_num; }
  void clear_graph();
  void clear_graph(int idx);
  void clear_edge_shard();
  void clear_feature_shard();
  void feature_shrink_to_fit();
  void merge_feature_shard();
  void release_graph();
  void release_graph_edge();
  void release_graph_node();
  // Lazily creates the shared neighbor-sampling LRU cache (one shard per
  // task-pool thread); idempotent under mutex_.
  virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (use_cache == false) {
        scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
            task_pool_size_, size_limit, ttl));
        use_cache = true;
      }
    }
    return 0;
  }
  virtual void load_node_weight(int type_id, int idx, std::string path);
#ifdef PADDLE_WITH_HETERPS
  // virtual int32_t start_graph_sampling() {
  //   return this->graph_sampler->start_graph_sampling();
  // }
  // virtual int32_t end_graph_sampling() {
  //   return this->graph_sampler->end_graph_sampling();
  // }
  // virtual int32_t set_graph_sample_callback(
  //     std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
  //         callback) {
  //   graph_sampler->set_graph_sample_callback(callback);
  //   return 0;
  // }
  virtual void make_partitions(int idx, int64_t gb_size, int device_len);
  virtual void export_partition_files(int idx, std::string file_path);
  virtual char *random_sample_neighbor_from_ssd(
      int idx,
      uint64_t id,
      int sample_size,
      const std::shared_ptr<std::mt19937_64> rng,
      int &actual_size);  // NOLINT
  virtual int32_t add_node_to_ssd(
      int type_id, int idx, uint64_t src_id, char *data, int len);
  virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
      int idx, const std::vector<uint64_t> &ids);
  virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea(
      int gpu_id, std::vector<uint64_t> &node_ids, int slot_num);  // NOLINT
  int32_t Load_to_ssd(const std::string &path, const std::string &param);
  int64_t load_graph_to_memory_from_ssd(int idx,
                                        std::vector<uint64_t> &ids);  // NOLINT
  int32_t make_complementary_graph(int idx, int64_t byte_size);
  int32_t dump_edges_to_ssd(int idx);
  int32_t get_partition_num(int idx) { return partitions[idx].size(); }
  std::vector<int> slot_feature_num_map() const {
    return slot_feature_num_map_;
  }
  // Returns a copy of one partition's id list; empty vector when either
  // index is out of range.
  std::vector<uint64_t> get_partition(size_t idx, size_t index) {
    if (idx >= partitions.size() || index >= partitions[idx].size())
      return std::vector<uint64_t>();
    return partitions[idx][index];
  }
  int32_t load_edges_to_ssd(const std::string &path,
                            bool reverse_edge,
                            const std::string &edge_type);
  int32_t load_next_partition(int idx);
  void set_search_level(int search_level) { this->search_level = search_level; }
  int search_level;
  int64_t total_memory_cost;
  // partitions[edge_type][partition_index] -> node ids of that partition.
  std::vector<std::vector<std::vector<uint64_t>>> partitions;
  int next_partition;
#endif
  virtual int32_t add_comm_edge(int idx, uint64_t src_id, uint64_t dst_id);
  virtual int32_t build_sampler(int idx, std::string sample_type = "random");
  void set_slot_feature_separator(const std::string &ch);
  void set_feature_separator(const std::string &ch);

  void build_graph_total_keys();
  void build_graph_type_keys();

  std::vector<uint64_t> graph_total_keys_;
  std::vector<std::vector<uint64_t>> graph_type_keys_;
  std::unordered_map<int, int> type_to_index_;

  // Shards indexed by [type_idx][shard]: edges and node features.
  std::vector<std::vector<GraphShard *>> edge_shards, feature_shards;
  size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
  int task_pool_size_ = 64;
  int load_thread_num = 160;

  const int random_sample_nodes_ranges = 3;

  std::vector<std::vector<std::unordered_map<uint64_t, double>>> node_weight;
  // Feature schema per type: names, dtypes, shapes and name->slot lookup.
  std::vector<std::vector<std::string>> feat_name;
  std::vector<std::vector<std::string>> feat_dtype;
  std::vector<std::vector<int32_t>> feat_shape;
  std::vector<std::unordered_map<std::string, int32_t>> feat_id_map;
  std::unordered_map<std::string, int> feature_to_id, edge_to_id;
  std::vector<std::string> id_to_feature, id_to_edge;
  std::string table_name;
  std::string table_type;
  std::vector<std::string> edge_type_size;

  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
  std::vector<std::shared_ptr<::ThreadPool>> _cpu_worker_pool;
  std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
  std::shared_ptr<::ThreadPool> load_node_edge_task_pool;
  // Neighbor-sampling result cache; created by make_neighbor_sample_cache.
  std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
  std::unordered_set<uint64_t> extra_nodes;
  std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
  bool use_cache, use_duplicate_nodes;
  int cache_size_limit;
  int cache_ttl;
  mutable std::mutex mutex_;
  bool build_sampler_on_cpu;
  bool is_load_reverse_edge = false;
  std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
  // paddle::framework::GpuPsGraphTable gpu_graph_table;
  paddle::distributed::RocksDBHandler *_db;
  // std::shared_ptr<::ThreadPool> graph_sample_pool;
  // std::shared_ptr<GraphSampler> graph_sampler;
  // REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
  std::string slot_feature_separator_ = std::string(" ");
  std::string feature_separator_ = std::string(" ");
  std::vector<int> slot_feature_num_map_;
};

794
/*
795 796 797 798 799 800 801 802 803 804 805 806 807 808
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
 public:
  CompleteGraphSampler() {}
  ~CompleteGraphSampler() {}
  // virtual pthread_rwlock_t *export_rw_lock();
  virtual int run_graph_sampling();
  virtual void init(size_t gpu_num, GraphTable *graph_table,
                    std::vector<std::string> args_);

 protected:
  GraphTable *graph_table;
  std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
D
danleifeng 已提交
809
  std::vector<std::vector<uint64_t>> sample_neighbors;
810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827
  // std::vector<GpuPsCommGraph> sample_res;
  // std::shared_ptr<std::mt19937_64> random;
  int gpu_num;
};

class BasicBfsGraphSampler : public GraphSampler {
 public:
  BasicBfsGraphSampler() {}
  ~BasicBfsGraphSampler() {}
  // virtual pthread_rwlock_t *export_rw_lock();
  virtual int run_graph_sampling();
  virtual void init(size_t gpu_num, GraphTable *graph_table,
                    std::vector<std::string> args_);

 protected:
  GraphTable *graph_table;
  // std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
  std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
D
danleifeng 已提交
828
  std::vector<std::vector<uint64_t>> sample_neighbors;
829
  size_t gpu_num;
830
  int init_search_size, node_num_for_each_shard, edge_num_for_each_node;
831
  int rounds, interval;
D
danleifeng 已提交
832
  std::vector<std::unordered_map<uint64_t, std::vector<uint64_t>>>
833
      sample_neighbors_map;
S
seemingwang 已提交
834
};
835
#endif
836
*/
837
}  // namespace distributed

};  // namespace paddle
840 841 842 843 844 845

namespace std {

// Hash support so SampleKey can serve as an unordered_map key (the key_map
// inside RandomSampleLRU). XOR-combines the integral members; is_weighted is
// deliberately not mixed in, matching the equality operator's weaker sibling.
template <>
struct hash<paddle::distributed::SampleKey> {
  size_t operator()(const paddle::distributed::SampleKey &s) const {
    size_t combined = s.idx;
    combined ^= s.node_key;
    combined ^= s.sample_size;
    return combined;
  }
};
}  // namespace std