未验证 提交 249081b6 编写于 作者: S seemingwang 提交者: GitHub

cache for graph_engine (#36880)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundant threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redundant graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neighbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs

* fix MultiSlotDataGenerator error

* cache for graph engine

* fix type compare error

* more test&fix thread terminating problem

* remove header

* change time interval of shrink
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 792d3d76
...@@ -17,11 +17,23 @@ ...@@ -17,11 +17,23 @@
#include <ThreadPool.h> #include <ThreadPool.h>
#include <assert.h> #include <assert.h>
#include <pthread.h> #include <pthread.h>
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
#include <list> #include <list>
#include <map>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <numeric>
#include <queue>
#include <set>
#include <string> #include <string>
#include <thread>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/distributed/table/accessor.h" #include "paddle/fluid/distributed/table/accessor.h"
...@@ -62,6 +74,294 @@ class GraphShard { ...@@ -62,6 +74,294 @@ class GraphShard {
int shard_num; int shard_num;
std::vector<Node *> bucket; std::vector<Node *> bucket;
}; };
// Result codes for LRU cache operations.
enum LRUResponse { ok = 0, blocked = 1, err = 2 };

// Key identifying one cached neighbor-sampling request:
// (node id, number of neighbors requested).
struct SampleKey {
  uint64_t node_key;
  size_t sample_size;
  bool operator==(const SampleKey &s) const {
    return node_key == s.node_key && sample_size == s.sample_size;
  }
};

// Hash functor for SampleKey.
// Uses a boost-style hash_combine instead of the original plain XOR:
// XOR maps symmetric pairs such as (a, b) and (b, a) — and any pair with
// equal fields — to the same bucket, causing avoidable collisions.
struct SampleKeyHash {
  size_t operator()(const SampleKey &s) const {
    size_t seed = std::hash<uint64_t>()(s.node_key);
    seed ^= std::hash<size_t>()(s.sample_size) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2);
    return seed;
  }
};

// Owning deep copy of a serialized sampling result buffer.
class SampleResult {
 public:
  size_t actual_size;
  char *buffer;
  SampleResult(size_t _actual_size, char *_buffer) : actual_size(_actual_size) {
    buffer = new char[actual_size];
    memcpy(buffer, _buffer, actual_size);
  }
  // Deep-copy semantics (Rule of Three). The original declared neither,
  // so any accidental copy would have double-deleted `buffer`.
  SampleResult(const SampleResult &other) : actual_size(other.actual_size) {
    buffer = new char[actual_size];
    memcpy(buffer, other.buffer, actual_size);
  }
  SampleResult &operator=(const SampleResult &other) {
    if (this != &other) {
      // allocate/copy first so a failed allocation leaves *this intact
      char *new_buffer = new char[other.actual_size];
      memcpy(new_buffer, other.buffer, other.actual_size);
      delete[] buffer;
      buffer = new_buffer;
      actual_size = other.actual_size;
    }
    return *this;
  }
  ~SampleResult() {
    // std::cout<<"in SampleResult deconstructor\n";
    delete[] buffer;
  }
};
// Doubly-linked-list node stored by RandomSampleLRU: carries one cached
// (key, value) pair together with its remaining hit budget and the time of
// its most recent hit.
template <typename K, typename V>
class LRUNode {
 public:
  LRUNode(K _key, V _data, size_t _ttl)
      : key(_key), data(_data), ttl(_ttl), pre(NULL), next(NULL) {}
  // timestamp of the last hit (milliseconds since epoch); stamped by the
  // owning list when the node is (re)appended
  std::chrono::milliseconds ms;
  K key;
  V data;
  // remaining number of hits before this entry expires
  size_t ttl;
  // intrusive list links (pre = towards least-recently-used)
  LRUNode<K, V> *pre, *next;
};
template <typename K, typename V, typename Hash = std::hash<K>>
class ScaledLRU;

// One shard of the ScaledLRU cache: a classic hash-map + doubly-linked-list
// LRU whose entries also carry a hit-count ttl — an entry is evicted after
// being read `ttl` times. The shard is not internally synchronized:
// query/insert try-acquire the parent ScaledLRU's rwlock in shared (read)
// mode and report `blocked` when a shrink holds it exclusively.
template <typename K, typename V, typename Hash = std::hash<K>>
class RandomSampleLRU {
 public:
  RandomSampleLRU(ScaledLRU<K, V, Hash> *_father) : father(_father) {
    node_size = 0;
    node_head = node_end = NULL;
    global_ttl = father->ttl;
  }
  // Free every node still owned by this shard.
  ~RandomSampleLRU() {
    LRUNode<K, V> *p;
    while (node_head != NULL) {
      p = node_head->next;
      delete node_head;
      node_head = p;
    }
  }
  // Look up keys[0..length); each hit is appended to `res` and its ttl
  // decremented. A hit whose ttl reaches 0 is evicted; otherwise it is
  // moved to the tail (most-recently-used position).
  // Returns blocked (without touching the shard) if a shrink is running.
  LRUResponse query(K *keys, size_t length, std::vector<std::pair<K, V>> &res) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    int init_node_size = node_size;
    try {
      for (size_t i = 0; i < length; i++) {
        auto iter = key_map.find(keys[i]);
        if (iter != key_map.end()) {
          res.push_back({keys[i], iter->second->data});
          iter->second->ttl--;
          if (iter->second->ttl == 0) {
            remove(iter->second, true);
          } else {
            remove(iter->second);
            add_to_tail(iter->second);
          }
        }
      }
    } catch (...) {
      pthread_rwlock_unlock(&father->rwlock);
      // report the net size change so the parent's global counter stays valid
      father->handle_size_diff(node_size - init_node_size);
      return LRUResponse::err;
    }
    pthread_rwlock_unlock(&father->rwlock);
    father->handle_size_diff(node_size - init_node_size);
    return LRUResponse::ok;
  }
  // Insert (or refresh) `length` key/value pairs. An existing key gets its
  // ttl reset and is moved to the tail. Same locking contract as query().
  LRUResponse insert(K *keys, V *data, size_t length) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    int init_node_size = node_size;
    try {
      for (size_t i = 0; i < length; i++) {
        auto iter = key_map.find(keys[i]);
        if (iter != key_map.end()) {
          iter->second->ttl = global_ttl;
          remove(iter->second);
          add_to_tail(iter->second);
          iter->second->data = data[i];
        } else {
          LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
          add_to_tail(temp);
          key_map[keys[i]] = temp;
        }
      }
    } catch (...) {
      pthread_rwlock_unlock(&father->rwlock);
      father->handle_size_diff(node_size - init_node_size);
      return LRUResponse::err;
    }
    pthread_rwlock_unlock(&father->rwlock);
    father->handle_size_diff(node_size - init_node_size);
    return LRUResponse::ok;
  }
  // Unlink `node` from the list; when `del` is set, also erase it from the
  // key map and free it.
  void remove(LRUNode<K, V> *node, bool del = false) {
    if (node->pre) {
      node->pre->next = node->next;
    } else {
      node_head = node->next;
    }
    if (node->next) {
      node->next->pre = node->pre;
    } else {
      node_end = node->pre;
    }
    node_size--;
    if (del) {
      // BUGFIX: erase from the map BEFORE freeing the node. The original
      // did `delete node; key_map.erase(node->key);`, reading node->key
      // from already-freed memory (use-after-free).
      key_map.erase(node->key);
      delete node;
    }
  }
  // Append `node` as most-recently-used and stamp its last-hit time.
  void add_to_tail(LRUNode<K, V> *node) {
    if (node_end == NULL) {
      node_head = node_end = node;
      node->next = node->pre = NULL;
    } else {
      node_end->next = node;
      node->pre = node_end;
      node->next = NULL;
      node_end = node;
    }
    node_size++;
    node->ms = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::system_clock::now().time_since_epoch());
  }

 private:
  std::unordered_map<K, LRUNode<K, V> *, Hash> key_map;
  ScaledLRU<K, V, Hash> *father;
  // hit budget handed to every new/refreshed entry (copied from father)
  size_t global_ttl;
  // current number of nodes in this shard
  int node_size;
  // head = least-recently-used, end = most-recently-used
  LRUNode<K, V> *node_head, *node_end;
  friend class ScaledLRU<K, V, Hash>;
};
// Sharded LRU cache with a per-entry hit-count ttl and a global node budget.
// query/insert address a shard by index and are only fenced (via rwlock)
// against the background shrink job — shards are not locked against each
// other, so callers presumably dedicate a shard per worker thread; TODO
// confirm against the call sites.
template <typename K, typename V, typename Hash>
class ScaledLRU {
 public:
  ScaledLRU(size_t shard_num, size_t size_limit, size_t _ttl)
      : size_limit(size_limit), ttl(_ttl) {
    pthread_rwlock_init(&rwlock, NULL);
    stop = false;
    thread_pool.reset(new ::ThreadPool(1));
    global_count = 0;
    lru_pool = std::vector<RandomSampleLRU<K, V, Hash>>(
        shard_num, RandomSampleLRU<K, V, Hash>(this));
    // Background job: wake every 3 s (or on shutdown notify) and run one
    // shrink pass through the single-worker thread pool.
    shrink_job = std::thread([this]() -> void {
      while (true) {
        {
          std::unique_lock<std::mutex> lock(mutex_);
          cv_.wait_for(lock, std::chrono::milliseconds(3000));
          if (stop) {
            return;
          }
        }
        // shrink();
        // std::cerr<<"shrink job in queue\n";
        auto status =
            thread_pool->enqueue([this]() -> int { return shrink(); });
        status.wait();
      }
    });
    // NOTE(review): the thread is detached, so it can wake after this
    // object is destroyed and touch mutex_/cv_ — a shutdown use-after-free
    // race. Joining in the destructor would be safer; confirm shutdown
    // ordering with the owner of this cache.
    shrink_job.detach();
  }
  ~ScaledLRU() {
    std::unique_lock<std::mutex> lock(mutex_);
    // std::cerr<<"cancel shrink job\n";
    stop = true;
    cv_.notify_one();
    // pthread_cancel(shrink_job.native_handle());
    // NOTE(review): rwlock is never pthread_rwlock_destroy()ed.
  }
  // Look up `keys` in shard `index`; hits are appended to `res`.
  LRUResponse query(size_t index, K *keys, size_t length,
                    std::vector<std::pair<K, V>> &res) {
    return lru_pool[index].query(keys, length, res);
  }
  // Insert `length` key/value pairs into shard `index`.
  LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
    return lru_pool[index].insert(keys, data, length);
  }
  // Evict the globally oldest entries (by last-hit timestamp) until the
  // total node count is back under size_limit. Takes the rwlock
  // exclusively, so concurrent query/insert return `blocked` meanwhile.
  int shrink() {
    int node_size = 0;
    std::string t = "";
    for (size_t i = 0; i < lru_pool.size(); i++) {
      node_size += lru_pool[i].node_size;
      // t += std::to_string(i) + "->" + std::to_string(lru_pool[i].node_size) +
      // " ";
    }
    // std::cout<<t<<std::endl;
    // NOTE(review): int-vs-size_t comparison here and below; harmless
    // while node counts stay small and non-negative.
    if (node_size <= size_limit) return 0;
    if (pthread_rwlock_wrlock(&rwlock) == 0) {
      try {
        global_count = 0;
        // Min-heap ordered by each shard head's last-hit time: merges the
        // shards' (time-ordered) lists so q.top() is the oldest live node.
        std::priority_queue<RemovedNode, std::vector<RemovedNode>,
                            std::greater<RemovedNode>>
            q;
        for (size_t i = 0; i < lru_pool.size(); i++) {
          if (lru_pool[i].node_size > 0) {
            global_count += lru_pool[i].node_size;
            q.push({lru_pool[i].node_head, &lru_pool[i]});
          }
        }
        if (global_count > size_limit) {
          // std::cout<<"before shrinking cache, cached nodes count =
          // "<<global_count<<std::endl;
          size_t remove = global_count - size_limit;
          while (remove--) {
            RemovedNode remove_node = q.top();
            q.pop();
            // advance the heap to this shard's next-oldest node
            auto next = remove_node.node->next;
            if (next) {
              q.push({next, remove_node.lru_pointer});
            }
            global_count--;
            remove_node.lru_pointer->key_map.erase(remove_node.node->key);
            remove_node.lru_pointer->remove(remove_node.node, true);
          }
          // std::cout<<"after shrinking cache, cached nodes count =
          // "<<global_count<<std::endl;
        }
      } catch (...) {
        // std::cout << "shrink cache failed"<<std::endl;
        pthread_rwlock_unlock(&rwlock);
        return -1;
      }
      pthread_rwlock_unlock(&rwlock);
      return 0;
    }
    return 0;
  }
  // Fold a shard's size delta into the global counter; eagerly enqueue a
  // shrink when the cache overshoots the limit by 50%.
  void handle_size_diff(int diff) {
    if (diff != 0) {
      __sync_fetch_and_add(&global_count, diff);
      if (global_count > int(1.5 * size_limit)) {
        // std::cout<<"global_count too large "<<global_count<<" enter start
        // shrink task\n";
        thread_pool->enqueue([this]() -> int { return shrink(); });
      }
    }
  }
  size_t get_ttl() { return ttl; }

 private:
  pthread_rwlock_t rwlock;
  // approximate total node count across shards (updated atomically)
  int global_count;
  // target maximum number of cached nodes
  size_t size_limit;
  // per-entry hit budget handed to each shard
  size_t ttl;
  // shutdown flag for the shrink thread (guarded by mutex_)
  bool stop;
  std::thread shrink_job;
  std::vector<RandomSampleLRU<K, V, Hash>> lru_pool;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  // heap element for shrink(): a shard-list node plus its owning shard
  struct RemovedNode {
    LRUNode<K, V> *node;
    RandomSampleLRU<K, V, Hash> *lru_pointer;
    bool operator>(const RemovedNode &a) const { return node->ms > a.node->ms; }
  };
  std::shared_ptr<::ThreadPool> thread_pool;
  friend class RandomSampleLRU<K, V, Hash>;
};
class GraphTable : public SparseTable { class GraphTable : public SparseTable {
public: public:
GraphTable() {} GraphTable() {}
......
...@@ -222,6 +222,7 @@ void testBatchSampleNeighboor( ...@@ -222,6 +222,7 @@ void testBatchSampleNeighboor(
} }
} }
void testCache();
void testGraphToBuffer(); void testGraphToBuffer();
// std::string nodes[] = {std::string("37\taa\t45;0.34\t145;0.31\t112;0.21"), // std::string nodes[] = {std::string("37\taa\t45;0.34\t145;0.31\t112;0.21"),
// std::string("96\tfeature\t48;1.4\t247;0.31\t111;1.21"), // std::string("96\tfeature\t48;1.4\t247;0.31\t111;1.21"),
...@@ -400,6 +401,8 @@ void RunClient( ...@@ -400,6 +401,8 @@ void RunClient(
} }
void RunBrpcPushSparse() { void RunBrpcPushSparse() {
std::cout << "in test cache";
testCache();
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
prepare_file(edge_file_name, 1); prepare_file(edge_file_name, 1);
...@@ -607,6 +610,64 @@ void RunBrpcPushSparse() { ...@@ -607,6 +610,64 @@ void RunBrpcPushSparse() {
client1.stop_server(); client1.stop_server();
} }
void testCache() {
  // Exercise ScaledLRU<SampleKey, shared_ptr<SampleResult>> with a single
  // shard, a size limit of 2 and a ttl of 4 (each entry survives 4 hits).
  ::paddle::distributed::ScaledLRU<
      ::paddle::distributed::SampleKey,
      std::shared_ptr<::paddle::distributed::SampleResult>,
      ::paddle::distributed::SampleKeyHash>
      st(1, 2, 4);
  ::paddle::distributed::SampleKey skey = {6, 1};
  std::vector<std::pair<::paddle::distributed::SampleKey,
                        std::shared_ptr<::paddle::distributed::SampleResult>>>
      r;
  // Query once expecting a hit, and verify the cached bytes match
  // `expected`; clears `r` afterwards.
  auto check_hit = [&](const char* expected) {
    st.query(0, &skey, 1, r);
    ASSERT_EQ((int)r.size(), 1);
    char* p = (char*)r[0].second->buffer;
    for (int j = 0; j < (int)r[0].second->actual_size; j++)
      ASSERT_EQ(p[j], expected[j]);
    r.clear();
  };

  // Miss on an empty cache.
  st.query(0, &skey, 1, r);
  ASSERT_EQ((int)r.size(), 0);

  const char* payload = "54321";
  std::shared_ptr<::paddle::distributed::SampleResult> sp(
      new ::paddle::distributed::SampleResult(5, (char*)payload));
  st.insert(0, &skey, &sp, 1);
  // The entry is served exactly ttl times...
  for (size_t i = 0; i < st.get_ttl(); i++) check_hit(payload);
  // ...and then expires.
  st.query(0, &skey, 1, r);
  ASSERT_EQ((int)r.size(), 0);

  payload = "342cd4321";
  std::shared_ptr<::paddle::distributed::SampleResult> sp1(
      new ::paddle::distributed::SampleResult(strlen(payload), (char*)payload));
  st.insert(0, &skey, &sp1, 1);
  // Consume half of the ttl...
  for (size_t i = 0; i < st.get_ttl() / 2; i++) check_hit(payload);

  // ...then overwrite the entry, which resets its ttl to the full budget.
  payload = "343332d4321";
  std::shared_ptr<::paddle::distributed::SampleResult> sp2(
      new ::paddle::distributed::SampleResult(strlen(payload), (char*)payload));
  st.insert(0, &skey, &sp2, 1);
  for (size_t i = 0; i < st.get_ttl(); i++) check_hit(payload);
  st.query(0, &skey, 1, r);
  ASSERT_EQ((int)r.size(), 0);
}
void testGraphToBuffer() { void testGraphToBuffer() {
::paddle::distributed::GraphNode s, s1; ::paddle::distributed::GraphNode s, s1;
s.set_feature_size(1); s.set_feature_size(1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册