diff --git a/examples/citation_benchmark/model.py b/examples/citation_benchmark/model.py
index dcb1f78cb0627c140bd5a2039e84ba2a3029cfac..6028e0f63a82432d588da5543996d57c90d1ffe1 100644
--- a/examples/citation_benchmark/model.py
+++ b/examples/citation_benchmark/model.py
@@ -72,9 +72,9 @@ class GAT(object):
 
     def forward(self, graph_wrapper, feature, phase):
         if phase == "train":
-            edge_dropout = 0
-        else:
             edge_dropout = self.edge_dropout
+        else:
+            edge_dropout = 0
 
         for i in range(self.num_layers):
             ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
@@ -113,9 +113,9 @@ class APPNP(object):
 
     def forward(self, graph_wrapper, feature, phase):
         if phase == "train":
-            edge_dropout = 0
-        else:
             edge_dropout = self.edge_dropout
+        else:
+            edge_dropout = 0
 
         for i in range(self.num_layers):
             feature = L.dropout(
@@ -169,9 +169,9 @@ class GCNII(object):
 
     def forward(self, graph_wrapper, feature, phase):
         if phase == "train":
-            edge_dropout = 0
-        else:
             edge_dropout = self.edge_dropout
+        else:
+            edge_dropout = 0
 
         for i in range(self.num_layers):
             feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
@@ -191,5 +191,3 @@ class GCNII(object):
 
         feature = L.fc(feature, self.num_class, act=None, name="output")
         return feature
-
-
diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py
index e91feddc69805d5c50ac4cfbf2e54df0238487cc..d9f8eafa616298487c114e4de51c682351f4a91b 100644
--- a/pgl/graph_wrapper.py
+++ b/pgl/graph_wrapper.py
@@ -25,7 +25,7 @@ from pgl.utils import op
 from pgl.utils import paddle_helper
 from pgl.utils.logger import log
 
-__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
+__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper", "BatchGraphWrapper"]
 
 def send(src, dst, nfeat, efeat, message_func):
     """Send message from src to dst.
@@ -101,7 +101,6 @@ class BaseGraphWrapper(object):
         self._indegree = None
         self._edge_uniq_dst = None
         self._edge_uniq_dst_count = None
-        self._node_ids = None
         self._graph_lod = None
         self._num_graph = None
         self._num_edges = None
@@ -416,13 +415,6 @@ class StaticGraphWrapper(BaseGraphWrapper):
             value=graph_lod)
         self._initializers.append(init)
 
-        node_ids_value = np.arange(0, graph.num_nodes, dtype="int64")
-        self._node_ids, init = paddle_helper.constant(
-            name=self._data_name_prefix + "/node_ids",
-            dtype="int64",
-            value=node_ids_value)
-        self._initializers.append(init)
-
         self._indegree, init = paddle_helper.constant(
             name=self._data_name_prefix + "/indegree",
             dtype="int64",
@@ -601,12 +593,6 @@ class GraphWrapper(BaseGraphWrapper):
             dtype="int32",
             stop_gradient=True)
 
-        self._node_ids = L.data(
-            self._data_name_prefix + "/node_ids",
-            shape=[None],
-            append_batch_size=False,
-            dtype="int64",
-            stop_gradient=True)
         self._indegree = L.data(
             self._data_name_prefix + "/indegree",
             shape=[None],
@@ -619,7 +605,6 @@ class GraphWrapper(BaseGraphWrapper):
             self._num_nodes,
             self._edge_uniq_dst,
             self._edge_uniq_dst_count,
-            self._node_ids,
             self._indegree,
             self._graph_lod,
             self._num_graph,
@@ -700,7 +685,6 @@ class GraphWrapper(BaseGraphWrapper):
             [graph.num_nodes], dtype="int64")
         feed_dict[self._data_name_prefix + '/uniq_dst'] = uniq_dst
         feed_dict[self._data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
-        feed_dict[self._data_name_prefix + '/node_ids'] = graph.nodes
         feed_dict[self._data_name_prefix + '/indegree'] = indegree
         feed_dict[self._data_name_prefix + '/graph_lod'] = graph_lod
         feed_dict[self._data_name_prefix + '/num_graph'] = np.array(
@@ -746,7 +730,6 @@ class DropEdgeWrapper(BaseGraphWrapper):
         self._num_nodes = graph_wrapper.num_nodes
         self._graph_lod = graph_wrapper.graph_lod
         self._num_graph = graph_wrapper.num_graph
-        self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")
 
         # Dropout Edges
         src, dst = graph_wrapper.edges
@@ -780,3 +763,96 @@ class DropEdgeWrapper(BaseGraphWrapper):
         self._edge_uniq_dst_count = L.concat([uniq_count, last])
         self._edge_uniq_dst_count.stop_gradient=True
         self._indegree = get_degree(self._edges_dst, self._num_nodes)
+
+
+class BatchGraphWrapper(BaseGraphWrapper):
+    """Implement a graph wrapper that lets users feed their own data holders.
+    This wrapper also supports multiple graphs, which is beneficial for data-parallel algorithms.
+
+    Args:
+        num_nodes (int32 or int64): Shape [ num_graph ].
+
+        num_edges (int32 or int64): Shape [ num_graph ].
+
+        edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ]
+                                or Tuple with (src, dst).
+
+        node_feats: A dictionary for node features. Each value should be a tensor
+                    with shape [ total_num_nodes_in_the_graphs, feature_size ].
+
+        edge_feats: A dictionary for edge features.
+                    Each value should be a tensor with shape
+                    [ total_num_edges_in_the_graphs, feature_size ].
+
+    """
+    def __init__(self, num_nodes, num_edges, edges, node_feats=None, edge_feats=None):
+        super(BatchGraphWrapper, self).__init__()
+
+        node_shift, edge_lod = self.__build_meta_data(num_nodes, num_edges)
+        self.__build_edges(edges, node_shift, edge_lod)
+
+        # assign node features
+        if node_feats is not None:
+            for key, value in node_feats.items():
+                self.node_feat_tensor_dict[key] = value
+
+        # assign edge features
+        if edge_feats is not None:
+            for key, value in edge_feats.items():
+                self.edge_feat_tensor_dict[key] = value
+
+        # other meta-data
+        self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(self._edges_dst, dtype="int32")
+        self._edge_uniq_dst.stop_gradient=True
+        last = L.reduce_sum(uniq_count, keep_dim=True)
+        uniq_count = L.cumsum(uniq_count, exclusive=True)
+        self._edge_uniq_dst_count = L.concat([uniq_count, last])
+        self._edge_uniq_dst_count.stop_gradient=True
+        self._indegree = get_degree(self._edges_dst, self._num_nodes)
+
+    def __build_meta_data(self, num_nodes, num_edges):
+        """ Merge information for nodes and edges.
+        """
+        num_nodes = L.reshape(num_nodes, [-1])
+        num_edges = L.reshape(num_edges, [-1])
+        num_nodes = paddle_helper.ensure_dtype(num_nodes, dtype="int32")
+        num_edges = paddle_helper.ensure_dtype(num_edges, dtype="int32")
+
+        num_graph = L.shape(num_nodes)[0]
+        sum_num_nodes = L.reduce_sum(num_nodes)
+        sum_num_edges = L.reduce_sum(num_edges)
+        edge_lod = L.concat([L.cumsum(num_edges, exclusive=True), sum_num_edges])
+        edge_lod = paddle_helper.lod_remove(edge_lod)
+
+        node_shift = L.cumsum(num_nodes, exclusive=True)
+        graph_lod = L.concat([node_shift, sum_num_nodes])
+        graph_lod = paddle_helper.lod_remove(graph_lod)
+        self._num_nodes = sum_num_nodes
+        self._num_edges = sum_num_edges
+        self._num_graph = num_graph
+        self._graph_lod = graph_lod
+        return node_shift, edge_lod
+
+
+    def __build_edges(self, edges, node_shift, edge_lod):
+        """ Merge subgraph edges.
+ """ + if isinstance(edges, tuple): + src, dst = edges + else: + src = edges[:, 0] + dst = edges[:, 1] + + src = L.reshape(src, [-1]) + dst = L.reshape(dst, [-1]) + src = paddle_helper.ensure_dtype(src, dtype="int32") + dst = paddle_helper.ensure_dtype(dst, dtype="int32") + # preprocess edges + lod_dst = L.lod_reset(dst, edge_lod) + node_shift = L.reshape(node_shift, [-1, 1]) + node_shift = L.sequence_expand_as(node_shift, lod_dst) + node_shift = L.reshape(node_shift, [-1]) + src = src + node_shift + dst = dst + node_shift + # sort edges + self._edges_dst, index = L.argsort(dst) + self._edges_src = L.gather(src, index, overwrite=False) diff --git a/pgl/layers/__init__.py b/pgl/layers/__init__.py index efc27aa5bda6316348c7c65d6d714de70584b1dc..f545c0e033359f75930b72a1cf70bd0574679235 100644 --- a/pgl/layers/__init__.py +++ b/pgl/layers/__init__.py @@ -18,10 +18,10 @@ from pgl.layers import conv from pgl.layers.conv import * from pgl.layers import set2set from pgl.layers.set2set import * -from pgl.layers import graph_pool -from pgl.layers.graph_pool import * +from pgl.layers import graph_op +from pgl.layers.graph_op import * __all__ = [] __all__ += conv.__all__ __all__ += set2set.__all__ -__all__ += graph_pool.__all__ +__all__ += graph_op.__all__ diff --git a/pgl/layers/graph_pool.py b/pgl/layers/graph_op.py similarity index 60% rename from pgl/layers/graph_pool.py rename to pgl/layers/graph_op.py index fbe6c500b7d3efc5a836b07982458293b4207b5d..042860bffd7bfd76288869c803b24eccf40daae0 100644 --- a/pgl/layers/graph_pool.py +++ b/pgl/layers/graph_op.py @@ -14,12 +14,13 @@ """This package implements common layers to help building graph neural networks. """ -import paddle.fluid as fluid +import paddle.fluid as F +import paddle.fluid.layers as L from pgl import graph_wrapper from pgl.utils import paddle_helper from pgl.utils import op -__all__ = ['graph_pooling', 'graph_norm'] +__all__ = ['graph_pooling', 'graph_norm', 'graph_gather'] def graph_pooling(gw, node_feat, pool_type): @@ -38,7 +39,7 @@ def graph_pooling(gw, node_feat, pool_type): A tensor with shape (num_graph, hidden_size) """ graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod) - graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type) + graph_feat = L.sequence_pool(graph_feat, pool_type) return graph_feat @@ -57,11 +58,42 @@ def graph_norm(gw, feature): Return: A tensor with shape (num_nodes, hidden_size) """ - nodes = fluid.layers.fill_constant( + nodes = L.fill_constant( [gw.num_nodes, 1], dtype="float32", value=1.0) norm = graph_pooling(gw, nodes, pool_type="sum") - norm = fluid.layers.sqrt(norm) + norm = L.sqrt(norm) feature_lod = op.nested_lod_reset(feature, gw.graph_lod) - norm = fluid.layers.sequence_expand_as(norm, feature_lod) + norm = L.sequence_expand_as(norm, feature_lod) norm.stop_gradient = True return feature_lod / norm + + +def graph_gather(gw, feature, index): + """Implementation of graph gather + + Gather the corresponding index for each graph. + + Args: + gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`) + + feature: A tensor with shape (num_nodes, ). + + index (int32): A tensor with K-rank where the first dim denotes the graph. + Shape (num_graph, ) or (num_graph, k1, k2, k3, ..., kn). + WARNING: We dont support negative index. 
+
+    Return:
+        A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size)
+    """
+    shape = L.shape(index)
+    output_dim = int(feature.shape[-1])
+    index = index + gw.graph_lod[:-1]
+    index = L.reshape(index, [-1])
+    feature = L.gather(feature, index, overwrite=False)
+    new_shape = []
+    for i in range(shape.shape[0]):
+        new_shape.append(shape[i])
+    new_shape.append(output_dim)
+    feature = L.reshape(feature, new_shape)
+    return feature
+
diff --git a/pgl/tests/test_batch_graph_wrapper.py b/pgl/tests/test_batch_graph_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5014ded40b4a3171f6e78a6e7e96a3ed1fa2ca4
--- /dev/null
+++ b/pgl/tests/test_batch_graph_wrapper.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+    This file is for testing BatchGraphWrapper.
+"""
+from __future__ import division
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import unicode_literals
+import unittest
+import numpy as np
+
+import paddle.fluid as F
+import paddle.fluid.layers as L
+
+from pgl.layers.conv import gcn
+from pgl import graph
+from pgl import graph_wrapper
+
+
+class BatchedGraphWrapper(unittest.TestCase):
+    """BatchedGraphWrapper
+    """
+    def test_batched_graph_wrapper(self):
+        """test_batch_graph_wrapper
+        """
+        np.random.seed(1)
+
+        graph_list = []
+
+        num_graph = 5
+        feed_num_nodes = []
+        feed_num_edges = []
+        feed_edges = []
+        feed_node_feats = []
+
+        for _ in range(num_graph):
+            num_nodes = np.random.randint(5, 20)
+            edges = np.random.randint(low=0, high=num_nodes, size=(10, 2))
+            node_feat = {"feature": np.random.rand(num_nodes, 4).astype("float32")}
+            single_graph = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)
+            feed_num_nodes.append(num_nodes)
+            feed_num_edges.append(len(edges))
+            feed_edges.append(edges)
+            feed_node_feats.append(node_feat["feature"])
+            graph_list.append(single_graph)
+
+        multi_graph = graph.MultiGraph(graph_list)
+
+        np.random.seed(1)
+        hidden_size = 8
+        num_nodes = 10
+
+        place = F.CUDAPlace(0) if F.is_compiled_with_cuda() else F.CPUPlace()
+        prog = F.Program()
+        startup_prog = F.Program()
+
+        with F.program_guard(prog, startup_prog):
+            with F.unique_name.guard():
+                # Standard Graph Wrapper
+                gw = graph_wrapper.GraphWrapper(
+                    name='graph',
+                    place=place,
+                    node_feat=[("feature", [-1, 4], "float32")])
+
+                output = gcn(gw,
+                             gw.node_feat['feature'],
+                             hidden_size=hidden_size,
+                             activation='relu',
+                             name='gcn')
+
+                # BatchGraphWrapper
+                num_nodes = L.data(name="num_nodes", shape=[-1], dtype="int32")
+                num_edges = L.data(name="num_edges", shape=[-1], dtype="int32")
+                edges = L.data(name="edges", shape=[-1, 2], dtype="int32")
+                node_feat = L.data(name="node_feats", shape=[-1, 4], dtype="float32")
+                batch_gw = graph_wrapper.BatchGraphWrapper(num_nodes=num_nodes,
+                                                           num_edges=num_edges,
+                                                           edges=edges,
+                                                           node_feats={"feature": node_feat})
+
+                output2 = gcn(batch_gw,
+                              batch_gw.node_feat['feature'],
+                              hidden_size=hidden_size,
+                              activation='relu',
+                              name='gcn')
+
+        exe = F.Executor(place)
+        exe.run(startup_prog)
+        feed_dict = gw.to_feed(multi_graph)
+        feed_dict["num_nodes"] = np.array(feed_num_nodes, dtype="int32")
+        feed_dict["num_edges"] = np.array(feed_num_edges, dtype="int32")
+        feed_dict["edges"] = np.array(np.concatenate(feed_edges, 0), dtype="int32").reshape([-1, 2])
+        feed_dict["node_feats"] = np.array(np.concatenate(feed_node_feats, 0), dtype="float32").reshape([-1, 4])
+
+        # Run
+        O1, O2 = exe.run(prog, feed=feed_dict, fetch_list=[output, output2])
+
+        # The outputs of the two graph wrappers should be identical.
+        for o1, o2 in zip(O1, O2):
+            dist = np.sum((o1 - o2) ** 2)
+            self.assertLess(dist, 1e-15)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/pgl/tests/test_graph_gather.py b/pgl/tests/test_graph_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..22c311f368ffc8649247c122decde20132d3cc75
--- /dev/null
+++ b/pgl/tests/test_graph_gather.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+    This file is for testing the graph_gather layer.
+"""
+from __future__ import division
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import unicode_literals
+import unittest
+import numpy as np
+
+import paddle.fluid as F
+import paddle.fluid.layers as L
+
+import pgl
+from pgl import graph
+from pgl import graph_wrapper
+
+
+class GraphGatherTest(unittest.TestCase):
+    """GraphGatherTest
+    """
+
+    def test_graph_gather(self):
+        """test_graph_gather
+        """
+        np.random.seed(1)
+
+        graph_list = []
+
+        num_graph = 10
+        for _ in range(num_graph):
+            num_nodes = np.random.randint(5, 20)
+            edges = np.random.randint(low=0, high=num_nodes, size=(10, 2))
+            node_feat = {"feature": np.random.rand(num_nodes, 4).astype("float32")}
+            g = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)
+            graph_list.append(g)
+
+        gg = graph.MultiGraph(graph_list)
+
+        use_cuda = False
+        place = F.CUDAPlace(0) if use_cuda else F.CPUPlace()
+
+        prog = F.Program()
+        startup_prog = F.Program()
+        with F.program_guard(prog, startup_prog):
+            gw = graph_wrapper.GraphWrapper(
+                name='graph',
+                place=place,
+                node_feat=g.node_feat_info(),
+                edge_feat=g.edge_feat_info())
+
+            index = L.data(name="index", dtype="int32", shape=[-1])
+            feats = pgl.layers.graph_gather(gw, gw.node_feat["feature"], index)
+
+        exe = F.Executor(place)
+        exe.run(startup_prog)
+        feed_dict = gw.to_feed(gg)
+        feed_dict["index"] = np.zeros(num_graph, dtype="int32")
+        ret = exe.run(prog, feed=feed_dict, fetch_list=[feats])
+        self.assertEqual(list(ret[0].shape), [num_graph, 4])
+        for i in range(num_graph):
+            dist = (ret[0][i] - graph_list[i].node_feat["feature"][0])
+            dist = np.sum(dist ** 2)
+            self.assertLess(dist, 1e-15)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/pgl/utils/paddle_helper.py b/pgl/utils/paddle_helper.py
index 2dd6ea248966a734ed2cbfefd342cd90655407f4..66b5ddc63e42f5c46500ba7101875ae2cfa6756e 100644
--- a/pgl/utils/paddle_helper.py
+++ b/pgl/utils/paddle_helper.py
@@ -22,13 +22,14 @@ import paddle
 from paddle.fluid import core
 import paddle.fluid as fluid
 import paddle.fluid.layer_helper as layer_helper
+import paddle.fluid.layers as L
 
 from pgl.utils.logger import log
 
 
 def gather(input, index):
     """Gather input from given index.
 
-    Slicing input data with given index. This function rewrite paddle.fluid.layers.gather
+    Slicing input data with given index. This function rewrite paddle.L.gather
     to fix issue: https://github.com/PaddlePaddle/Paddle/issues/17509
     when paddlepaddle's version is less than 1.5.
@@ -42,16 +43,16 @@ def gather(input, index):
     """
     try:
         # PaddlePaddle 1.5
-        output = fluid.layers.gather(input, index, overwrite=False)
+        output = L.gather(input, index, overwrite=False)
         return output
     except TypeError as e:
         warnings.warn("Your paddle version is less than 1.5"
                       " gather may be slower.")
        if index.dtype == core.VarDesc.VarType.INT32:
-            index = fluid.layers.cast(index, "int64")
+            index = L.cast(index, "int64")
 
         if index.shape[-1] != 1:
-            index = fluid.layers.reshape(index, shape=[-1, 1])
+            index = L.reshape(index, shape=[-1, 1])
 
         index.stop_gradient = True
         helper = layer_helper.LayerHelper("gather", **locals())  #**locals())
@@ -112,7 +113,7 @@ def constant(name, value, dtype, hide_batch_size=True):
         raise TypeError("value should be Numpy array.")
 
     value = value.astype(dtype)
-    data = fluid.layers.create_global_var(
+    data = L.create_global_var(
         shape=value.shape,
         value=0,
         dtype=value.dtype,
@@ -181,7 +182,7 @@ def lod_constant(name, value, lod, dtype):
     _lod = [0]
     for l in lod:
         _lod.append(_lod[-1] + l)
-    output = fluid.layers.lod_reset(data, target_lod=_lod)
+    output = L.lod_reset(data, target_lod=_lod)
     return output, data_initializer
 
 
@@ -189,7 +190,7 @@ def sequence_softmax(x, beta=None):
     """Compute sequence softmax over paddle LodTensor
 
     This function compute softmax normalization along with the length of sequence.
-    This function is an extention of :code:`fluid.layers.sequence_softmax` which can only
+    This function is an extention of :code:`L.sequence_softmax` which can only
     deal with LodTensor whose last dimension is 1.
 
     Args:
@@ -203,12 +204,12 @@ def sequence_softmax(x, beta=None):
     if beta is not None:
         x = x * beta
 
-    x_max = fluid.layers.sequence_pool(x, "max")
-    x_max = fluid.layers.sequence_expand_as(x_max, x)
+    x_max = L.sequence_pool(x, "max")
+    x_max = L.sequence_expand_as(x_max, x)
     x = x - x_max
-    exp_x = fluid.layers.exp(x)
-    sum_exp_x = fluid.layers.sequence_pool(exp_x, "sum")
-    sum_exp_x = fluid.layers.sequence_expand_as(sum_exp_x, exp_x)
+    exp_x = L.exp(x)
+    sum_exp_x = L.sequence_pool(exp_x, "sum")
+    sum_exp_x = L.sequence_expand_as(sum_exp_x, exp_x)
     return exp_x / sum_exp_x
 
 
@@ -228,7 +229,7 @@ def scatter_add(input, index, updates):
     Return:
         Same type and shape as input.
     """
-    output = fluid.layers.scatter(input, index, updates, overwrite=False)
+    output = L.scatter(input, index, updates, overwrite=False)
     return output
 
 
@@ -248,7 +249,7 @@ def scatter_max(input, index, updates):
     Return:
         Same type and shape as input.
     """
-    output = fluid.layers.scatter(input, index, updates, mode='max')
+    output = L.scatter(input, index, updates, mode='max')
     return output
 
 def masked_select(input, mask):
@@ -264,6 +265,41 @@ def masked_select(input, mask):
     Return:
         Part of inputs where mask is True.
""" - index = fluid.layers.where(mask) - return fluid.layers.gather(input, index) + index = L.where(mask) + return L.gather(input, index) + +def ensure_dtype(input, dtype): + """ensure_dtype + + If input is dtype, return input + + else cast input into dtype + + Args: + input: Input tensor + + dtype: a string of type + + Return: + If input is dtype, return input, else cast input into dtype + """ + if str(input.dtype) == dtype: + return input + else: + return L.cast(input, dtype=dtype) + +def lod_remove(input): + """Lod Remove + + Remove the lod for LodTensor and Flatten the data into 1D-Tensor. + + Args: + input: A tensor to be flattend + + Return: + A 1D input + """ + return L.reshape(L.reshape(input, [1, -1]), [-1]) + +