未验证 提交 8c58f962 编写于 作者: S seemingwang 提交者: GitHub

enable graph-engine to return all id (#42319)

* enable graph-engine to return all id

* change vector's dimension

* change vector's dimension

* enlarge returned ids dimensions
上级 32cae24c
......@@ -85,6 +85,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
}
return res;
}
int32_t GraphTable::add_node_to_ssd(int type_id, int idx, int64_t src_id,
char *data, int len) {
if (_db != NULL) {
......@@ -1060,6 +1061,26 @@ std::pair<int32_t, std::string> GraphTable::parse_feature(
return std::make_pair<int32_t, std::string>(-1, "");
}
std::vector<std::vector<int64_t>> GraphTable::get_all_id(int type_id, int idx,
                                                         int slice_num) {
  // Collect the id of every node held by this table and bucket the ids into
  // slice_num slices via (id % slice_num), so callers (e.g. the GPU graph
  // engine) can shard work across devices.
  //
  // type_id selects edge shards (0) vs feature shards (non-zero); idx picks
  // the concrete edge/feature type within that group.
  std::vector<std::vector<int64_t>> res(slice_num);
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  std::vector<std::future<std::vector<int64_t>>> tasks;
  tasks.reserve(search_shards.size());
  // Fan out one task per shard, round-robined over the worker pools.
  for (size_t i = 0; i < search_shards.size(); i++) {
    tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
        [&search_shards, i]() -> std::vector<int64_t> {
          return search_shards[i]->get_all_id();
        }));
  }
  // future::get() blocks until the result is ready, so no separate wait()
  // pass is needed.
  for (auto &task : tasks) {
    auto ids = task.get();
    // NOTE(review): ids are assumed non-negative; a negative id would make
    // id % slice_num negative and index out of bounds — confirm upstream.
    for (auto id : ids) res[id % slice_num].push_back(id);
  }
  return res;
}
int32_t GraphTable::pull_graph_list(int type_id, int idx, int start,
int total_size,
std::unique_ptr<char[]> &buffer,
......
......@@ -63,7 +63,13 @@ class GraphShard {
}
return res;
}
std::vector<int64_t> get_all_id() {
  // Return the id of every node currently stored in this shard.
  std::vector<int64_t> res;
  // Size is known up front — reserve to avoid incremental reallocation.
  res.reserve(bucket.size());
  for (auto &node : bucket) {
    res.push_back(node->get_id());
  }
  return res;
}
GraphNode *add_graph_node(int64_t id);
GraphNode *add_graph_node(Node *node);
FeatureNode *add_feature_node(int64_t id);
......@@ -465,6 +471,8 @@ class GraphTable : public Table {
int32_t load_edges(const std::string &path, bool reverse,
const std::string &edge_type);
std::vector<std::vector<int64_t>> get_all_id(int type, int idx,
int slice_num);
int32_t load_nodes(const std::string &path, std::string node_type);
int32_t add_graph_node(int idx, std::vector<int64_t> &id_list,
......
......@@ -58,6 +58,11 @@ void GraphGpuWrapper::set_device(std::vector<int> ids) {
device_id_mapping.push_back(device_id);
}
}
std::vector<std::vector<int64_t>> GraphGpuWrapper::get_all_id(int type, int idx,
                                                              int slice_num) {
  // Thin forwarding wrapper: delegate to the CPU-side graph table that the
  // GPU PS table owns, returning ids sliced into slice_num buckets.
  auto *gpu_table = (GpuPsGraphTable *)graph_table;
  return gpu_table->cpu_graph_table->get_all_id(type, idx, slice_num);
}
void GraphGpuWrapper::set_up_types(std::vector<std::string> &edge_types,
std::vector<std::string> &node_types) {
id_to_edge = edge_types;
......
......@@ -34,6 +34,8 @@ class GraphGpuWrapper {
std::string feat_dtype, int feat_shape);
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
std::vector<std::vector<int64_t>> get_all_id(int type, int idx,
int slice_num);
NodeQueryResult query_node_list(int gpu_id, int start, int query_size);
NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
bool cpu_switch);
......
......@@ -342,6 +342,7 @@ void BindGraphGpuWrapper(py::module* m) {
.def("add_table_feat_conf", &GraphGpuWrapper::add_table_feat_conf)
.def("load_edge_file", &GraphGpuWrapper::load_edge_file)
.def("upload_batch", &GraphGpuWrapper::upload_batch)
.def("get_all_id", &GraphGpuWrapper::get_all_id)
.def("load_node_file", &GraphGpuWrapper::load_node_file);
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册