diff --git a/examples/GATNE/model.py b/examples/GATNE/model.py index 18f83c89a31324256f20ae118372828fe8be955d..492aa3d97e07df5b4335adc3069df650ff320870 100644 --- a/examples/GATNE/model.py +++ b/examples/GATNE/model.py @@ -114,29 +114,29 @@ class GATNE(object): node_type_embed = fl.gather(node_type_embed, self.train_inputs) # M_r + tn_initializer = fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)) + trans_weights = fl.create_parameter( shape=[ self.edge_type_count, self.embedding_u_size, self.embedding_size // self.att_head ], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w') # W_r trans_weights_s1 = fl.create_parameter( shape=[self.edge_type_count, self.embedding_u_size, self.dim_a], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w_s1') # w_r trans_weights_s2 = fl.create_parameter( shape=[self.edge_type_count, self.dim_a, self.att_head], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w_s2') diff --git a/pgl/heter_graph.py b/pgl/heter_graph.py index b473003f51c9d62b9b7389fbba5c7cfbac84aa0a..fec1632479c01ff62819f29ca8939ddeb7295682 100644 --- a/pgl/heter_graph.py +++ b/pgl/heter_graph.py @@ -14,12 +14,13 @@ """ This package implement Heterogeneous Graph structure for handling Heterogeneous graph data. 
"""
+import os
 import time
 import numpy as np
 import pickle as pkl
 import time
 import pgl.graph_kernel as graph_kernel
-from pgl.graph import Graph
+from pgl.graph import Graph, MemmapGraph
 
-__all__ = ['HeterGraph', 'SubHeterGraph']
+__all__ = ['HeterGraph', 'SubHeterGraph', 'MemmapHeterGraph']
 
@@ -113,6 +114,43 @@ class HeterGraph(object):
 
         self._edge_types = self.edge_types_info()
 
+    def dump(self, path, indegree=False, outdegree=False):
+        """Dump the heterogeneous graph into the directory ``path``.
+
+        Graph-level metadata is stored as ``.npy`` / ``.pkl`` files and each
+        per-edge-type sub-graph is dumped into ``<path>/<edge_type>``.
+
+        Args:
+            path: Output directory; created when it does not exist yet.
+
+            indegree: If True, call ``indegree()`` on every sub-graph before
+                dumping (presumably so the in-degree index gets built and
+                saved too -- confirm against ``Graph.dump``).
+
+            outdegree: Same as ``indegree``, using ``outdegree()``.
+        """
+        if indegree:
+            for g in self._multi_graph.values():
+                g.indegree()
+
+        if outdegree:
+            for g in self._multi_graph.values():
+                g.outdegree()
+
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        np.save(os.path.join(path, "num_nodes.npy"), self._num_nodes)
+        np.save(os.path.join(path, "node_types.npy"), self._node_types)
+        with open(os.path.join(path, "edge_types.pkl"), 'wb') as f:
+            pkl.dump(self._edge_types, f)
+        with open(os.path.join(path, "nodes_type_dict.pkl"), 'wb') as f:
+            pkl.dump(self._nodes_type_dict, f)
+
+        for e_type, g in self._multi_graph.items():
+            sub_path = os.path.join(path, e_type)
+            g.dump(sub_path)
+
     @property
     def edge_types(self):
         """Return a list of edge types.
@@ -399,7 +437,7 @@ class HeterGraph(object):
 
         """
         edge_types_info = []
-        for key, _ in self._edges_dict.items():
+        for key, _ in self._multi_graph.items():
             edge_types_info.append(key)
 
         return edge_types_info
@@ -460,3 +498,38 @@ class SubHeterGraph(HeterGraph):
             A list of node ids in parent graph.
         """
         return graph_kernel.map_nodes(nodes, self._to_reindex)
+
+
+class MemmapHeterGraph(HeterGraph):
+    """Heterogeneous graph loaded from a ``HeterGraph.dump`` directory.
+
+    ``HeterGraph.__init__`` is not called; only the attributes assigned by
+    this class's ``__init__`` are set, and each per-edge-type sub-graph is a
+    ``MemmapGraph`` opened on ``<path>/<edge_type>``.
+
+    Args:
+        path: Directory previously written by ``HeterGraph.dump``.
+    """
+
+    def __init__(self, path):
+        # NOTE(review): allow_pickle=True and pkl.load below unpickle the
+        # files, so ``path`` must come from a trusted source.
+        num_nodes = np.load(os.path.join(path, 'num_nodes.npy'))
+        # np.save stores a plain scalar as a 0-d array; unwrap it so that
+        # ``_num_nodes`` is a plain scalar, not an ndarray.
+        if num_nodes.ndim == 0:
+            num_nodes = num_nodes.item()
+        self._num_nodes = num_nodes
+        self._node_types = np.load(
+            os.path.join(path, 'node_types.npy'), allow_pickle=True)
+
+        with open(os.path.join(path, 'edge_types.pkl'), 'rb') as f:
+            self._edge_types = pkl.load(f)
+
+        with open(os.path.join(path, "nodes_type_dict.pkl"), 'rb') as f:
+            self._nodes_type_dict = pkl.load(f)
+
+        self._multi_graph = {}
+        for e_type in self._edge_types:
+            sub_path = os.path.join(path, e_type)
+            self._multi_graph[e_type] = MemmapGraph(sub_path)
diff --git a/pgl/tests/test_MmapHeterGraph.py b/pgl/tests/test_MmapHeterGraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..189e56071389d494c0f858b87a26da0de6343ca2
--- /dev/null
+++ b/pgl/tests/test_MmapHeterGraph.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""test_hetergraph"""
+
+import time
+import unittest
+import json
+import os
+
+import numpy as np
+from pgl.sample import metapath_randomwalk
+from pgl.graph import Graph
+from pgl import heter_graph
+from pgl.heter_graph import MemmapHeterGraph
+
+
+def test_dump():
+    """Build the sample graph (node 5 has no 'p2a' successor) and dump it."""
+    np.random.seed(1)
+    edges = {}
+    edges['c2p'] = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
+                    (3, 7), (3, 4), (3, 8)]
+    edges['p2c'] = [(v, u) for u, v in edges['c2p']]
+    edges['p2a'] = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12),
+                    (6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]
+    edges['a2p'] = [(v, u) for u, v in edges['p2a']]
+
+    node_types = ['c' for _ in range(4)] + ['p' for _ in range(6)
+                                            ] + ['a' for _ in range(5)]
+    node_types = [(i, t) for i, t in enumerate(node_types)]
+
+    graph = heter_graph.HeterGraph(
+        num_nodes=len(node_types), edges=edges, node_types=node_types)
+
+    graph.dump("./hetergraph_mmap", outdegree=True)
+
+
+def test_load():
+    """Check that a freshly dumped graph can be opened again."""
+    test_dump()
+    MemmapHeterGraph("./hetergraph_mmap")
+
+
+class MmapHeterGraphTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        test_dump()  # write the on-disk fixture this class reads back
+        cls.graph = MemmapHeterGraph("./hetergraph_mmap")
+
+    def test_num_nodes_by_type(self):
+        """num_nodes_by_type returns the per-type node counts."""
+        n_types = {'c': 4, 'p': 6, 'a': 5}
+        for nt in n_types:
+            num_nodes = self.graph.num_nodes_by_type(nt)
+            self.assertEqual(num_nodes, n_types[nt])
+
+    def test_node_batch_iter(self):
+        """Unshuffled batches of 'p' nodes come out in id order."""
+        batch_size = 2
+        ground = [[4, 5], [6, 7], [8, 9]]
+        for idx, nodes in enumerate(
+                self.graph.node_batch_iter(
+                    batch_size=batch_size, shuffle=False, n_type='p')):
+            self.assertEqual(len(nodes), batch_size)
+            self.assertListEqual(list(nodes), ground[idx])
+
+    def test_sample_successor(self):
+        """Sampled 'p2a' successors are drawn from the true successors."""
+        nodes = [4, 5, 8]
+        md = 2
+        succes = self.graph.sample_successor(
+            edge_type='p2a', nodes=nodes, max_degree=md, return_eids=False)
+        self.assertIsInstance(succes, list)
+        ground = [[10, 11, 12, 14, 13], [], [14]]
+        for succ, g in zip(succes, ground):
+            self.assertIsInstance(succ, np.ndarray)
+            for i in succ:
+                self.assertIn(i, g)
+
+        nodes = [4]
+        succes = self.graph.sample_successor(
+            edge_type='p2a', nodes=nodes, max_degree=md, return_eids=False)
+        self.assertIsInstance(succes, list)
+        ground = [[10, 11, 12, 14, 13]]
+        for succ, g in zip(succes, ground):
+            self.assertIsInstance(succ, np.ndarray)
+            for i in succ:
+                self.assertIn(i, g)
+
+    def test_successor(self):
+        """successor returns every 'p2a' successor of each node."""
+        nodes = [4, 5, 8]
+        e_type = 'p2a'
+        succes = self.graph.successor(
+            edge_type=e_type,
+            nodes=nodes, )
+
+        self.assertIsInstance(succes, np.ndarray)
+        ground = [[10, 11, 12, 14, 13], [], [14]]
+        for succ, g in zip(succes, ground):
+            self.assertIsInstance(succ, np.ndarray)
+            self.assertCountEqual(succ, g)
+
+        nodes = [4]
+        e_type = 'p2a'
+        succes = self.graph.successor(
+            edge_type=e_type,
+            nodes=nodes, )
+
+        self.assertIsInstance(succes, np.ndarray)
+        ground = [[10, 11, 12, 14, 13]]
+        for succ, g in zip(succes, ground):
+            self.assertIsInstance(succ, np.ndarray)
+            self.assertCountEqual(succ, g)
+
+    def test_predecessor(self):
+        """predecessor returns every 'p2a' predecessor of each node."""
+        nodes = [11, 12, 13]
+        e_type = 'p2a'
+        pre = self.graph.predecessor(
+            edge_type=e_type,
+            nodes=nodes, )
+
+        self.assertIsInstance(pre, np.ndarray)
+
+        ground = [[4, 6, 7], [4, 6, 7], [4]]
+        for succ, g in zip(pre, ground):
+            self.assertIsInstance(succ, np.ndarray)
+            self.assertCountEqual(succ, g)
+
+        nodes = [11]
+        e_type = 'p2a'
+        pre = self.graph.predecessor(
+            edge_type=e_type,
+            nodes=nodes, )
+
+        self.assertIsInstance(pre, np.ndarray)
+        ground = [[4, 6, 7]]
+        for p, g in zip(pre, ground):
+            self.assertIsInstance(p, np.ndarray)
+            self.assertCountEqual(p, g)
+
+    def test_sample_nodes(self):
+        """sample_nodes draws ids of the requested type (or of any type)."""
+        p_ground = [4, 5, 6, 7, 8, 9]
+        sample_num = 10
+        nodes = self.graph.sample_nodes(sample_num=sample_num, n_type='p')
+
+        self.assertEqual(len(nodes), sample_num)
+        for n in nodes:
+            self.assertIn(n, p_ground)
+
+        # test n_type == None
+        ground = [i for i in range(15)]
+        nodes = self.graph.sample_nodes(sample_num=sample_num, n_type=None)
+        self.assertEqual(len(nodes), sample_num)
+        for n in nodes:
+            self.assertIn(n, ground)
+
+
+if __name__ == "__main__":
+    unittest.main()