Unverified commit 85c8c170, authored by seemingwang, committed by GitHub

simplify graph-engine's templates (#36990)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segmentation fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node & test & change data structure from linked list to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat returns py::bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundant threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redundant graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neighbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs

* fix MultiSlotDataGenerator error

* cache for graph engine

* fix type compare error

* more test&fix thread terminating problem

* remove header

* change time interval of shrink

* use cache when sample nodes

* remove unused function

* change unique_ptr to shared_ptr

* simplify cache template

* cache api on client

* fix
Co-authored-by: Huang Zhengjie <270018958@qq.com>
Co-authored-by: Weiyue Su <weiyue.su@gmail.com>
Co-authored-by: suweiyue <suweiyue@baidu.com>
Co-authored-by: luobin06 <luobin06@baidu.com>
Co-authored-by: liweibin02 <liweibin02@baidu.com>
Co-authored-by: tangwei12 <tangwei12@baidu.com>
Parent 9d2dd727
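
The core of the template simplification: instead of threading a custom SampleKeyHash functor as a third template parameter through RandomSampleLRU and ScaledLRU, the commit specializes std::hash for SampleKey, so std::unordered_map picks up the hasher implicitly. A minimal self-contained sketch of that pattern — Key below is a stand-in for paddle::distributed::SampleKey, not the actual type:

#include <cstdint>
#include <unordered_map>

// Stand-in for paddle::distributed::SampleKey (illustrative only).
struct Key {
  uint64_t node_key;
  int sample_size;
  bool operator==(const Key &o) const {
    return node_key == o.node_key && sample_size == o.sample_size;
  }
};

// Specializing std::hash means unordered_map<Key, V> needs no explicit
// hasher, so templates built on top of it can drop their Hash parameter.
namespace std {
template <>
struct hash<Key> {
  size_t operator()(const Key &k) const { return k.node_key ^ k.sample_size; }
};
}  // namespace std

int main() {
  std::unordered_map<Key, int> cache;  // no third template argument needed
  cache[{37, 4}] = 1;
  return cache.size() == 1 ? 0 : 1;
}
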
......@@ -302,7 +302,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
return fut;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res,
int server_index) {
......@@ -390,8 +390,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
0) {
++fail_num;
} else {
auto &res_io_buffer =
......@@ -435,7 +435,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();
......@@ -494,6 +494,47 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
closure);
return fut;
}
std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
uint32_t table_id, size_t total_size_limit, size_t ttl) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(
request_idx, PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
size_t size_limit = total_size_limit / server_size +
(total_size_limit % server_size != 0 ? 1 : 0);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)
->set_cmd_id(PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)
->add_params((char *)&size_limit, sizeof(size_t));
closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size, int step,
std::vector<FeatureNode> &res) {
......@@ -515,7 +556,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
index += node.get_size(false);
res.push_back(node);
}
delete buffer;
delete[] buffer;
}
closure->set_promise_value(ret);
});
......
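
For orientation, a hedged usage sketch of the new client-side cache call added above. The table id, limits, and node ids are illustrative values only, and worker_ptr is assumed to be an already initialized GraphBrpcClient; note that use_neighbors_sample_cache splits total_size_limit across servers (rounding up) before broadcasting PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE:

#include <cstdint>
#include <utility>
#include <vector>
// Assumes the Paddle headers declaring GraphBrpcClient are available.

void enable_cache_then_sample(
    paddle::distributed::GraphBrpcClient *worker_ptr) {
  // Enable the server-side LRU cache for neighbor sampling.
  auto cache_fut = worker_ptr->use_neighbors_sample_cache(
      /*table_id=*/0, /*total_size_limit=*/100000, /*ttl=*/6);
  cache_fut.wait();

  // Sample as usual; repeated requests within the ttl can now be served
  // from the per-shard cache instead of being re-sampled.
  std::vector<std::vector<std::pair<uint64_t, float>>> res;
  auto sample_fut = worker_ptr->batch_sample_neighbors(
      /*table_id=*/0, {37, 96}, /*sample_size=*/4, res);
  sample_fut.wait();
}
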
......@@ -61,8 +61,8 @@ class GraphBrpcClient : public BrpcPsClient {
public:
GraphBrpcClient() {}
virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighboors for each of them
virtual std::future<int32_t> batch_sample_neighboors(
// given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>>& res,
int server_index = -1);
......@@ -89,6 +89,9 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
size_t size_limit,
size_t ttl);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list);
virtual int32_t initialize();
......
......@@ -187,8 +187,8 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;
_service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
_service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] =
&GraphBrpcService::graph_random_sample_neighboors;
_service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
&GraphBrpcService::graph_random_sample_neighbors;
_service_handler_map[PS_GRAPH_SAMPLE_NODES] =
&GraphBrpcService::graph_random_sample_nodes;
_service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
......@@ -201,8 +201,9 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
&GraphBrpcService::graph_set_node_feat;
_service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
&GraphBrpcService::sample_neighboors_across_multi_servers;
&GraphBrpcService::sample_neighbors_across_multi_servers;
_service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
&GraphBrpcService::use_neighbors_sample_cache;
// shard initialization; the shard info of server_list can be obtained from env only after the server has started
initialize_shard_info();
......@@ -373,7 +374,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
int32_t GraphBrpcService::graph_random_sample_neighboors(
int32_t GraphBrpcService::graph_random_sample_neighbors(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
......@@ -389,7 +390,7 @@ int32_t GraphBrpcService::graph_random_sample_neighboors(
std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table)
->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes);
->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes);
cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(),
......@@ -448,7 +449,7 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
return 0;
}
int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
// sleep(5);
......@@ -456,7 +457,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
if (request.params_size() < 2) {
set_response_code(
response, -1,
"graph_random_sample request requires at least 2 arguments");
"graph_random_neighbors_sample request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t),
......@@ -519,7 +520,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
remote_call_num);
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) !=
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
0) {
++fail_num;
failed[request2server[request_idx]] = true;
......@@ -570,7 +571,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
closure->request(request_idx)->set_table_id(request.table_id());
closure->request(request_idx)->set_client_id(rank);
size_t node_num = node_id_buckets[request_idx].size();
......@@ -590,8 +591,8 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
}
if (server2request[rank] != -1) {
((GraphTable *)table)
->random_sample_neighboors(node_id_buckets.back().data(), sample_size,
local_buffers, local_actual_sizes);
->random_sample_neighbors(node_id_buckets.back().data(), sample_size,
local_buffers, local_actual_sizes);
}
local_promise.get()->set_value(0);
if (remote_call_num == 0) func(closure);
......@@ -636,5 +637,20 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
return 0;
}
int32_t GraphBrpcService::use_neighbors_sample_cache(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(response, -1,
"use_neighbors_sample_cache request requires at least 2 "
"arguments[cache_size, ttl]");
return 0;
}
size_t size_limit = *(size_t *)(request.params(0).c_str());
size_t ttl = *(size_t *)(request.params(1).c_str());
((GraphTable *)table)->make_neighbor_sample_cache(size_limit, ttl);
return 0;
}
} // namespace distributed
} // namespace paddle
......@@ -78,10 +78,10 @@ class GraphBrpcService : public PsBaseService {
int32_t initialize_shard_info();
int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t graph_random_sample_neighboors(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_random_sample_neighbors(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_random_sample_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
......@@ -116,9 +116,15 @@ class GraphBrpcService : public PsBaseService {
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t sample_neighboors_across_multi_servers(
Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t sample_neighbors_across_multi_servers(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t use_neighbors_sample_cache(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
private:
bool _is_initialize_shard_info;
......
......@@ -290,19 +290,29 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
}
}
std::vector<std::vector<std::pair<uint64_t, float>>>
GraphPyClient::batch_sample_neighboors(std::string name,
std::vector<uint64_t> node_ids,
int sample_size) {
GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids,
int sample_size) {
std::vector<std::vector<std::pair<uint64_t, float>>> v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v);
worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
status.wait();
}
return v;
}
void GraphPyClient::use_neighbors_sample_cache(std::string name,
size_t total_size_limit,
size_t ttl) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
worker_ptr->use_neighbors_sample_cache(table_id, total_size_limit, ttl);
status.wait();
}
}
std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index,
int sample_size) {
......
......@@ -148,13 +148,15 @@ class GraphPyClient : public GraphPyService {
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighboors(
std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighbors(
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names);
void use_neighbors_sample_cache(std::string name, size_t total_size_limit,
size_t ttl);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
......
......@@ -49,7 +49,7 @@ enum PsCmdID {
PS_STOP_PROFILER = 28;
PS_PUSH_GLOBAL_STEP = 29;
PS_PULL_GRAPH_LIST = 30;
PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
PS_GRAPH_SAMPLE_NEIGHBORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
PS_GRAPH_GET_NODE_FEAT = 33;
PS_GRAPH_CLEAR = 34;
......@@ -57,6 +57,7 @@ enum PsCmdID {
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
}
message PsRequestMessage {
......
......@@ -392,7 +392,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
memcpy(pointer, res.data(), actual_size);
return 0;
}
int32_t GraphTable::random_sample_neighboors(
int32_t GraphTable::random_sample_neighbors(
uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes) {
......
......@@ -89,12 +89,6 @@ struct SampleKey {
}
};
struct SampleKeyHash {
size_t operator()(const SampleKey &s) const {
return s.node_key ^ s.sample_size;
}
};
class SampleResult {
public:
size_t actual_size;
......@@ -121,13 +115,13 @@ class LRUNode {
// time to live
LRUNode<K, V> *pre, *next;
};
template <typename K, typename V, typename Hash = std::hash<K>>
template <typename K, typename V>
class ScaledLRU;
template <typename K, typename V, typename Hash = std::hash<K>>
template <typename K, typename V>
class RandomSampleLRU {
public:
RandomSampleLRU(ScaledLRU<K, V, Hash> *_father) : father(_father) {
RandomSampleLRU(ScaledLRU<K, V> *_father) : father(_father) {
node_size = 0;
node_head = node_end = NULL;
global_ttl = father->ttl;
......@@ -229,15 +223,15 @@ class RandomSampleLRU {
}
private:
std::unordered_map<K, LRUNode<K, V> *, Hash> key_map;
ScaledLRU<K, V, Hash> *father;
std::unordered_map<K, LRUNode<K, V> *> key_map;
ScaledLRU<K, V> *father;
size_t global_ttl;
int node_size;
LRUNode<K, V> *node_head, *node_end;
friend class ScaledLRU<K, V, Hash>;
friend class ScaledLRU<K, V>;
};
template <typename K, typename V, typename Hash>
template <typename K, typename V>
class ScaledLRU {
public:
ScaledLRU(size_t shard_num, size_t size_limit, size_t _ttl)
......@@ -246,8 +240,8 @@ class ScaledLRU {
stop = false;
thread_pool.reset(new ::ThreadPool(1));
global_count = 0;
lru_pool = std::vector<RandomSampleLRU<K, V, Hash>>(
shard_num, RandomSampleLRU<K, V, Hash>(this));
lru_pool = std::vector<RandomSampleLRU<K, V>>(shard_num,
RandomSampleLRU<K, V>(this));
shrink_job = std::thread([this]() -> void {
while (true) {
{
......@@ -352,16 +346,16 @@ class ScaledLRU {
size_t ttl;
bool stop;
std::thread shrink_job;
std::vector<RandomSampleLRU<K, V, Hash>> lru_pool;
std::vector<RandomSampleLRU<K, V>> lru_pool;
mutable std::mutex mutex_;
std::condition_variable cv_;
struct RemovedNode {
LRUNode<K, V> *node;
RandomSampleLRU<K, V, Hash> *lru_pointer;
RandomSampleLRU<K, V> *lru_pointer;
bool operator>(const RemovedNode &a) const { return node->ms > a.node->ms; }
};
std::shared_ptr<::ThreadPool> thread_pool;
friend class RandomSampleLRU<K, V, Hash>;
friend class RandomSampleLRU<K, V>;
};
class GraphTable : public SparseTable {
......@@ -373,7 +367,7 @@ class GraphTable : public SparseTable {
int &actual_size, bool need_feature,
int step);
virtual int32_t random_sample_neighboors(
virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes);
......@@ -433,11 +427,11 @@ class GraphTable : public SparseTable {
size_t get_server_num() { return server_num; }
virtual int32_t make_neigh_sample_cache(size_t size_limit, size_t ttl) {
virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
{
std::unique_lock<std::mutex> lock(mutex_);
if (use_cache == false) {
scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult, SampleKeyHash>(
scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
shard_end - shard_start, size_limit, ttl));
use_cache = true;
}
......@@ -460,10 +454,20 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult, SampleKeyHash>> scaled_lru;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
bool use_cache;
mutable std::mutex mutex_;
};
} // namespace distributed
}; // namespace paddle
namespace std {
template <>
struct hash<paddle::distributed::SampleKey> {
size_t operator()(const paddle::distributed::SampleKey &s) const {
return s.node_key ^ s.sample_size;
}
};
}
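
With std::hash<SampleKey> in place, every downstream instantiation loses its third argument; the testCache change further below is the canonical example. A minimal before/after sketch, assuming the table header above is included:

// Before this commit (hash functor passed explicitly):
//   ScaledLRU<SampleKey, SampleResult, SampleKeyHash> st(1, 2, 4);
// After it, std::hash<SampleKey> is found implicitly:
paddle::distributed::ScaledLRU<paddle::distributed::SampleKey,
                               paddle::distributed::SampleResult>
    st(/*shard_num=*/1, /*size_limit=*/2, /*ttl=*/4);
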
......@@ -111,7 +111,7 @@ void testFeatureNodeSerializeFloat64() {
void testSingleSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs;
auto pull_status = worker_ptr_->batch_sample_neighboors(
auto pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 4, vs);
pull_status.wait();
......@@ -127,7 +127,7 @@ void testSingleSampleNeighboor(
s.clear();
s1.clear();
vs.clear();
pull_status = worker_ptr_->batch_sample_neighboors(
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 96), 4, vs);
pull_status.wait();
s1 = {111, 48, 247};
......@@ -139,7 +139,7 @@ void testSingleSampleNeighboor(
ASSERT_EQ(true, s1.find(g) != s1.end());
}
vs.clear();
pull_status = worker_ptr_->batch_sample_neighboors(0, {96, 37}, 4, vs, 0);
pull_status = worker_ptr_->batch_sample_neighbors(0, {96, 37}, 4, vs, 0);
pull_status.wait();
ASSERT_EQ(vs.size(), 2);
}
......@@ -199,7 +199,7 @@ void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs;
std::vector<std::uint64_t> v = {37, 96};
auto pull_status = worker_ptr_->batch_sample_neighboors(0, v, 4, vs);
auto pull_status = worker_ptr_->batch_sample_neighbors(0, v, 4, vs);
pull_status.wait();
std::unordered_set<uint64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145};
......@@ -401,7 +401,6 @@ void RunClient(
}
void RunBrpcPushSparse() {
std::cout << "in test cache";
testCache();
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
......@@ -436,24 +435,24 @@ void RunBrpcPushSparse() {
sleep(5);
testSingleSampleNeighboor(worker_ptr_);
testBatchSampleNeighboor(worker_ptr_);
pull_status = worker_ptr_->batch_sample_neighboors(
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, vs);
pull_status.wait();
ASSERT_EQ(0, vs[0].size());
paddle::distributed::GraphTable* g =
(paddle::distributed::GraphTable*)pserver_ptr_->table(0);
size_t ttl = 6;
g->make_neigh_sample_cache(4, ttl);
g->make_neighbor_sample_cache(4, ttl);
int round = 5;
while (round--) {
vs.clear();
pull_status = worker_ptr_->batch_sample_neighboors(
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs);
pull_status.wait();
for (int i = 0; i < ttl; i++) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs1;
pull_status = worker_ptr_->batch_sample_neighboors(
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs1);
pull_status.wait();
ASSERT_EQ(vs[0].size(), vs1[0].size());
......@@ -560,13 +559,13 @@ void RunBrpcPushSparse() {
ASSERT_EQ(count_item_nodes.size(), 12);
}
vs = client1.batch_sample_neighboors(std::string("user2item"),
std::vector<uint64_t>(1, 96), 4);
vs = client1.batch_sample_neighbors(std::string("user2item"),
std::vector<uint64_t>(1, 96), 4);
ASSERT_EQ(vs[0].size(), 3);
std::vector<uint64_t> node_ids;
node_ids.push_back(96);
node_ids.push_back(37);
vs = client1.batch_sample_neighboors(std::string("user2item"), node_ids, 4);
vs = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4);
ASSERT_EQ(vs.size(), 2);
std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
......@@ -635,8 +634,7 @@ void RunBrpcPushSparse() {
void testCache() {
::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey,
::paddle::distributed::SampleResult,
::paddle::distributed::SampleKeyHash>
::paddle::distributed::SampleResult>
st(1, 2, 4);
char* str = new char[7];
strcpy(str, "54321");
......
......@@ -205,7 +205,8 @@ void BindGraphPyClient(py::module* m) {
.def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf)
.def("pull_graph_list", &GraphPyClient::pull_graph_list)
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors)
.def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors)
.def("remove_graph_node", &GraphPyClient::remove_graph_node)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server)
......