From cea0041b37e51a80056786d9bec490d0e30ecaf6 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Tue, 14 Jul 2020 10:32:52 +0800 Subject: [PATCH] add load support for pgl graph --- pgl/contrib/ogb/nodeproppred/dataset_pgl.py | 25 ++++++- pgl/graph.py | 80 ++++++++++++++++++--- 2 files changed, 94 insertions(+), 11 deletions(-) diff --git a/pgl/contrib/ogb/nodeproppred/dataset_pgl.py b/pgl/contrib/ogb/nodeproppred/dataset_pgl.py index a6cf521..e6405df 100644 --- a/pgl/contrib/ogb/nodeproppred/dataset_pgl.py +++ b/pgl/contrib/ogb/nodeproppred/dataset_pgl.py @@ -13,6 +13,7 @@ # limitations under the License. """NodePropPredDataset for pgl """ +import pgl import pandas as pd import shutil, os import os.path as osp @@ -69,8 +70,19 @@ class PglNodePropPredDataset(object): pre_processed_file_path = osp.join(processed_dir, 'pgl_data_processed') if osp.exists(pre_processed_file_path): - # TODO: Reload Preprocess files - pass + # TODO: Reload Preprocess files. DONE @ZHUI + # TODO: add support for heterogenous graph. + self.graph = [] + if os.path.isdir(pre_processed_file_path): + for i in range(len(os.listdir(pre_processed_file_path))): + graph_path = os.path.join(pre_processed_file_path, + "graph_{}".format(i)) + if os.path.exists(graph_path): + self.graph.append(pgl.graph.Graph().load(graph_path)) + node_label = np.load( + os.path.join(pre_processed_file_path, "node_label.npy")) + label_dict = {"labels": node_label} + self.labels = label_dict['labels'] else: ### check download if not osp.exists(osp.join(self.root, "raw")): @@ -152,8 +164,15 @@ class PglNodePropPredDataset(object): label_dict = {"labels": node_label} - # TODO: SAVE preprocess graph self.labels = label_dict['labels'] + # TODO: SAVE preprocess graph, DONE @ZHUI + for i in range(len(self.graph)): + self.graph[i].dump( + os.path.join(pre_processed_file_path, + "graph_{}".format(i))) + np.save( + os.path.join(pre_processed_file_path, "node_label.npy"), + node_label) def get_idx_split(self): """Train/Validation/Test split diff --git a/pgl/graph.py b/pgl/graph.py index 85ec406..493f8d5 100644 --- a/pgl/graph.py +++ b/pgl/graph.py @@ -19,6 +19,7 @@ import os import numpy as np import pickle as pkl import time +import warnings import pgl.graph_kernel as graph_kernel from collections import defaultdict @@ -44,9 +45,13 @@ class EdgeIndex(object): num_nodes: The exactive number of nodes. """ - def __init__(self, u, v, num_nodes): - self._degree, self._sorted_v, self._sorted_u, \ - self._sorted_eid, self._indptr = graph_kernel.build_index(u, v, num_nodes) + def __init__(self, u=None, v=None, num_nodes=None): + if num_nodes is None: + warnings.warn( + "Creat empty edge index, please load index before use it!") + else: + self._degree, self._sorted_v, self._sorted_u, \ + self._sorted_eid, self._indptr = graph_kernel.build_index(u, v, num_nodes) @property def degree(self): @@ -88,6 +93,18 @@ class EdgeIndex(object): np.save(os.path.join(path, 'sorted_eid.npy'), self._sorted_eid) np.save(os.path.join(path, 'indptr.npy'), self._indptr) + def load(self, path, mmap_mode=None): + self._degree = np.load( + os.path.join(path, 'degree.npy'), mmap_mode=mmap_mode) + self._sorted_u = np.load( + os.path.join(path, 'sorted_u.npy'), mmap_mode=mmap_mode) + self._sorted_v = np.load( + os.path.join(path, 'sorted_v.npy'), mmap_mode=mmap_mode) + self._sorted_eid = np.load( + os.path.join(path, 'sorted_eid.npy'), mmap_mode=mmap_mode) + self._indptr = np.load( + os.path.join(path, 'indptr.npy'), mmap_mode=mmap_mode) + class Graph(object): """Implementation of graph structure in pgl. @@ -121,7 +138,15 @@ class Graph(object): """ - def __init__(self, num_nodes, edges=None, node_feat=None, edge_feat=None): + def __init__(self, + num_nodes=None, + edges=None, + node_feat=None, + edge_feat=None): + if num_nodes is None: + warnings.warn( + "Creat empty Graph, please load graph data before use it!") + return if node_feat is not None: self._node_feat = node_feat else: @@ -172,6 +197,45 @@ class Graph(object): dump_feat(os.path.join(path, "node_feat"), self.node_feat) dump_feat(os.path.join(path, "edge_feat"), self.edge_feat) + def load(self, path, mmap_mode=None): + """ load graph from dumped files. + """ + if not os.path.exists(path): + raise ValueError("Not find path {}, can't load graph".format(path)) + + self._num_nodes = np.load(os.path.join(path, 'num_nodes.npy')) + self._edges = np.load( + os.path.join(path, 'edges.npy'), mmap_mode=mmap_mode) + if os.path.isdir(os.path.join(path, 'adj_src')): + edge_index = EdgeIndex() + edge_index.load(os.path.join(path, 'adj_src'), mmap_mode=mmap_mode) + self._adj_src_index = edge_index + else: + self._adj_src_index = None + + if os.path.isdir(os.path.join(path, 'adj_dst')): + edge_index = EdgeIndex() + edge_index.load(os.path.join(path, 'adj_dst'), mmap_mode=mmap_mode) + self._adj_dst_index = edge_index + else: + self._adj_dst_index = None + + def load_feat(feat_path): + """Load features from .npy file. + """ + feat = {} + if os.path.isdir(feat_path): + for feat_name in os.listdir(feat_path): + feat[os.path.splitext(feat_name)[0]] = np.load( + os.path.join(feat_path, feat_name), + mmap_mode=mmap_mode) + return feat + + self._node_feat = load_feat(os.path.join(path, 'node_feat')) + self._edge_feat = load_feat(os.path.join(path, 'edge_feat')) + + return self + @property def adj_src_index(self): """Return an EdgeIndex object for src. @@ -573,7 +637,7 @@ class Graph(object): eid (optional): Edge ids which will be included in the subgraph. edges (optional): Edge(src, dst) list which will be included in the subgraph. - + with_node_feat: Whether to inherit node features from parent graph. with_edge_feat: Whether to inherit edge features from parent graph. @@ -593,7 +657,7 @@ class Graph(object): edges = self._edges[eid] else: edges = np.array(edges, dtype="int64") - + sub_edges = graph_kernel.map_edges( np.arange( len(edges), dtype="int64"), edges, reindex) @@ -667,7 +731,7 @@ class Graph(object): replace: boolean, Whether the sample is with or without replacement. Return: - (u, v), eid + (u, v), eid each is a numy.array with the same shape. """ @@ -864,7 +928,7 @@ class MultiGraph(Graph): Examples: .. code-block:: python - + batch_graph = MultiGraph([graph1, graph2, graph3]) """ -- GitLab