Unverified commit 16752f3e, authored by Huang Zhengjie and committed by GitHub

Merge pull request #108 from Yelrose/master

Add APPNP; GCNII; DropEdge
# Easy Paper Reproduction for Citation Network (Cora/Pubmed/Citeseer)
This page reproduces the well-known **Graph Neural Network** papers on the citation networks (Cora/Pubmed/Citeseer), the **Hello world** datasets (**small** and **fast**) of graph neural networks, on which it is nevertheless hard to reach very high accuracy.
All datasets are run with the public split in the **semi-supervised** setting, and we report the average accuracy over 10 runs.
# Experiment Results
| Model | Cora | Pubmed | Citeseer | Remarks |
| ------------------------------------------------------------ | ------------ | ------------ | ------------ | --------------------------------------------------------- |
| [Vanilla GCN (Kipf 2017)](https://openreview.net/pdf?id=SJU4ayYgl) | 0.807(0.010) | 0.794(0.003) | 0.710(0.007) | |
| [GAT (Veličković 2017)](https://arxiv.org/pdf/1710.10903.pdf) | 0.834(0.004) | 0.772(0.004) | 0.700(0.006) | |
| [SGC (Wu 2019)](https://arxiv.org/pdf/1902.07153.pdf) | 0.818(0.000) | 0.782(0.000) | 0.708(0.000) | |
| [APPNP (Klicpera 2018)](https://arxiv.org/abs/1810.05997) | 0.846(0.003) | 0.803(0.002) | 0.719(0.003) | Almost the same as the results reported in Appendix E. |
| [GCNII (64 layers, 1500 epochs, Chen 2020)](https://arxiv.org/pdf/2007.02133.pdf) | 0.846(0.003) | 0.798(0.003) | 0.724(0.006) | |
# How to run the experiments
```shell
# Device choose
export CUDA_VISIBLE_DEVICES=0
# GCN
python train.py --conf config/gcn.yaml --use_cuda --dataset cora
python train.py --conf config/gcn.yaml --use_cuda --dataset pubmed
python train.py --conf config/gcn.yaml --use_cuda --dataset citeseer
# GAT
python train.py --conf config/gat.yaml --use_cuda --dataset cora
python train.py --conf config/gat.yaml --use_cuda --dataset pubmed
python train.py --conf config/gat.yaml --use_cuda --dataset citeseer
# SGC (Slow version)
python train.py --conf config/sgc.yaml --use_cuda --dataset cora
python train.py --conf config/sgc.yaml --use_cuda --dataset pubmed
python train.py --conf config/sgc.yaml --use_cuda --dataset citeseer
# APPNP
python train.py --conf config/appnp.yaml --use_cuda --dataset cora
python train.py --conf config/appnp.yaml --use_cuda --dataset pubmed
python train.py --conf config/appnp.yaml --use_cuda --dataset citeseer
# GCNII (The original code uses 1500 epochs.)
python train.py --conf config/gcnii.yaml --use_cuda --dataset cora --epoch 1500
python train.py --conf config/gcnii.yaml --use_cuda --dataset pubmed --epoch 1500
python train.py --conf config/gcnii.yaml --use_cuda --dataset citeseer --epoch 1500
```
**build_model.py**

```python
import pgl
import model
from pgl import data_loader
import paddle.fluid as fluid
import numpy as np
import time


def build_model(dataset, config, phase, main_prog):
    gw = pgl.graph_wrapper.GraphWrapper(
        name="graph",
        node_feat=dataset.graph.node_feat_info())

    GraphModel = getattr(model, config.model_name)
    m = GraphModel(config=config, num_class=dataset.num_classes)
    logits = m.forward(gw, gw.node_feat["words"], phase)

    # Compute the loss only on the labeled nodes selected by node_index.
    node_index = fluid.layers.data(
        "node_index",
        shape=[None, 1],
        dtype="int64",
        append_batch_size=False)
    node_label = fluid.layers.data(
        "node_label",
        shape=[None, 1],
        dtype="int64",
        append_batch_size=False)

    pred = fluid.layers.gather(logits, node_index)
    loss, pred = fluid.layers.softmax_with_cross_entropy(
        logits=pred, label=node_label, return_softmax=True)
    acc = fluid.layers.accuracy(input=pred, label=node_label, k=1)
    loss = fluid.layers.mean(loss)

    if phase == "train":
        adam = fluid.optimizer.Adam(
            learning_rate=config.learning_rate,
            regularization=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=config.weight_decay))
        adam.minimize(loss)
    return gw, loss, acc
```
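`build_model` computes the loss only on the labeled subset: `fluid.layers.gather` selects the rows of `logits` named by `node_index` before the softmax cross-entropy is applied. A quick numpy analogue of that gather-then-softmax step (shapes and indices made up for illustration):

```python
import numpy as np

logits = np.random.randn(2708, 7)      # one row per Cora node, one column per class
node_index = np.array([0, 5, 10])      # the labeled (e.g. training) nodes
node_label = np.array([3, 1, 6])

pred = logits[node_index]              # what fluid.layers.gather does
pred = np.exp(pred - pred.max(-1, keepdims=True))
pred = pred / pred.sum(-1, keepdims=True)   # softmax
loss = -np.log(pred[np.arange(len(node_label)), node_label]).mean()
print(loss)
```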
**config/appnp.yaml**

```yaml
model_name: APPNP
k_hop: 10
alpha: 0.1
num_layers: 1
learning_rate: 0.01
dropout: 0.5
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
```
**config/gat.yaml**

```yaml
model_name: GAT
learning_rate: 0.005
weight_decay: 0.0005
num_layers: 1
feat_drop: 0.6
attn_drop: 0.6
num_heads: 8
hidden_size: 8
edge_dropout: 0.0
```
**config/gcn.yaml**

```yaml
model_name: GCN
num_layers: 1
dropout: 0.5
hidden_size: 16
learning_rate: 0.01
weight_decay: 0.0005
edge_dropout: 0.0
```
**config/gcnii.yaml**

```yaml
model_name: GCNII
k_hop: 64
alpha: 0.1
num_layers: 1
learning_rate: 0.01
dropout: 0.6
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
```
**config/sgc.yaml**

```yaml
model_name: SGC
num_layers: 2
learning_rate: 0.2
weight_decay: 0.000005
feature_pre_normalize: False
```
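Each YAML file is loaded into an attribute-style dict, and the models read hyperparameters with `config.get(key, default)`, so a key missing from the YAML silently falls back to the default in the model's `__init__`. A minimal sketch of how `train.py` consumes these files:

```python
import yaml
from easydict import EasyDict as edict

with open("config/gcn.yaml") as f:
    config = edict(yaml.load(f, Loader=yaml.FullLoader))

print(config.model_name)            # "GCN"
print(config.get("num_layers", 1))  # 1, read from the YAML
print(config.get("lambda_l", 0.5))  # 0.5, falls back to the default
```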
**model.py**

```python
import pgl
import paddle.fluid.layers as L
import pgl.layers.conv as conv


def get_norm(indegree):
    float_degree = L.cast(indegree, dtype="float32")
    float_degree = L.clamp(float_degree, min=1.0)
    norm = L.pow(float_degree, factor=-0.5)
    return norm


class GCN(object):
    """Implementation of GCN"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.5)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        for i in range(self.num_layers):
            if phase == "train":
                ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
                norm = get_norm(ngw.indegree())
            else:
                ngw = graph_wrapper
                norm = graph_wrapper.node_feat["norm"]

            feature = pgl.layers.gcn(ngw,
                                     feature,
                                     self.hidden_size,
                                     activation="relu",
                                     norm=norm,
                                     name="layer_%s" % i)
            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')

        if phase == "train":
            ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
            norm = get_norm(ngw.indegree())
        else:
            ngw = graph_wrapper
            norm = graph_wrapper.node_feat["norm"]

        feature = conv.gcn(ngw,
                           feature,
                           self.num_class,
                           activation=None,
                           norm=norm,
                           name="output")
        return feature


class GAT(object):
    """Implementation of GAT"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.num_heads = config.get("num_heads", 8)
        self.hidden_size = config.get("hidden_size", 8)
        self.feat_dropout = config.get("feat_drop", 0.6)
        self.attn_dropout = config.get("attn_drop", 0.6)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        # Edge dropout is only applied during training.
        if phase == "train":
            edge_dropout = self.edge_dropout
        else:
            edge_dropout = 0

        for i in range(self.num_layers):
            ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
            feature = conv.gat(ngw,
                               feature,
                               self.hidden_size,
                               activation="elu",
                               name="gat_layer_%s" % i,
                               num_heads=self.num_heads,
                               feat_drop=self.feat_dropout,
                               attn_drop=self.attn_dropout)

        ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
        feature = conv.gat(ngw,
                           feature,
                           self.num_class,
                           num_heads=1,
                           activation=None,
                           feat_drop=self.feat_dropout,
                           attn_drop=self.attn_dropout,
                           name="output")
        return feature


class APPNP(object):
    """Implementation of APPNP"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.5)
        self.alpha = config.get("alpha", 0.1)
        self.k_hop = config.get("k_hop", 10)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        # Edge dropout is only applied during training.
        if phase == "train":
            edge_dropout = self.edge_dropout
        else:
            edge_dropout = 0

        for i in range(self.num_layers):
            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')
            feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)

        feature = L.dropout(
            feature,
            self.dropout,
            dropout_implementation='upscale_in_train')
        feature = L.fc(feature, self.num_class, act=None, name="output")

        feature = conv.appnp(graph_wrapper,
                             feature=feature,
                             edge_dropout=edge_dropout,
                             alpha=self.alpha,
                             k_hop=self.k_hop)
        return feature


class SGC(object):
    """Implementation of SGC"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)

    def forward(self, graph_wrapper, feature, phase):
        feature = conv.appnp(graph_wrapper,
                             feature=feature,
                             edge_dropout=0,
                             alpha=0,
                             k_hop=self.num_layers)
        feature.stop_gradient = True
        feature = L.fc(feature, self.num_class, act=None, bias_attr=False, name="output")
        return feature


class GCNII(object):
    """Implementation of GCNII"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.6)
        self.alpha = config.get("alpha", 0.1)
        self.lambda_l = config.get("lambda_l", 0.5)
        self.k_hop = config.get("k_hop", 64)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        # Edge dropout is only applied during training.
        if phase == "train":
            edge_dropout = self.edge_dropout
        else:
            edge_dropout = 0

        for i in range(self.num_layers):
            feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')

        feature = conv.gcnii(graph_wrapper,
                             feature=feature,
                             name="gcnii",
                             activation="relu",
                             lambda_l=self.lambda_l,
                             alpha=self.alpha,
                             dropout=self.dropout,
                             k_hop=self.k_hop)

        feature = L.fc(feature, self.num_class, act=None, name="output")
        return feature
```
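For reference, `get_norm` clamps node degrees at 1 and raises them to the power $-1/2$, so each propagation step above is the standard symmetrically normalized aggregation of Kipf & Welling:

$$
\mathrm{norm}_i = \max(\deg(i),\, 1)^{-1/2}, \qquad H' = \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H W
$$

where $\hat{A}$ is the adjacency matrix seen by the layer (after edge dropout, if any) and $\hat{D}$ its degree matrix; multiplying the features by `norm` before sending and again after receiving realizes the two $\hat{D}^{-1/2}$ factors.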
**train.py**

```python
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pgl
import model
from pgl import data_loader
from pgl.utils.logger import log
import paddle.fluid as fluid
import numpy as np
import time
import argparse
from build_model import build_model
import yaml
from easydict import EasyDict as edict
import tqdm


def normalize(feat):
    return feat / np.maximum(np.sum(feat, -1, keepdims=True), 1)


def load(name, normalized_feature=True):
    if name == 'cora':
        dataset = data_loader.CoraDataset()
    elif name == "pubmed":
        dataset = data_loader.CitationDataset("pubmed", symmetry_edges=True)
    elif name == "citeseer":
        dataset = data_loader.CitationDataset("citeseer", symmetry_edges=True)
    else:
        raise ValueError(name + " dataset doesn't exist")

    indegree = dataset.graph.indegree()
    norm = np.maximum(indegree.astype("float32"), 1)
    norm = np.power(norm, -0.5)
    dataset.graph.node_feat["norm"] = np.expand_dims(norm, -1)
    if normalized_feature:
        dataset.graph.node_feat["words"] = normalize(dataset.graph.node_feat["words"])
    return dataset


def main(args, config):
    dataset = load(args.dataset, args.feature_pre_normalize)

    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    train_program = fluid.default_main_program()
    startup_program = fluid.default_startup_program()
    with fluid.program_guard(train_program, startup_program):
        with fluid.unique_name.guard():
            gw, loss, acc = build_model(dataset,
                                        config=config,
                                        phase="train",
                                        main_prog=train_program)

    test_program = fluid.Program()
    with fluid.program_guard(test_program, startup_program):
        with fluid.unique_name.guard():
            _gw, v_loss, v_acc = build_model(dataset,
                                             config=config,
                                             phase="test",
                                             main_prog=test_program)

    test_program = test_program.clone(for_test=True)

    exe = fluid.Executor(place)

    train_index = dataset.train_index
    train_label = np.expand_dims(dataset.y[train_index], -1)
    train_index = np.expand_dims(train_index, -1)

    val_index = dataset.val_index
    val_label = np.expand_dims(dataset.y[val_index], -1)
    val_index = np.expand_dims(val_index, -1)

    test_index = dataset.test_index
    test_label = np.expand_dims(dataset.y[test_index], -1)
    test_index = np.expand_dims(test_index, -1)

    dur = []

    # Feed the graph once; only node_index/node_label change per step.
    feed_dict = gw.to_feed(dataset.graph)

    best_test = []
    for run in range(args.runs):
        exe.run(startup_program)
        cal_val_acc = []
        cal_test_acc = []
        cal_val_loss = []
        cal_test_loss = []
        for epoch in tqdm.tqdm(range(args.epoch)):
            feed_dict["node_index"] = np.array(train_index, dtype="int64")
            feed_dict["node_label"] = np.array(train_label, dtype="int64")
            train_loss, train_acc = exe.run(train_program,
                                            feed=feed_dict,
                                            fetch_list=[loss, acc],
                                            return_numpy=True)

            feed_dict["node_index"] = np.array(val_index, dtype="int64")
            feed_dict["node_label"] = np.array(val_label, dtype="int64")
            val_loss, val_acc = exe.run(test_program,
                                        feed=feed_dict,
                                        fetch_list=[v_loss, v_acc],
                                        return_numpy=True)
            cal_val_acc.append(val_acc[0])
            cal_val_loss.append(val_loss[0])

            feed_dict["node_index"] = np.array(test_index, dtype="int64")
            feed_dict["node_label"] = np.array(test_label, dtype="int64")
            test_loss, test_acc = exe.run(test_program,
                                          feed=feed_dict,
                                          fetch_list=[v_loss, v_acc],
                                          return_numpy=True)
            cal_test_acc.append(test_acc[0])
            cal_test_loss.append(test_loss[0])

        log.info("Runs %s: Model: %s Best Test Accuracy: %f" %
                 (run, config.model_name, cal_test_acc[np.argmin(cal_val_loss)]))
        best_test.append(cal_test_acc[np.argmin(cal_val_loss)])

    log.info("Dataset: %s Best Test Accuracy: %f ( stddev: %f )" %
             (args.dataset, np.mean(best_test), np.std(best_test)))
    print("Dataset: %s Best Test Accuracy: %f ( stddev: %f )" %
          (args.dataset, np.mean(best_test), np.std(best_test)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Benchmarking Citation Network')
    parser.add_argument(
        "--dataset", type=str, default="cora", help="dataset (cora, pubmed, citeseer)")
    parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
    parser.add_argument("--conf", type=str, help="config file for models")
    parser.add_argument("--epoch", type=int, default=200, help="Epoch")
    parser.add_argument("--runs", type=int, default=10, help="runs")
    parser.add_argument("--feature_pre_normalize", type=bool, default=True,
                        help="pre-normalize features")
    args = parser.parse_args()
    config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
    log.info(args)
    main(args, config)
```
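`normalize` above row-normalizes the bag-of-words features so each node's feature vector sums to 1; the `np.maximum(..., 1)` guard leaves all-zero rows untouched. A quick check:

```python
import numpy as np

def normalize(feat):
    return feat / np.maximum(np.sum(feat, -1, keepdims=True), 1)

feat = np.array([[2.0, 2.0], [0.0, 0.0]])
print(normalize(feat))  # [[0.5 0.5], [0. 0.]] -- zero rows stay zero
```

The remainder of the pull request touches the pgl library itself; the diffs follow.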
**pgl/__init__.py**

```diff
@@ -22,3 +22,4 @@ from pgl import heter_graph
 from pgl import heter_graph_wrapper
 from pgl import contrib
 from pgl import message_passing
+from pgl import sample
```
**pgl/graph_wrapper.py**

```diff
@@ -19,6 +19,7 @@ for PaddlePaddle.
 import warnings
 import numpy as np
 import paddle.fluid as fluid
+import paddle.fluid.layers as L
 from pgl.utils import op
 from pgl.utils import paddle_helper
@@ -47,10 +48,10 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
     try:
         out_dim = msg.shape[-1]
-        init_output = fluid.layers.fill_constant(
+        init_output = L.fill_constant(
             shape=[num_nodes, out_dim], value=0, dtype=msg.dtype)
         init_output.stop_gradient = False
-        empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype)
+        empty_msg_flag = L.cast(num_edges > 0, dtype=msg.dtype)
         msg = msg * empty_msg_flag
         output = paddle_helper.scatter_add(init_output, dst, msg)
         return output
@@ -59,7 +60,7 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
             "scatter_add is not supported with paddle version <= 1.5")

     def sum_func(message):
-        return fluid.layers.sequence_pool(message, "sum")
+        return L.sequence_pool(message, "sum")
     reduce_function = sum_func
@@ -67,13 +68,13 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
     output = reduce_function(bucketed_msg)
     output_dim = output.shape[-1]
-    empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype)
+    empty_msg_flag = L.cast(num_edges > 0, dtype=output.dtype)
     output = output * empty_msg_flag
-    init_output = fluid.layers.fill_constant(
+    init_output = L.fill_constant(
         shape=[num_nodes, output_dim], value=0, dtype=output.dtype)
     init_output.stop_gradient = True
-    final_output = fluid.layers.scatter(init_output, uniq_dst, output)
+    final_output = L.scatter(init_output, uniq_dst, output)
     return final_output
@@ -104,6 +105,7 @@ class BaseGraphWrapper(object):
         self._node_ids = None
         self._graph_lod = None
         self._num_graph = None
+        self._num_edges = None
         self._data_name_prefix = ""

     def __repr__(self):
@@ -470,7 +472,7 @@ class StaticGraphWrapper(BaseGraphWrapper):
 class GraphWrapper(BaseGraphWrapper):
     """Implement a graph wrapper that creates a graph data holders
-    that attributes and features in the graph are :code:`fluid.layers.data`.
+    that attributes and features in the graph are :code:`L.data`.
     And we provide interface :code:`to_feed` to help converting :code:`Graph`
     data into :code:`feed_dict`.
@@ -546,65 +548,65 @@ class GraphWrapper(BaseGraphWrapper):
     def __create_graph_attr_holders(self):
         """Create data holders for graph attributes.
         """
-        self._num_edges = fluid.layers.data(
+        self._num_edges = L.data(
             self._data_name_prefix + '/num_edges',
             shape=[1],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._num_graph = fluid.layers.data(
+        self._num_graph = L.data(
             self._data_name_prefix + '/num_graph',
             shape=[1],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._edges_src = fluid.layers.data(
+        self._edges_src = L.data(
             self._data_name_prefix + '/edges_src',
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._edges_dst = fluid.layers.data(
+        self._edges_dst = L.data(
             self._data_name_prefix + '/edges_dst',
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._num_nodes = fluid.layers.data(
+        self._num_nodes = L.data(
             self._data_name_prefix + '/num_nodes',
             shape=[1],
             append_batch_size=False,
             dtype='int64',
             stop_gradient=True)
-        self._edge_uniq_dst = fluid.layers.data(
+        self._edge_uniq_dst = L.data(
             self._data_name_prefix + "/uniq_dst",
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._graph_lod = fluid.layers.data(
+        self._graph_lod = L.data(
             self._data_name_prefix + "/graph_lod",
             shape=[None],
             append_batch_size=False,
             dtype="int32",
             stop_gradient=True)
-        self._edge_uniq_dst_count = fluid.layers.data(
+        self._edge_uniq_dst_count = L.data(
             self._data_name_prefix + "/uniq_dst_count",
             shape=[None],
             append_batch_size=False,
             dtype="int32",
             stop_gradient=True)
-        self._node_ids = fluid.layers.data(
+        self._node_ids = L.data(
             self._data_name_prefix + "/node_ids",
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._indegree = fluid.layers.data(
+        self._indegree = L.data(
             self._data_name_prefix + "/indegree",
             shape=[None],
             append_batch_size=False,
@@ -627,7 +629,7 @@ class GraphWrapper(BaseGraphWrapper):
                                          node_feat_dtype):
         """Create data holders for node features.
         """
-        feat_holder = fluid.layers.data(
+        feat_holder = L.data(
             self._data_name_prefix + '/node_feat/' + node_feat_name,
             shape=node_feat_shape,
             append_batch_size=False,
@@ -640,7 +642,7 @@ class GraphWrapper(BaseGraphWrapper):
                                          edge_feat_dtype):
         """Create edge holders for edge features.
         """
-        feat_holder = fluid.layers.data(
+        feat_holder = L.data(
             self._data_name_prefix + '/edge_feat/' + edge_feat_name,
             shape=edge_feat_shape,
             append_batch_size=False,
@@ -719,3 +721,61 @@ class GraphWrapper(BaseGraphWrapper):
         """Return the holder list.
         """
         return self._holder_list
+
+
+def get_degree(edge, num_nodes):
+    init_output = L.fill_constant(
+        shape=[num_nodes], value=0, dtype="float32")
+    init_output.stop_gradient = True
+    final_output = L.scatter(init_output,
+                             edge,
+                             L.full_like(edge, 1, dtype="float32"),
+                             overwrite=False)
+    return final_output
+
+
+class DropEdgeWrapper(BaseGraphWrapper):
+    """Implement of Edge Drop"""
+
+    def __init__(self, graph_wrapper, dropout, keep_self_loop=True):
+        super(DropEdgeWrapper, self).__init__()
+
+        # Copy the node information
+        for key, value in graph_wrapper.node_feat.items():
+            self.node_feat_tensor_dict[key] = value
+
+        self._num_nodes = graph_wrapper.num_nodes
+        self._graph_lod = graph_wrapper.graph_lod
+        self._num_graph = graph_wrapper.num_graph
+        self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")
+
+        # Drop edges at random
+        src, dst = graph_wrapper.edges
+        u = L.uniform_random(shape=L.cast(L.shape(src), 'int64'), min=0., max=1.)
+
+        # Avoid empty edge sets
+        keeped = L.cast(u > dropout, dtype="float32")
+        self._num_edges = L.reduce_sum(L.cast(keeped, "int32"))
+        keeped = keeped + L.cast(self._num_edges == 0, dtype="float32")
+
+        if keep_self_loop:
+            self_loop = L.cast(src == dst, dtype="float32")
+            keeped = keeped + self_loop
+
+        keeped = (keeped > 0.5)
+        src = paddle_helper.masked_select(src, keeped)
+        dst = paddle_helper.masked_select(dst, keeped)
+        src.stop_gradient = True
+        dst.stop_gradient = True
+        self._edges_src = src
+        self._edges_dst = dst
+
+        for key, value in graph_wrapper.edge_feat.items():
+            self.edge_feat_tensor_dict[key] = paddle_helper.masked_select(value, keeped)
+
+        self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(dst, dtype="int32")
+        self._edge_uniq_dst.stop_gradient = True
+        last = L.reduce_sum(uniq_count, keep_dim=True)
+        uniq_count = L.cumsum(uniq_count, exclusive=True)
+        self._edge_uniq_dst_count = L.concat([uniq_count, last])
+        self._edge_uniq_dst_count.stop_gradient = True
+        self._indegree = get_degree(self._edges_dst, self._num_nodes)
```
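`DropEdgeWrapper` is consumed through the `pgl.sample.edge_drop` helper added further below, and the returned wrapper is used exactly like the original `GraphWrapper`. A minimal sketch of the intended wiring (the toy graph and feature sizes are made up for illustration):

```python
import numpy as np
import pgl

# Build a toy 3-node graph and its data holders.
graph = pgl.graph.Graph(
    num_nodes=3,
    edges=[(0, 1), (1, 2), (2, 0)],
    node_feat={"words": np.random.rand(3, 4).astype("float32")})
gw = pgl.graph_wrapper.GraphWrapper(
    name="graph", node_feat=graph.node_feat_info())

# Drop ~10% of edges (self-loops kept by default); rates below 1e-5
# simply return the original wrapper unchanged.
ngw = pgl.sample.edge_drop(gw, 0.1)
output = pgl.layers.gcn(ngw, gw.node_feat["words"], 8,
                        activation="relu", name="l0")
```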
**pgl/layers/conv.py**

```diff
@@ -14,11 +14,14 @@
 """This package implements common layers to help building
 graph neural networks.
 """
+import pgl
 import paddle.fluid as fluid
+import paddle.fluid.layers as L
 from pgl.utils import paddle_helper
 from pgl import message_passing
+import numpy as np

-__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv']
+__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv', 'appnp', 'gcnii']


 def gcn(gw, feature, hidden_size, activation, name, norm=None):
@@ -50,7 +53,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
     size = feature.shape[-1]
     if size > hidden_size:
-        feature = fluid.layers.fc(feature,
+        feature = L.fc(feature,
                        size=hidden_size,
                        bias_attr=False,
                        param_attr=fluid.ParamAttr(name=name))
@@ -64,7 +67,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
         output = gw.recv(msg, "sum")
     else:
         output = gw.recv(msg, "sum")
-        output = fluid.layers.fc(output,
+        output = L.fc(output,
                        size=hidden_size,
                        bias_attr=False,
                        param_attr=fluid.ParamAttr(name=name))
@@ -72,12 +75,12 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
     if norm is not None:
         output = output * norm

-    bias = fluid.layers.create_parameter(
+    bias = L.create_parameter(
         shape=[hidden_size],
         dtype='float32',
         is_bias=True,
         name=name + '_bias')
-    output = fluid.layers.elementwise_add(output, bias, act=activation)
+    output = L.elementwise_add(output, bias, act=activation)
     return output
@@ -120,7 +123,7 @@ def gat(gw,
     def send_attention(src_feat, dst_feat, edge_feat):
         output = src_feat["left_a"] + dst_feat["right_a"]
-        output = fluid.layers.leaky_relu(
+        output = L.leaky_relu(
             output, alpha=0.2)  # (num_edges, num_heads)
         return {"alpha": output, "h": src_feat["h"]}
@@ -129,54 +132,54 @@ def gat(gw,
         h = msg["h"]
         alpha = paddle_helper.sequence_softmax(alpha)
         old_h = h
-        h = fluid.layers.reshape(h, [-1, num_heads, hidden_size])
-        alpha = fluid.layers.reshape(alpha, [-1, num_heads, 1])
+        h = L.reshape(h, [-1, num_heads, hidden_size])
+        alpha = L.reshape(alpha, [-1, num_heads, 1])
         if attn_drop > 1e-15:
-            alpha = fluid.layers.dropout(
+            alpha = L.dropout(
                 alpha,
                 dropout_prob=attn_drop,
                 is_test=is_test,
                 dropout_implementation="upscale_in_train")
         h = h * alpha
-        h = fluid.layers.reshape(h, [-1, num_heads * hidden_size])
-        h = fluid.layers.lod_reset(h, old_h)
-        return fluid.layers.sequence_pool(h, "sum")
+        h = L.reshape(h, [-1, num_heads * hidden_size])
+        h = L.lod_reset(h, old_h)
+        return L.sequence_pool(h, "sum")

     if feat_drop > 1e-15:
-        feature = fluid.layers.dropout(
+        feature = L.dropout(
             feature,
             dropout_prob=feat_drop,
             is_test=is_test,
             dropout_implementation='upscale_in_train')

-    ft = fluid.layers.fc(feature,
+    ft = L.fc(feature,
                          hidden_size * num_heads,
                          bias_attr=False,
                          param_attr=fluid.ParamAttr(name=name + '_weight'))
-    left_a = fluid.layers.create_parameter(
+    left_a = L.create_parameter(
         shape=[num_heads, hidden_size],
         dtype='float32',
         name=name + '_gat_l_A')
-    right_a = fluid.layers.create_parameter(
+    right_a = L.create_parameter(
         shape=[num_heads, hidden_size],
         dtype='float32',
         name=name + '_gat_r_A')
-    reshape_ft = fluid.layers.reshape(ft, [-1, num_heads, hidden_size])
-    left_a_value = fluid.layers.reduce_sum(reshape_ft * left_a, -1)
-    right_a_value = fluid.layers.reduce_sum(reshape_ft * right_a, -1)
+    reshape_ft = L.reshape(ft, [-1, num_heads, hidden_size])
+    left_a_value = L.reduce_sum(reshape_ft * left_a, -1)
+    right_a_value = L.reduce_sum(reshape_ft * right_a, -1)

     msg = gw.send(
         send_attention,
         nfeat_list=[("h", ft), ("left_a", left_a_value),
                     ("right_a", right_a_value)])
     output = gw.recv(msg, reduce_attention)
-    bias = fluid.layers.create_parameter(
+    bias = L.create_parameter(
         shape=[hidden_size * num_heads],
         dtype='float32',
         is_bias=True,
         name=name + '_bias')
     bias.stop_gradient = True
-    output = fluid.layers.elementwise_add(output, bias, act=activation)
+    output = L.elementwise_add(output, bias, act=activation)
     return output
@@ -219,7 +222,7 @@ def gin(gw,
     def send_src_copy(src_feat, dst_feat, edge_feat):
         return src_feat["h"]

-    epsilon = fluid.layers.create_parameter(
+    epsilon = L.create_parameter(
         shape=[1, 1],
         dtype="float32",
         attr=fluid.ParamAttr(name="%s_eps" % name),
@@ -232,13 +235,13 @@ def gin(gw,
     msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
     output = gw.recv(msg, "sum") + feature * (epsilon + 1.0)

-    output = fluid.layers.fc(output,
+    output = L.fc(output,
                   size=hidden_size,
                   act=None,
                   param_attr=fluid.ParamAttr(name="%s_w_0" % name),
                   bias_attr=fluid.ParamAttr(name="%s_b_0" % name))

-    output = fluid.layers.layer_norm(
+    output = L.layer_norm(
         output,
         begin_norm_axis=1,
         param_attr=fluid.ParamAttr(
@@ -249,9 +252,9 @@ def gin(gw,
             initializer=fluid.initializer.Constant(0.0)), )

     if activation is not None:
-        output = getattr(fluid.layers, activation)(output)
+        output = getattr(L, activation)(output)

-    output = fluid.layers.fc(output,
+    output = L.fc(output,
                   size=hidden_size,
                   act=activation,
                   param_attr=fluid.ParamAttr(name="%s_w_1" % name),
@@ -269,10 +272,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
         feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
         # E * M * D1
         old = feat_query
-        feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a])
-        feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
+        feat_query = L.reshape(feat_query, [-1, heads, hidden_size_a])
+        feat_key = L.reshape(feat_key, [-1, heads, hidden_size_a])
         # E * M
-        alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)
+        alpha = L.reduce_sum(feat_key * feat_query, dim=-1)

         return {'dst_node_feat': dst_feat['node_feat'],
                 'src_node_feat': src_feat['node_feat'],
@@ -286,15 +289,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
         # Features of each edge's source node
         src_feat = message['src_node_feat']
         # Each center node's own features
-        x = fluid.layers.sequence_pool(dst_feat, 'average')
+        x = L.sequence_pool(dst_feat, 'average')
         # Mean of the neighbors' features for each center node
-        z = fluid.layers.sequence_pool(src_feat, 'average')
+        z = L.sequence_pool(src_feat, 'average')

         # Compute the gate
         feat_gate = message['feat_gate']
-        g_max = fluid.layers.sequence_pool(feat_gate, 'max')
-        g = fluid.layers.concat([x, g_max, z], axis=1)
-        g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid")
+        g_max = L.sequence_pool(feat_gate, 'max')
+        g = L.concat([x, g_max, z], axis=1)
+        g = L.fc(g, heads, bias_attr=False, act="sigmoid")

         # softmax
         alpha = message['alpha']
@@ -302,19 +305,19 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
         feat_value = message['feat_value']  # E * (M * D2)
         old = feat_value
-        feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])  # E * M * D2
-        feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
-        feat_value = fluid.layers.reshape(feat_value, [-1, heads*hidden_size_v])  # E * (M * D2)
-        feat_value = fluid.layers.lod_reset(feat_value, old)
+        feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v])  # E * M * D2
+        feat_value = L.elementwise_mul(feat_value, alpha, axis=0)
+        feat_value = L.reshape(feat_value, [-1, heads*hidden_size_v])  # E * (M * D2)
+        feat_value = L.lod_reset(feat_value, old)

-        feat_value = fluid.layers.sequence_pool(feat_value, 'sum')  # N * (M * D2)
+        feat_value = L.sequence_pool(feat_value, 'sum')  # N * (M * D2)

-        feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])  # N * M * D2
+        feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v])  # N * M * D2

-        output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
-        output = fluid.layers.reshape(output, [-1, heads * hidden_size_v])  # N * (M * D2)
+        output = L.elementwise_mul(feat_value, g, axis=0)
+        output = L.reshape(output, [-1, heads * hidden_size_v])  # N * (M * D2)

-        output = fluid.layers.concat([x, output], axis=1)
+        output = L.concat([x, output], axis=1)

         return output
@@ -323,16 +326,16 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
     # Compute what each node needs to send out:
    # the projected feature vectors
     # N * (D1 * M)
-    feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
+    feat_key = L.fc(feature, hidden_size_a * heads, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_key'))
     # N * (D2 * M)
-    feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
+    feat_value = L.fc(feature, hidden_size_v * heads, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_value'))
     # N * (D1 * M)
-    feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
+    feat_query = L.fc(feature, hidden_size_a * heads, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_query'))
     # N * Dm
-    feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
+    feat_gate = L.fc(feature, hidden_size_m, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_gate'))

     # Send stage
@@ -346,10 +349,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
     # Aggregate the neighbors' features
     output = gw.recv(message, recv_func)
-    output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
+    output = L.fc(output, hidden_size_o, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_output'))
-    output = fluid.layers.leaky_relu(output, alpha=0.1)
-    output = fluid.layers.dropout(output, dropout_prob=0.1)
+    output = L.leaky_relu(output, alpha=0.1)
+    output = L.dropout(output, dropout_prob=0.1)

     return output
@@ -376,7 +379,7 @@ def gen_conv(gw,
     """
     if beta == "dynamic":
-        beta = fluid.layers.create_parameter(
+        beta = L.create_parameter(
             shape=[1],
             dtype='float32',
             default_initializer=
@@ -391,16 +394,132 @@ def gen_conv(gw,
     output = message_passing.msg_norm(feature, output, name)
     output = feature + output

-    output = fluid.layers.fc(output,
+    output = L.fc(output,
                    feature.shape[-1],
                    bias_attr=False,
                    act="relu",
                    param_attr=fluid.ParamAttr(name=name + '_weight1'))

-    output = fluid.layers.fc(output,
+    output = L.fc(output,
                    feature.shape[-1],
                    bias_attr=False,
                    param_attr=fluid.ParamAttr(name=name + '_weight2'))

     return output
+
+
+def get_norm(indegree):
+    """Get the Laplacian normalization term"""
+    float_degree = L.cast(indegree, dtype="float32")
+    float_degree = L.clamp(float_degree, min=1.0)
+    norm = L.pow(float_degree, factor=-0.5)
+    return norm
+
+
+def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10):
+    """Implementation of APPNP from "Predict then Propagate: Graph Neural Networks
+    meet Personalized PageRank" (ICLR 2019).
+
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+        feature: A tensor with shape (num_nodes, feature_size).
+        edge_dropout: Edge dropout rate.
+        alpha: The teleport (restart) probability of personalized PageRank.
+        k_hop: Number of propagation steps.
+
+    Return:
+        A tensor with shape (num_nodes, feature_size)
+    """
+
+    def send_src_copy(src_feat, dst_feat, edge_feat):
+        feature = src_feat["h"]
+        return feature
+
+    h0 = feature
+    ngw = gw
+    norm = get_norm(ngw.indegree())
+
+    for i in range(k_hop):
+        if edge_dropout > 1e-5:
+            ngw = pgl.sample.edge_drop(gw, edge_dropout)
+            norm = get_norm(ngw.indegree())
+
+        feature = feature * norm
+        msg = ngw.send(send_src_copy, nfeat_list=[("h", feature)])
+        feature = ngw.recv(msg, "sum")
+        feature = feature * norm
+        feature = feature * (1 - alpha) + h0 * alpha
+    return feature
```
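The loop above realizes the truncated personalized-PageRank propagation of APPNP: with $\hat{A} = \hat{D}^{-1/2} A \hat{D}^{-1/2}$ (degrees clamped at 1, re-sampled per step when edge dropout is active), each of the `k_hop` steps computes

$$
H^{(0)} = f_\theta(X), \qquad H^{(t+1)} = (1 - \alpha)\, \hat{A} H^{(t)} + \alpha H^{(0)}
$$

which approaches personalized-PageRank smoothing of the MLP predictions $f_\theta(X)$ as the number of steps grows.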
```diff
+def gcnii(gw,
+          feature,
+          name,
+          activation=None,
+          alpha=0.5,
+          lambda_l=0.5,
+          k_hop=1,
+          dropout=0.5,
+          is_test=False):
+    """Implementation of GCNII from "Simple and Deep Graph Convolutional Networks"
+    (https://arxiv.org/pdf/2007.02133.pdf).
+
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+        feature: A tensor with shape (num_nodes, feature_size).
+        activation: The activation for the output.
+        k_hop: Number of layers for gcnii.
+        lambda_l: The hyperparameter lambda in the paper.
+        alpha: The hyperparameter alpha in the paper.
+        dropout: Feature dropout rate.
+        is_test: train / test phase.
+
+    Return:
+        A tensor with shape (num_nodes, feature_size)
+    """
+
+    def send_src_copy(src_feat, dst_feat, edge_feat):
+        feature = src_feat["h"]
+        return feature
+
+    h0 = feature
+    ngw = gw
+    norm = get_norm(ngw.indegree())
+    hidden_size = feature.shape[-1]
+
+    for i in range(k_hop):
+        beta_i = np.log(1.0 * lambda_l / (i + 1) + 1)
+        feature = L.dropout(
+            feature,
+            dropout_prob=dropout,
+            is_test=is_test,
+            dropout_implementation='upscale_in_train')
+
+        feature = feature * norm
+        msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
+        feature = gw.recv(msg, "sum")
+        feature = feature * norm
+
+        # APPNP-style initial residual connection
+        feature = feature * (1 - alpha) + h0 * alpha
+
+        feature_transed = L.fc(feature, hidden_size,
+                               act=None, bias_attr=False,
+                               name=name + "_%s_w1" % i)
+        feature = feature_transed * beta_i + feature * (1 - beta_i)
+        if activation is not None:
+            feature = getattr(L, activation)(feature)
+    return feature
```
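`gcnii` augments the same propagation with GCNII's identity-mapped weight matrices; per hop $i$ (0-indexed, matching `beta_i = np.log(1.0 * lambda_l / (i + 1) + 1)` above):

$$
\beta_i = \log\left(\frac{\lambda}{i+1} + 1\right), \qquad
H^{(i+1)} = \sigma\Big(\big((1-\alpha)\hat{A}H^{(i)} + \alpha H^{(0)}\big)\big((1-\beta_i) I + \beta_i W^{(i)}\big)\Big)
$$

so early hops apply nearly full weight matrices while deep hops shrink toward the identity, which is what lets the 64-layer configuration in `config/gcnii.yaml` train stably.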
**pgl/sample.py**

```diff
@@ -516,3 +516,12 @@ def graph_saint_random_walk_sample(graph,
         nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True)
     subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64")
     return subgraph
+
+
+def edge_drop(graph_wrapper, dropout_rate, keep_self_loop=True):
+    if dropout_rate < 1e-5:
+        return graph_wrapper
+    else:
+        return pgl.graph_wrapper.DropEdgeWrapper(graph_wrapper,
+                                                 dropout_rate,
+                                                 keep_self_loop)
```
**pgl/utils/paddle_helper.py**

```diff
@@ -250,3 +250,20 @@ def scatter_max(input, index, updates):
     output = fluid.layers.scatter(input, index, updates, mode='max')
     return output
+
+
+def masked_select(input, mask):
+    """masked_select
+
+    Select the entries of ``input`` where ``mask`` is True.
+
+    Args:
+        input: Input tensor to be selected.
+        mask: A bool tensor used for slicing.
+
+    Return:
+        The part of ``input`` where ``mask`` is True.
+    """
+    index = fluid.layers.where(mask)
+    return fluid.layers.gather(input, index)
```
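Since `fluid.layers.where` returns the indices of the True entries and `gather` selects those rows, `masked_select` behaves like boolean indexing in numpy:

```python
import numpy as np

x = np.array([3, 1, 4, 1, 5])
mask = np.array([True, False, True, False, True])
index = np.flatnonzero(mask)  # the indices where() would produce
print(x[index])               # [3 4 5], i.e. x[mask]
```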