Commit 438b3f4c authored by Yelrose

add gcnii

Parent b809f488
model_name: APPNP
k_hop: 10
alpha: 0.1
num_layers: 1
learning_rate: 0.01
dropout: 0.5
hidden_size: 64
......
model_name: GCNII
k_hop: 64
alpha: 0.1
num_layers: 1
learning_rate: 0.01
dropout: 0.6
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
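As a side note, here is a minimal sketch of how a hyperparameter file like the one above could be loaded into the `config` dict that the model classes below read via `config.get(...)`. The filename `gcnii.yaml` and the use of PyYAML are assumptions for illustration, not part of this commit:

```python
# Hedged sketch: load a hyperparameter YAML like the one above into a
# plain dict. The path "gcnii.yaml" is illustrative, not from this commit.
import yaml

with open("gcnii.yaml") as f:
    config = yaml.safe_load(f)

assert config["model_name"] == "GCNII"
# These keys line up with the config.get(...) lookups in the model class.
print(config.get("k_hop", 1), config.get("lambda_l", 0.5))
```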
@@ -154,3 +154,42 @@ class SGC(object):
        feature = L.fc(feature, self.num_class, act=None, bias_attr=False, name="output")
        return feature
class GCNII(object):
    """Implementation of GCNII"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.6)
        self.alpha = config.get("alpha", 0.1)
        self.lambda_l = config.get("lambda_l", 0.5)
        self.k_hop = config.get("k_hop", 64)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        # Edge dropout is only active during training and disabled at
        # evaluation. Note that the conv.gcnii call below does not
        # currently consume this value.
        if phase == "train":
            edge_dropout = self.edge_dropout
        else:
            edge_dropout = 0

        # Fully-connected layers with dropout produce the initial
        # representation H^(0) before propagation.
        for i in range(self.num_layers):
            feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')

        feature = conv.gcnii(graph_wrapper,
                             feature=feature,
                             name="gcnii",
                             activation="relu",
                             lambda_l=self.lambda_l,
                             alpha=self.alpha,
                             dropout=self.dropout,
                             k_hop=self.k_hop)

        feature = L.fc(feature, self.num_class, act=None, name="output")
        return feature
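For orientation, a hedged sketch of wiring this class into a Paddle static program. It assumes the PGL 1.x `GraphWrapper` API (whose exact signature varies across 1.x releases) together with illustrative values for the feature size and class count; `gw`, the `config` dict, and the shapes are assumptions, not taken from this commit:

```python
# Hedged sketch, assuming PGL 1.x where GraphWrapper declares node
# feature slots as (name, shape, dtype) tuples. Shapes and class count
# below are illustrative only.
import paddle.fluid as fluid
import pgl

config = {
    "num_layers": 1, "hidden_size": 64, "dropout": 0.6,
    "alpha": 0.1, "lambda_l": 0.5, "k_hop": 64, "edge_dropout": 0.0,
}

train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    gw = pgl.graph_wrapper.GraphWrapper(
        name="graph",
        node_feat=[("feature", [None, 1433], "float32")])
    model = GCNII(config, num_class=7)
    logits = model.forward(gw, gw.node_feat["feature"], phase="train")
```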
@@ -19,6 +19,7 @@ import paddle.fluid as fluid
import paddle.fluid.layers as L
from pgl.utils import paddle_helper
from pgl import message_passing
import numpy as np
__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv', 'appnp', 'gcnii']
@@ -413,6 +414,7 @@ def get_norm(indegree):
    norm = L.pow(float_degree, factor=-0.5)
    return norm
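The value returned here is the inverse square root of the node degree; multiplying features by it before and after the sum-aggregation (as both `appnp` and `gcnii` do) realizes the symmetric normalization:

```latex
\mathrm{norm}_i = d_i^{-1/2}, \qquad \hat{A} = D^{-1/2} A D^{-1/2}
```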
def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10):
"""Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
meet Personalized PageRank" (ICLR 2019).
@@ -453,3 +455,71 @@ def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10):
        feature = feature * (1 - alpha) + h0 * alpha
    return feature
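The residual line above is the personalized-PageRank propagation step of APPNP, with `h0` playing the role of the teleport term:

```latex
H^{(k+1)} = (1 - \alpha)\,\hat{A} H^{(k)} + \alpha\, H^{(0)}
```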
def gcnii(gw,
          feature,
          name,
          activation=None,
          alpha=0.5,
          lambda_l=0.5,
          k_hop=1,
          dropout=0.5,
          is_test=False):
    """Implementation of GCNII from "Simple and Deep Graph Convolutional Networks"

    paper: https://arxiv.org/pdf/2007.02133.pdf

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        name: Parameter-name prefix for the per-hop weight matrices.

        activation: The activation for the output.

        k_hop: Number of propagation layers for gcnii.

        lambda_l: The hyperparameter lambda in the paper.

        alpha: The hyperparameter alpha in the paper.

        dropout: Feature dropout rate.

        is_test: Whether in test phase; disables feature dropout when True.

    Return:
        A tensor with shape (num_nodes, hidden_size)
    """

    def send_src_copy(src_feat, dst_feat, edge_feat):
        """Message function: each edge carries the source node's feature."""
        return src_feat["h"]

    h0 = feature  # initial representation H^(0), reused at every hop
    ngw = gw
    norm = get_norm(ngw.indegree())
    hidden_size = feature.shape[-1]

    for i in range(k_hop):
        # beta_i = log(lambda / (i + 1) + 1) shrinks the per-hop linear
        # transform as depth grows (the identity mapping in the paper).
        beta_i = np.log(1.0 * lambda_l / (i + 1) + 1)

        feature = L.dropout(
            feature,
            dropout_prob=dropout,
            is_test=is_test,
            dropout_implementation='upscale_in_train')

        # Symmetrically normalized aggregation: D^{-1/2} A D^{-1/2} H
        feature = feature * norm
        msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
        feature = gw.recv(msg, "sum")
        feature = feature * norm

        # Initial residual connection (APPNP-style propagation)
        feature = feature * (1 - alpha) + h0 * alpha

        # Identity mapping: (1 - beta_i) * I + beta_i * W
        feature_transed = L.fc(feature, hidden_size,
                               act=None, bias_attr=False,
                               name=name + "_%s_w1" % i)
        feature = feature_transed * beta_i + feature * (1 - beta_i)

        if activation is not None:
            feature = getattr(L, activation)(feature)

    return feature
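Each pass through the loop above matches the GCNII layer update from the paper, where \(\tilde{P}\) is the symmetrically normalized adjacency realized by the two `feature * norm` multiplications around the sum-aggregation:

```latex
H^{(\ell+1)} = \sigma\Big( \big((1-\alpha)\,\tilde{P} H^{(\ell)} + \alpha\, H^{(0)}\big)
               \big((1-\beta_\ell) I_n + \beta_\ell\, W^{(\ell)}\big) \Big),
\qquad \beta_\ell = \log\Big(\frac{\lambda}{\ell} + 1\Big)
```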