// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/pten/core/utils/rw_lock.h"

namespace paddle {
namespace distributed {
class GraphShard {
 public:
  size_t get_size();
  GraphShard() {}
51
  ~GraphShard();
S
seemingwang 已提交
52 53 54 55
  std::vector<Node *> &get_bucket() { return bucket; }
  std::vector<Node *> get_batch(int start, int end, int step);
  std::vector<uint64_t> get_ids_by_range(int start, int end) {
    std::vector<uint64_t> res;
56
    for (int i = start; i < end && i < (int)bucket.size(); i++) {
S
seemingwang 已提交
57 58 59 60
      res.push_back(bucket[i]->get_id());
    }
    return res;
  }
S
seemingwang 已提交
61

S
seemingwang 已提交
62
  GraphNode *add_graph_node(uint64_t id);
63
  GraphNode *add_graph_node(Node *node);
S
seemingwang 已提交
64 65
  FeatureNode *add_feature_node(uint64_t id);
  Node *find_node(uint64_t id);
66 67
  void delete_node(uint64_t id);
  void clear();
S
seemingwang 已提交
68
  void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
69
  std::unordered_map<uint64_t, int> &get_node_location() {
S
seemingwang 已提交
70 71 72 73 74 75 76
    return node_location;
  }

 private:
  std::unordered_map<uint64_t, int> node_location;
  std::vector<Node *> bucket;
};

// Return status of the LRU cache operations: ok on success, blocked when
// the cache-wide rwlock could not be acquired without waiting, err is
// reserved for errors (not produced in this header).
enum LRUResponse { ok = 0, blocked = 1, err = 2 };

// Cache key for a neighbor-sampling request: identifies the source node,
// the number of neighbors requested, and whether weighted sampling is used.
struct SampleKey {
  uint64_t node_key;
  size_t sample_size;
  bool is_weighted;
  SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted)
      : node_key(_node_key),
        sample_size(_sample_size),
        is_weighted(_is_weighted) {}
  // All three fields participate in equality (see std::hash specialization
  // at the bottom of this file).
  bool operator==(const SampleKey &s) const {
    return node_key == s.node_key && sample_size == s.sample_size &&
           is_weighted == s.is_weighted;
  }
};

// Holds one sampling result: a serialized byte buffer plus the number of
// valid bytes in it.
class SampleResult {
 public:
  size_t actual_size;  // number of valid bytes in buffer
  std::shared_ptr<char> buffer;
  SampleResult(size_t _actual_size, std::shared_ptr<char> &_buffer)
      : actual_size(_actual_size), buffer(_buffer) {}
  // Takes ownership of a heap-allocated char array; released with
  // delete[] via the custom deleter.
  SampleResult(size_t _actual_size, char *_buffer)
      : actual_size(_actual_size),
        buffer(_buffer, [](char *p) { delete[] p; }) {}
  ~SampleResult() {}
};

// Node of the doubly-linked list used by RandomSampleLRU: carries a
// key/value pair and a remaining time-to-live.
template <typename K, typename V>
class LRUNode {
 public:
  LRUNode(K _key, V _data, size_t _ttl) : key(_key), data(_data), ttl(_ttl) {
    next = pre = nullptr;
  }
  K key;
  V data;
  // time to live: remaining number of cache hits before the entry expires
  size_t ttl;
  LRUNode<K, V> *pre, *next;
};
template <typename K, typename V>
S
seemingwang 已提交
119 120
class ScaledLRU;

121
template <typename K, typename V>
S
seemingwang 已提交
122 123
class RandomSampleLRU {
 public:
124 125 126
  RandomSampleLRU(ScaledLRU<K, V> *_father) {
    father = _father;
    remove_count = 0;
S
seemingwang 已提交
127 128 129
    node_size = 0;
    node_head = node_end = NULL;
    global_ttl = father->ttl;
130
    total_diff = 0;
S
seemingwang 已提交
131 132 133 134 135 136 137 138 139 140 141 142 143
  }

  ~RandomSampleLRU() {
    LRUNode<K, V> *p;
    while (node_head != NULL) {
      p = node_head->next;
      delete node_head;
      node_head = p;
    }
  }
  LRUResponse query(K *keys, size_t length, std::vector<std::pair<K, V>> &res) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
144 145 146 147 148 149 150 151 152 153 154 155 156 157
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);

    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        res.emplace_back(keys[i], iter->second->data);
        iter->second->ttl--;
        if (iter->second->ttl == 0) {
          remove(iter->second);
          if (remove_count != 0) remove_count--;
        } else {
          move_to_tail(iter->second);
S
seemingwang 已提交
158 159
        }
      }
160 161 162 163 164
    }
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
S
seemingwang 已提交
165 166 167 168 169 170 171
    }
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  LRUResponse insert(K *keys, V *data, size_t length) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
172 173 174 175 176 177 178 179 180 181 182 183
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);
    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        move_to_tail(iter->second);
        iter->second->ttl = global_ttl;
        iter->second->data = data[i];
      } else {
        LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
        add_new(temp);
S
seemingwang 已提交
184 185
      }
    }
186 187 188 189 190 191
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }

S
seemingwang 已提交
192 193 194
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
195 196
  void remove(LRUNode<K, V> *node) {
    fetch(node);
S
seemingwang 已提交
197
    node_size--;
198 199
    key_map.erase(node->key);
    delete node;
200 201 202 203 204 205 206
  }

  void process_redundant(int process_size) {
    size_t length = std::min(remove_count, process_size);
    while (length--) {
      remove(node_head);
      remove_count--;
S
seemingwang 已提交
207
    }
208
    // std::cerr<<"after remove_count = "<<remove_count<<std::endl;
S
seemingwang 已提交
209 210
  }

211 212 213 214 215 216 217 218 219 220 221 222
  void move_to_tail(LRUNode<K, V> *node) {
    fetch(node);
    place_at_tail(node);
  }

  void add_new(LRUNode<K, V> *node) {
    node->ttl = global_ttl;
    place_at_tail(node);
    node_size++;
    key_map[node->key] = node;
  }
  void place_at_tail(LRUNode<K, V> *node) {
S
seemingwang 已提交
223 224 225 226 227 228 229 230 231 232 233
    if (node_end == NULL) {
      node_head = node_end = node;
      node->next = node->pre = NULL;
    } else {
      node_end->next = node;
      node->pre = node_end;
      node->next = NULL;
      node_end = node;
    }
  }

234 235 236 237 238 239 240 241 242 243 244 245 246
  void fetch(LRUNode<K, V> *node) {
    if (node->pre) {
      node->pre->next = node->next;
    } else {
      node_head = node->next;
    }
    if (node->next) {
      node->next->pre = node->pre;
    } else {
      node_end = node->pre;
    }
  }

S
seemingwang 已提交
247
 private:
248 249
  std::unordered_map<K, LRUNode<K, V> *> key_map;
  ScaledLRU<K, V> *father;
250
  size_t global_ttl, size_limit;
251
  int node_size, total_diff;
S
seemingwang 已提交
252
  LRUNode<K, V> *node_head, *node_end;
253
  friend class ScaledLRU<K, V>;
254
  int remove_count;
S
seemingwang 已提交
255 256
};

template <typename K, typename V>
S
seemingwang 已提交
258 259
class ScaledLRU {
 public:
260
  ScaledLRU(size_t _shard_num, size_t size_limit, size_t _ttl)
S
seemingwang 已提交
261
      : size_limit(size_limit), ttl(_ttl) {
262
    shard_num = _shard_num;
S
seemingwang 已提交
263 264 265 266
    pthread_rwlock_init(&rwlock, NULL);
    stop = false;
    thread_pool.reset(new ::ThreadPool(1));
    global_count = 0;
267 268
    lru_pool = std::vector<RandomSampleLRU<K, V>>(shard_num,
                                                  RandomSampleLRU<K, V>(this));
S
seemingwang 已提交
269 270 271 272
    shrink_job = std::thread([this]() -> void {
      while (true) {
        {
          std::unique_lock<std::mutex> lock(mutex_);
273
          cv_.wait_for(lock, std::chrono::milliseconds(20000));
S
seemingwang 已提交
274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
          if (stop) {
            return;
          }
        }
        auto status =
            thread_pool->enqueue([this]() -> int { return shrink(); });
        status.wait();
      }
    });
    shrink_job.detach();
  }
  ~ScaledLRU() {
    std::unique_lock<std::mutex> lock(mutex_);
    stop = true;
    cv_.notify_one();
  }
  LRUResponse query(size_t index, K *keys, size_t length,
                    std::vector<std::pair<K, V>> &res) {
    return lru_pool[index].query(keys, length, res);
  }
  LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
    return lru_pool[index].insert(keys, data, length);
  }
  int shrink() {
    int node_size = 0;
    for (size_t i = 0; i < lru_pool.size(); i++) {
300
      node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
S
seemingwang 已提交
301 302
    }

303
    if (node_size <= size_t(1.1 * size_limit) + 1) return 0;
S
seemingwang 已提交
304
    if (pthread_rwlock_wrlock(&rwlock) == 0) {
305 306 307 308 309 310 311 312 313 314 315 316 317 318
      // VLOG(0)<"in shrink\n";
      global_count = 0;
      for (size_t i = 0; i < lru_pool.size(); i++) {
        global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
      }
      // VLOG(0)<<"global_count "<<global_count<<"\n";
      if (global_count > size_limit) {
        size_t remove = global_count - size_limit;
        for (int i = 0; i < lru_pool.size(); i++) {
          lru_pool[i].total_diff = 0;
          lru_pool[i].remove_count +=
              1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
              global_count * remove;
          // VLOG(0)<<i<<" "<<lru_pool[i].remove_count<<std::endl;
S
seemingwang 已提交
319 320 321 322 323 324 325
        }
      }
      pthread_rwlock_unlock(&rwlock);
      return 0;
    }
    return 0;
  }
326

S
seemingwang 已提交
327 328 329
  void handle_size_diff(int diff) {
    if (diff != 0) {
      __sync_fetch_and_add(&global_count, diff);
330
      if (global_count > int(1.25 * size_limit)) {
S
seemingwang 已提交
331
        // VLOG(0)<<"global_count too large "<<global_count<<" enter start
S
seemingwang 已提交
332 333 334 335 336 337 338 339 340 341
        // shrink task\n";
        thread_pool->enqueue([this]() -> int { return shrink(); });
      }
    }
  }

  size_t get_ttl() { return ttl; }

 private:
  pthread_rwlock_t rwlock;
342
  size_t shard_num;
S
seemingwang 已提交
343
  int global_count;
344
  size_t size_limit, total, hit;
S
seemingwang 已提交
345 346 347
  size_t ttl;
  bool stop;
  std::thread shrink_job;
348
  std::vector<RandomSampleLRU<K, V>> lru_pool;
S
seemingwang 已提交
349 350 351
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  std::shared_ptr<::ThreadPool> thread_pool;
352
  friend class RandomSampleLRU<K, V>;
S
seemingwang 已提交
353 354
};

class GraphTable : public SparseTable {
 public:
357
  GraphTable() { use_cache = false; }
358
  virtual ~GraphTable();
S
seemingwang 已提交
359 360 361 362 363
  virtual int32_t pull_graph_list(int start, int size,
                                  std::unique_ptr<char[]> &buffer,
                                  int &actual_size, bool need_feature,
                                  int step);

364
  virtual int32_t random_sample_neighbors(
S
seemingwang 已提交
365
      uint64_t *node_ids, int sample_size,
366
      std::vector<std::shared_ptr<char>> &buffers,
367
      std::vector<int> &actual_sizes, bool need_weight);
S
seemingwang 已提交
368 369 370 371 372 373 374 375 376

  int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers,
                              int &actual_sizes);

  virtual int32_t get_nodes_ids_by_ranges(
      std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res);
  virtual int32_t initialize();

  int32_t load(const std::string &path, const std::string &param);
377
  int32_t load_graph_split_config(const std::string &path);
S
seemingwang 已提交
378 379 380 381 382

  int32_t load_edges(const std::string &path, bool reverse);

  int32_t load_nodes(const std::string &path, std::string node_type);

383 384 385 386 387
  int32_t add_graph_node(std::vector<uint64_t> &id_list,
                         std::vector<bool> &is_weight_list);

  int32_t remove_graph_node(std::vector<uint64_t> &id_list);

S
seemingwang 已提交
388
  int32_t get_server_index_by_id(uint64_t id);
S
seemingwang 已提交
389 390
  Node *find_node(uint64_t id);

391 392
  virtual int32_t pull_sparse(float *values,
                              const PullSparseValue &pull_value) {
S
seemingwang 已提交
393 394
    return 0;
  }
395

S
seemingwang 已提交
396 397 398 399
  virtual int32_t push_sparse(const uint64_t *keys, const float *values,
                              size_t num) {
    return 0;
  }
400

401
  virtual int32_t clear_nodes();
S
seemingwang 已提交
402 403 404 405 406 407 408 409
  virtual void clear() {}
  virtual int32_t flush() { return 0; }
  virtual int32_t shrink(const std::string &param) { return 0; }
  //指定保存路径
  virtual int32_t save(const std::string &path, const std::string &converter) {
    return 0;
  }
  virtual int32_t initialize_shard() { return 0; }
410
  virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
S
seemingwang 已提交
411 412 413 414 415 416 417
  virtual uint32_t get_thread_pool_index(uint64_t node_id);
  virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);

  virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids,
                                const std::vector<std::string> &feature_names,
                                std::vector<std::vector<std::string>> &res);

S
seemingwang 已提交
418 419 420 421 422
  virtual int32_t set_node_feat(
      const std::vector<uint64_t> &node_ids,
      const std::vector<std::string> &feature_names,
      const std::vector<std::vector<std::string>> &res);

S
seemingwang 已提交
423 424
  size_t get_server_num() { return server_num; }

425
  virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
426 427 428
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (use_cache == false) {
429
        scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
430
            task_pool_size_, size_limit, ttl));
431 432 433 434 435 436
        use_cache = true;
      }
    }
    return 0;
  }

S
seemingwang 已提交
437
 protected:
438
  std::vector<GraphShard *> shards, extra_shards;
S
seemingwang 已提交
439
  size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
S
seemingwang 已提交
440
  const int task_pool_size_ = 24;
S
seemingwang 已提交
441 442 443 444 445 446 447 448 449 450
  const int random_sample_nodes_ranges = 3;

  std::vector<std::string> feat_name;
  std::vector<std::string> feat_dtype;
  std::vector<int32_t> feat_shape;
  std::unordered_map<std::string, int32_t> feat_id_map;
  std::string table_name;
  std::string table_type;

  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
451
  std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
452
  std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
453 454 455
  std::unordered_set<uint64_t> extra_nodes;
  std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
  bool use_cache, use_duplicate_nodes;
456
  mutable std::mutex mutex_;
S
seemingwang 已提交
457
};
}  // namespace distributed

};  // namespace paddle

namespace std {

// Hash support so SampleKey can be used in unordered containers.
template <>
struct hash<paddle::distributed::SampleKey> {
  size_t operator()(const paddle::distributed::SampleKey &s) const {
    // Combine every field that participates in operator== (the original
    // node_key ^ sample_size ignored is_weighted, so keys differing only
    // in weighting always collided). hash_combine-style mixing.
    size_t h = hash<uint64_t>()(s.node_key);
    h ^= s.sample_size + 0x9e3779b9 + (h << 6) + (h >> 2);
    h ^= static_cast<size_t>(s.is_weighted) + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
  }
};
}  // namespace std