未验证 提交 b44db69f 编写于 作者: S seemingwang 提交者: GitHub

graph-engine cache optimization (#37168)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundent threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redandunt graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neigbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs

* fix MultiSlotDataGenerator error

* cache for graph engine

* fix type compare error

* more test&fix thread terminating problem

* remove header

* change time interval of shrink

* use cache when sample nodes

* remove unused function

* change unique_ptr to shared_ptr

* simplify cache template

* cache api on client

* fix

* reduce sample threads when cache is not used

* reduce cache memory

* cache optimization

* remove test function

* remove extra fetch function
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 f2a56c6a
......@@ -105,8 +105,6 @@ class LRUNode {
LRUNode(K _key, V _data, size_t _ttl) : key(_key), data(_data), ttl(_ttl) {
next = pre = NULL;
}
std::chrono::milliseconds ms;
// the last hit time
K key;
V data;
size_t ttl;
......@@ -119,12 +117,13 @@ class ScaledLRU;
template <typename K, typename V>
class RandomSampleLRU {
public:
RandomSampleLRU(ScaledLRU<K, V> *_father) : father(_father) {
RandomSampleLRU(ScaledLRU<K, V> *_father) {
father = _father;
remove_count = 0;
node_size = 0;
node_head = node_end = NULL;
global_ttl = father->ttl;
extra_penalty = 0;
size_limit = (father->size_limit / father->shard_num + 1);
total_diff = 0;
}
~RandomSampleLRU() {
......@@ -138,53 +137,55 @@ class RandomSampleLRU {
LRUResponse query(K *keys, size_t length, std::vector<std::pair<K, V>> &res) {
if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
return LRUResponse::blocked;
int init_node_size = node_size;
try {
// pthread_rwlock_rdlock(&father->rwlock);
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
res.emplace_back(keys[i], iter->second->data);
iter->second->ttl--;
if (iter->second->ttl == 0) {
remove(iter->second);
} else {
move_to_tail(iter->second);
}
// pthread_rwlock_rdlock(&father->rwlock);
int init_size = node_size - remove_count;
process_redundant(length * 3);
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
res.emplace_back(keys[i], iter->second->data);
iter->second->ttl--;
if (iter->second->ttl == 0) {
remove(iter->second);
if (remove_count != 0) remove_count--;
} else {
move_to_tail(iter->second);
}
}
} catch (...) {
pthread_rwlock_unlock(&father->rwlock);
father->handle_size_diff(node_size - init_node_size);
return LRUResponse::err;
}
total_diff += node_size - remove_count - init_size;
if (total_diff >= 500 || total_diff < -500) {
father->handle_size_diff(total_diff);
total_diff = 0;
}
pthread_rwlock_unlock(&father->rwlock);
father->handle_size_diff(node_size - init_node_size);
return LRUResponse::ok;
}
LRUResponse insert(K *keys, V *data, size_t length) {
if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
return LRUResponse::blocked;
int init_node_size = node_size;
try {
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
move_to_tail(iter->second);
iter->second->ttl = global_ttl;
iter->second->data = data[i];
} else {
LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
add_new(temp);
}
// pthread_rwlock_rdlock(&father->rwlock);
int init_size = node_size - remove_count;
process_redundant(length * 3);
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
move_to_tail(iter->second);
iter->second->ttl = global_ttl;
iter->second->data = data[i];
} else {
LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
add_new(temp);
}
} catch (...) {
pthread_rwlock_unlock(&father->rwlock);
father->handle_size_diff(node_size - init_node_size);
return LRUResponse::err;
}
total_diff += node_size - remove_count - init_size;
if (total_diff >= 500 || total_diff < -500) {
father->handle_size_diff(total_diff);
total_diff = 0;
}
pthread_rwlock_unlock(&father->rwlock);
father->handle_size_diff(node_size - init_node_size);
return LRUResponse::ok;
}
void remove(LRUNode<K, V> *node) {
......@@ -192,9 +193,15 @@ class RandomSampleLRU {
node_size--;
key_map.erase(node->key);
delete node;
if (node_size >= size_limit) {
extra_penalty -= 1.0;
}
void process_redundant(int process_size) {
size_t length = std::min(remove_count, process_size);
while (length--) {
remove(node_head);
remove_count--;
}
// std::cerr<<"after remove_count = "<<remove_count<<std::endl;
}
void move_to_tail(LRUNode<K, V> *node) {
......@@ -207,12 +214,6 @@ class RandomSampleLRU {
place_at_tail(node);
node_size++;
key_map[node->key] = node;
if (node_size > size_limit) {
extra_penalty += penalty_inc;
if (extra_penalty >= 1.0) {
remove(node_head);
}
}
}
void place_at_tail(LRUNode<K, V> *node) {
if (node_end == NULL) {
......@@ -224,8 +225,6 @@ class RandomSampleLRU {
node->next = NULL;
node_end = node;
}
node->ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch());
}
void fetch(LRUNode<K, V> *node) {
......@@ -245,11 +244,10 @@ class RandomSampleLRU {
std::unordered_map<K, LRUNode<K, V> *> key_map;
ScaledLRU<K, V> *father;
size_t global_ttl, size_limit;
int node_size;
int node_size, total_diff;
LRUNode<K, V> *node_head, *node_end;
friend class ScaledLRU<K, V>;
float extra_penalty;
const float penalty_inc = 0.75;
int remove_count;
};
template <typename K, typename V>
......@@ -295,52 +293,33 @@ class ScaledLRU {
int shrink() {
int node_size = 0;
for (size_t i = 0; i < lru_pool.size(); i++) {
node_size += lru_pool[i].node_size;
node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
}
if (node_size <= 1.2 * size_limit) return 0;
if (node_size <= size_t(1.1 * size_limit) + 1) return 0;
if (pthread_rwlock_wrlock(&rwlock) == 0) {
try {
global_count = 0;
std::priority_queue<RemovedNode, std::vector<RemovedNode>,
std::greater<RemovedNode>>
q;
for (size_t i = 0; i < lru_pool.size(); i++) {
if (lru_pool[i].node_size > 0) {
global_count += lru_pool[i].node_size;
q.push({lru_pool[i].node_head, &lru_pool[i]});
}
}
if (global_count > size_limit) {
// VLOG(0)<<"before shrinking cache, cached nodes count =
// "<<global_count<<std::endl;
size_t remove = global_count - size_limit;
while (remove--) {
RemovedNode remove_node = q.top();
q.pop();
auto next = remove_node.node->next;
if (next) {
q.push({next, remove_node.lru_pointer});
}
global_count--;
remove_node.lru_pointer->remove(remove_node.node);
}
for (size_t i = 0; i < lru_pool.size(); i++) {
lru_pool[i].size_limit = lru_pool[i].node_size;
lru_pool[i].extra_penalty = 0;
}
// VLOG(0)<<"after shrinking cache, cached nodes count =
// // "<<global_count<<std::endl;
// VLOG(0)<"in shrink\n";
global_count = 0;
for (size_t i = 0; i < lru_pool.size(); i++) {
global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
}
// VLOG(0)<<"global_count "<<global_count<<"\n";
if (global_count > size_limit) {
size_t remove = global_count - size_limit;
for (int i = 0; i < lru_pool.size(); i++) {
lru_pool[i].total_diff = 0;
lru_pool[i].remove_count +=
1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
global_count * remove;
// VLOG(0)<<i<<" "<<lru_pool[i].remove_count<<std::endl;
}
} catch (...) {
pthread_rwlock_unlock(&rwlock);
return -1;
}
pthread_rwlock_unlock(&rwlock);
return 0;
}
return 0;
}
void handle_size_diff(int diff) {
if (diff != 0) {
__sync_fetch_and_add(&global_count, diff);
......@@ -358,18 +337,13 @@ class ScaledLRU {
pthread_rwlock_t rwlock;
size_t shard_num;
int global_count;
size_t size_limit;
size_t size_limit, total, hit;
size_t ttl;
bool stop;
std::thread shrink_job;
std::vector<RandomSampleLRU<K, V>> lru_pool;
mutable std::mutex mutex_;
std::condition_variable cv_;
struct RemovedNode {
LRUNode<K, V> *node;
RandomSampleLRU<K, V> *lru_pointer;
bool operator>(const RemovedNode &a) const { return node->ms > a.node->ms; }
};
std::shared_ptr<::ThreadPool> thread_pool;
friend class RandomSampleLRU<K, V>;
};
......@@ -448,7 +422,7 @@ class GraphTable : public SparseTable {
std::unique_lock<std::mutex> lock(mutex_);
if (use_cache == false) {
scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
shard_end - shard_start, size_limit, ttl));
task_pool_size_, size_limit, ttl));
use_cache = true;
}
}
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -681,28 +678,6 @@ void testCache() {
}
st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 0);
::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey,
::paddle::distributed::SampleResult>
cache1(2, 1, 4);
str = new char[18];
strcpy(str, "3433776521");
result = new ::paddle::distributed::SampleResult(strlen(str), str);
cache1.insert(1, &skey, result, 1);
::paddle::distributed::SampleKey skey1 = {8, 1};
char* str1 = new char[18];
strcpy(str1, "3xcf2eersfd");
usleep(3000); // sleep 3ms to guaruntee that skey1's time stamp is larger
// than skey;
auto result1 = new ::paddle::distributed::SampleResult(strlen(str1), str1);
cache1.insert(0, &skey1, result1, 1);
sleep(1); // sleep 1 s to guarantee that shrinking work is done
cache1.query(1, &skey, 1, r);
ASSERT_EQ((int)r.size(), 0);
cache1.query(0, &skey1, 1, r);
ASSERT_EQ((int)r.size(), 1);
char* p1 = (char*)r[0].second.buffer.get();
for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p1[j], str1[j]);
r.clear();
}
void testGraphToBuffer() {
::paddle::distributed::GraphNode s, s1;
......@@ -718,4 +693,4 @@ void testGraphToBuffer() {
VLOG(0) << s1.get_feature(0);
}
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册