diff --git a/examples/SAGPool/README.md b/examples/SAGPool/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ebed2e7a4b4025a63e2e3f4538cd4a23739c0123 --- /dev/null +++ b/examples/SAGPool/README.md @@ -0,0 +1,53 @@ +# Self-Attention Graph Pooling + +SAGPool is a graph pooling method based on self-attention. Self-attention uses graph convolution, which allows the pooling method to consider both node features and graph topology. Based on PGL, we implement the SAGPool algorithm and train the model on five datasets. + +## Datasets + +There are five datasets, including D&D, PROTEINS, NCI1, NCI109 and FRANKENSTEIN. You can download the datasets from [here](https://bj.bcebos.com/paddle-pgl/SAGPool/data.zip), and unzip it directly. The pkl format datasets should be in directory ./data. + +## Dependencies + +- [paddlepaddle >= 1.8](https://github.com/PaddlePaddle/paddle) +- [pgl 1.1](https://github.com/PaddlePaddle/PGL) + +## How to run + +``` +python main.py --dataset_name DD --learning_rate 0.005 --weight_decay 0.00001 + +python main.py --dataset_name PROTEINS --learning_rate 0.001 --hidden_size 32 --weight_decay 0.00001 + +python main.py --dataset_name NCI1 --learning_rate 0.001 --weight_decay 0.00001 + +python main.py --dataset_name NCI109 --learning_rate 0.0005 --hidden_size 64 --weight_decay 0.0001 --patience 200 + +python main.py --dataset_name FRANKENSTEIN --learning_rate 0.001 --weight_decay 0.0001 +``` + +## Hyperparameters + +- seed: random seed +- batch\_size: the number of batch size +- learning\_rate: learning rate of optimizer +- weight\_decay: the weight decay for L2 regularization +- hidden\_size: the hidden size of gcn +- pooling\_ratio: the pooling ratio of SAGPool +- dropout\_ratio: the number of dropout ratio +- dataset\_name: the name of datasets, including DD, PROTEINS, NCI1, NCI109, FRANKENSTEIN +- epochs: maximum number of epochs +- patience: patience for early stopping +- use\_cuda: whether to use cuda +- save\_model: the name for the best model + +## Performance + +We evaluate the implemented method for 20 random seeds using 10-fold cross validation, following the same training procedures as in the paper. + +| dataset | mean accuracy | standard deviation | mean accuracy(paper) | standard deviation(paper) | +| ------------ | ------------- | ------------------ | -------------------- | ------------------------- | +| DD | 74.4181 | 1.0244 | 76.19 | 0.94 | +| PROTEINS | 72.7858 | 0.6617 | 70.04 | 1.47 | +| NCI1 | 75.781 | 1.2125 | 74.18 | 1.2 | +| NCI109 | 74.3156 | 1.3 | 74.06 | 0.78 | +| FRANKENSTEIN | 60.7826 | 0.629 | 62.57 | 0.6 | diff --git a/examples/SAGPool/args.py b/examples/SAGPool/args.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6c44bc558638a5d6ec0a3c8c957a0730aa22a5 --- /dev/null +++ b/examples/SAGPool/args.py @@ -0,0 +1,43 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +parser = argparse.ArgumentParser() + +parser.add_argument('--seed', type=int, default=777, + help='seed') +parser.add_argument('--batch_size', type=int, default=128, + help='batch size') +parser.add_argument('--learning_rate', type=float, default=0.0005, + help='learning rate') +parser.add_argument('--weight_decay', type=float, default=0.0001, + help='weight decay') +parser.add_argument('--hidden_size', type=int, default=128, + help='gcn hidden size') +parser.add_argument('--pooling_ratio', type=float, default=0.5, + help='pooling ratio of SAGPool') +parser.add_argument('--dropout_ratio', type=float, default=0.5, + help='dropout ratio') +parser.add_argument('--dataset_name', type=str, default='DD', + help='DD/PROTEINS/NCI1/NCI109/FRANKENSTEIN') +parser.add_argument('--epochs', type=int, default=100000, + help='maximum number of epochs') +parser.add_argument('--patience', type=int, default=50, + help='patience for early stopping') +parser.add_argument('--use_cuda', type=bool, default=True, + help='use cuda or cpu') +parser.add_argument('--save_model', type=str, + help='save model name') + diff --git a/examples/SAGPool/base_dataset.py b/examples/SAGPool/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..711e9203e311bb2c9cf5b6e9afc2834122a93a01 --- /dev/null +++ b/examples/SAGPool/base_dataset.py @@ -0,0 +1,96 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import random +import pgl +from pgl.utils.logger import log +from pgl.graph import Graph, MultiGraph +import numpy as np +import pickle + +class BaseDataset(object): + def __init__(self): + pass + + def __getitem__(self, idx): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + +class Subset(BaseDataset): + """Subset of a dataset at specified indices. + + Args: + dataset (Dataset): The whole Dataset + indices (sequence): Indices in the whole set selected for subset + """ + + def __init__(self, dataset, indices): + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + return self.dataset[self.indices[idx]] + + def __len__(self): + return len(self.indices) + + +class Dataset(BaseDataset): + def __init__(self, args): + self.args = args + + with open('data/%s.pkl' % args.dataset_name, 'rb') as f: + graphs_info_list = pickle.load(f) + + self.pgl_graph_list = [] + self.graph_label_list = [] + for i in range(len(graphs_info_list) - 1): + graph = graphs_info_list[i] + edges_l, edges_r = graph["edge_src"], graph["edge_dst"] + + # add self-loops + if self.args.dataset_name != "FRANKENSTEIN": + num_nodes = graph["num_nodes"] + x = np.arange(0, num_nodes) + edges_l = np.append(edges_l, x) + edges_r = np.append(edges_r, x) + + edges = list(zip(edges_l, edges_r)) + g = pgl.graph.Graph(num_nodes=graph["num_nodes"], edges=edges) + g.node_feat["feat"] = graph["node_feat"] + self.pgl_graph_list.append(g) + self.graph_label_list.append(graph["label"]) + + self.num_classes = graphs_info_list[-1]["num_classes"] + self.num_features = graphs_info_list[-1]["num_features"] + + def __getitem__(self, idx): + return self.pgl_graph_list[idx], self.graph_label_list[idx] + + def shuffle(self): + """shuffle the dataset. + """ + cc = list(zip(self.pgl_graph_list, self.graph_label_list)) + random.seed(self.args.seed) + random.shuffle(cc) + a, b = zip(*cc) + self.pgl_graph_list[:], self.graph_label_list[:] = a, b + + def __len__(self): + return len(self.pgl_graph_list) diff --git a/examples/SAGPool/conv.py b/examples/SAGPool/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..5250f242b5bd5aff4f6fa9324a82a21e46302b6f --- /dev/null +++ b/examples/SAGPool/conv.py @@ -0,0 +1,66 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle.fluid.layers as L + +def norm_gcn(gw, feature, hidden_size, activation, name, norm=None): + """Implementation of graph convolutional neural networks(GCN), using different + normalization method. + Args: + gw: Graph wrapper object. + + feature: A tensor with shape (num_nodes, feature_size). + + hidden_size: The hidden size for norm gcn. + + activation: The activation for the output. + + name: Norm gcn layer names. + + norm: If norm is not None, then the feature will be normalized. Norm must + be tensor with shape (num_nodes,) and dtype float32. + + Return: + A tensor with shape (num_nodes, hidden_size) + """ + + size = feature.shape[-1] + feature = L.fc(feature, + size=hidden_size, + bias_attr=False, + param_attr=fluid.ParamAttr(name=name)) + + if norm is not None: + src, dst = gw.edges + norm_src = L.gather(norm, src, overwrite=False) + norm_dst = L.gather(norm, dst, overwrite=False) + norm = norm_src * norm_dst + + def send_src_copy(src_feat, dst_feat, edge_feat): + return src_feat["h"] * norm + else: + def send_src_copy(src_feat, dst_feat, edge_feat): + return src_feat["h"] + + msg = gw.send(send_src_copy, nfeat_list=[("h", feature)]) + output = gw.recv(msg, "sum") + + bias = L.create_parameter( + shape=[hidden_size], + dtype='float32', + is_bias=True, + name=name + '_bias') + output = L.elementwise_add(output, bias, act=activation) + return output diff --git a/examples/SAGPool/dataloader.py b/examples/SAGPool/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..21ea7d43a64dbb3849afc8d5abf017a7bbac2661 --- /dev/null +++ b/examples/SAGPool/dataloader.py @@ -0,0 +1,143 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import collections +import paddle +import pgl +from pgl.utils.logger import log +from pgl.graph import Graph, MultiGraph + +def batch_iter(data, batch_size): + """node_batch_iter + """ + size = len(data) + perm = np.arange(size) + np.random.shuffle(perm) + start = 0 + while start < size: + index = perm[start:start + batch_size] + start += batch_size + yield data[index] + + +def scan_batch_iter(data, batch_size): + """scan_batch_iter + """ + batch = [] + for example in data.scan(): + batch.append(example) + if len(batch) == batch_size: + yield batch + batch = [] + + if len(batch) > 0: + yield batch + + +def label_to_onehot(labels): + """Return one-hot representations of labels + """ + onehot_labels = [] + for label in labels: + if label == 0: + onehot_labels.append([1, 0]) + else: + onehot_labels.append([0, 1]) + onehot_labels = np.array(onehot_labels) + return onehot_labels + + +class GraphDataloader(object): + """Graph Dataloader + """ + def __init__(self, + dataset, + graph_wrapper, + batch_size, + seed=0, + buf_size=1000, + shuffle=True): + + self.shuffle = shuffle + self.seed = seed + self.batch_size = batch_size + self.dataset = dataset + self.buf_size = buf_size + self.graph_wrapper = graph_wrapper + + def batch_fn(self, batch_examples): + """ batch_fun batch producer """ + graphs = [b[0] for b in batch_examples] + labels = [b[1] for b in batch_examples] + join_graph = MultiGraph(graphs) + + # normalize + indegree = join_graph.indegree() + norm = np.zeros_like(indegree, dtype="float32") + norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5) + join_graph.node_feat["norm"] = np.expand_dims(norm, -1) + + feed_dict = self.graph_wrapper.to_feed(join_graph) + labels = np.array(labels) + feed_dict["labels_1dim"] = labels + labels = label_to_onehot(labels) + feed_dict["labels"] = labels + + graph_lod = join_graph.graph_lod + graph_id = [] + for i in range(1, len(graph_lod)): + graph_node_num = graph_lod[i] - graph_lod[i - 1] + graph_id += [i - 1] * graph_node_num + graph_id = np.array(graph_id, dtype="int32") + feed_dict["graph_id"] = graph_id + + return feed_dict + + def batch_iter(self): + """ batch_iter """ + if self.shuffle: + for batch in batch_iter(self, self.batch_size): + yield batch + else: + for batch in scan_batch_iter(self, self.batch_size): + yield batch + + def __len__(self): + """__len__""" + return len(self.dataset) + + def __getitem__(self, idx): + """__getitem__""" + if isinstance(idx, collections.Iterable): + return [self.dataset[bidx] for bidx in idx] + else: + return self.dataset[idx] + + def __iter__(self): + """__iter__""" + def func_run(): + for batch_examples in self.batch_iter(): + batch_dict = self.batch_fn(batch_examples) + yield batch_dict + + r = paddle.reader.buffered(func_run, self.buf_size) + + for batch in r(): + yield batch + + def scan(self): + """scan""" + for example in self.dataset: + yield example diff --git a/examples/SAGPool/layers.py b/examples/SAGPool/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3dfa0822ece9e564adfbc9b15e0a62e1e1f4b08d --- /dev/null +++ b/examples/SAGPool/layers.py @@ -0,0 +1,141 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as L +import pgl +from pgl.graph_wrapper import GraphWrapper +from pgl.utils.logger import log +from conv import norm_gcn +from pgl.layers.conv import gcn + +def topk_pool(gw, score, graph_id, ratio): + """Implementation of topk pooling, where k means pooling ratio. + + Args: + gw: Graph wrapper object. + + score: The attention score of all nodes, which is used to select + important nodes. + + graph_id: The graphs that the nodes belong to. + + ratio: The pooling ratio of nodes we want to select. + + Return: + perm: The index of nodes we choose. + + ratio_length: The selected node numbers of each graph. + """ + + graph_lod = gw.graph_lod + graph_nodes = gw.num_nodes + num_graph = gw.num_graph + + num_nodes = L.ones(shape=[graph_nodes], dtype="float32") + num_nodes = L.lod_reset(num_nodes, graph_lod) + num_nodes_per_graph = L.sequence_pool(num_nodes, pool_type='sum') + max_num_nodes = L.reduce_max(num_nodes_per_graph, dim=0) + max_num_nodes = L.cast(max_num_nodes, dtype="int32") + + index = L.arange(0, gw.num_nodes, dtype="int64") + offset = L.gather(graph_lod, graph_id, overwrite=False) + index = (index - offset) + (graph_id * max_num_nodes) + index.stop_gradient = True + + # padding + dense_score = L.fill_constant(shape=[num_graph * max_num_nodes], + dtype="float32", value=-999999) + index = L.reshape(index, shape=[-1]) + dense_score = L.scatter(dense_score, index, updates=score) + num_graph = L.cast(num_graph, dtype="int32") + dense_score = L.reshape(dense_score, + shape=[num_graph, max_num_nodes]) + + # record the sorted index + _, sort_index = L.argsort(dense_score, axis=-1, descending=True) + + # recover the index range + graph_lod = graph_lod[:-1] + graph_lod = L.reshape(graph_lod, shape=[-1, 1]) + graph_lod = L.cast(graph_lod, dtype="int64") + sort_index = L.elementwise_add(sort_index, graph_lod, axis=-1) + sort_index = L.reshape(sort_index, shape=[-1, 1]) + + # use sequence_slice to choose selected node index + pad_lod = L.arange(0, (num_graph + 1) * max_num_nodes, step=max_num_nodes, dtype="int32") + sort_index = L.lod_reset(sort_index, pad_lod) + ratio_length = L.ceil(num_nodes_per_graph * ratio) + ratio_length = L.cast(ratio_length, dtype="int64") + ratio_length = L.reshape(ratio_length, shape=[-1, 1]) + offset = L.zeros(shape=[num_graph, 1], dtype="int64") + choose_index = L.sequence_slice(input=sort_index, offset=offset, length=ratio_length) + + perm = L.reshape(choose_index, shape=[-1]) + return perm, ratio_length + + +def sag_pool(gw, feature, ratio, graph_id, dataset, name, activation=L.tanh): + """Implementation of self-attention graph pooling (SAGPool) + + This is an implementation of the paper SELF-ATTENTION GRAPH POOLING + (https://arxiv.org/pdf/1904.08082.pdf) + + Args: + gw: Graph wrapper object. + + feature: A tensor with shape (num_nodes, feature_size). + + ratio: The pooling ratio of nodes we want to select. + + graph_id: The graphs that the nodes belong to. + + dataset: To differentiate FRANKENSTEIN dataset and other datasets. + + name: The name of SAGPool layer. + + activation: The activation function. + + Return: + new_feature: A tensor with shape (num_nodes, feature_size), and the unselected + nodes' feature is masked by zero. + + ratio_length: The selected node numbers of each graph. + + """ + if dataset == "FRANKENSTEIN": + gcn_ = gcn + else: + gcn_ = norm_gcn + + score = gcn_(gw=gw, + feature=feature, + hidden_size=1, + activation=None, + norm=gw.node_feat["norm"], + name=name) + score = L.squeeze(score, axes=[]) + perm, ratio_length = topk_pool(gw, score, graph_id, ratio) + + mask = L.zeros_like(score) + mask = L.cast(mask, dtype="float32") + updates = L.ones_like(perm) + updates = L.cast(updates, dtype="float32") + mask = L.scatter(mask, perm, updates) + new_feature = L.elementwise_mul(feature, mask, axis=0) + temp_score = activation(score) + new_feature = L.elementwise_mul(new_feature, temp_score, axis=0) + return new_feature, ratio_length diff --git a/examples/SAGPool/main.py b/examples/SAGPool/main.py new file mode 100644 index 0000000000000000000000000000000000000000..8895311e0f729bf572919faf07df50c7921cb545 --- /dev/null +++ b/examples/SAGPool/main.py @@ -0,0 +1,194 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import argparse +import pgl +from pgl.utils.logger import log +import paddle + +import re +import time +import random +import numpy as np +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as L +import pgl +from pgl.utils.logger import log + +from model import GlobalModel +from base_dataset import Subset, Dataset +from dataloader import GraphDataloader +from args import parser +import warnings +from sklearn.model_selection import KFold + +warnings.filterwarnings("ignore") + +def main(args, train_dataset, val_dataset, test_dataset): + """main function for running one testing results. + """ + log.info("Train Examples: %s" % len(train_dataset)) + log.info("Val Examples: %s" % len(val_dataset)) + log.info("Test Examples: %s" % len(test_dataset)) + + train_program = fluid.Program() + train_program.random_seed = args.seed + startup_program = fluid.Program() + startup_program.random_seed = args.seed + + if args.use_cuda: + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + log.info("building model") + + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + graph_model = GlobalModel(args, dataset) + train_loader = GraphDataloader(train_dataset, + graph_model.graph_wrapper, + batch_size=args.batch_size) + optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate, + regularization=fluid.regularizer.L2DecayRegularizer(args.weight_decay)) + optimizer.minimize(graph_model.loss) + + exe.run(startup_program) + test_program = fluid.Program() + test_program = train_program.clone(for_test=True) + + val_loader = GraphDataloader(val_dataset, + graph_model.graph_wrapper, + batch_size=args.batch_size, + shuffle=False) + test_loader = GraphDataloader(test_dataset, + graph_model.graph_wrapper, + batch_size=args.batch_size, + shuffle=False) + + min_loss = 1e10 + global_step = 0 + for epoch in range(args.epochs): + for feed_dict in train_loader: + loss, pred = exe.run(train_program, + feed=feed_dict, + fetch_list=[graph_model.loss, graph_model.pred]) + + log.info("Epoch: %d, global_step: %d, Training loss: %f" \ + % (epoch, global_step, loss)) + global_step += 1 + + # validation + valid_loss = 0. + correct = 0. + for feed_dict in val_loader: + valid_loss_, correct_ = exe.run(test_program, + feed=feed_dict, + fetch_list=[graph_model.loss, graph_model.correct]) + valid_loss += valid_loss_ + correct += correct_ + + if epoch % 50 == 0: + log.info("Epoch:%d, Validation loss: %f, Validation acc: %f" \ + % (epoch, valid_loss, correct / len(val_loader))) + + if valid_loss < min_loss: + min_loss = valid_loss + patience = 0 + path = "./save/%s" % args.dataset_name + if not os.path.exists(path): + os.makedirs(path) + fluid.save(train_program, "%s/%s" \ + % (path, args.save_model)) + log.info("Model saved at epoch %d" % epoch) + else: + patience += 1 + if patience > args.patience: + break + + correct = 0. + new_test_program = fluid.Program() + fluid.load(new_test_program, "./save/%s/%s" \ + % (args.dataset_name, args.save_model), exe) + for feed_dict in test_loader: + correct_ = exe.run(test_program, + feed=feed_dict, + fetch_list=[graph_model.correct]) + correct += correct_[0] + log.info("Test acc: %f" % (correct / len(test_loader))) + return correct / len(test_loader) + + +def split_10_cv(dataset, args): + """10 folds cross validation + """ + dataset.shuffle() + X = np.array([0] * len(dataset)) + y = X + kf = KFold(n_splits=10, shuffle=False) + + i = 1 + test_acc = [] + for train_index, test_index in kf.split(X, y): + train_val_dataset = Subset(dataset, train_index) + test_dataset = Subset(dataset, test_index) + train_val_index_range = list(range(0, len(train_val_dataset))) + num_val = int(len(train_val_dataset) / 9) + val_dataset = Subset(train_val_dataset, train_val_index_range[:num_val]) + train_dataset = Subset(train_val_dataset, train_val_index_range[num_val:]) + + log.info("######%d fold of 10-fold cross validation######" % i) + i += 1 + test_acc_ = main(args, train_dataset, val_dataset, test_dataset) + test_acc.append(test_acc_) + + mean_acc = sum(test_acc) / len(test_acc) + return mean_acc, test_acc + + +def random_seed_20(args, dataset): + """run for 20 random seeds + """ + alist = random.sample(range(1,1000),20) + test_acc_fold = [] + for seed in alist: + log.info('############ Seed %d ############' % seed) + args.seed = seed + + test_acc_fold_, _ = split_10_cv(dataset, args) + log.info('Mean test acc at seed %d: %f' % (seed, test_acc_fold_)) + test_acc_fold.append(test_acc_fold_) + + mean_acc = sum(test_acc_fold) / len(test_acc_fold) + temp = [(acc - mean_acc) * (acc - mean_acc) for acc in test_acc_fold] + standard_std = math.sqrt(sum(temp) / len(test_acc_fold)) + + log.info('Final mean test acc using 20 random seeds(mean for 10-fold): %f' % (mean_acc)) + log.info('Final standard std using 20 random seeds(mean for 10-fold): %f' % (standard_std)) + + +if __name__ == "__main__": + args = parser.parse_args() + log.info('loading data...') + dataset = Dataset(args) + log.info("preprocess finish.") + args.num_classes = dataset.num_classes + args.num_features = dataset.num_features + random_seed_20(args, dataset) diff --git a/examples/SAGPool/model.py b/examples/SAGPool/model.py new file mode 100644 index 0000000000000000000000000000000000000000..cdfbe6f4f2d911dff041522d3e75266b4867c5a3 --- /dev/null +++ b/examples/SAGPool/model.py @@ -0,0 +1,136 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from random import random +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as L +import pgl +from pgl.graph import Graph, MultiGraph +from pgl.graph_wrapper import GraphWrapper +from pgl.utils.logger import log +from pgl.layers.conv import gcn +from layers import sag_pool +from conv import norm_gcn + +class GlobalModel(object): + """Implementation of global pooling architecture with SAGPool. + """ + def __init__(self, args, dataset): + self.args = args + self.dataset = dataset + self.hidden_size = args.hidden_size + self.num_classes = args.num_classes + self.num_features = args.num_features + self.pooling_ratio = args.pooling_ratio + self.dropout_ratio = args.dropout_ratio + self.batch_size = args.batch_size + + graph_data = [] + g, label = self.dataset[0] + graph_data.append(g) + g, label = self.dataset[1] + graph_data.append(g) + + batch_graph = MultiGraph(graph_data) + indegree = batch_graph.indegree() + norm = np.zeros_like(indegree, dtype="float32") + norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5) + batch_graph.node_feat["norm"] = np.expand_dims(norm, -1) + graph_data = batch_graph + + self.graph_wrapper = GraphWrapper( + name="graph", + node_feat=graph_data.node_feat_info() + ) + self.labels = L.data( + "labels", + shape=[None, self.args.num_classes], + dtype="int32", + append_batch_size=False) + + self.labels_1dim = L.data( + "labels_1dim", + shape=[None], + dtype="int32", + append_batch_size=False) + + self.graph_id = L.data( + "graph_id", + shape=[None], + dtype="int32", + append_batch_size=False) + + if self.args.dataset_name == "FRANKENSTEIN": + self.gcn = gcn + else: + self.gcn = norm_gcn + + self.build_model() + + def build_model(self): + node_features = self.graph_wrapper.node_feat["feat"] + + output = self.gcn(gw=self.graph_wrapper, + feature=node_features, + hidden_size=self.hidden_size, + activation="relu", + norm=self.graph_wrapper.node_feat["norm"], + name="gcn_layer_1") + output1 = output + output = self.gcn(gw=self.graph_wrapper, + feature=output, + hidden_size=self.hidden_size, + activation="relu", + norm=self.graph_wrapper.node_feat["norm"], + name="gcn_layer_2") + output2 = output + output = self.gcn(gw=self.graph_wrapper, + feature=output, + hidden_size=self.hidden_size, + activation="relu", + norm=self.graph_wrapper.node_feat["norm"], + name="gcn_layer_3") + + output = L.concat(input=[output1, output2, output], axis=-1) + + output, ratio_length = sag_pool(gw=self.graph_wrapper, + feature=output, + ratio=self.pooling_ratio, + graph_id=self.graph_id, + dataset=self.args.dataset_name, + name="sag_pool_1") + output = L.lod_reset(output, self.graph_wrapper.graph_lod) + cat1 = L.sequence_pool(output, "sum") + ratio_length = L.cast(ratio_length, dtype="float32") + cat1 = L.elementwise_div(cat1, ratio_length, axis=-1) + cat2 = L.sequence_pool(output, "max") + output = L.concat(input=[cat2, cat1], axis=-1) + + output = L.fc(output, size=self.hidden_size, act="relu") + output = L.dropout(output, dropout_prob=self.dropout_ratio) + output = L.fc(output, size=self.hidden_size // 2, act="relu") + output = L.fc(output, size=self.num_classes, act=None, + param_attr=fluid.ParamAttr(name="final_fc")) + + self.labels = L.cast(self.labels, dtype="float32") + loss = L.sigmoid_cross_entropy_with_logits(x=output, label=self.labels) + self.loss = L.mean(loss) + pred = L.sigmoid(output) + self.pred = L.argmax(x=pred, axis=-1) + correct = L.equal(self.pred, self.labels_1dim) + correct = L.cast(correct, dtype="int32") + self.correct = L.reduce_sum(correct)