common_graph_table.h 28.4 KB
Newer Older
S
seemingwang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>

#include <algorithm>
#include <cassert>
#include <chrono>              // NOLINT
#include <condition_variable>  // NOLINT
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"

#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
S
seemingwang 已提交
52 53 54 55 56 57
namespace paddle {
namespace distributed {
class GraphShard {
 public:
  size_t get_size();
  GraphShard() {}
58
  ~GraphShard();
S
seemingwang 已提交
59 60
  std::vector<Node *> &get_bucket() { return bucket; }
  std::vector<Node *> get_batch(int start, int end, int step);
D
danleifeng 已提交
61 62
  void get_ids_by_range(int start, int end, std::vector<uint64_t> *res) {
    res->reserve(res->size() + end - start);
L
lxsbupt 已提交
63
    for (int i = start; i < end && i < static_cast<int>(bucket.size()); i++) {
D
danleifeng 已提交
64
      res->emplace_back(bucket[i]->get_id());
S
seemingwang 已提交
65 66
    }
  }
D
danleifeng 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
  size_t get_all_id(std::vector<std::vector<uint64_t>> *shard_keys,
                    int slice_num) {
    int bucket_num = bucket.size();
    shard_keys->resize(slice_num);
    for (int i = 0; i < slice_num; ++i) {
      (*shard_keys)[i].reserve(bucket_num / slice_num);
    }
    for (int i = 0; i < bucket_num; i++) {
      uint64_t k = bucket[i]->get_id();
      (*shard_keys)[k % slice_num].emplace_back(k);
    }
    return bucket_num;
  }
  size_t get_all_neighbor_id(std::vector<std::vector<uint64_t>> *total_res,
                             int slice_num) {
    std::vector<uint64_t> keys;
    for (size_t i = 0; i < bucket.size(); i++) {
      size_t neighbor_size = bucket[i]->get_neighbor_size();
      size_t n = keys.size();
      keys.resize(n + neighbor_size);
      for (size_t j = 0; j < neighbor_size; j++) {
        keys[n + j] = bucket[i]->get_neighbor_id(j);
      }
    }
    return dedup2shard_keys(&keys, total_res, slice_num);
  }
  size_t get_all_feature_ids(std::vector<std::vector<uint64_t>> *total_res,
                             int slice_num) {
    std::vector<uint64_t> keys;
L
lxsbupt 已提交
96
    for (size_t i = 0; i < bucket.size(); i++) {
D
danleifeng 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109 110
      bucket[i]->get_feature_ids(&keys);
    }
    return dedup2shard_keys(&keys, total_res, slice_num);
  }
  size_t dedup2shard_keys(std::vector<uint64_t> *keys,
                          std::vector<std::vector<uint64_t>> *total_res,
                          int slice_num) {
    size_t num = keys->size();
    uint64_t last_key = 0;
    // sort key insert to vector
    std::sort(keys->begin(), keys->end());
    total_res->resize(slice_num);
    for (int shard_id = 0; shard_id < slice_num; ++shard_id) {
      (*total_res)[shard_id].reserve(num / slice_num);
111
    }
D
danleifeng 已提交
112 113 114 115 116 117 118 119 120
    for (size_t i = 0; i < num; ++i) {
      const uint64_t &k = (*keys)[i];
      if (i > 0 && last_key == k) {
        continue;
      }
      last_key = k;
      (*total_res)[k % slice_num].push_back(k);
    }
    return num;
121
  }
D
danleifeng 已提交
122
  GraphNode *add_graph_node(uint64_t id);
123
  GraphNode *add_graph_node(Node *node);
D
danleifeng 已提交
124 125 126
  FeatureNode *add_feature_node(uint64_t id, bool is_overlap = true);
  Node *find_node(uint64_t id);
  void delete_node(uint64_t id);
127
  void clear();
D
danleifeng 已提交
128 129
  void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
  std::unordered_map<uint64_t, int> &get_node_location() {
S
seemingwang 已提交
130 131 132
    return node_location;
  }

L
lxsbupt 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
  void shrink_to_fit() {
    bucket.shrink_to_fit();
    for (size_t i = 0; i < bucket.size(); i++) {
      bucket[i]->shrink_to_fit();
    }
  }

  void merge_shard(GraphShard *&shard) {  // NOLINT
    bucket.reserve(bucket.size() + shard->bucket.size());
    for (size_t i = 0; i < shard->bucket.size(); i++) {
      auto node_id = shard->bucket[i]->get_id();
      if (node_location.find(node_id) == node_location.end()) {
        node_location[node_id] = bucket.size();
        bucket.push_back(shard->bucket[i]);
      }
    }
    shard->node_location.clear();
    shard->bucket.clear();
    delete shard;
    shard = NULL;
  }

 public:
D
danleifeng 已提交
156
  std::unordered_map<uint64_t, int> node_location;
S
seemingwang 已提交
157 158
  std::vector<Node *> bucket;
};
S
seemingwang 已提交
159 160 161 162

// Result code of RandomSampleLRU / ScaledLRU query() and insert().
enum LRUResponse {
  ok = 0,       // the operation was carried out
  blocked = 1,  // the owner's rwlock could not be read-locked; nothing done
  err = 2,      // NOTE(review): not returned anywhere in this header
};

// Cache key of one neighbor-sampling request. Two keys are equal only when
// all four fields match.
struct SampleKey {
  int idx;
  uint64_t node_key;
  size_t sample_size;
  bool is_weighted;

  SampleKey(int _idx,
            uint64_t _node_key,
            size_t _sample_size,
            bool _is_weighted)
      : idx(_idx),
        node_key(_node_key),
        sample_size(_sample_size),
        is_weighted(_is_weighted) {}

  bool operator==(const SampleKey &s) const {
    return idx == s.idx && node_key == s.node_key &&
           sample_size == s.sample_size && is_weighted == s.is_weighted;
  }
};

// A sampled payload: a shared byte buffer plus the size recorded for it
// (presumably the number of meaningful bytes — confirm with producers).
class SampleResult {
 public:
  size_t actual_size;
  std::shared_ptr<char> buffer;

  // Shares an existing buffer. A const reference also accepts const and
  // temporary shared_ptrs, and non-const lvalues still bind as before.
  SampleResult(size_t _actual_size, const std::shared_ptr<char> &_buffer)
      : actual_size(_actual_size), buffer(_buffer) {}
  // Takes ownership of a buffer allocated with new[]; it is freed with
  // delete[] when the last SampleResult sharing it goes away.
  SampleResult(size_t _actual_size, char *_buffer)
      : actual_size(_actual_size),
        buffer(_buffer, [](char *p) { delete[] p; }) {}
  // Rule of Zero: no user-declared destructor. The former empty destructor
  // suppressed the implicit move constructor/assignment, so every "move"
  // copied the shared_ptr (an atomic refcount round-trip).
};

template <typename K, typename V>
class LRUNode {
 public:
  // Builds a detached list node (both links null) holding a key, its value
  // and the entry's remaining time to live.
  LRUNode(K _key, V _data, size_t _ttl)
      : key(_key), data(_data), ttl(_ttl), pre(nullptr), next(nullptr) {}

  K key;
  V data;
  size_t ttl;  // time to live
  LRUNode<K, V> *pre;
  LRUNode<K, V> *next;
};
206
// Forward declaration: each RandomSampleLRU shard keeps a pointer back to the
// ScaledLRU that owns it.
template <typename K, typename V>
class ScaledLRU;

209
template <typename K, typename V>
S
seemingwang 已提交
210 211
class RandomSampleLRU {
 public:
L
lxsbupt 已提交
212
  explicit RandomSampleLRU(ScaledLRU<K, V> *_father) {
213 214
    father = _father;
    remove_count = 0;
S
seemingwang 已提交
215 216 217
    node_size = 0;
    node_head = node_end = NULL;
    global_ttl = father->ttl;
218
    total_diff = 0;
S
seemingwang 已提交
219 220 221 222 223 224 225 226 227 228
  }

  ~RandomSampleLRU() {
    LRUNode<K, V> *p;
    while (node_head != NULL) {
      p = node_head->next;
      delete node_head;
      node_head = p;
    }
  }
L
lxsbupt 已提交
229 230 231
  LRUResponse query(K *keys,
                    size_t length,
                    std::vector<std::pair<K, V>> &res) {  // NOLINT
S
seemingwang 已提交
232 233
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
234 235 236 237 238 239 240 241 242 243 244 245 246 247
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);

    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        res.emplace_back(keys[i], iter->second->data);
        iter->second->ttl--;
        if (iter->second->ttl == 0) {
          remove(iter->second);
          if (remove_count != 0) remove_count--;
        } else {
          move_to_tail(iter->second);
S
seemingwang 已提交
248 249
        }
      }
250 251 252 253 254
    }
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
S
seemingwang 已提交
255 256 257 258 259 260 261
    }
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  LRUResponse insert(K *keys, V *data, size_t length) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
262 263 264 265 266 267 268 269 270 271 272 273
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);
    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        move_to_tail(iter->second);
        iter->second->ttl = global_ttl;
        iter->second->data = data[i];
      } else {
        LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
        add_new(temp);
S
seemingwang 已提交
274 275
      }
    }
276 277 278 279 280 281
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }

S
seemingwang 已提交
282 283 284
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
285 286
  void remove(LRUNode<K, V> *node) {
    fetch(node);
S
seemingwang 已提交
287
    node_size--;
288 289
    key_map.erase(node->key);
    delete node;
290 291 292
  }

  void process_redundant(int process_size) {
293
    int length = std::min(remove_count, process_size);
294 295 296
    while (length--) {
      remove(node_head);
      remove_count--;
S
seemingwang 已提交
297 298 299
    }
  }

300 301 302 303 304 305 306 307 308 309 310 311
  void move_to_tail(LRUNode<K, V> *node) {
    fetch(node);
    place_at_tail(node);
  }

  void add_new(LRUNode<K, V> *node) {
    node->ttl = global_ttl;
    place_at_tail(node);
    node_size++;
    key_map[node->key] = node;
  }
  void place_at_tail(LRUNode<K, V> *node) {
S
seemingwang 已提交
312 313 314 315 316 317 318 319 320 321 322
    if (node_end == NULL) {
      node_head = node_end = node;
      node->next = node->pre = NULL;
    } else {
      node_end->next = node;
      node->pre = node_end;
      node->next = NULL;
      node_end = node;
    }
  }

323 324 325 326 327 328 329 330 331 332 333 334 335
  void fetch(LRUNode<K, V> *node) {
    if (node->pre) {
      node->pre->next = node->next;
    } else {
      node_head = node->next;
    }
    if (node->next) {
      node->next->pre = node->pre;
    } else {
      node_end = node->pre;
    }
  }

S
seemingwang 已提交
336
 private:
337 338
  std::unordered_map<K, LRUNode<K, V> *> key_map;
  ScaledLRU<K, V> *father;
339
  size_t global_ttl, size_limit;
340
  int node_size, total_diff;
S
seemingwang 已提交
341
  LRUNode<K, V> *node_head, *node_end;
342
  friend class ScaledLRU<K, V>;
343
  int remove_count;
S
seemingwang 已提交
344 345
};

346
template <typename K, typename V>
S
seemingwang 已提交
347 348
class ScaledLRU {
 public:
349
  ScaledLRU(size_t _shard_num, size_t size_limit, size_t _ttl)
S
seemingwang 已提交
350
      : size_limit(size_limit), ttl(_ttl) {
351
    shard_num = _shard_num;
S
seemingwang 已提交
352 353 354 355
    pthread_rwlock_init(&rwlock, NULL);
    stop = false;
    thread_pool.reset(new ::ThreadPool(1));
    global_count = 0;
356 357
    lru_pool = std::vector<RandomSampleLRU<K, V>>(shard_num,
                                                  RandomSampleLRU<K, V>(this));
S
seemingwang 已提交
358 359 360 361
    shrink_job = std::thread([this]() -> void {
      while (true) {
        {
          std::unique_lock<std::mutex> lock(mutex_);
362
          cv_.wait_for(lock, std::chrono::milliseconds(20000));
S
seemingwang 已提交
363 364 365 366 367
          if (stop) {
            return;
          }
        }
        auto status =
Z
zhaocaibei123 已提交
368
            thread_pool->enqueue([this]() -> int { return Shrink(); });
S
seemingwang 已提交
369 370 371 372 373 374 375 376 377 378
        status.wait();
      }
    });
    shrink_job.detach();
  }
  ~ScaledLRU() {
    std::unique_lock<std::mutex> lock(mutex_);
    stop = true;
    cv_.notify_one();
  }
379 380 381
  LRUResponse query(size_t index,
                    K *keys,
                    size_t length,
L
lxsbupt 已提交
382
                    std::vector<std::pair<K, V>> &res) {  // NOLINT
S
seemingwang 已提交
383 384 385 386 387
    return lru_pool[index].query(keys, length, res);
  }
  LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
    return lru_pool[index].insert(keys, data, length);
  }
Z
zhaocaibei123 已提交
388
  int Shrink() {
L
lxsbupt 已提交
389
    size_t node_size = 0;
S
seemingwang 已提交
390
    for (size_t i = 0; i < lru_pool.size(); i++) {
391
      node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
S
seemingwang 已提交
392 393
    }

L
lxsbupt 已提交
394
    if (node_size <= static_cast<size_t>(1.1 * size_limit) + 1) return 0;
S
seemingwang 已提交
395
    if (pthread_rwlock_wrlock(&rwlock) == 0) {
396 397 398 399
      global_count = 0;
      for (size_t i = 0; i < lru_pool.size(); i++) {
        global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
      }
L
lxsbupt 已提交
400
      if (static_cast<size_t>(global_count) > size_limit) {
401
        size_t remove = global_count - size_limit;
402
        for (size_t i = 0; i < lru_pool.size(); i++) {
403 404 405 406
          lru_pool[i].total_diff = 0;
          lru_pool[i].remove_count +=
              1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
              global_count * remove;
S
seemingwang 已提交
407 408 409 410 411 412 413
        }
      }
      pthread_rwlock_unlock(&rwlock);
      return 0;
    }
    return 0;
  }
414

S
seemingwang 已提交
415 416 417
  void handle_size_diff(int diff) {
    if (diff != 0) {
      __sync_fetch_and_add(&global_count, diff);
L
lxsbupt 已提交
418
      if (global_count > static_cast<int>(1.25 * size_limit)) {
Z
zhaocaibei123 已提交
419
        thread_pool->enqueue([this]() -> int { return Shrink(); });
S
seemingwang 已提交
420 421 422 423 424 425 426 427
      }
    }
  }

  size_t get_ttl() { return ttl; }

 private:
  pthread_rwlock_t rwlock;
428
  size_t shard_num;
S
seemingwang 已提交
429
  int global_count;
430
  size_t size_limit, total, hit;
S
seemingwang 已提交
431 432 433
  size_t ttl;
  bool stop;
  std::thread shrink_job;
434
  std::vector<RandomSampleLRU<K, V>> lru_pool;
S
seemingwang 已提交
435 436 437
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  std::shared_ptr<::ThreadPool> thread_pool;
438
  friend class RandomSampleLRU<K, V>;
S
seemingwang 已提交
439 440
};

441
/*
#ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable;
class GraphSampler {
 public:
  GraphSampler() {
    status = GraphSamplerStatus::waiting;
    thread_pool.reset(new ::ThreadPool(1));
    callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
      return;
    };
  }
  virtual int loadData(const std::string &path){
    return 0;
  }
  virtual int run_graph_sampling() = 0;
  virtual int start_graph_sampling() {
    if (status != GraphSamplerStatus::waiting) {
      return -1;
    }
    std::promise<int> prom;
    std::future<int> fut = prom.get_future();
    graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
      prom.set_value(0);
      status = GraphSamplerStatus::running;
      return run_graph_sampling();
    });
    return fut.get();
  }
  virtual void init(size_t gpu_num, GraphTable *graph_table,
                    std::vector<std::string> args) = 0;
  virtual void set_graph_sample_callback(
      std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
          callback) {
    this->callback = callback;
  }

  virtual int end_graph_sampling() {
    if (status == GraphSamplerStatus::running) {
      status = GraphSamplerStatus::terminating;
      return graph_sample_task_over.get();
    }
    return -1;
  }
  virtual GraphSamplerStatus get_graph_sampler_status() { return status; }

 protected:
  std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
      callback;
  std::shared_ptr<::ThreadPool> thread_pool;
  GraphSamplerStatus status;
  std::future<int> graph_sample_task_over;
  std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
*/
498

499 500
// Which kind of table a GraphTable operation addresses (presumably
// edge_shards vs feature_shards — confirm against the implementation).
enum GraphTableType {
  EDGE_TABLE,     // 0
  FEATURE_TABLE,  // 1
};

501
class GraphTable : public Table {
S
seemingwang 已提交
502
 public:
503 504 505 506
  GraphTable() {
    use_cache = false;
    shard_num = 0;
    rw_lock.reset(new pthread_rwlock_t());
507 508 509 510
#ifdef PADDLE_WITH_HETERPS
    next_partition = 0;
    total_memory_cost = 0;
#endif
511
  }
512
  virtual ~GraphTable();
513 514 515 516 517 518 519 520 521 522 523 524

  virtual void *GetShard(size_t shard_idx) { return 0; }

  static int32_t sparse_local_shard_num(uint32_t shard_num,
                                        uint32_t server_num) {
    if (shard_num % server_num == 0) {
      return shard_num / server_num;
    }
    size_t local_shard_num = shard_num / server_num + 1;
    return local_shard_num;
  }

525 526
  static size_t get_sparse_shard(uint32_t shard_num,
                                 uint32_t server_num,
527 528 529 530
                                 uint64_t key) {
    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
  }

531
  virtual int32_t pull_graph_list(GraphTableType table_type,
532 533 534
                                  int idx,
                                  int start,
                                  int size,
L
lxsbupt 已提交
535 536
                                  std::unique_ptr<char[]> &buffer,  // NOLINT
                                  int &actual_size,                 // NOLINT
537
                                  bool need_feature,
S
seemingwang 已提交
538 539
                                  int step);

540
  virtual int32_t random_sample_neighbors(
541
      int idx,
D
danleifeng 已提交
542
      uint64_t *node_ids,
543
      int sample_size,
L
lxsbupt 已提交
544 545
      std::vector<std::shared_ptr<char>> &buffers,  // NOLINT
      std::vector<int> &actual_sizes,               // NOLINT
546
      bool need_weight);
S
seemingwang 已提交
547

548
  int32_t random_sample_nodes(GraphTableType table_type,
549 550
                              int idx,
                              int sample_size,
L
lxsbupt 已提交
551 552
                              std::unique_ptr<char[]> &buffers,  // NOLINT
                              int &actual_sizes);                // NOLINT
S
seemingwang 已提交
553 554

  virtual int32_t get_nodes_ids_by_ranges(
555
      GraphTableType table_type,
556 557
      int idx,
      std::vector<std::pair<int, int>> ranges,
L
lxsbupt 已提交
558
      std::vector<uint64_t> &res);  // NOLINT
Z
zhaocaibei123 已提交
559 560
  virtual int32_t Initialize() { return 0; }
  virtual int32_t Initialize(const TableParameter &config,
561
                             const FsClientParameter &fs_config);
Z
zhaocaibei123 已提交
562
  virtual int32_t Initialize(const GraphParameter &config);
L
lxsbupt 已提交
563
  void init_worker_poll(int gpu_num);
Z
zhaocaibei123 已提交
564
  int32_t Load(const std::string &path, const std::string &param);
L
lxsbupt 已提交
565 566 567
  int32_t load_node_and_edge_file(std::string etype2files,
                                  std::string ntype2files,
                                  std::string graph_data_local_path,
D
danleifeng 已提交
568
                                  int part_num,
569 570
                                  bool reverse,
                                  const std::vector<bool> &is_reverse_edge_map);
L
lxsbupt 已提交
571 572 573
  int32_t parse_edge_and_load(std::string etype2files,
                              std::string graph_data_local_path,
                              int part_num,
574 575
                              bool reverse,
                              const std::vector<bool> &is_reverse_edge_map);
L
lxsbupt 已提交
576 577 578 579 580 581 582 583 584
  int32_t parse_node_and_load(std::string ntype2files,
                              std::string graph_data_local_path,
                              int part_num);
  std::string get_inverse_etype(std::string &etype);  // NOLINT
  int32_t parse_type_to_typepath(
      std::string &type2files,  // NOLINT
      std::string graph_data_local_path,
      std::vector<std::string> &res_type,                            // NOLINT
      std::unordered_map<std::string, std::string> &res_type2path);  // NOLINT
585 586
  int32_t load_edges(const std::string &path,
                     bool reverse,
587
                     const std::string &edge_type);
588
  int get_all_id(GraphTableType table_type,
D
danleifeng 已提交
589 590
                 int slice_num,
                 std::vector<std::vector<uint64_t>> *output);
591
  int get_all_neighbor_id(GraphTableType table_type,
D
danleifeng 已提交
592 593
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
594
  int get_all_id(GraphTableType table_type,
D
danleifeng 已提交
595 596 597
                 int idx,
                 int slice_num,
                 std::vector<std::vector<uint64_t>> *output);
598
  int get_all_neighbor_id(GraphTableType table_type,
D
danleifeng 已提交
599 600 601
                          int id,
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
602
  int get_all_feature_ids(GraphTableType table_type,
D
danleifeng 已提交
603 604 605
                          int idx,
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
L
lxsbupt 已提交
606 607
  int get_node_embedding_ids(int slice_num,
                             std::vector<std::vector<uint64_t>> *output);
D
danleifeng 已提交
608 609 610 611 612 613 614 615 616
  int32_t load_nodes(const std::string &path,
                     std::string node_type = std::string());
  std::pair<uint64_t, uint64_t> parse_edge_file(const std::string &path,
                                                int idx,
                                                bool reverse);
  std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path,
                                                const std::string &node_type,
                                                int idx);
  std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path);
617
  int32_t add_graph_node(int idx,
L
lxsbupt 已提交
618 619
                         std::vector<uint64_t> &id_list,      // NOLINT
                         std::vector<bool> &is_weight_list);  // NOLINT
620

L
lxsbupt 已提交
621
  int32_t remove_graph_node(int idx, std::vector<uint64_t> &id_list);  // NOLINT
622

D
danleifeng 已提交
623
  int32_t get_server_index_by_id(uint64_t id);
624 625
  Node *find_node(GraphTableType table_type, int idx, uint64_t id);
  Node *find_node(GraphTableType table_type, uint64_t id);
S
seemingwang 已提交
626

L
lxsbupt 已提交
627 628
  virtual int32_t Pull(TableContext &context) { return 0; }  // NOLINT
  virtual int32_t Push(TableContext &context) { return 0; }  // NOLINT
Y
yaoxuefeng 已提交
629

630
  virtual int32_t clear_nodes(GraphTableType table_type, int idx);
Z
zhaocaibei123 已提交
631 632 633
  virtual void Clear() {}
  virtual int32_t Flush() { return 0; }
  virtual int32_t Shrink(const std::string &param) { return 0; }
L
lxsbupt 已提交
634
  // 指定保存路径
Z
zhaocaibei123 已提交
635
  virtual int32_t Save(const std::string &path, const std::string &converter) {
S
seemingwang 已提交
636 637
    return 0;
  }
Z
zhaocaibei123 已提交
638 639
  virtual int32_t InitializeShard() { return 0; }
  virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
640 641 642 643 644 645 646 647 648 649
    _shard_idx = shard_idx;
    /*
    _shard_num is not used in graph_table, this following operation is for the
    purpose of
    being compatible with base class table.
    */
    _shard_num = server_num;
    this->server_num = server_num;
    return 0;
  }
D
danleifeng 已提交
650 651 652 653 654 655
  virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
  virtual uint32_t get_thread_pool_index(uint64_t node_id);
  virtual int parse_feature(int idx,
                            const char *feat_str,
                            size_t len,
                            FeatureNode *node);
S
seemingwang 已提交
656

L
lxsbupt 已提交
657
  virtual int32_t get_node_feat(
658
      int idx,
D
danleifeng 已提交
659
      const std::vector<uint64_t> &node_ids,
S
seemingwang 已提交
660
      const std::vector<std::string> &feature_names,
L
lxsbupt 已提交
661 662 663 664 665 666 667
      std::vector<std::vector<std::string>> &res);  // NOLINT

  virtual int32_t set_node_feat(
      int idx,
      const std::vector<uint64_t> &node_ids,              // NOLINT
      const std::vector<std::string> &feature_names,      // NOLINT
      const std::vector<std::vector<std::string>> &res);  // NOLINT
S
seemingwang 已提交
668

S
seemingwang 已提交
669
  size_t get_server_num() { return server_num; }
L
lxsbupt 已提交
670
  void clear_graph();
671
  void clear_graph(int idx);
L
lxsbupt 已提交
672 673 674 675 676 677 678
  void clear_edge_shard();
  void clear_feature_shard();
  void feature_shrink_to_fit();
  void merge_feature_shard();
  void release_graph();
  void release_graph_edge();
  void release_graph_node();
679
  virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
680 681 682
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (use_cache == false) {
683
        scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
684
            task_pool_size_, size_limit, ttl));
685 686 687 688 689
        use_cache = true;
      }
    }
    return 0;
  }
690
  virtual void load_node_weight(int type_id, int idx, std::string path);
691
#ifdef PADDLE_WITH_HETERPS
692 693 694 695 696 697 698 699 700 701 702 703
  // virtual int32_t start_graph_sampling() {
  //   return this->graph_sampler->start_graph_sampling();
  // }
  // virtual int32_t end_graph_sampling() {
  //   return this->graph_sampler->end_graph_sampling();
  // }
  // virtual int32_t set_graph_sample_callback(
  //     std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
  //         callback) {
  //   graph_sampler->set_graph_sample_callback(callback);
  //   return 0;
  // }
704
  virtual void make_partitions(int idx, int64_t gb_size, int device_len);
705
  virtual void export_partition_files(int idx, std::string file_path);
706
  virtual char *random_sample_neighbor_from_ssd(
707
      int idx,
D
danleifeng 已提交
708
      uint64_t id,
709 710
      int sample_size,
      const std::shared_ptr<std::mt19937_64> rng,
L
lxsbupt 已提交
711
      int &actual_size);  // NOLINT
712
  virtual int32_t add_node_to_ssd(
D
danleifeng 已提交
713
      int type_id, int idx, uint64_t src_id, char *data, int len);
714
  virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
L
lxsbupt 已提交
715
      int idx, const std::vector<uint64_t> &ids);
D
danleifeng 已提交
716
  virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea(
L
lxsbupt 已提交
717
      int gpu_id, std::vector<uint64_t> &node_ids, int slot_num);  // NOLINT
718
  int32_t Load_to_ssd(const std::string &path, const std::string &param);
L
lxsbupt 已提交
719 720
  int64_t load_graph_to_memory_from_ssd(int idx,
                                        std::vector<uint64_t> &ids);  // NOLINT
721 722 723
  int32_t make_complementary_graph(int idx, int64_t byte_size);
  int32_t dump_edges_to_ssd(int idx);
  int32_t get_partition_num(int idx) { return partitions[idx].size(); }
L
lxsbupt 已提交
724 725 726 727 728
  std::vector<int> slot_feature_num_map() const {
    return slot_feature_num_map_;
  }
  std::vector<uint64_t> get_partition(size_t idx, size_t index) {
    if (idx >= partitions.size() || index >= partitions[idx].size())
D
danleifeng 已提交
729
      return std::vector<uint64_t>();
730 731
    return partitions[idx][index];
  }
732 733
  int32_t load_edges_to_ssd(const std::string &path,
                            bool reverse_edge,
734 735 736
                            const std::string &edge_type);
  int32_t load_next_partition(int idx);
  void set_search_level(int search_level) { this->search_level = search_level; }
737
  int search_level;
738
  int64_t total_memory_cost;
D
danleifeng 已提交
739
  std::vector<std::vector<std::vector<uint64_t>>> partitions;
740
  int next_partition;
741
#endif
D
danleifeng 已提交
742
  virtual int32_t add_comm_edge(int idx, uint64_t src_id, uint64_t dst_id);
743
  virtual int32_t build_sampler(int idx, std::string sample_type = "random");
L
lxsbupt 已提交
744
  void set_slot_feature_separator(const std::string &ch);
D
danleifeng 已提交
745
  void set_feature_separator(const std::string &ch);
L
lxsbupt 已提交
746 747 748 749 750 751 752 753

  void build_graph_total_keys();
  void build_graph_type_keys();

  std::vector<uint64_t> graph_total_keys_;
  std::vector<std::vector<uint64_t>> graph_type_keys_;
  std::unordered_map<int, int> type_to_index_;

754
  std::vector<std::vector<GraphShard *>> edge_shards, feature_shards;
S
seemingwang 已提交
755
  size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
L
lxsbupt 已提交
756
  int task_pool_size_ = 64;
D
danleifeng 已提交
757 758
  int load_thread_num = 160;

S
seemingwang 已提交
759 760
  const int random_sample_nodes_ranges = 3;

D
danleifeng 已提交
761
  std::vector<std::vector<std::unordered_map<uint64_t, double>>> node_weight;
762 763 764 765 766 767
  std::vector<std::vector<std::string>> feat_name;
  std::vector<std::vector<std::string>> feat_dtype;
  std::vector<std::vector<int32_t>> feat_shape;
  std::vector<std::unordered_map<std::string, int32_t>> feat_id_map;
  std::unordered_map<std::string, int> feature_to_id, edge_to_id;
  std::vector<std::string> id_to_feature, id_to_edge;
S
seemingwang 已提交
768 769
  std::string table_name;
  std::string table_type;
L
lxsbupt 已提交
770
  std::vector<std::string> edge_type_size;
S
seemingwang 已提交
771 772

  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
L
lxsbupt 已提交
773
  std::vector<std::shared_ptr<::ThreadPool>> _cpu_worker_pool;
774
  std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
D
danleifeng 已提交
775
  std::shared_ptr<::ThreadPool> load_node_edge_task_pool;
776
  std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
D
danleifeng 已提交
777 778
  std::unordered_set<uint64_t> extra_nodes;
  std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
779
  bool use_cache, use_duplicate_nodes;
780 781
  int cache_size_limit;
  int cache_ttl;
782
  mutable std::mutex mutex_;
D
danleifeng 已提交
783
  bool build_sampler_on_cpu;
L
lxsbupt 已提交
784
  bool is_load_reverse_edge = false;
785 786 787
  std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
  // paddle::framework::GpuPsGraphTable gpu_graph_table;
788
  paddle::distributed::RocksDBHandler *_db;
D
danleifeng 已提交
789 790 791
  // std::shared_ptr<::ThreadPool> graph_sample_pool;
  // std::shared_ptr<GraphSampler> graph_sampler;
  // REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
792
#endif
L
lxsbupt 已提交
793
  std::string slot_feature_separator_ = std::string(" ");
D
danleifeng 已提交
794
  std::string feature_separator_ = std::string(" ");
L
lxsbupt 已提交
795
  std::vector<int> slot_feature_num_map_;
796
  bool is_parse_node_fail_ = false;
797 798
};

799
/*
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
 public:
  CompleteGraphSampler() {}
  ~CompleteGraphSampler() {}
  // virtual pthread_rwlock_t *export_rw_lock();
  virtual int run_graph_sampling();
  virtual void init(size_t gpu_num, GraphTable *graph_table,
                    std::vector<std::string> args_);

 protected:
  GraphTable *graph_table;
  std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
  std::vector<std::vector<uint64_t>> sample_neighbors;
  // std::vector<GpuPsCommGraph> sample_res;
  // std::shared_ptr<std::mt19937_64> random;
  int gpu_num;
};

class BasicBfsGraphSampler : public GraphSampler {
 public:
  BasicBfsGraphSampler() {}
  ~BasicBfsGraphSampler() {}
  // virtual pthread_rwlock_t *export_rw_lock();
  virtual int run_graph_sampling();
  virtual void init(size_t gpu_num, GraphTable *graph_table,
                    std::vector<std::string> args_);

 protected:
  GraphTable *graph_table;
  // std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
  std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
  std::vector<std::vector<uint64_t>> sample_neighbors;
  size_t gpu_num;
  int init_search_size, node_num_for_each_shard, edge_num_for_each_node;
  int rounds, interval;
  std::vector<std::unordered_map<uint64_t, std::vector<uint64_t>>>
      sample_neighbors_map;
};
#endif
*/
842
}  // namespace distributed
843

844
};  // namespace paddle
845 846 847 848 849 850

namespace std {

template <>
struct hash<paddle::distributed::SampleKey> {
  // Hashes every field that SampleKey::operator== compares, mixing them with
  // the boost::hash_combine recipe. The former `idx ^ node_key ^ sample_size`
  // had two problems:
  //   - it ignored is_weighted, so weighted/unweighted keys always collided;
  //   - its fields cancel each other, e.g. (node_key=1, sample_size=2) and
  //     (node_key=2, sample_size=1) collide.
  // Equal keys still hash equal.
  size_t operator()(const paddle::distributed::SampleKey &s) const {
    size_t seed = std::hash<int>()(s.idx);
    auto combine = [&seed](size_t value) {
      seed ^= value + static_cast<size_t>(0x9e3779b97f4a7c15ULL) +
              (seed << 6) + (seed >> 2);
    };
    combine(std::hash<uint64_t>()(s.node_key));
    combine(std::hash<size_t>()(s.sample_size));
    combine(std::hash<bool>()(s.is_weighted));
    return seed;
  }
};

}  // namespace std