// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <algorithm>
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/common_table.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace distributed {
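// A single shard of the graph. Nodes live in a flat bucket (a vector of
// Node *), and node_location maps each node id to its bucket index for
// O(1) lookup and deletion.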
class GraphShard {
 public:
  size_t get_size();
  GraphShard() {}
  GraphShard(int shard_num) { this->shard_num = shard_num; }
  ~GraphShard();
  std::vector<Node *> &get_bucket() { return bucket; }
  std::vector<Node *> get_batch(int start, int end, int step);
  std::vector<uint64_t> get_ids_by_range(int start, int end) {
    std::vector<uint64_t> res;
    for (int i = start; i < end && i < (int)bucket.size(); i++) {
      res.push_back(bucket[i]->get_id());
    }
    return res;
  }

  GraphNode *add_graph_node(uint64_t id);
  FeatureNode *add_feature_node(uint64_t id);
  Node *find_node(uint64_t id);
  void delete_node(uint64_t id);
  void clear();
  void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
  std::unordered_map<uint64_t, int> get_node_location() {
    return node_location;
  }

 private:
  std::unordered_map<uint64_t, int> node_location;
  int shard_num;
  std::vector<Node *> bucket;
};
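
// Usage sketch for GraphShard (illustrative values only; the Node types
// come from graph/graph_node.h):
//
//   GraphShard shard(/* shard_num = */ 128);
//   shard.add_graph_node(42);
//   shard.add_neighbor(42, /* dst_id = */ 7, /* weight = */ 1.0f);
//   std::vector<uint64_t> ids = shard.get_ids_by_range(0, 10);
//   Node *n = shard.find_node(42);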

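// Result of a cache operation: `blocked` means the try-lock on the shared
// rwlock failed (for instance, while a shrink pass holds the write lock),
// so the caller should skip the cache or retry.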
enum LRUResponse { ok = 0, blocked = 1, err = 2 };

struct SampleKey {
  uint64_t node_key;
  size_t sample_size;
  SampleKey(uint64_t _node_key, size_t _sample_size)
      : node_key(_node_key), sample_size(_sample_size) {}
  bool operator==(const SampleKey &s) const {
    return node_key == s.node_key && sample_size == s.sample_size;
  }
};

class SampleResult {
 public:
  size_t actual_size;
  std::shared_ptr<char> buffer;
  SampleResult(size_t _actual_size, std::shared_ptr<char> &_buffer)
      : actual_size(_actual_size), buffer(_buffer) {}
  SampleResult(size_t _actual_size, char *_buffer)
      : actual_size(_actual_size),
        buffer(_buffer, [](char *p) { delete[] p; }) {}
  ~SampleResult() {}
};

template <typename K, typename V>
class LRUNode {
 public:
  LRUNode(K _key, V _data, size_t _ttl) : key(_key), data(_data), ttl(_ttl) {
    next = pre = NULL;
  }
  K key;
  V data;
  size_t ttl;
  // time to live
  LRUNode<K, V> *pre, *next;
};
template <typename K, typename V>
class ScaledLRU;

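// Per-thread LRU list owned by a ScaledLRU. Entries form a doubly linked
// list from oldest (node_head) to most recently used (node_end), with
// key_map providing O(1) lookup. Every query hit consumes one unit of the
// entry's ttl; remove_count is an eviction quota assigned by the parent
// during shrink().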
template <typename K, typename V>
class RandomSampleLRU {
 public:
  RandomSampleLRU(ScaledLRU<K, V> *_father) {
    father = _father;
    remove_count = 0;
    node_size = 0;
    node_head = node_end = NULL;
    global_ttl = father->ttl;
    total_diff = 0;
  }

  ~RandomSampleLRU() {
    LRUNode<K, V> *p;
    while (node_head != NULL) {
      p = node_head->next;
      delete node_head;
      node_head = p;
    }
  }
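  // Look up `length` keys and append hits to `res` as (key, value) pairs.
  // Non-blocking: returns LRUResponse::blocked when the parent's rwlock
  // cannot be acquired. A hit consumes one unit of ttl and refreshes the
  // entry's recency.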
  LRUResponse query(K *keys, size_t length, std::vector<std::pair<K, V>> &res) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);

    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        res.emplace_back(keys[i], iter->second->data);
        iter->second->ttl--;
        if (iter->second->ttl == 0) {
          remove(iter->second);
          if (remove_count != 0) remove_count--;
        } else {
          move_to_tail(iter->second);
        }
      }
    }
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
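  // Insert `length` (key, value) pairs; an existing key has its data
  // overwritten and its ttl reset to global_ttl. Non-blocking, like
  // query().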
  LRUResponse insert(K *keys, V *data, size_t length) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);
    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        move_to_tail(iter->second);
        iter->second->ttl = global_ttl;
        iter->second->data = data[i];
      } else {
        LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
        add_new(temp);
      }
    }
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }

    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  void remove(LRUNode<K, V> *node) {
    fetch(node);
    node_size--;
    key_map.erase(node->key);
    delete node;
  }

  void process_redundant(int process_size) {
    size_t length = std::min(remove_count, process_size);
    while (length--) {
      remove(node_head);
      remove_count--;
    }
    // std::cerr<<"after remove_count = "<<remove_count<<std::endl;
  }

  void move_to_tail(LRUNode<K, V> *node) {
    fetch(node);
    place_at_tail(node);
  }

  void add_new(LRUNode<K, V> *node) {
    node->ttl = global_ttl;
    place_at_tail(node);
    node_size++;
    key_map[node->key] = node;
  }
  void place_at_tail(LRUNode<K, V> *node) {
    if (node_end == NULL) {
      node_head = node_end = node;
      node->next = node->pre = NULL;
    } else {
      node_end->next = node;
      node->pre = node_end;
      node->next = NULL;
      node_end = node;
    }
  }

  void fetch(LRUNode<K, V> *node) {
    if (node->pre) {
      node->pre->next = node->next;
    } else {
      node_head = node->next;
    }
    if (node->next) {
      node->next->pre = node->pre;
    } else {
      node_end = node->pre;
    }
  }

 private:
  std::unordered_map<K, LRUNode<K, V> *> key_map;
  ScaledLRU<K, V> *father;
  size_t global_ttl, size_limit;
  int node_size, total_diff;
  LRUNode<K, V> *node_head, *node_end;
  friend class ScaledLRU<K, V>;
  int remove_count;
};

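// Sharded LRU cache: one RandomSampleLRU per worker thread, coordinated
// through a shared rwlock (readers are ordinary query/insert traffic, the
// writer is the shrink pass). A detached background thread wakes every 20
// seconds and trims the pools back toward size_limit.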
template <typename K, typename V>
class ScaledLRU {
 public:
  ScaledLRU(size_t _shard_num, size_t size_limit, size_t _ttl)
      : size_limit(size_limit), ttl(_ttl) {
    shard_num = _shard_num;
    pthread_rwlock_init(&rwlock, NULL);
    stop = false;
    thread_pool.reset(new ::ThreadPool(1));
    global_count = 0;
    lru_pool = std::vector<RandomSampleLRU<K, V>>(shard_num,
                                                  RandomSampleLRU<K, V>(this));
    shrink_job = std::thread([this]() -> void {
      while (true) {
        {
          std::unique_lock<std::mutex> lock(mutex_);
          cv_.wait_for(lock, std::chrono::milliseconds(20000));
          if (stop) {
            return;
          }
        }
        auto status =
            thread_pool->enqueue([this]() -> int { return shrink(); });
        status.wait();
      }
    });
    shrink_job.detach();
  }
  ~ScaledLRU() {
    std::unique_lock<std::mutex> lock(mutex_);
    stop = true;
    cv_.notify_one();
  }
  LRUResponse query(size_t index, K *keys, size_t length,
                    std::vector<std::pair<K, V>> &res) {
    return lru_pool[index].query(keys, length, res);
  }
  LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
    return lru_pool[index].insert(keys, data, length);
  }
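  // Rebalancing pass: if the total entry count exceeds size_limit by more
  // than ~10%, take the write lock and hand each per-thread pool an
  // eviction quota (remove_count) proportional to its share of the cache.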
  int shrink() {
    int node_size = 0;
    for (size_t i = 0; i < lru_pool.size(); i++) {
      node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
    }

    if (node_size <= size_t(1.1 * size_limit) + 1) return 0;
    if (pthread_rwlock_wrlock(&rwlock) == 0) {
      // VLOG(0)<"in shrink\n";
      global_count = 0;
      for (size_t i = 0; i < lru_pool.size(); i++) {
        global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
      }
      // VLOG(0)<<"global_count "<<global_count<<"\n";
      if (global_count > size_limit) {
        size_t remove = global_count - size_limit;
        for (size_t i = 0; i < lru_pool.size(); i++) {
          lru_pool[i].total_diff = 0;
          lru_pool[i].remove_count +=
              1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
              global_count * remove;
          // VLOG(0)<<i<<" "<<lru_pool[i].remove_count<<std::endl;
        }
      }
      pthread_rwlock_unlock(&rwlock);
      return 0;
    }
    return 0;
  }

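  // Per-thread pools batch their size changes in total_diff and report
  // them here once the batch passes +/-500: the delta is folded into
  // global_count atomically, and an early shrink is scheduled when the
  // cache exceeds 125% of size_limit.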
  void handle_size_diff(int diff) {
    if (diff != 0) {
      __sync_fetch_and_add(&global_count, diff);
      if (global_count > int(1.25 * size_limit)) {
        // VLOG(0)<<"global_count too large "<<global_count<<" enter start
        // shrink task\n";
        thread_pool->enqueue([this]() -> int { return shrink(); });
      }
    }
  }

  size_t get_ttl() { return ttl; }

 private:
  pthread_rwlock_t rwlock;
  size_t shard_num;
  int global_count;
  size_t size_limit, total, hit;
  size_t ttl;
  bool stop;
  std::thread shrink_job;
  std::vector<RandomSampleLRU<K, V>> lru_pool;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  std::shared_ptr<::ThreadPool> thread_pool;
  friend class RandomSampleLRU<K, V>;
};
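
// Usage sketch (illustrative values only; each `index` should be used by
// a single thread, since the per-thread pools are not internally locked):
//
//   ScaledLRU<SampleKey, SampleResult> cache(
//       /* shard_num = */ 24, /* size_limit = */ 100000, /* ttl = */ 4);
//   std::vector<std::pair<SampleKey, SampleResult>> hits;
//   SampleKey key(/* node_key = */ 42, /* sample_size = */ 10);
//   if (cache.query(/* index = */ 0, &key, 1, hits) == LRUResponse::ok) {
//     // hits now holds any cached samples for `key`.
//   }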

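// A sparse table holding one partition of the distributed graph. Nodes
// are spread across GraphShard buckets, requests are dispatched onto a
// fixed-size thread pool, and an optional ScaledLRU cache (enabled via
// make_neighbor_sample_cache) memoizes neighbor-sampling results.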
class GraphTable : public SparseTable {
 public:
  GraphTable() { use_cache = false; }
  virtual ~GraphTable() {}
  virtual int32_t pull_graph_list(int start, int size,
                                  std::unique_ptr<char[]> &buffer,
                                  int &actual_size, bool need_feature,
                                  int step);

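  // For each id in node_ids, draw up to sample_size neighbors; every
  // result is returned as a packed byte buffer together with the number
  // of bytes actually written in actual_sizes.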
  virtual int32_t random_sample_neighbors(
      uint64_t *node_ids, int sample_size,
      std::vector<std::shared_ptr<char>> &buffers,
      std::vector<int> &actual_sizes);

  int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers,
                              int &actual_sizes);

  virtual int32_t get_nodes_ids_by_ranges(
      std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res);
  virtual int32_t initialize();

  int32_t load(const std::string &path, const std::string &param);

  int32_t load_edges(const std::string &path, bool reverse);

  int32_t load_nodes(const std::string &path, std::string node_type);

  int32_t add_graph_node(std::vector<uint64_t> &id_list,
                         std::vector<bool> &is_weight_list);

  int32_t remove_graph_node(std::vector<uint64_t> &id_list);

  int32_t get_server_index_by_id(uint64_t id);
  Node *find_node(uint64_t id);

  virtual int32_t pull_sparse(float *values,
                              const PullSparseValue &pull_value) {
    return 0;
  }

  virtual int32_t push_sparse(const uint64_t *keys, const float *values,
                              size_t num) {
    return 0;
  }

  virtual int32_t clear_nodes();
  virtual void clear() {}
  virtual int32_t flush() { return 0; }
  virtual int32_t shrink(const std::string &param) { return 0; }
  // Specify the save path.
  virtual int32_t save(const std::string &path, const std::string &converter) {
    return 0;
  }
  virtual int32_t initialize_shard() { return 0; }
  virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
  virtual uint32_t get_thread_pool_index(uint64_t node_id);
  virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);

  virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids,
                                const std::vector<std::string> &feature_names,
                                std::vector<std::vector<std::string>> &res);

  virtual int32_t set_node_feat(
      const std::vector<uint64_t> &node_ids,
      const std::vector<std::string> &feature_names,
      const std::vector<std::vector<std::string>> &res);

  size_t get_server_num() { return server_num; }

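  // Lazily enable the neighbor-sampling cache: a single ScaledLRU shared
  // by all task-pool threads, bounded to roughly size_limit entries, each
  // surviving at most `ttl` reads. Only the first call has any effect.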
  virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (use_cache == false) {
        scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
            task_pool_size_, size_limit, ttl));
        use_cache = true;
      }
    }
    return 0;
  }

 protected:
  std::vector<GraphShard> shards;
  size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
  const int task_pool_size_ = 24;
  const int random_sample_nodes_ranges = 3;

  std::vector<std::string> feat_name;
  std::vector<std::string> feat_dtype;
  std::vector<int32_t> feat_shape;
  std::unordered_map<std::string, int32_t> feat_id_map;
  std::string table_name;
  std::string table_type;

  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
  std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
  std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
  bool use_cache;
  mutable std::mutex mutex_;
};
}  // namespace distributed

}  // namespace paddle

namespace std {

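// Hash support so SampleKey can serve as the key of the cache's
// std::unordered_map; node key and sample size are simply XOR-ed.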
template <>
struct hash<paddle::distributed::SampleKey> {
  size_t operator()(const paddle::distributed::SampleKey &s) const {
    return s.node_key ^ s.sample_size;
  }
};
}  // namespace std