From 2dfbbe8236daca5d0790412dce37e90f713c1260 Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Wed, 12 Aug 2020 12:05:51 +0800 Subject: [PATCH] add graph_gather op; rename graph_pool.py as graph_op.py --- pgl/layers/__init__.py | 6 +- pgl/layers/{graph_pool.py => graph_op.py} | 43 ++++++++++-- pgl/tests/test_graph_gather.py | 83 +++++++++++++++++++++++ pgl/utils/paddle_helper.py | 6 ++ 4 files changed, 129 insertions(+), 9 deletions(-) rename pgl/layers/{graph_pool.py => graph_op.py} (61%) create mode 100644 pgl/tests/test_graph_gather.py diff --git a/pgl/layers/__init__.py b/pgl/layers/__init__.py index efc27aa..f545c0e 100644 --- a/pgl/layers/__init__.py +++ b/pgl/layers/__init__.py @@ -18,10 +18,10 @@ from pgl.layers import conv from pgl.layers.conv import * from pgl.layers import set2set from pgl.layers.set2set import * -from pgl.layers import graph_pool -from pgl.layers.graph_pool import * +from pgl.layers import graph_op +from pgl.layers.graph_op import * __all__ = [] __all__ += conv.__all__ __all__ += set2set.__all__ -__all__ += graph_pool.__all__ +__all__ += graph_op.__all__ diff --git a/pgl/layers/graph_pool.py b/pgl/layers/graph_op.py similarity index 61% rename from pgl/layers/graph_pool.py rename to pgl/layers/graph_op.py index fbe6c50..1528bfe 100644 --- a/pgl/layers/graph_pool.py +++ b/pgl/layers/graph_op.py @@ -14,12 +14,13 @@ """This package implements common layers to help building graph neural networks. """ -import paddle.fluid as fluid +import paddle.fluid as F +import paddle.fluid.layers as L from pgl import graph_wrapper from pgl.utils import paddle_helper from pgl.utils import op -__all__ = ['graph_pooling', 'graph_norm'] +__all__ = ['graph_pooling', 'graph_norm', 'graph_gather'] def graph_pooling(gw, node_feat, pool_type): @@ -38,7 +39,7 @@ def graph_pooling(gw, node_feat, pool_type): A tensor with shape (num_graph, hidden_size) """ graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod) - graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type) + graph_feat = L.sequence_pool(graph_feat, pool_type) return graph_feat @@ -57,11 +58,41 @@ def graph_norm(gw, feature): Return: A tensor with shape (num_nodes, hidden_size) """ - nodes = fluid.layers.fill_constant( + nodes = L.fill_constant( [gw.num_nodes, 1], dtype="float32", value=1.0) norm = graph_pooling(gw, nodes, pool_type="sum") - norm = fluid.layers.sqrt(norm) + norm = L.sqrt(norm) feature_lod = op.nested_lod_reset(feature, gw.graph_lod) - norm = fluid.layers.sequence_expand_as(norm, feature_lod) + norm = L.sequence_expand_as(norm, feature_lod) norm.stop_gradient = True return feature_lod / norm + + +def graph_gather(gw, feature, index): + """Implementation of graph gather + + Gather the corresponding index for each graph. + + Args: + gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`) + + feature: A tensor with shape (num_nodes, ). + + index (int32): A tensor with K-rank where the first dim denotes the graph. + Shape (num_graph, ) or (num_graph, k1, k2, k3, ..., kn). + WARNING: We dont support negative index. + + Return: + A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size) + """ + shape = L.shape(index) + index = index + gw.graph_lod[:-1] + index = L.reshape(index, [-1]) + feature = L.gather(feature, index, overwrite=False) + new_shape = [] + for i in range(shape.shape[0]): + new_shape.append(shape[i]) + new_shape.append(-1) + feature = L.reshape(feature, new_shape) + return feature + diff --git a/pgl/tests/test_graph_gather.py b/pgl/tests/test_graph_gather.py new file mode 100644 index 0000000..22c311f --- /dev/null +++ b/pgl/tests/test_graph_gather.py @@ -0,0 +1,83 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + This file is for testing gin layer. +""" +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals +import unittest +import numpy as np + +import paddle.fluid as F +import paddle.fluid.layers as L + +import pgl +from pgl import graph +from pgl import graph_wrapper + + +class GraphGatherTest(unittest.TestCase): + """GraphGatherTest + """ + + def test_graph_gather(self): + """test_graph_gather + """ + np.random.seed(1) + + graph_list = [] + + num_graph = 10 + for _ in range(num_graph): + num_nodes = np.random.randint(5, 20) + edges = np.random.randint(low=0, high=num_nodes, size=(10, 2)) + node_feat = {"feature": np.random.rand(num_nodes, 4).astype("float32")} + g = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat) + graph_list.append(g) + + gg = graph.MultiGraph(graph_list) + + + use_cuda = False + place = F.CUDAPlace(0) if use_cuda else F.CPUPlace() + + prog = F.Program() + startup_prog = F.Program() + with F.program_guard(prog, startup_prog): + gw = graph_wrapper.GraphWrapper( + name='graph', + place=place, + node_feat=g.node_feat_info(), + edge_feat=g.edge_feat_info()) + + index = L.data(name="index", dtype="int32", shape=[-1]) + feats = pgl.layers.graph_gather(gw, gw.node_feat["feature"], index) + + + exe = F.Executor(place) + exe.run(startup_prog) + feed_dict = gw.to_feed(gg) + feed_dict["index"] = np.zeros(num_graph, dtype="int32") + ret = exe.run(prog, feed=feed_dict, fetch_list=[feats]) + self.assertEqual(list(ret[0].shape), [num_graph, 4]) + for i in range(num_graph): + dist = (ret[0][i] - graph_list[i].node_feat["feature"][0]) + dist = np.sum(dist ** 2) + self.assertLess(dist, 1e-15) + + +if __name__ == "__main__": + unittest.main() diff --git a/pgl/utils/paddle_helper.py b/pgl/utils/paddle_helper.py index 2dd6ea2..883e9e6 100644 --- a/pgl/utils/paddle_helper.py +++ b/pgl/utils/paddle_helper.py @@ -267,3 +267,9 @@ def masked_select(input, mask): index = fluid.layers.where(mask) return fluid.layers.gather(input, index) + +def ensure_dtype(input, dtype): + if input.dtype == dtype: + return input + else: + return fluid.layers.cast(input, dtype=dtype) -- GitLab