提交 969c1880 编写于 作者: W Webbley

add mmap mode for heter graph

上级 bccff22a
...@@ -14,12 +14,13 @@ ...@@ -14,12 +14,13 @@
""" """
This package implement Heterogeneous Graph structure for handling Heterogeneous graph data. This package implement Heterogeneous Graph structure for handling Heterogeneous graph data.
""" """
import os
import time import time
import numpy as np import numpy as np
import pickle as pkl import pickle as pkl
import time import time
import pgl.graph_kernel as graph_kernel import pgl.graph_kernel as graph_kernel
from pgl.graph import Graph from pgl.graph import Graph, MemmapGraph
__all__ = ['HeterGraph', 'SubHeterGraph'] __all__ = ['HeterGraph', 'SubHeterGraph']
...@@ -113,6 +114,30 @@ class HeterGraph(object): ...@@ -113,6 +114,30 @@ class HeterGraph(object):
self._edge_types = self.edge_types_info() self._edge_types = self.edge_types_info()
def dump(self, path, indegree=False, outdegree=False):
if indegree:
for e_type, g in self._multi_graph.items():
g.indegree()
if outdegree:
for e_type, g in self._multi_graph.items():
g.outdegree()
if not os.path.exists(path):
os.makedirs(path)
np.save(os.path.join(path, "num_nodes.npy"), self._num_nodes)
np.save(os.path.join(path, "node_types.npy"), self._node_types)
with open(os.path.join(path, "edge_types.pkl"), 'wb') as f:
pkl.dump(self._edge_types, f)
with open(os.path.join(path, "nodes_type_dict.pkl"), 'wb') as f:
pkl.dump(self._nodes_type_dict, f)
for e_type, g in self._multi_graph.items():
sub_path = os.path.join(path, e_type)
g.dump(sub_path)
@property @property
def edge_types(self): def edge_types(self):
"""Return a list of edge types. """Return a list of edge types.
...@@ -399,7 +424,7 @@ class HeterGraph(object): ...@@ -399,7 +424,7 @@ class HeterGraph(object):
""" """
edge_types_info = [] edge_types_info = []
for key, _ in self._edges_dict.items(): for key, _ in self._multi_graph.items():
edge_types_info.append(key) edge_types_info.append(key)
return edge_types_info return edge_types_info
...@@ -460,3 +485,21 @@ class SubHeterGraph(HeterGraph): ...@@ -460,3 +485,21 @@ class SubHeterGraph(HeterGraph):
A list of node ids in parent graph. A list of node ids in parent graph.
""" """
return graph_kernel.map_nodes(nodes, self._to_reindex) return graph_kernel.map_nodes(nodes, self._to_reindex)
class MemmapHeterGraph(HeterGraph):
def __init__(self, path):
self._num_nodes = np.load(os.path.join(path, 'num_nodes.npy'))
self._node_types = np.load(
os.path.join(path, 'node_types.npy'), allow_pickle=True)
with open(os.path.join(path, 'edge_types.pkl'), 'rb') as f:
self._edge_types = pkl.load(f)
with open(os.path.join(path, "nodes_type_dict.pkl"), 'rb') as f:
self._nodes_type_dict = pkl.load(f)
self._multi_graph = {}
for e_type in self._edge_types:
sub_path = os.path.join(path, e_type)
self._multi_graph[e_type] = MemmapGraph(sub_path)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_hetergraph"""
import time
import unittest
import json
import os
import numpy as np
from pgl.sample import metapath_randomwalk
from pgl.graph import Graph
from pgl import heter_graph
from pgl.heter_graph import MemmapHeterGraph
def test_dump():
np.random.seed(1)
edges = {}
# for test no successor
edges['c2p'] = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
(3, 7), (3, 4), (3, 8)]
edges['p2c'] = [(v, u) for u, v in edges['c2p']]
edges['p2a'] = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12),
(6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]
edges['a2p'] = [(v, u) for u, v in edges['p2a']]
node_types = ['c' for _ in range(4)] + ['p' for _ in range(6)
] + ['a' for _ in range(5)]
node_types = [(i, t) for i, t in enumerate(node_types)]
graph = heter_graph.HeterGraph(
num_nodes=len(node_types), edges=edges, node_types=node_types)
graph.dump("./hetergraph_mmap", outdegree=True)
def test_load():
graph = MemmapHeterGraph("./hetergraph_mmap")
class MmapHeterGraphTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.graph = MemmapHeterGraph("./hetergraph_mmap")
def test_num_nodes_by_type(self):
print()
n_types = {'c': 4, 'p': 6, 'a': 5}
for nt in n_types:
num_nodes = self.graph.num_nodes_by_type(nt)
self.assertEqual(num_nodes, n_types[nt])
def test_node_batch_iter(self):
print()
batch_size = 2
ground = [[4, 5], [6, 7], [8, 9]]
for idx, nodes in enumerate(
self.graph.node_batch_iter(
batch_size=batch_size, shuffle=False, n_type='p')):
self.assertEqual(len(nodes), batch_size)
self.assertListEqual(list(nodes), ground[idx])
def test_sample_successor(self):
print()
nodes = [4, 5, 8]
md = 2
succes = self.graph.sample_successor(
edge_type='p2a', nodes=nodes, max_degree=md, return_eids=False)
self.assertIsInstance(succes, list)
ground = [[10, 11, 12, 14, 13], [], [14]]
for succ, g in zip(succes, ground):
self.assertIsInstance(succ, np.ndarray)
for i in succ:
self.assertIn(i, g)
nodes = [4]
succes = self.graph.sample_successor(
edge_type='p2a', nodes=nodes, max_degree=md, return_eids=False)
self.assertIsInstance(succes, list)
ground = [[10, 11, 12, 14, 13]]
for succ, g in zip(succes, ground):
self.assertIsInstance(succ, np.ndarray)
for i in succ:
self.assertIn(i, g)
def test_successor(self):
print()
nodes = [4, 5, 8]
e_type = 'p2a'
succes = self.graph.successor(
edge_type=e_type,
nodes=nodes, )
self.assertIsInstance(succes, np.ndarray)
ground = [[10, 11, 12, 14, 13], [], [14]]
for succ, g in zip(succes, ground):
self.assertIsInstance(succ, np.ndarray)
self.assertCountEqual(succ, g)
nodes = [4]
e_type = 'p2a'
succes = self.graph.successor(
edge_type=e_type,
nodes=nodes, )
self.assertIsInstance(succes, np.ndarray)
ground = [[10, 11, 12, 14, 13]]
for succ, g in zip(succes, ground):
self.assertIsInstance(succ, np.ndarray)
self.assertCountEqual(succ, g)
def test_predecessor(self):
print()
nodes = [11, 12, 13]
e_type = 'p2a'
pre = self.graph.predecessor(
edge_type=e_type,
nodes=nodes, )
self.assertIsInstance(pre, np.ndarray)
print(pre)
ground = [[4, 6, 7], [4, 6, 7], [4]]
for succ, g in zip(pre, ground):
self.assertIsInstance(succ, np.ndarray)
self.assertCountEqual(succ, g)
nodes = [11]
e_type = 'p2a'
pre = self.graph.predecessor(
edge_type=e_type,
nodes=nodes, )
print(pre)
self.assertIsInstance(pre, np.ndarray)
ground = [[4, 6, 7]]
for p, g in zip(pre, ground):
self.assertIsInstance(p, np.ndarray)
self.assertCountEqual(p, g)
def test_sample_nodes(self):
print()
p_ground = [4, 5, 6, 7, 8, 9]
sample_num = 10
nodes = self.graph.sample_nodes(sample_num=sample_num, n_type='p')
self.assertEqual(len(nodes), sample_num)
for n in nodes:
self.assertIn(n, p_ground)
# test n_type == None
ground = [i for i in range(15)]
nodes = self.graph.sample_nodes(sample_num=sample_num, n_type=None)
self.assertEqual(len(nodes), sample_num)
for n in nodes:
self.assertIn(n, ground)
if __name__ == "__main__":
unittest.main()
# test_dump()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册