From 5b76e78f84af70e34e817005e5c805b3db9270c4 Mon Sep 17 00:00:00 2001
From: fengshikun01
Date: Tue, 23 Jun 2020 20:19:03 +0800
Subject: [PATCH] deepergcn

---
 examples/deeper_gcn/README.md |  59 ++++++++++
 examples/deeper_gcn/model.py  |  89 +++++++++++++++
 examples/deeper_gcn/train.py  | 155 ++++++++++++++++++++++++++
 pgl/__init__.py               |   1 +
 pgl/layers/conv.py            |  56 +++++++++-
 pgl/message_passing.py        | 203 ++++++++++++++++++++++++++++++++++
 pgl/utils/paddle_helper.py    |   7 +-
 7 files changed, 567 insertions(+), 3 deletions(-)
 create mode 100644 examples/deeper_gcn/README.md
 create mode 100644 examples/deeper_gcn/model.py
 create mode 100644 examples/deeper_gcn/train.py
 create mode 100644 pgl/message_passing.py

diff --git a/examples/deeper_gcn/README.md b/examples/deeper_gcn/README.md
new file mode 100644
index 0000000..4baefbb
--- /dev/null
+++ b/examples/deeper_gcn/README.md
@@ -0,0 +1,59 @@
+# DeeperGCN: All You Need to Train Deeper GCNs
+
+[DeeperGCN](https://arxiv.org/abs/2006.07739) is a framework for training very deep graph convolutional networks on graph-structured data. It combines a generalized graph convolution (GENConv), a softmax message aggregator with a learnable inverse temperature beta, message normalization (MsgNorm), and pre-activation residual connections. Based on PGL, we reproduce the DeeperGCN algorithm on the citation network benchmarks.
+
+### Simple example to build a DeeperGCN block
+
+To build a DeeperGCN model, one can stack the pre-defined ```pgl.layers.gen_conv``` layer with layer normalization, ReLU, dropout and a residual connection, as in ```model.py```:
+```python
+import pgl
+import paddle.fluid as fluid
+
+def deeper_gcn_block(gw, feature, name, dropout_prob):
+    # pre-activation residual block: LayerNorm -> ReLU -> Dropout -> GENConv -> Res
+    old_feature = feature
+    output = fluid.layers.layer_norm(feature, begin_norm_axis=1)
+    output = fluid.layers.relu(output)
+    output = fluid.layers.dropout(output,
+        dropout_prob=dropout_prob,
+        dropout_implementation="upscale_in_train")
+    output = pgl.layers.gen_conv(gw, output, name=name, beta="dynamic")
+    return output + old_feature
+```
+
+### Datasets
+
+The datasets contain three citation networks: CORA, PUBMED, CITESEER. The details for these three datasets can be found in the [paper](https://arxiv.org/abs/1609.02907).
+
+### Dependencies
+
+- paddlepaddle>=1.6
+- pgl
+
+### Performance
+
+We train our models for 200 epochs and report the accuracy on the test dataset.
+
+| Dataset | Accuracy |
+| --- | --- |
+| Cora | ~83% |
+| Pubmed | ~78% |
+| Citeseer | ~70% |
+
+### How to run
+
+For example, use GPU to train DeeperGCN on the cora dataset:
+```
+python train.py --dataset cora --use_cuda
+```
+
+#### Hyperparameters
+
+- dataset: The citation dataset "cora", "citeseer", "pubmed".
+- use_cuda: Use gpu if use_cuda is assigned.
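+
+#### What GENConv aggregation computes
+
+For intuition, here is a minimal NumPy sketch (not part of the PGL API; the helper names below are made up for illustration) of the softmax aggregation and message normalization that ```pgl.layers.gen_conv``` applies to the messages received by a single node:
+```python
+import numpy as np
+
+def softmax_agg_np(messages, beta=1.0):
+    # messages: (num_neighbors, hidden_size); beta is the inverse temperature
+    logits = beta * messages
+    alpha = np.exp(logits - logits.max(axis=0, keepdims=True))
+    alpha = alpha / alpha.sum(axis=0, keepdims=True)   # softmax over neighbors, per feature
+    return (alpha * messages).sum(axis=0)              # weighted sum -> (hidden_size,)
+
+def msg_norm_np(x, msg, s=1.0):
+    # x: (feature_size,) centre node feature; msg: (feature_size,) aggregated message
+    # returns the rescaled message; gen_conv then adds it back to x
+    msg = msg / (np.linalg.norm(msg) + 1e-12)
+    return s * np.linalg.norm(x) * msg
+
+msgs = np.random.rand(4, 8).astype("float32")          # 4 neighbors, hidden size 8
+x = np.random.rand(8).astype("float32")
+print(x + msg_norm_np(x, softmax_agg_np(msgs, beta=1.0)))
+```
+A larger beta makes the aggregation behave more like max pooling, while beta close to 0 approaches mean pooling; with ```beta="dynamic"``` the layer learns this trade-off during training.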
diff --git a/examples/deeper_gcn/model.py b/examples/deeper_gcn/model.py
new file mode 100644
index 0000000..f5f4126
--- /dev/null
+++ b/examples/deeper_gcn/model.py
@@ -0,0 +1,89 @@
+import pgl
+import paddle.fluid as fluid
+
+def DeeperGCN(gw, feature, num_layers,
+              hidden_size, num_tasks, name, dropout_prob):
+    """Implementation of DeeperGCN, see the paper
+    "DeeperGCN: All You Need to Train Deeper GCNs" in
+    https://arxiv.org/pdf/2006.07739.pdf
+
+    Args:
+        gw: Graph wrapper object
+
+        feature: A tensor with shape (num_nodes, feature_size)
+
+        num_layers: number of gen_conv layers in DeeperGCN
+
+        hidden_size: hidden size used in DeeperGCN
+
+        num_tasks: size of the final prediction, e.g. the number of classes
+
+        name: name prefix of the DeeperGCN layers
+
+        dropout_prob: dropout probability used in DeeperGCN
+
+    Return:
+        A tensor with shape (num_nodes, num_tasks)
+    """
+
+    beta = "dynamic"
+    feature = fluid.layers.fc(feature,
+                              hidden_size,
+                              bias_attr=False,
+                              param_attr=fluid.ParamAttr(name=name + '_weight'))
+
+    output = pgl.layers.gen_conv(gw, feature, name=name + "_gen_conv_0", beta=beta)
+
+    for layer in range(num_layers):
+        # pre-activation residual block: LN -> ReLU -> Dropout -> GENConv -> Res
+        old_output = output
+        # 1. layer norm
+        output = fluid.layers.layer_norm(
+            output,
+            begin_norm_axis=1,
+            param_attr=fluid.ParamAttr(
+                name="norm_scale_%s_%d" % (name, layer),
+                initializer=fluid.initializer.Constant(1.0)),
+            bias_attr=fluid.ParamAttr(
+                name="norm_bias_%s_%d" % (name, layer),
+                initializer=fluid.initializer.Constant(0.0)))
+
+        # 2. ReLU
+        output = fluid.layers.relu(output)
+
+        # 3. dropout
+        output = fluid.layers.dropout(
+            output,
+            dropout_prob=dropout_prob,
+            dropout_implementation="upscale_in_train")
+
+        # 4. gen_conv
+        output = pgl.layers.gen_conv(
+            gw, output, name=name + "_gen_conv_%d" % layer, beta=beta)
+
+        # 5. residual connection
+        output = output + old_output
+
+    # final block: LN + ReLU + dropout
+    output = fluid.layers.layer_norm(
+        output,
+        begin_norm_axis=1,
+        param_attr=fluid.ParamAttr(
+            name="norm_scale_%s_%d" % (name, num_layers),
+            initializer=fluid.initializer.Constant(1.0)),
+        bias_attr=fluid.ParamAttr(
+            name="norm_bias_%s_%d" % (name, num_layers),
+            initializer=fluid.initializer.Constant(0.0)))
+    output = fluid.layers.relu(output)
+    output = fluid.layers.dropout(
+        output,
+        dropout_prob=dropout_prob,
+        dropout_implementation="upscale_in_train")
+
+    # final prediction
+    output = fluid.layers.fc(output,
+                             num_tasks,
+                             bias_attr=False,
+                             param_attr=fluid.ParamAttr(name=name + '_final_weight'))
+
+    return output
+
+
diff --git a/examples/deeper_gcn/train.py b/examples/deeper_gcn/train.py
new file mode 100644
index 0000000..83b2c69
--- /dev/null
+++ b/examples/deeper_gcn/train.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#-*- coding: utf-8 -*-
+import pgl
+from pgl import data_loader
+from pgl.utils.logger import log
+import paddle.fluid as fluid
+import numpy as np
+import time
+import argparse
+from pgl.utils.log_writer import LogWriter # vdl
+from model import DeeperGCN
+
+def load(name):
+    if name == 'cora':
+        dataset = data_loader.CoraDataset()
+    elif name == "pubmed":
+        dataset = data_loader.CitationDataset("pubmed", symmetry_edges=False)
+    elif name == "citeseer":
+        dataset = data_loader.CitationDataset("citeseer", symmetry_edges=False)
+    else:
+        raise ValueError(name + " dataset doesn't exist")
+    return dataset
+
+
+def main(args):
+    # vdl
+    writer = LogWriter("checkpoints/train_history")
+
+    dataset = load(args.dataset)
+    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
+    train_program = fluid.Program()
+    startup_program = fluid.Program()
+    test_program = fluid.Program()
+    hidden_size = 64
+    num_layers = 50
+
+    with fluid.program_guard(train_program, startup_program):
+        gw = pgl.graph_wrapper.GraphWrapper(
+            name="graph",
+            node_feat=dataset.graph.node_feat_info())
+
+        output = DeeperGCN(gw,
+                           gw.node_feat["words"],
+                           num_layers,
+                           hidden_size,
+                           dataset.num_classes,
+                           "deepergcn",
+                           0.1)
+
+        node_index = fluid.layers.data(
+            "node_index",
+            shape=[None, 1],
+            dtype="int64",
+            append_batch_size=False)
+        node_label = fluid.layers.data(
+            "node_label",
+            shape=[None, 1],
+            dtype="int64",
+            append_batch_size=False)
+
+        pred = fluid.layers.gather(output, node_index)
+        loss, pred = fluid.layers.softmax_with_cross_entropy(
+            logits=pred, label=node_label, return_softmax=True)
+        acc = fluid.layers.accuracy(input=pred, label=node_label, k=1)
+        loss = fluid.layers.mean(loss)
+
+    test_program = train_program.clone(for_test=True)
+    with fluid.program_guard(train_program, startup_program):
+        adam = fluid.optimizer.Adam(
+            regularization=fluid.regularizer.L2DecayRegularizer(
+                regularization_coeff=0.0005),
+            learning_rate=0.005)
+        adam.minimize(loss)
+
+    exe = fluid.Executor(place)
+    exe.run(startup_program)
+
+    feed_dict = gw.to_feed(dataset.graph)
+
+    train_index = dataset.train_index
+    train_label = np.expand_dims(dataset.y[train_index], -1)
+    train_index = np.expand_dims(train_index, -1)
+
+    val_index = dataset.val_index
+    val_label = np.expand_dims(dataset.y[val_index], -1)
+    val_index = np.expand_dims(val_index, -1)
+
+    test_index = dataset.test_index
+    test_label = np.expand_dims(dataset.y[test_index], -1)
+    test_index = np.expand_dims(test_index, -1)
+
+    # collect the learnable beta parameters of gen_conv for logging
+    beta_param_list = []
+    for param in train_program.global_block().all_parameters():
+        if param.name.endswith("_beta"):
+            beta_param_list.append(param)
+
+    dur = []
+    for epoch in range(200):
+        if epoch >= 3:
+            t0 = time.time()
+        feed_dict["node_index"] = np.array(train_index, dtype="int64")
+        feed_dict["node_label"] = np.array(train_label, dtype="int64")
+        train_loss, train_acc = exe.run(train_program,
+                                        feed=feed_dict,
+                                        fetch_list=[loss, acc],
+                                        return_numpy=True)
+        for param in beta_param_list:
+            beta = np.array(fluid.global_scope().find_var(param.name).get_tensor())
+            writer.add_scalar(param.name, beta, epoch)
+
+        if epoch >= 3:
+            time_per_epoch = 1.0 * (time.time() - t0)
+            dur.append(time_per_epoch)
+
+        feed_dict["node_index"] = np.array(val_index, dtype="int64")
+        feed_dict["node_label"] = np.array(val_label, dtype="int64")
+        val_loss, val_acc = exe.run(test_program,
+                                    feed=feed_dict,
+                                    fetch_list=[loss, acc],
+                                    return_numpy=True)
+
+        # timing only starts at epoch 3, so guard against an empty list
+        avg_time = np.mean(dur) if len(dur) > 0 else 0.0
+        log.info("Epoch %d " % epoch + "(%.5lf sec) " % avg_time +
+                 "Train Loss: %f " % train_loss + "Train Acc: %f " % train_acc +
+                 "Val Loss: %f " % val_loss + "Val Acc: %f " % val_acc)
+
+    feed_dict["node_index"] = np.array(test_index, dtype="int64")
+    feed_dict["node_label"] = np.array(test_label, dtype="int64")
+    test_loss, test_acc = exe.run(test_program,
+                                  feed=feed_dict,
+                                  fetch_list=[loss, acc],
+                                  return_numpy=True)
+    log.info("Accuracy: %f" % test_acc)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='DeeperGCN')
+    parser.add_argument(
+        "--dataset", type=str, default="cora",
+        help="dataset (cora, pubmed, citeseer)")
+    parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
+    args = parser.parse_args()
+    log.info(args)
+    main(args)
diff --git a/pgl/__init__.py b/pgl/__init__.py
index 93375e9..7543265 100644
--- a/pgl/__init__.py
+++ b/pgl/__init__.py
@@ -21,3 +21,4 @@ from pgl import data_loader
 from pgl import heter_graph
 from pgl import heter_graph_wrapper
 from pgl import contrib
+from pgl import message_passing
diff --git a/pgl/layers/conv.py b/pgl/layers/conv.py
index 68a1d73..4d17c32 100644
--- a/pgl/layers/conv.py
+++ b/pgl/layers/conv.py
@@ -15,10 +15,10 @@ graph neural networks.
 """
 
 import paddle.fluid as fluid
-from pgl import graph_wrapper
 from pgl.utils import paddle_helper
+from pgl import message_passing
 
-__all__ = ['gcn', 'gat', 'gin', 'gaan']
+__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv']
 
 
 def gcn(gw, feature, hidden_size, activation, name, norm=None):
@@ -352,3 +352,55 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
     output = fluid.layers.dropout(output, dropout_prob=0.1)
 
     return output
+
+
+def gen_conv(gw,
+             feature,
+             name,
+             beta=None):
+    """Implementation of GENeralized Graph Convolution (GENConv), see the paper
+    "DeeperGCN: All You Need to Train Deeper GCNs" in
+    https://arxiv.org/pdf/2006.07739.pdf
+
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+
+        feature: A tensor with shape (num_nodes, feature_size).
+
+        name: name prefix of the gen_conv layer.
+
+        beta: a float in [0, +infinity), the string "dynamic" (learnable
+            inverse temperature), or None (plain softmax aggregation).
+
+    Return:
+        A tensor with shape (num_nodes, feature_size)
+    """
+
+    if beta == "dynamic":
+        beta = fluid.layers.create_parameter(
+            shape=[1],
+            dtype='float32',
+            default_initializer=
+                fluid.initializer.ConstantInitializer(value=1.0),
+            name=name + '_beta')
+
+    # message passing
+    msg = gw.send(message_passing.copy_send, nfeat_list=[("h", feature)])
+    output = gw.recv(msg, message_passing.softmax_agg(beta))
+
+    # msg norm
+    output = message_passing.msg_norm(feature, output, name)
+    output = feature + output
+
+    output = fluid.layers.fc(output,
+                             feature.shape[-1],
+                             bias_attr=False,
+                             act="relu",
+                             param_attr=fluid.ParamAttr(name=name + '_weight1'))
+
+    output = fluid.layers.fc(output,
+                             feature.shape[-1],
+                             bias_attr=False,
+                             param_attr=fluid.ParamAttr(name=name + '_weight2'))
+
+    return output
+
diff --git a/pgl/message_passing.py b/pgl/message_passing.py
new file mode 100644
index 0000000..858ea97
--- /dev/null
+++ b/pgl/message_passing.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This package implements some common message passing
+functions to help build graph neural networks.
+"""
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.layers as L
+from pgl.utils import paddle_helper
+
+__all__ = ['copy_send', 'weighted_copy_send', 'mean_recv',
+           'sum_recv', 'max_recv', 'lstm_recv', 'graphsage_sum',
+           'graphsage_mean', 'pinsage_mean', 'pinsage_sum',
+           'softmax_agg', 'msg_norm']
+
+
+def copy_send(src_feat, dst_feat, edge_feat):
+    """Send function that copies the source node feature "h" as the message."""
+    return src_feat["h"]
+
+def weighted_copy_send(src_feat, dst_feat, edge_feat):
+    """Send function that scales the source node feature "h" by the edge weight."""
+    return src_feat["h"] * edge_feat["weight"]
+
+def mean_recv(feat):
+    """Receive function that averages the messages of each node."""
+    return fluid.layers.sequence_pool(feat, pool_type="average")
+
+
+def sum_recv(feat):
+    """Receive function that sums the messages of each node."""
+    return fluid.layers.sequence_pool(feat, pool_type="sum")
+
+
+def max_recv(feat):
+    """Receive function that takes the element-wise max of the messages of each node."""
+    return fluid.layers.sequence_pool(feat, pool_type="max")
+
+
+def lstm_recv(feat):
+    """Receive function that aggregates the messages of each node with an LSTM.
+
+    Note: the LSTM hidden size is currently fixed to 128.
+    """
+    hidden_dim = 128
+    forward, _ = fluid.layers.dynamic_lstm(
+        input=feat, size=hidden_dim * 4, use_peepholes=False)
+    output = fluid.layers.sequence_last_step(forward)
+    return output
+
+
+def graphsage_sum(gw, feature, hidden_size, act, initializer, learning_rate, name):
+    """GraphSAGE layer with sum aggregation of neighbor features."""
+    msg = gw.send(copy_send, nfeat_list=[("h", feature)])
+    neigh_feature = gw.recv(msg, sum_recv)
+    self_feature = feature
+    self_feature = fluid.layers.fc(self_feature,
+                                   hidden_size,
+                                   act=act,
+                                   param_attr=fluid.ParamAttr(name=name + "_l.w_0",
+                                       initializer=initializer,
+                                       learning_rate=learning_rate),
+                                   bias_attr=name + "_l.b_0")
+    neigh_feature = fluid.layers.fc(neigh_feature,
+                                    hidden_size,
+                                    act=act,
+                                    param_attr=fluid.ParamAttr(name=name + "_r.w_0",
+                                        initializer=initializer,
+                                        learning_rate=learning_rate),
+                                    bias_attr=name + "_r.b_0")
+    output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
+    output = fluid.layers.l2_normalize(output, axis=1)
+    return output
+
+
+def graphsage_mean(gw, feature, hidden_size, act, initializer, learning_rate, name):
+    """GraphSAGE layer with mean aggregation of neighbor features."""
+    msg = gw.send(copy_send, nfeat_list=[("h", feature)])
+    neigh_feature = gw.recv(msg, mean_recv)
+    self_feature = feature
+    self_feature = fluid.layers.fc(self_feature,
+                                   hidden_size,
+                                   act=act,
+                                   param_attr=fluid.ParamAttr(name=name + "_l.w_0",
+                                       initializer=initializer,
+                                       learning_rate=learning_rate),
+                                   bias_attr=name + "_l.b_0")
+    neigh_feature = fluid.layers.fc(neigh_feature,
+                                    hidden_size,
+                                    act=act,
+                                    param_attr=fluid.ParamAttr(name=name + "_r.w_0",
+                                        initializer=initializer,
+                                        learning_rate=learning_rate),
+                                    bias_attr=name + "_r.b_0")
+    output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
+    output = fluid.layers.l2_normalize(output, axis=1)
+    return output
+
+
+def pinsage_mean(gw, feature, hidden_size, act, initializer, learning_rate, name):
+    """PinSAGE layer with edge-weighted mean aggregation of neighbor features."""
+    msg = gw.send(weighted_copy_send, nfeat_list=[("h", feature)], efeat_list=["weight"])
+    neigh_feature = gw.recv(msg, mean_recv)
+    self_feature = feature
+    self_feature = fluid.layers.fc(self_feature,
+                                   hidden_size,
+                                   act=act,
+                                   param_attr=fluid.ParamAttr(name=name + "_l.w_0",
+                                       initializer=initializer,
+                                       learning_rate=learning_rate),
+                                   bias_attr=name + "_l.b_0")
+    neigh_feature = fluid.layers.fc(neigh_feature,
+                                    hidden_size,
+                                    act=act,
+                                    param_attr=fluid.ParamAttr(name=name + "_r.w_0",
+                                        initializer=initializer,
+                                        learning_rate=learning_rate),
+                                    bias_attr=name + "_r.b_0")
+    output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
+    output = fluid.layers.l2_normalize(output, axis=1)
+    return output
+
+
+def pinsage_sum(gw, feature, hidden_size, act, initializer, learning_rate, name):
+    """PinSAGE layer with edge-weighted sum aggregation of neighbor features."""
+    msg = gw.send(weighted_copy_send, nfeat_list=[("h", feature)], efeat_list=["weight"])
+    neigh_feature = gw.recv(msg, sum_recv)
+    self_feature = feature
+    self_feature = fluid.layers.fc(self_feature,
+                                   hidden_size,
+                                   act=act,
+                                   param_attr=fluid.ParamAttr(name=name + "_l.w_0",
+                                       initializer=initializer,
+                                       learning_rate=learning_rate),
+                                   bias_attr=name + "_l.b_0")
+    neigh_feature = fluid.layers.fc(neigh_feature,
+                                    hidden_size,
+                                    act=act,
+                                    param_attr=fluid.ParamAttr(name=name + "_r.w_0",
+                                        initializer=initializer,
+                                        learning_rate=learning_rate),
+                                    bias_attr=name + "_r.b_0")
+    output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
+    output = fluid.layers.l2_normalize(output, axis=1)
+    return output
+
+
+def softmax_agg(beta):
+    """Implementation of the softmax aggregator (SoftMax_Agg), see more information in the paper
+    "DeeperGCN: All You Need to Train Deeper GCNs"
+    (https://arxiv.org/pdf/2006.07739.pdf)
+
+    Args:
+        beta: Inverse temperature of the softmax. A float, a learnable scalar
+            parameter, or None (plain softmax).
+
+    Return:
+        A receive function. Given the received message, a lod-tensor of shape
+        (batch_size, seq_len, hidden_size), it returns an output tensor with
+        shape (num_nodes, hidden_size).
+    """
+
+    def softmax_agg_inside(msg):
+        alpha = paddle_helper.sequence_softmax(msg, beta)
+        msg = msg * alpha
+        return fluid.layers.sequence_pool(msg, "sum")
+
+    return softmax_agg_inside
+
+
+def msg_norm(x, msg, name):
+    """Implementation of message normalization (MsgNorm), see more information in the paper
+    "DeeperGCN: All You Need to Train Deeper GCNs"
+    (https://arxiv.org/pdf/2006.07739.pdf)
+
+    Args:
+        x: centre node feature (num_nodes, feature_size)
+        msg: aggregated neighbor message (num_nodes, feature_size)
+        name: parameter name prefix for the learnable scale factor s
+
+    Return:
+        An output tensor with shape (num_nodes, feature_size)
+    """
+    s = fluid.layers.create_parameter(
+        shape=[1],
+        dtype='float32',
+        default_initializer=
+            fluid.initializer.ConstantInitializer(value=1.0),
+        name=name + '_s_msg_norm')
+
+    msg = fluid.layers.l2_normalize(msg, axis=1)
+    # l2 norm of the centre node feature, as in the paper (sqrt of the sum of squares)
+    x_norm = fluid.layers.sqrt(
+        fluid.layers.reduce_sum(x * x, dim=1, keep_dim=True))
+    msg = msg * x_norm * s
+    return msg
diff --git a/pgl/utils/paddle_helper.py b/pgl/utils/paddle_helper.py
index adbece5..3570fac 100644
--- a/pgl/utils/paddle_helper.py
+++ b/pgl/utils/paddle_helper.py
@@ -185,7 +185,7 @@ def lod_constant(name, value, lod, dtype):
     return output, data_initializer
 
 
-def sequence_softmax(x):
+def sequence_softmax(x, beta=None):
     """Compute sequence softmax over paddle LodTensor
 
     This function compute softmax normalization along with the length of sequence.
@@ -194,10 +194,15 @@ def sequence_softmax(x):
 
     Args:
         x: The input variable which is a LodTensor.
+        beta: Inverse temperature; if not None, x is scaled by beta before the softmax.
 
     Return:
         Output of sequence_softmax
     """
+
+    if beta is not None:
+        x = x * beta
+
     x_max = fluid.layers.sequence_pool(x, "max")
     x_max = fluid.layers.sequence_expand_as(x_max, x)
     x = x - x_max
-- 
GitLab
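
A quick way to sanity-check the new `beta` argument of `sequence_softmax` is to compare it against a plain NumPy reference. The sketch below is illustrative only (the sequence lengths and values are made up, and `sequence_softmax_np` is not a PGL function); it mirrors the patched helper, which simply scales the input by `beta` before the per-sequence softmax:

```python
import numpy as np

def sequence_softmax_np(x, seq_lens, beta=None):
    # x: (total_len, hidden) - all sequences concatenated, like a LoD tensor
    # seq_lens: length of each sequence; beta: optional inverse temperature
    if beta is not None:
        x = x * beta
    out = np.empty_like(x, dtype="float64")
    start = 0
    for n in seq_lens:
        seg = x[start:start + n]
        seg = np.exp(seg - seg.max(axis=0, keepdims=True))  # subtract the per-sequence max for stability
        out[start:start + n] = seg / seg.sum(axis=0, keepdims=True)
        start += n
    return out

# two sequences of lengths 3 and 2; a larger beta sharpens the weights toward the largest logit
x = np.array([[1.0], [2.0], [3.0], [0.5], [1.5]])
print(sequence_softmax_np(x, [3, 2], beta=1.0).ravel())
print(sequence_softmax_np(x, [3, 2], beta=10.0).ravel())
```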