提交 cea0041b 编写于 作者: Z Zhong Hui

add load support for pgl graph

上级 bcbed7f4
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""NodePropPredDataset for pgl """NodePropPredDataset for pgl
""" """
import pgl
import pandas as pd import pandas as pd
import shutil, os import shutil, os
import os.path as osp import os.path as osp
...@@ -69,8 +70,19 @@ class PglNodePropPredDataset(object): ...@@ -69,8 +70,19 @@ class PglNodePropPredDataset(object):
pre_processed_file_path = osp.join(processed_dir, 'pgl_data_processed') pre_processed_file_path = osp.join(processed_dir, 'pgl_data_processed')
if osp.exists(pre_processed_file_path): if osp.exists(pre_processed_file_path):
# TODO: Reload Preprocess files # TODO: Reload Preprocess files. DONE @ZHUI
pass # TODO: add support for heterogenous graph.
self.graph = []
if os.path.isdir(pre_processed_file_path):
for i in range(len(os.listdir(pre_processed_file_path))):
graph_path = os.path.join(pre_processed_file_path,
"graph_{}".format(i))
if os.path.exists(graph_path):
self.graph.append(pgl.graph.Graph().load(graph_path))
node_label = np.load(
os.path.join(pre_processed_file_path, "node_label.npy"))
label_dict = {"labels": node_label}
self.labels = label_dict['labels']
else: else:
### check download ### check download
if not osp.exists(osp.join(self.root, "raw")): if not osp.exists(osp.join(self.root, "raw")):
...@@ -152,8 +164,15 @@ class PglNodePropPredDataset(object): ...@@ -152,8 +164,15 @@ class PglNodePropPredDataset(object):
label_dict = {"labels": node_label} label_dict = {"labels": node_label}
# TODO: SAVE preprocess graph
self.labels = label_dict['labels'] self.labels = label_dict['labels']
# TODO: SAVE preprocess graph, DONE @ZHUI
for i in range(len(self.graph)):
self.graph[i].dump(
os.path.join(pre_processed_file_path,
"graph_{}".format(i)))
np.save(
os.path.join(pre_processed_file_path, "node_label.npy"),
node_label)
def get_idx_split(self): def get_idx_split(self):
"""Train/Validation/Test split """Train/Validation/Test split
......
...@@ -19,6 +19,7 @@ import os ...@@ -19,6 +19,7 @@ import os
import numpy as np import numpy as np
import pickle as pkl import pickle as pkl
import time import time
import warnings
import pgl.graph_kernel as graph_kernel import pgl.graph_kernel as graph_kernel
from collections import defaultdict from collections import defaultdict
...@@ -44,9 +45,13 @@ class EdgeIndex(object): ...@@ -44,9 +45,13 @@ class EdgeIndex(object):
num_nodes: The exactive number of nodes. num_nodes: The exactive number of nodes.
""" """
def __init__(self, u, v, num_nodes): def __init__(self, u=None, v=None, num_nodes=None):
self._degree, self._sorted_v, self._sorted_u, \ if num_nodes is None:
self._sorted_eid, self._indptr = graph_kernel.build_index(u, v, num_nodes) warnings.warn(
"Creat empty edge index, please load index before use it!")
else:
self._degree, self._sorted_v, self._sorted_u, \
self._sorted_eid, self._indptr = graph_kernel.build_index(u, v, num_nodes)
@property @property
def degree(self): def degree(self):
...@@ -88,6 +93,18 @@ class EdgeIndex(object): ...@@ -88,6 +93,18 @@ class EdgeIndex(object):
np.save(os.path.join(path, 'sorted_eid.npy'), self._sorted_eid) np.save(os.path.join(path, 'sorted_eid.npy'), self._sorted_eid)
np.save(os.path.join(path, 'indptr.npy'), self._indptr) np.save(os.path.join(path, 'indptr.npy'), self._indptr)
def load(self, path, mmap_mode=None):
self._degree = np.load(
os.path.join(path, 'degree.npy'), mmap_mode=mmap_mode)
self._sorted_u = np.load(
os.path.join(path, 'sorted_u.npy'), mmap_mode=mmap_mode)
self._sorted_v = np.load(
os.path.join(path, 'sorted_v.npy'), mmap_mode=mmap_mode)
self._sorted_eid = np.load(
os.path.join(path, 'sorted_eid.npy'), mmap_mode=mmap_mode)
self._indptr = np.load(
os.path.join(path, 'indptr.npy'), mmap_mode=mmap_mode)
class Graph(object): class Graph(object):
"""Implementation of graph structure in pgl. """Implementation of graph structure in pgl.
...@@ -121,7 +138,15 @@ class Graph(object): ...@@ -121,7 +138,15 @@ class Graph(object):
""" """
def __init__(self, num_nodes, edges=None, node_feat=None, edge_feat=None): def __init__(self,
num_nodes=None,
edges=None,
node_feat=None,
edge_feat=None):
if num_nodes is None:
warnings.warn(
"Creat empty Graph, please load graph data before use it!")
return
if node_feat is not None: if node_feat is not None:
self._node_feat = node_feat self._node_feat = node_feat
else: else:
...@@ -172,6 +197,45 @@ class Graph(object): ...@@ -172,6 +197,45 @@ class Graph(object):
dump_feat(os.path.join(path, "node_feat"), self.node_feat) dump_feat(os.path.join(path, "node_feat"), self.node_feat)
dump_feat(os.path.join(path, "edge_feat"), self.edge_feat) dump_feat(os.path.join(path, "edge_feat"), self.edge_feat)
def load(self, path, mmap_mode=None):
""" load graph from dumped files.
"""
if not os.path.exists(path):
raise ValueError("Not find path {}, can't load graph".format(path))
self._num_nodes = np.load(os.path.join(path, 'num_nodes.npy'))
self._edges = np.load(
os.path.join(path, 'edges.npy'), mmap_mode=mmap_mode)
if os.path.isdir(os.path.join(path, 'adj_src')):
edge_index = EdgeIndex()
edge_index.load(os.path.join(path, 'adj_src'), mmap_mode=mmap_mode)
self._adj_src_index = edge_index
else:
self._adj_src_index = None
if os.path.isdir(os.path.join(path, 'adj_dst')):
edge_index = EdgeIndex()
edge_index.load(os.path.join(path, 'adj_dst'), mmap_mode=mmap_mode)
self._adj_dst_index = edge_index
else:
self._adj_dst_index = None
def load_feat(feat_path):
"""Load features from .npy file.
"""
feat = {}
if os.path.isdir(feat_path):
for feat_name in os.listdir(feat_path):
feat[os.path.splitext(feat_name)[0]] = np.load(
os.path.join(feat_path, feat_name),
mmap_mode=mmap_mode)
return feat
self._node_feat = load_feat(os.path.join(path, 'node_feat'))
self._edge_feat = load_feat(os.path.join(path, 'edge_feat'))
return self
@property @property
def adj_src_index(self): def adj_src_index(self):
"""Return an EdgeIndex object for src. """Return an EdgeIndex object for src.
...@@ -573,7 +637,7 @@ class Graph(object): ...@@ -573,7 +637,7 @@ class Graph(object):
eid (optional): Edge ids which will be included in the subgraph. eid (optional): Edge ids which will be included in the subgraph.
edges (optional): Edge(src, dst) list which will be included in the subgraph. edges (optional): Edge(src, dst) list which will be included in the subgraph.
with_node_feat: Whether to inherit node features from parent graph. with_node_feat: Whether to inherit node features from parent graph.
with_edge_feat: Whether to inherit edge features from parent graph. with_edge_feat: Whether to inherit edge features from parent graph.
...@@ -593,7 +657,7 @@ class Graph(object): ...@@ -593,7 +657,7 @@ class Graph(object):
edges = self._edges[eid] edges = self._edges[eid]
else: else:
edges = np.array(edges, dtype="int64") edges = np.array(edges, dtype="int64")
sub_edges = graph_kernel.map_edges( sub_edges = graph_kernel.map_edges(
np.arange( np.arange(
len(edges), dtype="int64"), edges, reindex) len(edges), dtype="int64"), edges, reindex)
...@@ -667,7 +731,7 @@ class Graph(object): ...@@ -667,7 +731,7 @@ class Graph(object):
replace: boolean, Whether the sample is with or without replacement. replace: boolean, Whether the sample is with or without replacement.
Return: Return:
(u, v), eid (u, v), eid
each is a numy.array with the same shape. each is a numy.array with the same shape.
""" """
...@@ -864,7 +928,7 @@ class MultiGraph(Graph): ...@@ -864,7 +928,7 @@ class MultiGraph(Graph):
Examples: Examples:
.. code-block:: python .. code-block:: python
batch_graph = MultiGraph([graph1, graph2, graph3]) batch_graph = MultiGraph([graph1, graph2, graph3])
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册