diff --git a/examples/citation_benchmark/README.md b/examples/citation_benchmark/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f2abd2a1fc920bc5a32ab4b45b28473566f72481 --- /dev/null +++ b/examples/citation_benchmark/README.md @@ -0,0 +1,62 @@ +# Easy Paper Reproduction for Citation Network (Cora/Pubmed/Citeseer) + + + +This page tries to reproduce all the **Graph Neural Network** paper for Citation Network (Cora/Pubmed/Citeseer), which is the **Hello world** dataset (**small** and **fast**) for graph neural networks. But it's very hard to achieve very high performance. + + + +All datasets are runned with public split of **semi-supervised** settings. And we report the averarge accuracy by running 10 times. + + + +# Experiment Results + +| Model | Cora | Pubmed | Citeseer | Remarks | +| ------------------------------------------------------------ | ------------ | ------------ | ------------ | --------------------------------------------------------- | +| [Vanilla GCN (Kipf 2017)](https://openreview.net/pdf?id=SJU4ayYgl ) | 0.807(0.010) | 0.794(0.003) | 0.710(0.007) | | +| [GAT (Veličković 2017)](https://arxiv.org/pdf/1710.10903.pdf) | 0.834(0.004) | 0.772(0.004) | 0.700(0.006) | | +| [SGC(Wu 2019)](https://arxiv.org/pdf/1902.07153.pdf) | 0.818(0.000) | 0.782(0.000) | 0.708(0.000) | | +| [APPNP (Johannes 2018)](https://arxiv.org/abs/1810.05997) | 0.846(0.003) | 0.803(0.002) | 0.719(0.003) | Almost the same with the results reported in Appendix E. | +| [GCNII (64 Layers, 1500 Epochs, Chen 2020)](https://arxiv.org/pdf/2007.02133.pdf) | 0.846(0.003) | 0.798(0.003) | 0.724(0.006) | | + + + + + +How to run the experiments? + + + +```shell +# Device choose +export CUDA_VISIBLE_DEVICES=0 +# GCN +python train.py --conf config/gcn.yaml --use_cuda --dataset cora +python train.py --conf config/gcn.yaml --use_cuda --dataset pubmed +python train.py --conf config/gcn.yaml --use_cuda --dataset citeseer + + +# GAT +python train.py --conf config/gat.yaml --use_cuda --dataset cora +python train.py --conf config/gat.yaml --use_cuda --dataset pubmed +python train.py --conf config/gat.yaml --use_cuda --dataset citeseer + + +# SGC (Slow version) +python train.py --conf config/sgc.yaml --use_cuda --dataset cora +python train.py --conf config/sgc.yaml --use_cuda --dataset pubmed +python train.py --conf config/sgc.yaml --use_cuda --dataset citeseer + +# APPNP +python train.py --conf config/appnp.yaml --use_cuda --dataset cora +python train.py --conf config/appnp.yaml --use_cuda --dataset pubmed +python train.py --conf config/appnp.yaml --use_cuda --dataset citeseer + +# GCNII (The original code use 1500 epochs.) +python train.py --conf config/gcnii.yaml --use_cuda --dataset cora --epoch 1500 +python train.py --conf config/gcnii.yaml --use_cuda --dataset pubmed --epoch 1500 +python train.py --conf config/gcnii.yaml --use_cuda --dataset citeseer --epoch 1500 +``` + + diff --git a/examples/citation_benchmark/build_model.py b/examples/citation_benchmark/build_model.py new file mode 100644 index 0000000000000000000000000000000000000000..179483d5284a7de18f59455ddfdaf5c742d8e810 --- /dev/null +++ b/examples/citation_benchmark/build_model.py @@ -0,0 +1,43 @@ +import pgl +import model +from pgl import data_loader +import paddle.fluid as fluid +import numpy as np +import time + +def build_model(dataset, config, phase, main_prog): + gw = pgl.graph_wrapper.GraphWrapper( + name="graph", + node_feat=dataset.graph.node_feat_info()) + + GraphModel = getattr(model, config.model_name) + m = GraphModel(config=config, num_class=dataset.num_classes) + logits = m.forward(gw, gw.node_feat["words"], phase) + + # Take the last + node_index = fluid.layers.data( + "node_index", + shape=[None, 1], + dtype="int64", + append_batch_size=False) + node_label = fluid.layers.data( + "node_label", + shape=[None, 1], + dtype="int64", + append_batch_size=False) + + pred = fluid.layers.gather(logits, node_index) + loss, pred = fluid.layers.softmax_with_cross_entropy( + logits=pred, label=node_label, return_softmax=True) + acc = fluid.layers.accuracy(input=pred, label=node_label, k=1) + loss = fluid.layers.mean(loss) + + if phase == "train": + adam = fluid.optimizer.Adam( + learning_rate=config.learning_rate, + regularization=fluid.regularizer.L2DecayRegularizer( + regularization_coeff=config.weight_decay)) + adam.minimize(loss) + return gw, loss, acc + + diff --git a/examples/citation_benchmark/config/appnp.yaml b/examples/citation_benchmark/config/appnp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5b1b986a46d856d199503caa85dc6d05b0866e5b --- /dev/null +++ b/examples/citation_benchmark/config/appnp.yaml @@ -0,0 +1,9 @@ +model_name: APPNP +k_hop: 10 +alpha: 0.1 +num_layer: 1 +learning_rate: 0.01 +dropout: 0.5 +hidden_size: 64 +weight_decay: 0.0005 +edge_dropout: 0.0 diff --git a/examples/citation_benchmark/config/gat.yaml b/examples/citation_benchmark/config/gat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f62aedb1a3891388e16b2667bc556d236c59c9a --- /dev/null +++ b/examples/citation_benchmark/config/gat.yaml @@ -0,0 +1,9 @@ +model_name: GAT +learning_rate: 0.005 +weight_decay: 0.0005 +num_layers: 1 +feat_drop: 0.6 +attn_drop: 0.6 +num_heads: 8 +hidden_size: 8 +edge_dropout: 0.0 diff --git a/examples/citation_benchmark/config/gcn.yaml b/examples/citation_benchmark/config/gcn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..beb9129068e12ad077156d4f030f1c06f1cdbb01 --- /dev/null +++ b/examples/citation_benchmark/config/gcn.yaml @@ -0,0 +1,7 @@ +model_name: GCN +num_layers: 1 +dropout: 0.5 +hidden_size: 16 +learning_rate: 0.01 +weight_decay: 0.0005 +edge_dropout: 0.0 diff --git a/examples/citation_benchmark/config/gcnii.yaml b/examples/citation_benchmark/config/gcnii.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8fc5595f1eeb4965cf8cf5bf3f42cb8a94089e05 --- /dev/null +++ b/examples/citation_benchmark/config/gcnii.yaml @@ -0,0 +1,9 @@ +model_name: GCNII +k_hop: 64 +alpha: 0.1 +num_layer: 1 +learning_rate: 0.01 +dropout: 0.6 +hidden_size: 64 +weight_decay: 0.0005 +edge_dropout: 0.0 diff --git a/examples/citation_benchmark/config/sgc.yaml b/examples/citation_benchmark/config/sgc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2124608fa5533d2b29886f3b63dad973e60e0b5f --- /dev/null +++ b/examples/citation_benchmark/config/sgc.yaml @@ -0,0 +1,5 @@ +model_name: SGC +num_layers: 2 +learning_rate: 0.2 +weight_decay: 0.000005 +feature_pre_normalize: False diff --git a/examples/citation_benchmark/model.py b/examples/citation_benchmark/model.py new file mode 100644 index 0000000000000000000000000000000000000000..dcb1f78cb0627c140bd5a2039e84ba2a3029cfac --- /dev/null +++ b/examples/citation_benchmark/model.py @@ -0,0 +1,195 @@ +import pgl +import paddle.fluid.layers as L +import pgl.layers.conv as conv + +def get_norm(indegree): + float_degree = L.cast(indegree, dtype="float32") + float_degree = L.clamp(float_degree, min=1.0) + norm = L.pow(float_degree, factor=-0.5) + return norm + + +class GCN(object): + """Implement of GCN + """ + def __init__(self, config, num_class): + self.num_class = num_class + self.num_layers = config.get("num_layers", 1) + self.hidden_size = config.get("hidden_size", 64) + self.dropout = config.get("dropout", 0.5) + self.edge_dropout = config.get("edge_dropout", 0.0) + + def forward(self, graph_wrapper, feature, phase): + + for i in range(self.num_layers): + + if phase == "train": + ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout) + norm = get_norm(ngw.indegree()) + else: + ngw = graph_wrapper + norm = graph_wrapper.node_feat["norm"] + + + feature = pgl.layers.gcn(ngw, + feature, + self.hidden_size, + activation="relu", + norm=norm, + name="layer_%s" % i) + + feature = L.dropout( + feature, + self.dropout, + dropout_implementation='upscale_in_train') + + if phase == "train": + ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout) + norm = get_norm(ngw.indegree()) + else: + ngw = graph_wrapper + norm = graph_wrapper.node_feat["norm"] + + feature = conv.gcn(ngw, + feature, + self.num_class, + activation=None, + norm=norm, + name="output") + + return feature + +class GAT(object): + """Implement of GAT""" + def __init__(self, config, num_class): + self.num_class = num_class + self.num_layers = config.get("num_layers", 1) + self.num_heads = config.get("num_heads", 8) + self.hidden_size = config.get("hidden_size", 8) + self.feat_dropout = config.get("feat_drop", 0.6) + self.attn_dropout = config.get("attn_drop", 0.6) + self.edge_dropout = config.get("edge_dropout", 0.0) + + def forward(self, graph_wrapper, feature, phase): + if phase == "train": + edge_dropout = 0 + else: + edge_dropout = self.edge_dropout + + for i in range(self.num_layers): + ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout) + + feature = conv.gat(ngw, + feature, + self.hidden_size, + activation="elu", + name="gat_layer_%s" % i, + num_heads=self.num_heads, + feat_drop=self.feat_dropout, + attn_drop=self.attn_dropout) + + ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout) + feature = conv.gat(ngw, + feature, + self.num_class, + num_heads=1, + activation=None, + feat_drop=self.feat_dropout, + attn_drop=self.attn_dropout, + name="output") + return feature + + +class APPNP(object): + """Implement of APPNP""" + def __init__(self, config, num_class): + self.num_class = num_class + self.num_layers = config.get("num_layers", 1) + self.hidden_size = config.get("hidden_size", 64) + self.dropout = config.get("dropout", 0.5) + self.alpha = config.get("alpha", 0.1) + self.k_hop = config.get("k_hop", 10) + self.edge_dropout = config.get("edge_dropout", 0.0) + + def forward(self, graph_wrapper, feature, phase): + if phase == "train": + edge_dropout = 0 + else: + edge_dropout = self.edge_dropout + + for i in range(self.num_layers): + feature = L.dropout( + feature, + self.dropout, + dropout_implementation='upscale_in_train') + feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i) + + feature = L.dropout( + feature, + self.dropout, + dropout_implementation='upscale_in_train') + + feature = L.fc(feature, self.num_class, act=None, name="output") + + feature = conv.appnp(graph_wrapper, + feature=feature, + edge_dropout=edge_dropout, + alpha=self.alpha, + k_hop=self.k_hop) + return feature + +class SGC(object): + """Implement of SGC""" + def __init__(self, config, num_class): + self.num_class = num_class + self.num_layers = config.get("num_layers", 1) + + def forward(self, graph_wrapper, feature, phase): + feature = conv.appnp(graph_wrapper, + feature=feature, + edge_dropout=0, + alpha=0, + k_hop=self.num_layers) + feature.stop_gradient=True + feature = L.fc(feature, self.num_class, act=None, bias_attr=False, name="output") + return feature + + +class GCNII(object): + """Implement of GCNII""" + def __init__(self, config, num_class): + self.num_class = num_class + self.num_layers = config.get("num_layers", 1) + self.hidden_size = config.get("hidden_size", 64) + self.dropout = config.get("dropout", 0.6) + self.alpha = config.get("alpha", 0.1) + self.lambda_l = config.get("lambda_l", 0.5) + self.k_hop = config.get("k_hop", 64) + self.edge_dropout = config.get("edge_dropout", 0.0) + + def forward(self, graph_wrapper, feature, phase): + if phase == "train": + edge_dropout = 0 + else: + edge_dropout = self.edge_dropout + + for i in range(self.num_layers): + feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i) + feature = L.dropout( + feature, + self.dropout, + dropout_implementation='upscale_in_train') + + feature = conv.gcnii(graph_wrapper, + feature=feature, + name="gcnii", + activation="relu", + lambda_l=self.lambda_l, + alpha=self.alpha, + dropout=self.dropout, + k_hop=self.k_hop) + + feature = L.fc(feature, self.num_class, act=None, name="output") + return feature + + diff --git a/examples/citation_benchmark/train.py b/examples/citation_benchmark/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9e96fdadd8e47185a53a8424aa0f1abddd31a72e --- /dev/null +++ b/examples/citation_benchmark/train.py @@ -0,0 +1,152 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pgl +import model# import LabelGraphGCN +from pgl import data_loader +from pgl.utils.logger import log +import paddle.fluid as fluid +import numpy as np +import time +import argparse +from build_model import build_model +import yaml +from easydict import EasyDict as edict +import tqdm + +def normalize(feat): + return feat / np.maximum(np.sum(feat, -1, keepdims=True), 1) + + +def load(name, normalized_feature=True): + if name == 'cora': + dataset = data_loader.CoraDataset() + elif name == "pubmed": + dataset = data_loader.CitationDataset("pubmed", symmetry_edges=True) + elif name == "citeseer": + dataset = data_loader.CitationDataset("citeseer", symmetry_edges=True) + else: + raise ValueError(name + " dataset doesn't exists") + + indegree = dataset.graph.indegree() + norm = np.maximum(indegree.astype("float32"), 1) + norm = np.power(norm, -0.5) + dataset.graph.node_feat["norm"] = np.expand_dims(norm, -1) + dataset.graph.node_feat["words"] = normalize(dataset.graph.node_feat["words"]) + return dataset + + +def main(args, config): + dataset = load(args.dataset, args.feature_pre_normalize) + place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() + train_program = fluid.default_main_program() + startup_program = fluid.default_startup_program() + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + gw, loss, acc = build_model(dataset, + config=config, + phase="train", + main_prog=train_program) + + test_program = fluid.Program() + with fluid.program_guard(test_program, startup_program): + with fluid.unique_name.guard(): + _gw, v_loss, v_acc = build_model(dataset, + config=config, + phase="test", + main_prog=test_program) + + test_program = test_program.clone(for_test=True) + + exe = fluid.Executor(place) + + + train_index = dataset.train_index + train_label = np.expand_dims(dataset.y[train_index], -1) + train_index = np.expand_dims(train_index, -1) + + val_index = dataset.val_index + val_label = np.expand_dims(dataset.y[val_index], -1) + val_index = np.expand_dims(val_index, -1) + + test_index = dataset.test_index + test_label = np.expand_dims(dataset.y[test_index], -1) + test_index = np.expand_dims(test_index, -1) + + dur = [] + + # Feed data + feed_dict = gw.to_feed(dataset.graph) + + + best_test = [] + + for run in range(args.runs): + exe.run(startup_program) + cal_val_acc = [] + cal_test_acc = [] + cal_val_loss = [] + cal_test_loss = [] + for epoch in tqdm.tqdm(range(args.epoch)): + feed_dict["node_index"] = np.array(train_index, dtype="int64") + feed_dict["node_label"] = np.array(train_label, dtype="int64") + train_loss, train_acc = exe.run(train_program, + feed=feed_dict, + fetch_list=[loss, acc], + return_numpy=True) + + + feed_dict["node_index"] = np.array(val_index, dtype="int64") + feed_dict["node_label"] = np.array(val_label, dtype="int64") + val_loss, val_acc = exe.run(test_program, + feed=feed_dict, + fetch_list=[v_loss, v_acc], + return_numpy=True) + + cal_val_acc.append(val_acc[0]) + cal_val_loss.append(val_loss[0]) + + feed_dict["node_index"] = np.array(test_index, dtype="int64") + feed_dict["node_label"] = np.array(test_label, dtype="int64") + test_loss, test_acc = exe.run(test_program, + feed=feed_dict, + fetch_list=[v_loss, v_acc], + return_numpy=True) + + cal_test_acc.append(test_acc[0]) + cal_test_loss.append(test_loss[0]) + + + log.info("Runs %s: Model: %s Best Test Accuracy: %f" % (run, config.model_name, + cal_test_acc[np.argmin(cal_val_loss)])) + + best_test.append(cal_test_acc[np.argmin(cal_val_loss)]) + log.info("Dataset: %s Best Test Accuracy: %f ( stddev: %f )" % (args.dataset, np.mean(best_test), np.std(best_test))) + print("Dataset: %s Best Test Accuracy: %f ( stddev: %f )" % (args.dataset, np.mean(best_test), np.std(best_test))) + + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Benchmarking Citation Network') + parser.add_argument( + "--dataset", type=str, default="cora", help="dataset (cora, pubmed)") + parser.add_argument("--use_cuda", action='store_true', help="use_cuda") + parser.add_argument("--conf", type=str, help="config file for models") + parser.add_argument("--epoch", type=int, default=200, help="Epoch") + parser.add_argument("--runs", type=int, default=10, help="runs") + parser.add_argument("--feature_pre_normalize", type=bool, default=True, help="pre_normalize feature") + args = parser.parse_args() + config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader)) + log.info(args) + main(args, config) diff --git a/examples/erniesage/README.en.md b/examples/erniesage/README.en.md index 5933b9576bcc882de5014a6d09e5ef09107fce19..c21b2dc144f05882caa7749785c215dd917d844e 100644 --- a/examples/erniesage/README.en.md +++ b/examples/erniesage/README.en.md @@ -49,6 +49,8 @@ sh local_run.sh config/enriesage_v1_gpu.yaml sh local_run.sh config/enriesage_v1_cpu.yaml ``` +**NOTE**: To help users better understand the ERNIESage Model, we provide a running example in Baidu AIStudio. Please visit here: https://aistudio.baidu.com/aistudio/projectdetail/667443. + ## Hyperparamters - learner_type: `gpu` or `cpu`; gpu use fleet Collective mode, cpu use fleet Transpiler mode. diff --git a/examples/erniesage/README.md b/examples/erniesage/README.md index 78004852b9aa1f95b372f06d4104e7d65118f6e7..7ae48ba7d2ce763a327632c2a494c10f82151a90 100644 --- a/examples/erniesage/README.md +++ b/examples/erniesage/README.md @@ -50,6 +50,8 @@ sh local_run.sh config/erniesage_v2_gpu.yaml sh local_run.sh config/erniesage_v2_cpu.yaml ``` +**NOTE**:为了方便用户们学习使用ERNIESage,我们在百度AIStudio中提供了可以直接运行的ERNIESage实例,详情可见:https://aistudio.baidu.com/aistudio/projectdetail/667443. + ## Hyperparamters - learner_type: `gpu` or `cpu`; gpu 使用fleet Collective 模式, cpu 使用fleet Transpiler 模式. diff --git a/examples/xformer/README.md b/examples/xformer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..616d0ad53e1dfe054732d052c4a97cfce5cb88c1 --- /dev/null +++ b/examples/xformer/README.md @@ -0,0 +1,31 @@ +# X-Transformer + +Models based on Transformers are wildly successful for a wide variety of Natural Language Processing (NLP) tasks and consequently are a mainstay of modern NLP research. Transformer is constituted of a self-attention and a feed-forward module. The self-attention mechanism allows each token in the input sequence to attend independently to every other token in the sequence. From the view of graph representation, the generalized attention mechanism can be described by a Undirected Complete Graph whose vertex is the token. So, the attention module can be implemented by a graph library, especially recently the efficient attention implementation, e.g. [BigBird](https://arxiv.org/abs/2007.14062) \ [LongFormer](https://arxiv.org/abs/2004.05150) \ [Sparse Transformer](https://arxiv.org/abs/1904.10509). + +We have showcased the [BigBird](https://arxiv.org/abs/2007.14062) implementation and tested the performence as show below, and the [LongFormer](https://arxiv.org/abs/2004.05150) \ [Sparse Transformer](https://arxiv.org/abs/1904.10509) can be easily implemented by revised the correspoding code. + + + +## Dependencies + +- [paddlepaddle >= 1.7](https://github.com/PaddlePaddle/paddle) +- [pgl 1.1](https://github.com/PaddlePaddle/PGL) + + +## Performance + +We have evaluate the implemented method on a summarization dataset CNN/DM. The experiment was conducted on two P40 GPU cards. + +| CNN/DM | BatchSize | R1 | R2 | R3 | speed(steps/s) | +| ------------------ | --------- | ----------------- | ----------------- | ----------------- | ------ | +| LEAD | - | 40.42 | 17.62 | 36.67 | - | +| Oracle | - | 52.59 | 31.24 | 48.87 | - | +| non-sparse, L=512 | 32 | 42.175 | 19.392 | 38.613 | 0.6359 | +| L=2048 | 10 | 41.334 | 18.369 | 37.752 | 0.8246 | +| L=1024 | 20 | 41.453 | 18.529 | 37.872 | 0.6432 | +| L=768 | 26 | 41.611 | 18.735 | 38.051 | 0.6517 | +| L=512 | 40 | 41.742 | 18.733 | 38.127 | 0.6213 | + +**\**** For this task, we warm up from ERNIE 2.0 en directly rather than pretrain the model for the additional position embedding, so the embedding for the position which is larger than 512 is used repeatedly from ERNIE 2.0. +This may cause score degradation. But in the future, we will test the pre-trained model. + diff --git a/examples/xformer/sparse_scaled_dot_product_attention.py b/examples/xformer/sparse_scaled_dot_product_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1df85ae4e0b2437365a9f49df3676c609f61b37c --- /dev/null +++ b/examples/xformer/sparse_scaled_dot_product_attention.py @@ -0,0 +1,183 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import paddle.fluid as fluid +import paddle.fluid.layers as L +import paddle.fluid.layers as layers +from pgl.utils import paddle_helper +import pgl + + +def masked_select(input, mask): + """masked_select + + Slice the value from given Mask + + Args: + input: Input tensor to be selected + + mask: A bool tensor for sliced. + + Return: + Part of inputs where mask is True. + """ + index = L.where(mask) + return L.gather(input, index, overwrite=False) + + +class BigBirdWrapper(pgl.graph_wrapper.BaseGraphWrapper): + """Implement of Big Bird by PGL graph wrapper """ + def __init__(self, input_mask): + super(BigBirdWrapper, self).__init__() + max_seqlen = L.shape(input_mask)[1] + input_mask = L.reshape(input_mask, [-1]) + num_nodes = L.shape(input_mask)[0] + src, dst = build_edges(num_nodes, input_mask, max_seqlen) + self._edges_src = src + self._edges_dst = dst + self._edges_src.stop_gradient=True + self._edges_dst.stop_gradient=True + self._num_nodes = num_nodes + self._num_edges = L.shape(self._edges_src)[0] + self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32") + self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(self._edges_dst, dtype="int32") + self._edge_uniq_dst.stop_gradient=True + last = L.reduce_sum(uniq_count, keep_dim=True) + uniq_count = L.cumsum(uniq_count, exclusive=True) + self._edge_uniq_dst_count = L.concat([uniq_count, last]) + self._edge_uniq_dst_count.stop_gradient=True + + +def select_edges(src, dst, input_mask, num_nodes, max_seqlen): + src = fluid.layers.elementwise_max(src, num_nodes * 0) + dst = fluid.layers.elementwise_max(dst, num_nodes * 0) + src = fluid.layers.elementwise_min(src, num_nodes - 1) + dst = fluid.layers.elementwise_min(dst, num_nodes - 1) + + conditions = [] + conditions.append(L.gather(input_mask, src) > 0.5) + conditions.append(L.gather(input_mask, dst) > 0.5) + block_src = src / max_seqlen + block_dst = dst / max_seqlen + conditions.append(block_src == block_dst) + mask = None + for cond in conditions: + if mask is None: + mask = cond + else: + mask = L.logical_and(mask, cond) + + dst = masked_select(dst, mask) + src = masked_select(src, mask) + return src, dst + + +def uniq_edges(src, dst, num_nodes): + sorted_dst = L.cast(dst, dtype="int64") + sorted_src = L.cast(src, dtype="int64") + num_nodes = L.cast(num_nodes, dtype="int64") + edge_hash = sorted_dst * num_nodes + sorted_src + edge_hash, _ = L.argsort(edge_hash) + edge_hash, _ = L.unique(edge_hash, dtype="int64") + sorted_src = L.elementwise_mod(edge_hash, num_nodes) + sorted_dst = L.elementwise_div(edge_hash, num_nodes) + sorted_src = L.cast(sorted_src, dtype="int32") + sorted_dst = L.cast(sorted_dst, dtype="int32") + return sorted_src, sorted_dst + + +def build_edges(num_nodes, input_mask, max_seqlen): + edges = L.range(start=0, end=num_nodes, step=1, dtype="int32") + all_edges = [] + # Window + filter_func = lambda x, y: select_edges(x, y, input_mask, num_nodes, max_seqlen) + + all_edges.append(filter_func(edges - 1, edges)) # win-1 + all_edges.append(filter_func(edges + 1, edges)) # win-2 + all_edges.append(filter_func(edges, edges)) #self-loop + + # Global Assume [CLS] is the first token. + + # vertical cls-window attention + cls_position = edges / max_seqlen * max_seqlen + all_edges.append(filter_func(cls_position, edges)) + + # horizontal cls attention + all_edges.append(filter_func(edges, cls_position)) + + # Random + for i in range(2): + rand_edge = L.floor(L.uniform_random(min=0, max=1, shape=[num_nodes]) * L.cast(max_seqlen, dtype="float32")) + rand_edge = L.cast(rand_edge, dtype="int32") + cls_position + all_edges.append(filter_func(rand_edge, edges)) + + if len(all_edges) > 1: + src = L.concat([ s for s, d in all_edges], 0) + dst = L.concat([ d for s, d in all_edges], 0) + else: + src = all_edges[0][0] + dst = all_edges[0][1] + + # sort edges + sorted_src, sorted_dst = uniq_edges(src, dst, num_nodes) + return sorted_src, sorted_dst + + +def sparse_scaled_dot_product_attention(q, k, v, input_mask, dropout_rate, n_head, d_key, d_value): + def send_q_k_spmm(src_feat, dst_feat, edge_feat): + # q [ num_edges, n_head * dim] + # k [ num_edges, n_head * dim] + # v [ num_edges, n_head * dim] + _q = dst_feat["q"] + _k = src_feat["k"] + _v = src_feat["v"] + _q = L.reshape(_q, [-1, n_head, _q.shape[-1] // n_head]) + _k = L.reshape(_k, [-1, n_head, _k.shape[-1] // n_head]) + score = L.reduce_sum(_q * _k, -1) # [num_edge, n_head] + return { "score": score, "value": _v} + + def recv_score_v_spmm(msg): + score = msg["score"] + score = paddle_helper.sequence_softmax(score) + score = layers.dropout( + score, + dropout_prob=dropout_rate, + dropout_implementation="upscale_in_train", + is_test=False) + + score = L.reshape(score, [-1, n_head, 1]) + _v = msg["value"] + _new_v = L.reshape(_v, [-1, n_head, _v.shape[-1] // n_head]) + + _new_v = _new_v * score + + _new_v = L.reshape(_new_v, [-1, _v.shape[-1]]) + _new_v = L.lod_reset(_new_v, _v) + return L.sequence_pool(_new_v, "sum") + + graph_wrapper = BigBirdWrapper(input_mask) + old_v = v + + q = L.reshape(q, [-1, d_key * n_head]) + k = L.reshape(k, [-1, d_key * n_head]) + v = L.reshape(v, [-1, d_value * n_head]) + + q = L.scale(q, scale=d_key ** -0.5) + msg = graph_wrapper.send(send_q_k_spmm, nfeat_list=[("k", k), ("v", v), ("q", q)]) + out = graph_wrapper.recv(msg, recv_score_v_spmm) + out = L.reshape(out, [-1, L.shape(old_v)[1], d_value * n_head]) + return out, out + + diff --git a/examples/xformer/transformer_encoder_sparse.py b/examples/xformer/transformer_encoder_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..43345d8a050ef00acf6ce57c3eb7a7a9ed01ee0d --- /dev/null +++ b/examples/xformer/transformer_encoder_sparse.py @@ -0,0 +1,361 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import paddle.fluid as fluid +import paddle.fluid.layers as L +import paddle.fluid.layers as layers + + +from .sparse_scaled_dot_product_attention import sparse_scaled_dot_product_attention + +to_3d = lambda a: a # will change later +to_2d = lambda a: a + + +def multi_head_attention(queries, + keys, + values, + attn_bias, + d_key, + d_value, + d_model, + input_mask, + n_head=1, + dropout_rate=0., + cache=None, + param_initializer=None, + name='multi_head_att'): + """ + Multi-Head Attention. Note that attn_bias is added to the logit before + computing softmax activiation to mask certain selected positions so that + they will not considered in attention weights. + """ + keys = queries if keys is None else keys + values = keys if values is None else values + + def __compute_qkv(queries, keys, values, n_head, d_key, d_value): + """ + Add linear projection to queries, keys, and values. + """ + q = layers.fc(input=queries, + size=d_key * n_head, + num_flatten_dims=len(queries.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_query_fc.w_0', + initializer=param_initializer), + bias_attr=name + '_query_fc.b_0') + k = layers.fc(input=keys, + size=d_key * n_head, + num_flatten_dims=len(keys.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_key_fc.w_0', + initializer=param_initializer), + bias_attr=name + '_key_fc.b_0') + v = layers.fc(input=values, + size=d_value * n_head, + num_flatten_dims=len(values.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_value_fc.w_0', + initializer=param_initializer), + bias_attr=name + '_value_fc.b_0') + return q, k, v + + def __split_heads(x, n_head): + """ + Reshape the last dimension of inpunt tensor x so that it becomes two + dimensions and then transpose. Specifically, input a tensor with shape + [bs, max_sequence_length, n_head * hidden_dim] then output a tensor + with shape [bs, n_head, max_sequence_length, hidden_dim]. + """ + hidden_size = x.shape[-1] + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. + reshaped = layers.reshape( + x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True) + + # permuate the dimensions into: + # [batch_size, n_head, max_sequence_len, hidden_size_per_head] + return layers.transpose(x=reshaped, perm=[0, 2, 1, 3]) + + def __combine_heads(x): + """ + Transpose and then reshape the last two dimensions of inpunt tensor x + so that it becomes one dimension, which is reverse to __split_heads. + """ + if len(x.shape) == 3: return x + if len(x.shape) != 4: + raise ValueError("Input(x) should be a 4-D Tensor.") + trans_x = layers.transpose(x, perm=[0, 2, 1, 3]) + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. + #trans_x.desc.set_shape((-1, 1, n_head, d_value)) + return layers.reshape(x=trans_x, shape=[0, 0, d_model], inplace=True) + + q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value) + q = to_3d(q) + k = to_3d(k) + v = to_3d(v) + + if cache is not None: # use cache and concat time steps + # Since the inplace reshape in __split_heads changes the shape of k and + # v, which is the cache input for next time step, reshape the cache + # input from the previous time step first. + k = cache["k"] = layers.concat( + [layers.reshape( + cache["k"], shape=[0, 0, d_model]), k], axis=1) + v = cache["v"] = layers.concat( + [layers.reshape( + cache["v"], shape=[0, 0, d_model]), v], axis=1) + + out, _ = sparse_scaled_dot_product_attention(q, k, v, + input_mask, dropout_rate, n_head, d_key, d_value) + + out = to_2d(out) + + # Project back to the model size. + proj_out = layers.fc(input=out, + size=d_model, + num_flatten_dims=len(out.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_output_fc.w_0', + initializer=param_initializer), + bias_attr=name + '_output_fc.b_0') + return proj_out, _ + + +def positionwise_feed_forward(x, + d_inner_hid, + d_hid, + dropout_rate, + hidden_act, + param_initializer=None, + name='ffn'): + """ + Position-wise Feed-Forward Networks. + This module consists of two linear transformations with a ReLU activation + in between, which is applied to each position separately and identically. + """ + hidden = layers.fc(input=x, + size=d_inner_hid, + num_flatten_dims=len(x.shape) - 1, + act=hidden_act, + param_attr=fluid.ParamAttr( + name=name + '_fc_0.w_0', + initializer=param_initializer), + bias_attr=name + '_fc_0.b_0') + if dropout_rate: + hidden = layers.dropout( + hidden, + dropout_prob=dropout_rate, + dropout_implementation="upscale_in_train", + is_test=False) + out = layers.fc(input=hidden, + size=d_hid, + num_flatten_dims=len(hidden.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_fc_1.w_0', + initializer=param_initializer), + bias_attr=name + '_fc_1.b_0') + return out + + +def pre_post_process_layer(prev_out, + out, + process_cmd, + dropout_rate=0., + name=''): + """ + Add residual connection, layer normalization and droput to the out tensor + optionally according to the value of process_cmd. + This will be used before or after multi-head attention and position-wise + feed-forward networks. + """ + for cmd in process_cmd: + if cmd == "a": # add residual connection + out = out + prev_out if prev_out else out + elif cmd == "n": # add layer normalization + out_dtype = out.dtype + if out_dtype == fluid.core.VarDesc.VarType.FP16: + out = layers.cast(x=out, dtype="float32") + out = layers.layer_norm( + out, + begin_norm_axis=len(out.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_layer_norm_scale', + initializer=fluid.initializer.Constant(1.)), + bias_attr=fluid.ParamAttr( + name=name + '_layer_norm_bias', + initializer=fluid.initializer.Constant(0.))) + if out_dtype == fluid.core.VarDesc.VarType.FP16: + out = layers.cast(x=out, dtype="float16") + elif cmd == "d": # add dropout + if dropout_rate: + out = layers.dropout( + out, + dropout_prob=dropout_rate, + dropout_implementation="upscale_in_train", + is_test=False) + return out + + +pre_process_layer = partial(pre_post_process_layer, None) +post_process_layer = pre_post_process_layer + + +def encoder_layer(enc_input, + input_mask, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + hidden_act, + preprocess_cmd="n", + postprocess_cmd="da", + param_initializer=None, + name=''): + """The encoder layers that can be stacked to form a deep encoder. + This module consits of a multi-head (self) attention followed by + position-wise feed-forward networks and both the two components companied + with the post_process_layer to add residual connection, layer normalization + and droput. + """ + attn_output, ctx_multiheads_attn = multi_head_attention( + pre_process_layer( + enc_input, + preprocess_cmd, + prepostprocess_dropout, + name=name + '_pre_att'), + None, + None, + attn_bias, + d_key, + d_value, + d_model, + input_mask, + n_head, + attention_dropout, + param_initializer=param_initializer, + name=name + '_multi_head_att') + attn_output = post_process_layer( + enc_input, + attn_output, + postprocess_cmd, + prepostprocess_dropout, + name=name + '_post_att') + + ffd_output = positionwise_feed_forward( + pre_process_layer( + attn_output, + preprocess_cmd, + prepostprocess_dropout, + name=name + '_pre_ffn'), + d_inner_hid, + d_model, + relu_dropout, + hidden_act, + param_initializer=param_initializer, + name=name + '_ffn') + + ret = post_process_layer( + attn_output, + ffd_output, + postprocess_cmd, + prepostprocess_dropout, + name=name + '_post_ffn') + + return ret, ctx_multiheads_attn, ffd_output + + +def build_pad_idx(input_mask): + pad_idx = L.where(L.cast(L.squeeze(input_mask, [2]), 'bool')) + return pad_idx + + +def build_attn_bias(input_mask, n_head, dtype): + attn_bias = L.matmul( + input_mask, input_mask, transpose_y=True) # [batch, seq, seq] + attn_bias = (1. - attn_bias) * -10000. + attn_bias = L.stack([attn_bias] * n_head, 1) + if attn_bias.dtype != dtype: + attn_bias = L.cast(attn_bias, dtype) + return attn_bias + + +def encoder(enc_input, + input_mask, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + hidden_act, + preprocess_cmd="n", + postprocess_cmd="da", + param_initializer=None, + name=''): + """ + The encoder is composed of a stack of identical layers returned by calling + encoder_layer. + """ + + d_shape = L.shape(input_mask) + pad_idx = build_pad_idx(input_mask) + attn_bias = build_attn_bias(input_mask, n_head, enc_input.dtype) + + enc_input = to_2d(enc_input) + all_hidden = [] + all_attn = [] + all_ffn = [] + for i in range(n_layer): + enc_output, ctx_multiheads_attn, ffn_output = encoder_layer( + enc_input, + input_mask, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + hidden_act, + preprocess_cmd, + postprocess_cmd, + param_initializer=param_initializer, + name=name + '_layer_' + str(i)) + all_hidden.append(enc_output) + all_attn.append(ctx_multiheads_attn) + all_ffn.append(ffn_output) + enc_input = enc_output + enc_output = pre_process_layer( + enc_output, + preprocess_cmd, + prepostprocess_dropout, + name="post_encoder") + enc_output = to_3d(enc_output) + return enc_output, all_hidden, all_attn, all_ffn + + diff --git a/pgl/__init__.py b/pgl/__init__.py index 7543265d30492f8d1fe7a898f948166ae89001ea..e364540e8192d6dbebab72e1af552f00a4919c72 100644 --- a/pgl/__init__.py +++ b/pgl/__init__.py @@ -22,3 +22,4 @@ from pgl import heter_graph from pgl import heter_graph_wrapper from pgl import contrib from pgl import message_passing +from pgl import sample diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index 91dda8f78796aedd493b37e85a92ad9ecb1c6664..e91feddc69805d5c50ac4cfbf2e54df0238487cc 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -19,6 +19,7 @@ for PaddlePaddle. import warnings import numpy as np import paddle.fluid as fluid +import paddle.fluid.layers as L from pgl.utils import op from pgl.utils import paddle_helper @@ -26,12 +27,11 @@ from pgl.utils.logger import log __all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"] - def send(src, dst, nfeat, efeat, message_func): """Send message from src to dst. """ - src_feat = op.read_rows(nfeat, src) - dst_feat = op.read_rows(nfeat, dst) + src_feat = op.RowReader(nfeat, src) + dst_feat = op.RowReader(nfeat, dst) msg = message_func(src_feat, dst_feat, efeat) return msg @@ -47,10 +47,10 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, try: out_dim = msg.shape[-1] - init_output = fluid.layers.fill_constant( + init_output = L.fill_constant( shape=[num_nodes, out_dim], value=0, dtype=msg.dtype) init_output.stop_gradient = False - empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype) + empty_msg_flag = L.cast(num_edges > 0, dtype=msg.dtype) msg = msg * empty_msg_flag output = paddle_helper.scatter_add(init_output, dst, msg) return output @@ -59,7 +59,7 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, "scatter_add is not supported with paddle version <= 1.5") def sum_func(message): - return fluid.layers.sequence_pool(message, "sum") + return L.sequence_pool(message, "sum") reduce_function = sum_func @@ -67,13 +67,13 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, output = reduce_function(bucketed_msg) output_dim = output.shape[-1] - empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype) + empty_msg_flag = L.cast(num_edges > 0, dtype=output.dtype) output = output * empty_msg_flag - init_output = fluid.layers.fill_constant( + init_output = L.fill_constant( shape=[num_nodes, output_dim], value=0, dtype=output.dtype) init_output.stop_gradient = True - final_output = fluid.layers.scatter(init_output, uniq_dst, output) + final_output = L.scatter(init_output, uniq_dst, output) return final_output @@ -104,6 +104,7 @@ class BaseGraphWrapper(object): self._node_ids = None self._graph_lod = None self._num_graph = None + self._num_edges = None self._data_name_prefix = "" def __repr__(self): @@ -142,11 +143,13 @@ class BaseGraphWrapper(object): """ if efeat_list is None: efeat_list = {} + if nfeat_list is None: nfeat_list = {} src, dst = self.edges nfeat = {} + for feat in nfeat_list: if isinstance(feat, str): nfeat[feat] = self.node_feat[feat] @@ -470,7 +473,7 @@ class StaticGraphWrapper(BaseGraphWrapper): class GraphWrapper(BaseGraphWrapper): """Implement a graph wrapper that creates a graph data holders - that attributes and features in the graph are :code:`fluid.layers.data`. + that attributes and features in the graph are :code:`L.data`. And we provide interface :code:`to_feed` to help converting :code:`Graph` data into :code:`feed_dict`. @@ -546,65 +549,65 @@ class GraphWrapper(BaseGraphWrapper): def __create_graph_attr_holders(self): """Create data holders for graph attributes. """ - self._num_edges = fluid.layers.data( + self._num_edges = L.data( self._data_name_prefix + '/num_edges', shape=[1], append_batch_size=False, dtype="int64", stop_gradient=True) - self._num_graph = fluid.layers.data( + self._num_graph = L.data( self._data_name_prefix + '/num_graph', shape=[1], append_batch_size=False, dtype="int64", stop_gradient=True) - self._edges_src = fluid.layers.data( + self._edges_src = L.data( self._data_name_prefix + '/edges_src', shape=[None], append_batch_size=False, dtype="int64", stop_gradient=True) - self._edges_dst = fluid.layers.data( + self._edges_dst = L.data( self._data_name_prefix + '/edges_dst', shape=[None], append_batch_size=False, dtype="int64", stop_gradient=True) - self._num_nodes = fluid.layers.data( + self._num_nodes = L.data( self._data_name_prefix + '/num_nodes', shape=[1], append_batch_size=False, dtype='int64', stop_gradient=True) - self._edge_uniq_dst = fluid.layers.data( + self._edge_uniq_dst = L.data( self._data_name_prefix + "/uniq_dst", shape=[None], append_batch_size=False, dtype="int64", stop_gradient=True) - self._graph_lod = fluid.layers.data( + self._graph_lod = L.data( self._data_name_prefix + "/graph_lod", shape=[None], append_batch_size=False, dtype="int32", stop_gradient=True) - self._edge_uniq_dst_count = fluid.layers.data( + self._edge_uniq_dst_count = L.data( self._data_name_prefix + "/uniq_dst_count", shape=[None], append_batch_size=False, dtype="int32", stop_gradient=True) - self._node_ids = fluid.layers.data( + self._node_ids = L.data( self._data_name_prefix + "/node_ids", shape=[None], append_batch_size=False, dtype="int64", stop_gradient=True) - self._indegree = fluid.layers.data( + self._indegree = L.data( self._data_name_prefix + "/indegree", shape=[None], append_batch_size=False, @@ -627,7 +630,7 @@ class GraphWrapper(BaseGraphWrapper): node_feat_dtype): """Create data holders for node features. """ - feat_holder = fluid.layers.data( + feat_holder = L.data( self._data_name_prefix + '/node_feat/' + node_feat_name, shape=node_feat_shape, append_batch_size=False, @@ -640,7 +643,7 @@ class GraphWrapper(BaseGraphWrapper): edge_feat_dtype): """Create edge holders for edge features. """ - feat_holder = fluid.layers.data( + feat_holder = L.data( self._data_name_prefix + '/edge_feat/' + edge_feat_name, shape=edge_feat_shape, append_batch_size=False, @@ -719,3 +722,61 @@ class GraphWrapper(BaseGraphWrapper): """Return the holder list. """ return self._holder_list + + +def get_degree(edge, num_nodes): + init_output = L.fill_constant( + shape=[num_nodes], value=0, dtype="float32") + init_output.stop_gradient = True + final_output = L.scatter(init_output, + edge, + L.full_like(edge, 1, dtype="float32"), + overwrite=False) + return final_output + +class DropEdgeWrapper(BaseGraphWrapper): + """Implement of Edge Drop """ + def __init__(self, graph_wrapper, dropout, keep_self_loop=True): + super(DropEdgeWrapper, self).__init__() + + # Copy Node's information + for key, value in graph_wrapper.node_feat.items(): + self.node_feat_tensor_dict[key] = value + + self._num_nodes = graph_wrapper.num_nodes + self._graph_lod = graph_wrapper.graph_lod + self._num_graph = graph_wrapper.num_graph + self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32") + + # Dropout Edges + src, dst = graph_wrapper.edges + u = L.uniform_random(shape=L.cast(L.shape(src), 'int64'), min=0., max=1.) + + + # Avoid Empty Edges + keeped = L.cast(u > dropout, dtype="float32") + self._num_edges = L.reduce_sum(L.cast(keeped, "int32")) + keeped = keeped + L.cast(self._num_edges == 0, dtype="float32") + + if keep_self_loop: + self_loop = L.cast(src == dst, dtype="float32") + keeped = keeped + self_loop + + keeped = (keeped > 0.5) + src = paddle_helper.masked_select(src, keeped) + dst = paddle_helper.masked_select(dst, keeped) + src.stop_gradient=True + dst.stop_gradient=True + self._edges_src = src + self._edges_dst = dst + + for key, value in graph_wrapper.edge_feat.items(): + self.edge_feat_tensor_dict[key] = paddle_helper.masked_select(value, keeped) + + self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(dst, dtype="int32") + self._edge_uniq_dst.stop_gradient=True + last = L.reduce_sum(uniq_count, keep_dim=True) + uniq_count = L.cumsum(uniq_count, exclusive=True) + self._edge_uniq_dst_count = L.concat([uniq_count, last]) + self._edge_uniq_dst_count.stop_gradient=True + self._indegree = get_degree(self._edges_dst, self._num_nodes) diff --git a/pgl/layers/conv.py b/pgl/layers/conv.py index 4d17c323b1f338f185bd3c90d60ac36741664886..6973b3734b17aa9623c11422893da52568851f61 100644 --- a/pgl/layers/conv.py +++ b/pgl/layers/conv.py @@ -14,11 +14,14 @@ """This package implements common layers to help building graph neural networks. """ +import pgl import paddle.fluid as fluid +import paddle.fluid.layers as L from pgl.utils import paddle_helper from pgl import message_passing +import numpy as np -__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv'] +__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv', 'appnp', 'gcnii'] def gcn(gw, feature, hidden_size, activation, name, norm=None): @@ -50,7 +53,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None): size = feature.shape[-1] if size > hidden_size: - feature = fluid.layers.fc(feature, + feature = L.fc(feature, size=hidden_size, bias_attr=False, param_attr=fluid.ParamAttr(name=name)) @@ -64,7 +67,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None): output = gw.recv(msg, "sum") else: output = gw.recv(msg, "sum") - output = fluid.layers.fc(output, + output = L.fc(output, size=hidden_size, bias_attr=False, param_attr=fluid.ParamAttr(name=name)) @@ -72,12 +75,12 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None): if norm is not None: output = output * norm - bias = fluid.layers.create_parameter( + bias = L.create_parameter( shape=[hidden_size], dtype='float32', is_bias=True, name=name + '_bias') - output = fluid.layers.elementwise_add(output, bias, act=activation) + output = L.elementwise_add(output, bias, act=activation) return output @@ -120,7 +123,7 @@ def gat(gw, def send_attention(src_feat, dst_feat, edge_feat): output = src_feat["left_a"] + dst_feat["right_a"] - output = fluid.layers.leaky_relu( + output = L.leaky_relu( output, alpha=0.2) # (num_edges, num_heads) return {"alpha": output, "h": src_feat["h"]} @@ -129,54 +132,54 @@ def gat(gw, h = msg["h"] alpha = paddle_helper.sequence_softmax(alpha) old_h = h - h = fluid.layers.reshape(h, [-1, num_heads, hidden_size]) - alpha = fluid.layers.reshape(alpha, [-1, num_heads, 1]) + h = L.reshape(h, [-1, num_heads, hidden_size]) + alpha = L.reshape(alpha, [-1, num_heads, 1]) if attn_drop > 1e-15: - alpha = fluid.layers.dropout( + alpha = L.dropout( alpha, dropout_prob=attn_drop, is_test=is_test, dropout_implementation="upscale_in_train") h = h * alpha - h = fluid.layers.reshape(h, [-1, num_heads * hidden_size]) - h = fluid.layers.lod_reset(h, old_h) - return fluid.layers.sequence_pool(h, "sum") + h = L.reshape(h, [-1, num_heads * hidden_size]) + h = L.lod_reset(h, old_h) + return L.sequence_pool(h, "sum") if feat_drop > 1e-15: - feature = fluid.layers.dropout( + feature = L.dropout( feature, dropout_prob=feat_drop, is_test=is_test, dropout_implementation='upscale_in_train') - ft = fluid.layers.fc(feature, + ft = L.fc(feature, hidden_size * num_heads, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_weight')) - left_a = fluid.layers.create_parameter( + left_a = L.create_parameter( shape=[num_heads, hidden_size], dtype='float32', name=name + '_gat_l_A') - right_a = fluid.layers.create_parameter( + right_a = L.create_parameter( shape=[num_heads, hidden_size], dtype='float32', name=name + '_gat_r_A') - reshape_ft = fluid.layers.reshape(ft, [-1, num_heads, hidden_size]) - left_a_value = fluid.layers.reduce_sum(reshape_ft * left_a, -1) - right_a_value = fluid.layers.reduce_sum(reshape_ft * right_a, -1) + reshape_ft = L.reshape(ft, [-1, num_heads, hidden_size]) + left_a_value = L.reduce_sum(reshape_ft * left_a, -1) + right_a_value = L.reduce_sum(reshape_ft * right_a, -1) msg = gw.send( send_attention, nfeat_list=[("h", ft), ("left_a", left_a_value), ("right_a", right_a_value)]) output = gw.recv(msg, reduce_attention) - bias = fluid.layers.create_parameter( + bias = L.create_parameter( shape=[hidden_size * num_heads], dtype='float32', is_bias=True, name=name + '_bias') bias.stop_gradient = True - output = fluid.layers.elementwise_add(output, bias, act=activation) + output = L.elementwise_add(output, bias, act=activation) return output @@ -219,7 +222,7 @@ def gin(gw, def send_src_copy(src_feat, dst_feat, edge_feat): return src_feat["h"] - epsilon = fluid.layers.create_parameter( + epsilon = L.create_parameter( shape=[1, 1], dtype="float32", attr=fluid.ParamAttr(name="%s_eps" % name), @@ -232,13 +235,13 @@ def gin(gw, msg = gw.send(send_src_copy, nfeat_list=[("h", feature)]) output = gw.recv(msg, "sum") + feature * (epsilon + 1.0) - output = fluid.layers.fc(output, + output = L.fc(output, size=hidden_size, act=None, param_attr=fluid.ParamAttr(name="%s_w_0" % name), bias_attr=fluid.ParamAttr(name="%s_b_0" % name)) - output = fluid.layers.layer_norm( + output = L.layer_norm( output, begin_norm_axis=1, param_attr=fluid.ParamAttr( @@ -249,9 +252,9 @@ def gin(gw, initializer=fluid.initializer.Constant(0.0)), ) if activation is not None: - output = getattr(fluid.layers, activation)(output) + output = getattr(L, activation)(output) - output = fluid.layers.fc(output, + output = L.fc(output, size=hidden_size, act=activation, param_attr=fluid.ParamAttr(name="%s_w_1" % name), @@ -269,10 +272,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key'] # E * M * D1 old = feat_query - feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a]) - feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a]) + feat_query = L.reshape(feat_query, [-1, heads, hidden_size_a]) + feat_key = L.reshape(feat_key, [-1, heads, hidden_size_a]) # E * M - alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1) + alpha = L.reduce_sum(feat_key * feat_query, dim=-1) return {'dst_node_feat': dst_feat['node_feat'], 'src_node_feat': src_feat['node_feat'], @@ -286,15 +289,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o # 每条边的出发点的特征 src_feat = message['src_node_feat'] # 每个中心点自己的特征 - x = fluid.layers.sequence_pool(dst_feat, 'average') + x = L.sequence_pool(dst_feat, 'average') # 每个中心点的邻居的特征的平均值 - z = fluid.layers.sequence_pool(src_feat, 'average') + z = L.sequence_pool(src_feat, 'average') # 计算 gate feat_gate = message['feat_gate'] - g_max = fluid.layers.sequence_pool(feat_gate, 'max') - g = fluid.layers.concat([x, g_max, z], axis=1) - g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid") + g_max = L.sequence_pool(feat_gate, 'max') + g = L.concat([x, g_max, z], axis=1) + g = L.fc(g, heads, bias_attr=False, act="sigmoid") # softmax alpha = message['alpha'] @@ -302,19 +305,19 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o feat_value = message['feat_value'] # E * (M * D2) old = feat_value - feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2 - feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0) - feat_value = fluid.layers.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2) - feat_value = fluid.layers.lod_reset(feat_value, old) + feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2 + feat_value = L.elementwise_mul(feat_value, alpha, axis=0) + feat_value = L.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2) + feat_value = L.lod_reset(feat_value, old) - feat_value = fluid.layers.sequence_pool(feat_value, 'sum') # N * (M * D2) + feat_value = L.sequence_pool(feat_value, 'sum') # N * (M * D2) - feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2 + feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2 - output = fluid.layers.elementwise_mul(feat_value, g, axis=0) - output = fluid.layers.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2) + output = L.elementwise_mul(feat_value, g, axis=0) + output = L.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2) - output = fluid.layers.concat([x, output], axis=1) + output = L.concat([x, output], axis=1) return output @@ -323,16 +326,16 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o # 计算每个点自己需要发送出去的内容 # 投影后的特征向量 # N * (D1 * M) - feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False, + feat_key = L.fc(feature, hidden_size_a * heads, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_project_key')) # N * (D2 * M) - feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False, + feat_value = L.fc(feature, hidden_size_v * heads, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_project_value')) # N * (D1 * M) - feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False, + feat_query = L.fc(feature, hidden_size_a * heads, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_project_query')) # N * Dm - feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False, + feat_gate = L.fc(feature, hidden_size_m, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_project_gate')) # send 阶段 @@ -346,10 +349,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o # 聚合邻居特征 output = gw.recv(message, recv_func) - output = fluid.layers.fc(output, hidden_size_o, bias_attr=False, + output = L.fc(output, hidden_size_o, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_project_output')) - output = fluid.layers.leaky_relu(output, alpha=0.1) - output = fluid.layers.dropout(output, dropout_prob=0.1) + output = L.leaky_relu(output, alpha=0.1) + output = L.dropout(output, dropout_prob=0.1) return output @@ -376,7 +379,7 @@ def gen_conv(gw, """ if beta == "dynamic": - beta = fluid.layers.create_parameter( + beta = L.create_parameter( shape=[1], dtype='float32', default_initializer= @@ -391,16 +394,132 @@ def gen_conv(gw, output = message_passing.msg_norm(feature, output, name) output = feature + output - output = fluid.layers.fc(output, + output = L.fc(output, feature.shape[-1], bias_attr=False, act="relu", param_attr=fluid.ParamAttr(name=name + '_weight1')) - output = fluid.layers.fc(output, + output = L.fc(output, feature.shape[-1], bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_weight2')) return output +def get_norm(indegree): + """Get Laplacian Normalization""" + float_degree = L.cast(indegree, dtype="float32") + float_degree = L.clamp(float_degree, min=1.0) + norm = L.pow(float_degree, factor=-0.5) + return norm + + +def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10): + """Implementation of APPNP of "Predict then Propagate: Graph Neural Networks + meet Personalized PageRank" (ICLR 2019). + + Args: + gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`) + + feature: A tensor with shape (num_nodes, feature_size). + + edge_dropout: Edge dropout rate. + + k_hop: K Steps for Propagation + + Return: + A tensor with shape (num_nodes, hidden_size) + """ + + def send_src_copy(src_feat, dst_feat, edge_feat): + feature = src_feat["h"] + return feature + + h0 = feature + ngw = gw + norm = get_norm(ngw.indegree()) + + for i in range(k_hop): + if edge_dropout > 1e-5: + ngw = pgl.sample.edge_drop(gw, edge_dropout) + norm = get_norm(ngw.indegree()) + + feature = feature * norm + + msg = gw.send(send_src_copy, nfeat_list=[("h", feature)]) + + feature = gw.recv(msg, "sum") + + feature = feature * norm + + feature = feature * (1 - alpha) + h0 * alpha + return feature + + +def gcnii(gw, + feature, + name, + activation=None, + alpha=0.5, + lambda_l=0.5, + k_hop=1, + dropout=0.5, + is_test=False): + """Implementation of GCNII of "Simple and Deep Graph Convolutional Networks" + + paper: https://arxiv.org/pdf/2007.02133.pdf + + Args: + gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`) + + feature: A tensor with shape (num_nodes, feature_size). + + activation: The activation for the output. + + k_hop: Number of layers for gcnii. + + lambda_l: The hyperparameter of lambda in the paper. + + alpha: The hyperparameter of alpha in the paper. + + dropout: Feature dropout rate. + + is_test: train / test phase. + + Return: + A tensor with shape (num_nodes, hidden_size) + """ + + def send_src_copy(src_feat, dst_feat, edge_feat): + feature = src_feat["h"] + return feature + + h0 = feature + ngw = gw + norm = get_norm(ngw.indegree()) + hidden_size = feature.shape[-1] + + for i in range(k_hop): + beta_i = np.log(1.0 * lambda_l / (i + 1) + 1) + feature = L.dropout( + feature, + dropout_prob=dropout, + is_test=is_test, + dropout_implementation='upscale_in_train') + + feature = feature * norm + msg = gw.send(send_src_copy, nfeat_list=[("h", feature)]) + feature = gw.recv(msg, "sum") + feature = feature * norm + + # appnp + feature = feature * (1 - alpha) + h0 * alpha + + feature_transed = L.fc(feature, hidden_size, + act=None, bias_attr=False, + name=name+"_%s_w1" % i) + feature = feature_transed * beta_i + feature * (1 - beta_i) + if activation is not None: + feature = getattr(L, activation)(feature) + return feature diff --git a/pgl/sample.py b/pgl/sample.py index 81241d5dc6f8224283abebeaa35da69644e9d1a1..e890c3783779d1863d5c848d0bb97a37febe2928 100644 --- a/pgl/sample.py +++ b/pgl/sample.py @@ -516,3 +516,12 @@ def graph_saint_random_walk_sample(graph, nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True) subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64") return subgraph + + +def edge_drop(graph_wrapper, dropout_rate, keep_self_loop=True): + if dropout_rate < 1e-5: + return graph_wrapper + else: + return pgl.graph_wrapper.DropEdgeWrapper(graph_wrapper, + dropout_rate, + keep_self_loop) diff --git a/pgl/utils/op.py b/pgl/utils/op.py index fe3945381aad1bb9bf59bfde8e78de6db0491ccc..2052adaf8d0bc7a5639c20fbfb1d107d9e61b9e5 100644 --- a/pgl/utils/op.py +++ b/pgl/utils/op.py @@ -68,3 +68,18 @@ def read_rows(data, index): return new_data else: return paddle_helper.gather(data, index) + + +class RowReader(object): + """Memory Efficient RowReader + """ + def __init__(self, nfeat, index): + self.nfeat = nfeat + self.loaded_nfeat = {} + self.index = index + + def __getitem__(self, key): + if key not in self.loaded_nfeat: + self.loaded_nfeat[key] = read_rows(self.nfeat[key], self.index) + return self.loaded_nfeat[key] + diff --git a/pgl/utils/paddle_helper.py b/pgl/utils/paddle_helper.py index 3570fac2c9da6b668108d4216cac9d415ce68dcd..2dd6ea248966a734ed2cbfefd342cd90655407f4 100644 --- a/pgl/utils/paddle_helper.py +++ b/pgl/utils/paddle_helper.py @@ -250,3 +250,20 @@ def scatter_max(input, index, updates): output = fluid.layers.scatter(input, index, updates, mode='max') return output + +def masked_select(input, mask): + """masked_select + + Slice the value from given Mask + + Args: + input: Input tensor to be selected + + mask: A bool tensor for sliced. + + Return: + Part of inputs where mask is True. + """ + index = fluid.layers.where(mask) + return fluid.layers.gather(input, index) +