diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc index 78f239f80d44599513066ba8bc985137a42097fc..130a76a683e64e804ba40798aa7cdb0be179a6a8 100644 --- a/paddle/fluid/distributed/service/graph_py_service.cc +++ b/paddle/fluid/distributed/service/graph_py_service.cc @@ -289,10 +289,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { status.wait(); } } -std::vector>> + +std::pair>, std::vector> GraphPyClient::batch_sample_neighbors(std::string name, std::vector node_ids, - int sample_size) { + int sample_size, bool return_weight, + bool return_edges) { std::vector>> v; if (this->table_id_map.count(name)) { uint32_t table_id = this->table_id_map[name]; @@ -300,7 +302,31 @@ GraphPyClient::batch_sample_neighbors(std::string name, worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v); status.wait(); } - return v; + + // res.first[0]: neighbors (nodes) + // res.first[1]: slice index + // res.first[2]: src nodes + // res.second: edges weight + std::pair>, std::vector> res; + res.first.push_back({}); + res.first.push_back({}); + if (return_edges) res.first.push_back({}); + for (size_t i = 0; i < v.size(); i++) { + for (size_t j = 0; j < v[i].size(); j++) { + res.first[0].push_back(v[i][j].first); + if (return_edges) res.first[2].push_back(node_ids[i]); + if (return_weight) res.second.push_back(v[i][j].second); + } + if (i == v.size() - 1) break; + + if (i == 0) { + res.first[1].push_back(v[i].size()); + } else { + res.first[1].push_back(v[i].size() + res.first[1].back()); + } + } + + return res; } void GraphPyClient::use_neighbors_sample_cache(std::string name, diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h index 2d36edbf9c17d91bac742272c464fa2d2a39efe3..a860d1f58d3a23e79ca3d3a380b6067c13e76371 100644 --- a/paddle/fluid/distributed/service/graph_py_service.h +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -148,8 +148,10 @@ class GraphPyClient : public GraphPyService { int get_client_id() { return client_id; } void set_client_id(int client_id) { this->client_id = client_id; } void start_client(); - std::vector>> batch_sample_neighbors( - std::string name, std::vector node_ids, int sample_size); + std::pair>, std::vector> + batch_sample_neighbors(std::string name, std::vector node_ids, + int sample_size, bool return_weight, + bool return_edges); std::vector random_sample_nodes(std::string name, int server_index, int sample_size); std::vector> get_node_feat( diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 6c31856ed704db64a557acebd598051a5c2345c8..0c4a473570e3d7cb2b1fb8bb945a484815e708e1 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -193,7 +193,6 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { std::ifstream file(path); std::string line; while (std::getline(file, line)) { - count++; auto values = paddle::string::split_string(line, "\t"); if (values.size() < 2) continue; auto id = std::stoull(values[1]); @@ -207,7 +206,9 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { if (count % 1000000 == 0) { VLOG(0) << count << " nodes are loaded from filepath"; + VLOG(0) << line; } + count++; std::string nt = values[0]; if (nt != node_type) { @@ -273,6 +274,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { } if (count % 1000000 == 0) { VLOG(0) << count << " edges are loaded from filepath"; + VLOG(0) << line; } size_t index = src_shard_id - shard_start; diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index 93547e10d49399f6a7bf09ff8b90fbd8d3996e2d..9b55daa210c10e48548991103941345b962e5a65 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -556,15 +556,17 @@ void RunBrpcPushSparse() { ASSERT_EQ(count_item_nodes.size(), 12); } - vs = client1.batch_sample_neighbors(std::string("user2item"), - std::vector(1, 96), 4); - ASSERT_EQ(vs[0].size(), 3); + std::pair>, std::vector> res; + res = client1.batch_sample_neighbors( + std::string("user2item"), std::vector(1, 96), 4, true, false); + ASSERT_EQ(res.first[0].size(), 3); std::vector node_ids; node_ids.push_back(96); node_ids.push_back(37); - vs = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4); + res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4, + true, false); - ASSERT_EQ(vs.size(), 2); + ASSERT_EQ(res.first[1].size(), 1); std::vector nodes_ids = client2.random_sample_nodes("user", 0, 6); ASSERT_EQ(nodes_ids.size(), 2); ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) || @@ -693,4 +695,4 @@ void testGraphToBuffer() { VLOG(0) << s1.get_feature(0); } -TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } \ No newline at end of file +TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 7c57d2b6fb851b46fe2928c6a00c14eec10e91cb..3f3eec345cb616c37f84cdc0ddf628d9350e5b87 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -209,6 +209,8 @@ void BindGraphPyClient(py::module* m) { .def("start_client", &GraphPyClient::start_client) .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors) .def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors) + .def("use_neighbors_sample_cache", + &GraphPyClient::use_neighbors_sample_cache) .def("remove_graph_node", &GraphPyClient::remove_graph_node) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes) .def("stop_server", &GraphPyClient::stop_server)