提交 141fe25b 编写于 作者: Y Yelrose

add appnp

上级 d96c9759
...@@ -18,7 +18,7 @@ import paddle.fluid as fluid ...@@ -18,7 +18,7 @@ import paddle.fluid as fluid
from pgl.utils import paddle_helper from pgl.utils import paddle_helper
from pgl import message_passing from pgl import message_passing
__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv'] __all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv', 'appnp']
def gcn(gw, feature, hidden_size, activation, name, norm=None): def gcn(gw, feature, hidden_size, activation, name, norm=None):
...@@ -404,3 +404,40 @@ def gen_conv(gw, ...@@ -404,3 +404,40 @@ def gen_conv(gw,
return output return output
def appnp(gw, feature, norm=None, alpha=0.2, k_hop=10):
"""Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
meet Personalized PageRank" (ICLR 2019).
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
norm: If :code:`norm` is not None, then the feature will be normalized. Norm must
be tensor with shape (num_nodes,) and dtype float32.
k_hop: K Steps for Propagation
Return:
A tensor with shape (num_nodes, hidden_size)
"""
def send_src_copy(src_feat, dst_feat, edge_feat):
feature = src_feat["h"]
return feature
h0 = feature
for i in range(k_hop):
if norm is not None:
feature = feature * norm
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
feature = gw.recv(msg, "sum")
if norm is not None:
feature = feature * norm
feature = feature * (1 - alpha) + h0 * alpha
return feature
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册