提交 969c1880 编写于 作者: W Webbley

add mmap mode for heter graph

上级 bccff22a
......@@ -14,12 +14,13 @@
"""
This package implements a Heterogeneous Graph structure for handling heterogeneous graph data.
"""
import os
import time
import numpy as np
import pickle as pkl
import time
import pgl.graph_kernel as graph_kernel
from pgl.graph import Graph
from pgl.graph import Graph, MemmapGraph
__all__ = ['HeterGraph', 'SubHeterGraph']
......@@ -113,6 +114,30 @@ class HeterGraph(object):
self._edge_types = self.edge_types_info()
def dump(self, path, indegree=False, outdegree=False):
    """Dump this heterogeneous graph into the directory ``path``.

    The directory receives ``num_nodes.npy``, ``node_types.npy``,
    ``edge_types.pkl`` and ``nodes_type_dict.pkl``; afterwards the graph of
    every edge type dumps itself into the sub-directory
    ``path/<edge_type>``.

    Args:
        path: Output directory. It is created when it does not exist yet.
        indegree: When True, ``indegree()`` is called on the graph of every
            edge type before anything is written (presumably so the degree
            data is materialised and saved by ``g.dump`` -- confirm against
            ``Graph.dump``).
        outdegree: Same as ``indegree`` but calling ``outdegree()``.
    """
    sub_graphs = self._multi_graph

    if indegree:
        for sub_graph in sub_graphs.values():
            sub_graph.indegree()
    if outdegree:
        for sub_graph in sub_graphs.values():
            sub_graph.outdegree()

    if not os.path.exists(path):
        os.makedirs(path)

    def _pickle_to(file_name, payload):
        # Serialise ``payload`` with pickle into ``path/file_name``.
        with open(os.path.join(path, file_name), 'wb') as f:
            pkl.dump(payload, f)

    np.save(os.path.join(path, "num_nodes.npy"), self._num_nodes)
    np.save(os.path.join(path, "node_types.npy"), self._node_types)
    _pickle_to("edge_types.pkl", self._edge_types)
    _pickle_to("nodes_type_dict.pkl", self._nodes_type_dict)

    # One sub-directory per edge type, named after the edge type itself.
    for e_type, sub_graph in sub_graphs.items():
        sub_graph.dump(os.path.join(path, e_type))
@property
def edge_types(self):
"""Return a list of edge types.
......@@ -399,7 +424,7 @@ class HeterGraph(object):
"""
edge_types_info = []
for key, _ in self._edges_dict.items():
for key, _ in self._multi_graph.items():
edge_types_info.append(key)
return edge_types_info
......@@ -460,3 +485,21 @@ class SubHeterGraph(HeterGraph):
A list of node ids in parent graph.
"""
return graph_kernel.map_nodes(nodes, self._to_reindex)
class MemmapHeterGraph(HeterGraph):
    """Heterogeneous graph restored from a directory written by ``HeterGraph.dump``.

    The constructor reads ``num_nodes.npy``, ``node_types.npy``,
    ``edge_types.pkl`` and ``nodes_type_dict.pkl`` from ``path`` and opens
    one ``MemmapGraph`` per edge type from the sub-directory
    ``path/<edge_type>``.

    Args:
        path: Directory that holds a previously dumped heterogeneous graph.
    """

    def __init__(self, path):
        # NOTE(review): the base-class __init__ is not called here; every
        # attribute set below is restored from disk instead.
        def in_dir(name):
            return os.path.join(path, name)

        self._num_nodes = np.load(in_dir('num_nodes.npy'))
        self._node_types = np.load(
            in_dir('node_types.npy'), allow_pickle=True)

        # NOTE(review): unpickling (here and via allow_pickle above) can run
        # arbitrary code, so ``path`` should only point at trusted dumps.
        with open(in_dir('edge_types.pkl'), 'rb') as f:
            self._edge_types = pkl.load(f)
        with open(in_dir("nodes_type_dict.pkl"), 'rb') as f:
            self._nodes_type_dict = pkl.load(f)

        self._multi_graph = {
            e_type: MemmapGraph(in_dir(e_type))
            for e_type in self._edge_types
        }
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_hetergraph"""
import time
import unittest
import json
import os
import numpy as np
from pgl.sample import metapath_randomwalk
from pgl.graph import Graph
from pgl import heter_graph
from pgl.heter_graph import MemmapHeterGraph
def test_dump():
    """Build a small heterogeneous graph and dump it to ``./hetergraph_mmap``.

    Nodes 0-3 are of type 'c', 4-9 of type 'p' and 10-14 of type 'a'.
    """
    np.random.seed(1)

    # Node 5 has no outgoing 'p2a' edge, which covers the no-successor case.
    c2p = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6), (3, 7),
           (3, 4), (3, 8)]
    p2a = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12), (6, 11),
           (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]

    # The reverse relations are the same edge lists with endpoints swapped.
    edges = {
        'c2p': c2p,
        'p2c': [(dst, src) for src, dst in c2p],
        'p2a': p2a,
        'a2p': [(dst, src) for src, dst in p2a],
    }

    type_of_node = ['c'] * 4 + ['p'] * 6 + ['a'] * 5
    node_types = list(enumerate(type_of_node))

    graph = heter_graph.HeterGraph(
        num_nodes=len(node_types), edges=edges, node_types=node_types)
    graph.dump("./hetergraph_mmap", outdegree=True)
def test_load():
    """Check that the dump in ``./hetergraph_mmap`` can be opened without error."""
    MemmapHeterGraph("./hetergraph_mmap")
class MmapHeterGraphTest(unittest.TestCase):
    """Tests for ``MemmapHeterGraph`` on the graph written by ``test_dump``.

    In that graph nodes 0-3 have type 'c', 4-9 type 'p' and 10-14 type 'a'.
    """

    @classmethod
    def setUpClass(cls):
        """Open the dumped graph once for all tests, creating the dump if needed."""
        # unittest.main() only collects TestCase classes, so the module-level
        # test_dump() is not guaranteed to have run before this point.
        if not os.path.isdir("./hetergraph_mmap"):
            test_dump()
        cls.graph = MemmapHeterGraph("./hetergraph_mmap")

    def _assert_same_neighbors(self, result, ground):
        """Assert every row of ``result`` is an ndarray with exactly the expected items."""
        for row, expected in zip(result, ground):
            self.assertIsInstance(row, np.ndarray)
            self.assertCountEqual(row, expected)

    def _assert_sampled_from(self, result, ground):
        """Assert every row of ``result`` is an ndarray drawn from the allowed items."""
        for row, candidates in zip(result, ground):
            self.assertIsInstance(row, np.ndarray)
            for node in row:
                self.assertIn(node, candidates)

    def test_num_nodes_by_type(self):
        """The per-type node counts match the dumped graph."""
        n_types = {'c': 4, 'p': 6, 'a': 5}
        for nt in n_types:
            num_nodes = self.graph.num_nodes_by_type(nt)
            self.assertEqual(num_nodes, n_types[nt])

    def test_node_batch_iter(self):
        """Unshuffled batches over 'p' nodes come out in id order."""
        batch_size = 2
        ground = [[4, 5], [6, 7], [8, 9]]
        for idx, nodes in enumerate(
                self.graph.node_batch_iter(
                    batch_size=batch_size, shuffle=False, n_type='p')):
            self.assertEqual(len(nodes), batch_size)
            self.assertListEqual(list(nodes), ground[idx])

    def test_sample_successor(self):
        """Sampled 'p2a' successors are a list of ndarrays of real successors."""
        md = 2

        nodes = [4, 5, 8]
        succes = self.graph.sample_successor(
            edge_type='p2a', nodes=nodes, max_degree=md, return_eids=False)
        self.assertIsInstance(succes, list)
        # Node 5 has no 'p2a' edge, hence the empty candidate list.
        self._assert_sampled_from(succes, [[10, 11, 12, 14, 13], [], [14]])

        nodes = [4]
        succes = self.graph.sample_successor(
            edge_type='p2a', nodes=nodes, max_degree=md, return_eids=False)
        self.assertIsInstance(succes, list)
        self._assert_sampled_from(succes, [[10, 11, 12, 14, 13]])

    def test_successor(self):
        """'p2a' successors match the dumped edges, for several nodes and for one."""
        e_type = 'p2a'

        nodes = [4, 5, 8]
        succes = self.graph.successor(
            edge_type=e_type,
            nodes=nodes, )
        self.assertIsInstance(succes, np.ndarray)
        self._assert_same_neighbors(succes, [[10, 11, 12, 14, 13], [], [14]])

        nodes = [4]
        succes = self.graph.successor(
            edge_type=e_type,
            nodes=nodes, )
        self.assertIsInstance(succes, np.ndarray)
        self._assert_same_neighbors(succes, [[10, 11, 12, 14, 13]])

    def test_predecessor(self):
        """'p2a' predecessors match the dumped edges, for several nodes and for one."""
        e_type = 'p2a'

        nodes = [11, 12, 13]
        pre = self.graph.predecessor(
            edge_type=e_type,
            nodes=nodes, )
        self.assertIsInstance(pre, np.ndarray)
        self._assert_same_neighbors(pre, [[4, 6, 7], [4, 6, 7], [4]])

        nodes = [11]
        pre = self.graph.predecessor(
            edge_type=e_type,
            nodes=nodes, )
        self.assertIsInstance(pre, np.ndarray)
        self._assert_same_neighbors(pre, [[4, 6, 7]])

    def test_sample_nodes(self):
        """Sampled nodes have the requested count and the requested type."""
        sample_num = 10

        p_ground = [4, 5, 6, 7, 8, 9]
        nodes = self.graph.sample_nodes(sample_num=sample_num, n_type='p')
        self.assertEqual(len(nodes), sample_num)
        for n in nodes:
            self.assertIn(n, p_ground)

        # With n_type == None any node of the graph may be returned.
        ground = list(range(15))
        nodes = self.graph.sample_nodes(sample_num=sample_num, n_type=None)
        self.assertEqual(len(nodes), sample_num)
        for n in nodes:
            self.assertIn(n, ground)
if __name__ == "__main__":
    # NOTE: unittest.main() only collects unittest.TestCase classes, so the
    # module-level test_dump() / test_load() functions are not run from here.
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册