提交 9cd4d74d 编写于 作者: W Webbley

add gin test

上级 8246974a
......@@ -180,7 +180,13 @@ def gat(gw,
return output
def gin(gw, feature, name, init_eps=0.0, train_eps=False, apply_func=None):
def gin(gw,
feature,
hidden_size,
activation,
name,
init_eps=0.0,
train_eps=False):
"""Implementation of Graph Isomorphism Network (GIN) layer.
This is an implementation of the paper How Powerful are Graph Neural Networks?
......@@ -193,19 +199,18 @@ def gin(gw, feature, name, init_eps=0.0, train_eps=False, apply_func=None):
name: GIN layer names.
hidden_size: The hidden size for gin.
activation: The activation for the output.
init_eps: float, optional
Initial :math:`\epsilon` value, default is 0.
train_eps: bool, optional
if True, :math:`\epsilon` will be a learnable parameter.
apply_func: Callable activation function or None.
Default is None. If not None, apply this function to the updated feature.
Return:
A tensor with shape (num_nodes, output_size) where ``output_size`` is the
output dimensionality of ``apply_func``. If ``apply_func`` is None, ``output_size``
should be the same as ``feature_size``.
A tensor with shape (num_nodes, hidden_size).
"""
def send_src_copy(src_feat, dst_feat, edge_feat):
......@@ -214,8 +219,9 @@ def gin(gw, feature, name, init_eps=0.0, train_eps=False, apply_func=None):
epsilon = fluid.layers.create_parameter(
shape=[1, 1],
dtype="float32",
attr=F.ParamAttr(name="%s_eps" % name),
default_initializer=F.initializer.ConstantInitializer(value=init_eps))
attr=fluid.ParamAttr(name="%s_eps" % name),
default_initializer=fluid.initializer.ConstantInitializer(
value=init_eps))
if not train_eps:
epsilon.stop_gradient = True
......@@ -223,7 +229,17 @@ def gin(gw, feature, name, init_eps=0.0, train_eps=False, apply_func=None):
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
output = gw.recv(msg, "sum") + (1.0 + epsilon) * feature
if apply_func is not None:
output = apply_func(output, name)
output = fluid.layers.fc(output,
size=hidden_size,
bias_attr=False,
param_attr=fluid.ParamAttr(name="%s_w" % name))
bias = fluid.layers.create_parameter(
shape=[hidden_size],
dtype='float32',
is_bias=True,
attr=fluid.ParamAttr(name="%s_b" % name))
output = fluid.layers.elementwise_add(output, bias, act=activation)
return output
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file is for testing gin layer.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl.layers.conv import gin
from pgl import graph
from pgl import graph_wrapper
class GinTest(unittest.TestCase):
"""GinTest
"""
def test_gin(self):
"""test_gin
"""
np.random.seed(1)
hidden_size = 8
num_nodes = 10
edges = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
(3, 7), (3, 4), (3, 8)]
inver_edges = [(v, u) for u, v in edges]
edges.extend(inver_edges)
node_feat = {"feature": np.random.rand(10, 4).astype("float32")}
g = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)
use_cuda = False
place = F.GPUPlace(0) if use_cuda else F.CPUPlace()
prog = F.Program()
startup_prog = F.Program()
with F.program_guard(prog, startup_prog):
gw = graph_wrapper.GraphWrapper(
name='graph',
place=place,
node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info())
output = gin(gw,
gw.node_feat['feature'],
hidden_size=hidden_size,
activation='relu',
name='gin',
init_eps=1,
train_eps=True)
exe = F.Executor(place)
exe.run(startup_prog)
ret = exe.run(prog, feed=gw.to_feed(g), fetch_list=[output])
self.assertEqual(ret[0].shape[0], num_nodes)
self.assertEqual(ret[0].shape[1], hidden_size)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册