提交 cea0041b 编写于 作者: Z Zhong Hui

add load support for pgl graph

上级 bcbed7f4
......@@ -13,6 +13,7 @@
# limitations under the License.
"""NodePropPredDataset for pgl
"""
import pgl
import pandas as pd
import shutil, os
import os.path as osp
......@@ -69,8 +70,19 @@ class PglNodePropPredDataset(object):
pre_processed_file_path = osp.join(processed_dir, 'pgl_data_processed')
if osp.exists(pre_processed_file_path):
# TODO: Reload Preprocess files
pass
# TODO: Reload Preprocess files. DONE @ZHUI
# TODO: add support for heterogeneous graph.
self.graph = []
if os.path.isdir(pre_processed_file_path):
for i in range(len(os.listdir(pre_processed_file_path))):
graph_path = os.path.join(pre_processed_file_path,
"graph_{}".format(i))
if os.path.exists(graph_path):
self.graph.append(pgl.graph.Graph().load(graph_path))
node_label = np.load(
os.path.join(pre_processed_file_path, "node_label.npy"))
label_dict = {"labels": node_label}
self.labels = label_dict['labels']
else:
### check download
if not osp.exists(osp.join(self.root, "raw")):
......@@ -152,8 +164,15 @@ class PglNodePropPredDataset(object):
label_dict = {"labels": node_label}
# TODO: SAVE preprocess graph
self.labels = label_dict['labels']
# TODO: SAVE preprocess graph, DONE @ZHUI
for i in range(len(self.graph)):
self.graph[i].dump(
os.path.join(pre_processed_file_path,
"graph_{}".format(i)))
np.save(
os.path.join(pre_processed_file_path, "node_label.npy"),
node_label)
def get_idx_split(self):
"""Train/Validation/Test split
......
......@@ -19,6 +19,7 @@ import os
import numpy as np
import pickle as pkl
import time
import warnings
import pgl.graph_kernel as graph_kernel
from collections import defaultdict
......@@ -44,9 +45,13 @@ class EdgeIndex(object):
num_nodes: The exact number of nodes.
"""
def __init__(self, u, v, num_nodes):
self._degree, self._sorted_v, self._sorted_u, \
self._sorted_eid, self._indptr = graph_kernel.build_index(u, v, num_nodes)
def __init__(self, u=None, v=None, num_nodes=None):
    """Build the edge index, or create an empty placeholder.

    Args:
        u: First edge-endpoint array handed to ``graph_kernel.build_index``.
        v: Second edge-endpoint array handed to ``graph_kernel.build_index``.
        num_nodes: Number of nodes. When it is ``None`` no index is built:
            a warning is emitted and none of the index attributes are set,
            so ``load`` has to be called before the object is used.
    """
    if num_nodes is None:
        # Placeholder instance: nothing is assigned on purpose.
        warnings.warn(
            "Creat empty edge index, please load index before use it!")
        return

    built = graph_kernel.build_index(u, v, num_nodes)
    (self._degree, self._sorted_v, self._sorted_u, self._sorted_eid,
     self._indptr) = built
@property
def degree(self):
......@@ -88,6 +93,18 @@ class EdgeIndex(object):
np.save(os.path.join(path, 'sorted_eid.npy'), self._sorted_eid)
np.save(os.path.join(path, 'indptr.npy'), self._indptr)
def load(self, path, mmap_mode=None):
    """Restore the index arrays from ``.npy`` files stored under ``path``.

    Args:
        path: Directory holding ``degree.npy``, ``sorted_u.npy``,
            ``sorted_v.npy``, ``sorted_eid.npy`` and ``indptr.npy``.
        mmap_mode: Passed through unchanged to ``np.load`` for every file.
    """
    # Each attribute is read from the file named after it (minus the
    # leading underscore); the order below is the order of assignment.
    for field in ("degree", "sorted_u", "sorted_v", "sorted_eid", "indptr"):
        loaded = np.load(
            os.path.join(path, field + ".npy"), mmap_mode=mmap_mode)
        setattr(self, "_" + field, loaded)
class Graph(object):
"""Implementation of graph structure in pgl.
......@@ -121,7 +138,15 @@ class Graph(object):
"""
def __init__(self, num_nodes, edges=None, node_feat=None, edge_feat=None):
def __init__(self,
num_nodes=None,
edges=None,
node_feat=None,
edge_feat=None):
if num_nodes is None:
warnings.warn(
"Creat empty Graph, please load graph data before use it!")
return
if node_feat is not None:
self._node_feat = node_feat
else:
......@@ -172,6 +197,45 @@ class Graph(object):
dump_feat(os.path.join(path, "node_feat"), self.node_feat)
dump_feat(os.path.join(path, "edge_feat"), self.edge_feat)
def load(self, path, mmap_mode=None):
    """Populate this graph from a directory of dumped files.

    Args:
        path: Directory containing ``num_nodes.npy`` and ``edges.npy``.
            The sub-directories ``adj_src``, ``adj_dst``, ``node_feat``
            and ``edge_feat`` are read when they exist.
        mmap_mode: Passed to ``np.load`` for the edges and feature files
            and to ``EdgeIndex.load``; ``num_nodes.npy`` is loaded
            without it.

    Return:
        This graph instance (``self``).

    Raises:
        ValueError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise ValueError("Not find path {}, can't load graph".format(path))

    # NOTE(review): the value is kept exactly as np.load returns it (an
    # ndarray, not a Python int) -- confirm callers are fine with that.
    self._num_nodes = np.load(os.path.join(path, 'num_nodes.npy'))
    self._edges = np.load(
        os.path.join(path, 'edges.npy'), mmap_mode=mmap_mode)

    def read_index(dir_name):
        """Load an EdgeIndex from a sub-directory, or None if it is absent.
        """
        index_dir = os.path.join(path, dir_name)
        if not os.path.isdir(index_dir):
            return None
        index = EdgeIndex()
        index.load(index_dir, mmap_mode=mmap_mode)
        return index

    def read_feat(feat_dir):
        """Load every file in feat_dir, keyed by file name sans extension.
        """
        if not os.path.isdir(feat_dir):
            return {}
        return {
            os.path.splitext(file_name)[0]: np.load(
                os.path.join(feat_dir, file_name), mmap_mode=mmap_mode)
            for file_name in os.listdir(feat_dir)
        }

    self._adj_src_index = read_index('adj_src')
    self._adj_dst_index = read_index('adj_dst')
    self._node_feat = read_feat(os.path.join(path, 'node_feat'))
    self._edge_feat = read_feat(os.path.join(path, 'edge_feat'))
    return self
@property
def adj_src_index(self):
"""Return an EdgeIndex object for src.
......@@ -573,7 +637,7 @@ class Graph(object):
eid (optional): Edge ids which will be included in the subgraph.
edges (optional): Edge(src, dst) list which will be included in the subgraph.
with_node_feat: Whether to inherit node features from parent graph.
with_edge_feat: Whether to inherit edge features from parent graph.
......@@ -593,7 +657,7 @@ class Graph(object):
edges = self._edges[eid]
else:
edges = np.array(edges, dtype="int64")
sub_edges = graph_kernel.map_edges(
np.arange(
len(edges), dtype="int64"), edges, reindex)
......@@ -667,7 +731,7 @@ class Graph(object):
replace: boolean, Whether the sample is with or without replacement.
Return:
(u, v), eid
(u, v), eid
each is a numpy.array with the same shape.
"""
......@@ -864,7 +928,7 @@ class MultiGraph(Graph):
Examples:
.. code-block:: python
batch_graph = MultiGraph([graph1, graph2, graph3])
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册