Unverified · Commit fff91869 authored by Webbley, committed by GitHub

Merge pull request #38 from Liwb5/develop

add gin layer
@@ -96,7 +96,7 @@ class GNNModel(object):
        loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred,
                                                              self.edge_label)
-        loss = fluid.layers.reduce_mean(loss)
+        loss = fluid.layers.reduce_sum(loss)
        return pred, prob, loss
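Note on the hunk above: reduce_sum makes the loss scale with the number of edges in the batch, while reduce_mean does not, so the change effectively multiplies gradient magnitudes by the batch size. A minimal numpy sketch of the relationship, with hypothetical logits and labels:

import numpy as np

# Numerically stable per-edge sigmoid cross-entropy, hypothetical values.
logits = np.array([0.3, -1.2, 2.0], dtype="float32")
labels = np.array([1.0, 0.0, 1.0], dtype="float32")
per_edge = (np.maximum(logits, 0) - logits * labels +
            np.log1p(np.exp(-np.abs(logits))))

# reduce_sum and reduce_mean differ only by the edge count.
assert np.isclose(per_edge.sum(), per_edge.mean() * per_edge.size)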
@@ -223,8 +223,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge):
"float32").reshape(-1, 1)
y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
input_dict = {
"y_true": splitted_edge["valid_edge_label"],
"y_pred": y_pred.reshape(-1, ),
"y_pred_pos":
y_pred[splitted_edge["valid_edge_label"] == 1].reshape(-1, ),
"y_pred_neg":
y_pred[splitted_edge["valid_edge_label"] == 0].reshape(-1, )
}
result["valid"] = evaluator.eval(input_dict)
@@ -234,8 +236,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge):
"float32").reshape(-1, 1)
y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
input_dict = {
"y_true": splitted_edge["test_edge_label"],
"y_pred": y_pred.reshape(-1, ),
"y_pred_pos":
y_pred[splitted_edge["test_edge_label"] == 1].reshape(-1, ),
"y_pred_neg":
y_pred[splitted_edge["test_edge_label"] == 0].reshape(-1, )
}
result["test"] = evaluator.eval(input_dict)
return result
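Note on the two hunks above: the evaluator is now fed positive and negative predictions separately (y_pred_pos / y_pred_neg) instead of a flat y_true / y_pred pair, matching evaluators that rank positives against negatives, such as OGB's link-prediction Evaluator. A minimal sketch of the split, with hypothetical labels and scores:

import numpy as np

# Hypothetical edge labels (1 = positive, 0 = negative) and model scores.
labels = np.array([1, 0, 1, 0])
scores = np.array([0.9, 0.2, 0.7, 0.4], dtype="float32")

input_dict = {
    "y_pred_pos": scores[labels == 1].reshape(-1, ),
    "y_pred_neg": scores[labels == 0].reshape(-1, ),
}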
@@ -18,7 +18,7 @@ import paddle.fluid as fluid
from pgl import graph_wrapper
from pgl.utils import paddle_helper
-__all__ = ['gcn', 'gat']
+__all__ = ['gcn', 'gat', 'gin']
def gcn(gw, feature, hidden_size, activation, name, norm=None):
@@ -178,3 +178,73 @@ def gat(gw,
    bias.stop_gradient = True
    output = fluid.layers.elementwise_add(output, bias, act=activation)
    return output
def gin(gw,
        feature,
        hidden_size,
        activation,
        name,
        init_eps=0.0,
        train_eps=False):
    r"""Implementation of Graph Isomorphism Network (GIN) layer.

    This is an implementation of the paper "How Powerful are Graph Neural
    Networks?" (https://arxiv.org/pdf/1810.00826.pdf).

    In their implementation, all MLPs have 2 layers. Batch normalization is
    applied on every hidden layer.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        hidden_size: The hidden size for gin.

        activation: The activation for the output.

        name: GIN layer name.

        init_eps: float, optional
            Initial :math:`\epsilon` value, default is 0.

        train_eps: bool, optional
            If True, :math:`\epsilon` will be a learnable parameter.

    Return:
        A tensor with shape (num_nodes, hidden_size).
    """

    def send_src_copy(src_feat, dst_feat, edge_feat):
        # Each edge simply forwards the source node's feature.
        return src_feat["h"]

    # epsilon is created as a parameter so that it can optionally be trained.
    epsilon = fluid.layers.create_parameter(
        shape=[1, 1],
        dtype="float32",
        attr=fluid.ParamAttr(name="%s_eps" % name),
        default_initializer=fluid.initializer.ConstantInitializer(
            value=init_eps))

    if not train_eps:
        epsilon.stop_gradient = True

    # Sum-aggregate neighbor features, then add the re-weighted self feature:
    # (1 + epsilon) * h_v + sum_{u in N(v)} h_u
    msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
    output = gw.recv(msg, "sum") + (1.0 + epsilon) * feature

    # Two-layer MLP with batch normalization on the hidden layer.
    output = fluid.layers.fc(output,
                             size=hidden_size,
                             act=None,
                             param_attr=fluid.ParamAttr(name="%s_w_0" % name),
                             bias_attr=fluid.ParamAttr(name="%s_b_0" % name))
    output = fluid.layers.batch_norm(output)
    output = getattr(fluid.layers, activation)(output)

    output = fluid.layers.fc(output,
                             size=hidden_size,
                             act=activation,
                             param_attr=fluid.ParamAttr(name="%s_w_1" % name),
                             bias_attr=fluid.ParamAttr(name="%s_b_1" % name))
    return output
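For reference, the layer above computes the GIN node update from the paper, with the MLP realized as the two fc layers plus batch normalization:

h_v^{(k)} = \mathrm{MLP}^{(k)}\left(\left(1 + \epsilon^{(k)}\right) \cdot h_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)}\right)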
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file is for testing gin layer.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl.layers.conv import gin
from pgl import graph
from pgl import graph_wrapper
class GinTest(unittest.TestCase):
"""GinTest
"""
def test_gin(self):
"""test_gin
"""
np.random.seed(1)
hidden_size = 8
num_nodes = 10
edges = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
(3, 7), (3, 4), (3, 8)]
inver_edges = [(v, u) for u, v in edges]
edges.extend(inver_edges)
node_feat = {"feature": np.random.rand(10, 4).astype("float32")}
g = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)
use_cuda = False
place = F.GPUPlace(0) if use_cuda else F.CPUPlace()
prog = F.Program()
startup_prog = F.Program()
with F.program_guard(prog, startup_prog):
gw = graph_wrapper.GraphWrapper(
name='graph',
place=place,
node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info())
output = gin(gw,
gw.node_feat['feature'],
hidden_size=hidden_size,
activation='relu',
name='gin',
init_eps=1,
train_eps=True)
exe = F.Executor(place)
exe.run(startup_prog)
ret = exe.run(prog, feed=gw.to_feed(g), fetch_list=[output])
self.assertEqual(ret[0].shape[0], num_nodes)
self.assertEqual(ret[0].shape[1], hidden_size)
if __name__ == "__main__":
unittest.main()