未验证 提交 521a274e 编写于 作者: Webbley's avatar Webbley 提交者: GitHub

optimize the data structure to speed up sampling in graph engine. (#37315)

* optimize the data structure from c++ to python to speed up sampling in graph engine

* update test
上级 c3d3001f
...@@ -289,10 +289,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { ...@@ -289,10 +289,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
status.wait(); status.wait();
} }
} }
std::vector<std::vector<std::pair<uint64_t, float>>>
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name, GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids, std::vector<uint64_t> node_ids,
int sample_size) { int sample_size, bool return_weight,
bool return_edges) {
std::vector<std::vector<std::pair<uint64_t, float>>> v; std::vector<std::vector<std::pair<uint64_t, float>>> v;
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
...@@ -300,7 +302,31 @@ GraphPyClient::batch_sample_neighbors(std::string name, ...@@ -300,7 +302,31 @@ GraphPyClient::batch_sample_neighbors(std::string name,
worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v); worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
status.wait(); status.wait();
} }
return v;
// res.first[0]: neighbors (nodes)
// res.first[1]: slice index
// res.first[2]: src nodes
// res.second: edges weight
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
res.first.push_back({});
res.first.push_back({});
if (return_edges) res.first.push_back({});
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) {
res.first[0].push_back(v[i][j].first);
if (return_edges) res.first[2].push_back(node_ids[i]);
if (return_weight) res.second.push_back(v[i][j].second);
}
if (i == v.size() - 1) break;
if (i == 0) {
res.first[1].push_back(v[i].size());
} else {
res.first[1].push_back(v[i].size() + res.first[1].back());
}
}
return res;
} }
void GraphPyClient::use_neighbors_sample_cache(std::string name, void GraphPyClient::use_neighbors_sample_cache(std::string name,
......
...@@ -148,8 +148,10 @@ class GraphPyClient : public GraphPyService { ...@@ -148,8 +148,10 @@ class GraphPyClient : public GraphPyService {
int get_client_id() { return client_id; } int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; } void set_client_id(int client_id) { this->client_id = client_id; }
void start_client(); void start_client();
std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighbors( std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
std::string name, std::vector<uint64_t> node_ids, int sample_size); batch_sample_neighbors(std::string name, std::vector<uint64_t> node_ids,
int sample_size, bool return_weight,
bool return_edges);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index, std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size); int sample_size);
std::vector<std::vector<std::string>> get_node_feat( std::vector<std::vector<std::string>> get_node_feat(
......
...@@ -193,7 +193,6 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { ...@@ -193,7 +193,6 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
std::ifstream file(path); std::ifstream file(path);
std::string line; std::string line;
while (std::getline(file, line)) { while (std::getline(file, line)) {
count++;
auto values = paddle::string::split_string<std::string>(line, "\t"); auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue; if (values.size() < 2) continue;
auto id = std::stoull(values[1]); auto id = std::stoull(values[1]);
...@@ -207,7 +206,9 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { ...@@ -207,7 +206,9 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
if (count % 1000000 == 0) { if (count % 1000000 == 0) {
VLOG(0) << count << " nodes are loaded from filepath"; VLOG(0) << count << " nodes are loaded from filepath";
VLOG(0) << line;
} }
count++;
std::string nt = values[0]; std::string nt = values[0];
if (nt != node_type) { if (nt != node_type) {
...@@ -273,6 +274,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { ...@@ -273,6 +274,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
} }
if (count % 1000000 == 0) { if (count % 1000000 == 0) {
VLOG(0) << count << " edges are loaded from filepath"; VLOG(0) << count << " edges are loaded from filepath";
VLOG(0) << line;
} }
size_t index = src_shard_id - shard_start; size_t index = src_shard_id - shard_start;
......
...@@ -556,15 +556,17 @@ void RunBrpcPushSparse() { ...@@ -556,15 +556,17 @@ void RunBrpcPushSparse() {
ASSERT_EQ(count_item_nodes.size(), 12); ASSERT_EQ(count_item_nodes.size(), 12);
} }
vs = client1.batch_sample_neighbors(std::string("user2item"), std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
std::vector<uint64_t>(1, 96), 4); res = client1.batch_sample_neighbors(
ASSERT_EQ(vs[0].size(), 3); std::string("user2item"), std::vector<uint64_t>(1, 96), 4, true, false);
ASSERT_EQ(res.first[0].size(), 3);
std::vector<uint64_t> node_ids; std::vector<uint64_t> node_ids;
node_ids.push_back(96); node_ids.push_back(96);
node_ids.push_back(37); node_ids.push_back(37);
vs = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4); res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4,
true, false);
ASSERT_EQ(vs.size(), 2); ASSERT_EQ(res.first[1].size(), 1);
std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6); std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
ASSERT_EQ(nodes_ids.size(), 2); ASSERT_EQ(nodes_ids.size(), 2);
ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) || ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) ||
...@@ -693,4 +695,4 @@ void testGraphToBuffer() { ...@@ -693,4 +695,4 @@ void testGraphToBuffer() {
VLOG(0) << s1.get_feature(0); VLOG(0) << s1.get_feature(0);
} }
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
\ No newline at end of file
...@@ -209,6 +209,8 @@ void BindGraphPyClient(py::module* m) { ...@@ -209,6 +209,8 @@ void BindGraphPyClient(py::module* m) {
.def("start_client", &GraphPyClient::start_client) .def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors) .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors)
.def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors) .def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors)
.def("use_neighbors_sample_cache",
&GraphPyClient::use_neighbors_sample_cache)
.def("remove_graph_node", &GraphPyClient::remove_graph_node) .def("remove_graph_node", &GraphPyClient::remove_graph_node)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server) .def("stop_server", &GraphPyClient::stop_server)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册