未验证 提交 521a274e 编写于 作者: Webbley's avatar Webbley 提交者: GitHub

optimize the data structure to speed up sampling in graph engine. (#37315)

* optimize the data structure from c++ to python to speed up sampling in graph engine

* update test
上级 c3d3001f
......@@ -289,10 +289,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
status.wait();
}
}
std::vector<std::vector<std::pair<uint64_t, float>>>
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids,
int sample_size) {
int sample_size, bool return_weight,
bool return_edges) {
std::vector<std::vector<std::pair<uint64_t, float>>> v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
......@@ -300,7 +302,31 @@ GraphPyClient::batch_sample_neighbors(std::string name,
worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
status.wait();
}
return v;
// res.first[0]: neighbors (nodes)
// res.first[1]: slice index
// res.first[2]: src nodes
// res.second: edges weight
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
res.first.push_back({});
res.first.push_back({});
if (return_edges) res.first.push_back({});
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) {
res.first[0].push_back(v[i][j].first);
if (return_edges) res.first[2].push_back(node_ids[i]);
if (return_weight) res.second.push_back(v[i][j].second);
}
if (i == v.size() - 1) break;
if (i == 0) {
res.first[1].push_back(v[i].size());
} else {
res.first[1].push_back(v[i].size() + res.first[1].back());
}
}
return res;
}
void GraphPyClient::use_neighbors_sample_cache(std::string name,
......
......@@ -148,8 +148,10 @@ class GraphPyClient : public GraphPyService {
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighbors(
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
batch_sample_neighbors(std::string name, std::vector<uint64_t> node_ids,
int sample_size, bool return_weight,
bool return_edges);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<std::vector<std::string>> get_node_feat(
......
......@@ -193,7 +193,6 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
std::ifstream file(path);
std::string line;
while (std::getline(file, line)) {
count++;
auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue;
auto id = std::stoull(values[1]);
......@@ -207,7 +206,9 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
if (count % 1000000 == 0) {
VLOG(0) << count << " nodes are loaded from filepath";
VLOG(0) << line;
}
count++;
std::string nt = values[0];
if (nt != node_type) {
......@@ -273,6 +274,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
}
if (count % 1000000 == 0) {
VLOG(0) << count << " edges are loaded from filepath";
VLOG(0) << line;
}
size_t index = src_shard_id - shard_start;
......
......@@ -556,15 +556,17 @@ void RunBrpcPushSparse() {
ASSERT_EQ(count_item_nodes.size(), 12);
}
vs = client1.batch_sample_neighbors(std::string("user2item"),
std::vector<uint64_t>(1, 96), 4);
ASSERT_EQ(vs[0].size(), 3);
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
res = client1.batch_sample_neighbors(
std::string("user2item"), std::vector<uint64_t>(1, 96), 4, true, false);
ASSERT_EQ(res.first[0].size(), 3);
std::vector<uint64_t> node_ids;
node_ids.push_back(96);
node_ids.push_back(37);
vs = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4);
res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4,
true, false);
ASSERT_EQ(vs.size(), 2);
ASSERT_EQ(res.first[1].size(), 1);
std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
ASSERT_EQ(nodes_ids.size(), 2);
ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) ||
......@@ -693,4 +695,4 @@ void testGraphToBuffer() {
VLOG(0) << s1.get_feature(0);
}
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
\ No newline at end of file
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
......@@ -209,6 +209,8 @@ void BindGraphPyClient(py::module* m) {
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors)
.def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors)
.def("use_neighbors_sample_cache",
&GraphPyClient::use_neighbors_sample_cache)
.def("remove_graph_node", &GraphPyClient::remove_graph_node)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册