diff --git a/examples/gin/Dataset.py b/examples/gin/Dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a3fd8cfc8671817ebc47c9b42b7c9e1e28ce42b
--- /dev/null
+++ b/examples/gin/Dataset.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file implements the dataset for the GIN model.
+"""
+
+import os
+import sys
+import numpy as np
+
+from sklearn.model_selection import StratifiedKFold
+
+import pgl
+from pgl.utils.logger import log
+
+
+def fold10_split(dataset, fold_idx=0, seed=0, shuffle=True):
+    """10-fold splitter"""
+    assert 0 <= fold_idx < 10, "fold_idx must be from 0 to 9."
+
+    skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)
+    labels = []
+    for i in range(len(dataset)):
+        g, c = dataset[i]
+        labels.append(c)
+
+    idx_list = []
+    for idx in skf.split(np.zeros(len(labels)), labels):
+        idx_list.append(idx)
+    train_idx, valid_idx = idx_list[fold_idx]
+
+    log.info("train_set : test_set == %d : %d" %
+             (len(train_idx), len(valid_idx)))
+    return Subset(dataset, train_idx), Subset(dataset, valid_idx)
+
+
+def random_split(dataset, split_ratio=0.7, seed=0, shuffle=True):
+    """random splitter"""
+    np.random.seed(seed)
+    indices = list(range(len(dataset)))
+    np.random.shuffle(indices)
+    split = int(split_ratio * len(dataset))
+    train_idx, valid_idx = indices[:split], indices[split:]
+
+    log.info("train_set : test_set == %d : %d" %
+             (len(train_idx), len(valid_idx)))
+    return Subset(dataset, train_idx), Subset(dataset, valid_idx)
+
+
+class BaseDataset(object):
+    """BaseDataset"""
+
+    def __init__(self):
+        pass
+
+    def __getitem__(self, idx):
+        """getitem"""
+        raise NotImplementedError
+
+    def __len__(self):
+        """len"""
+        raise NotImplementedError
+
+
+class Subset(BaseDataset):
+    """
+    Subset of a dataset at specified indices.
+    """
+
+    def __init__(self, dataset, indices):
+        self.dataset = dataset
+        self.indices = indices
+
+    def __getitem__(self, idx):
+        """getitem"""
+        return self.dataset[self.indices[idx]]
+
+    def __len__(self):
+        """len"""
+        return len(self.indices)
+
+
+class GINDataset(BaseDataset):
+    """Dataset for Graph Isomorphism Network (GIN)
+    Adapted from https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip.
+    """
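+
+    # Expected on-disk format (inferred from the parser in _load_data below):
+    #     line 1:  N                       -- total number of graphs
+    #     then, for each graph:
+    #         num_nodes graph_label
+    #         followed by one line per node:
+    #             node_label num_neighbors neighbor_1 ... neighbor_k [optional float features ...]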
+ """ + + def __init__(self, + data_path, + dataset_name, + self_loop, + degree_as_nlabel=False): + self.data_path = data_path + self.dataset_name = dataset_name + self.self_loop = self_loop + self.degree_as_nlabel = degree_as_nlabel + + self.graph_list = [] + self.glabel_list = [] + + # relabel + self.glabel_dict = {} + self.nlabel_dict = {} + self.elabel_dict = {} + self.ndegree_dict = {} + + # global num + self.num_graph = 0 # total graphs number + self.n = 0 # total nodes number + self.m = 0 # total edges number + + # global num of classes + self.gclasses = 0 + self.nclasses = 0 + self.eclasses = 0 + self.dim_nfeats = 0 + + # flags + self.degree_as_nlabel = degree_as_nlabel + self.nattrs_flag = False + self.nlabels_flag = False + + self._load_data() + + def __len__(self): + """return the number of graphs""" + return len(self.graph_list) + + def __getitem__(self, idx): + """getitem""" + return self.graph_list[idx], self.glabel_list[idx] + + def _load_data(self): + """Loads dataset + """ + filename = os.path.join(self.data_path, self.dataset_name, + "%s.txt" % self.dataset_name) + log.info("loading data from %s" % filename) + + with open(filename, 'r') as reader: + # first line --> N, means total number of graphs + self.num_graph = int(reader.readline().strip()) + + for i in range(self.num_graph): + if (i + 1) % int(self.num_graph / 10) == 0: + log.info("processing graph %s" % (i + 1)) + graph = dict() + # second line --> [num_node, label] + # means [node number of a graph, class label of a graph] + grow = reader.readline().strip().split() + n_nodes, glabel = [int(w) for w in grow] + + # relabel graphs + if glabel not in self.glabel_dict: + mapped = len(self.glabel_dict) + self.glabel_dict[glabel] = mapped + + graph['num_nodes'] = n_nodes + self.glabel_list.append(self.glabel_dict[glabel]) + + nlabels = [] + node_features = [] + num_edges = 0 + edges = [] + + for j in range(graph['num_nodes']): + slots = reader.readline().strip().split() + + # handle edges and node feature(if has) + tmp = int(slots[ + 1]) + 2 # tmp == 2 + num_edges of current node + if tmp == len(slots): + # no node feature + nrow = [int(w) for w in slots] + nfeat = None + elif tmp < len(slots): + nrow = [int(w) for w in slots[:tmp]] + nfeat = [float(w) for w in slots[tmp:]] + node_features.append(nfeat) + else: + raise Exception('edge number is not correct!') + + # relabel nodes if is has labels + # if it doesn't have node labels, then every nrow[0] == 0 + if not nrow[0] in self.nlabel_dict: + mapped = len(self.nlabel_dict) + self.nlabel_dict[nrow[0]] = mapped + + nlabels.append(self.nlabel_dict[nrow[0]]) + num_edges += nrow[1] + edges.extend([(j, u) for u in nrow[2:]]) + + if self.self_loop: + num_edges += 1 + edges.append((j, j)) + + if node_features != []: + node_features = np.stack(node_features) + graph['attr'] = node_features + self.nattrs_flag = True + else: + node_features = None + graph['attr'] = node_features + + graph['nlabel'] = np.array( + nlabels, dtype="int64").reshape(-1, 1) + if len(self.nlabel_dict) > 1: + self.nlabels_flag = True + + graph['edges'] = edges + assert num_edges == len(edges) + + g = pgl.graph.Graph( + num_nodes=graph['num_nodes'], + edges=graph['edges'], + node_feat={ + 'nlabel': graph['nlabel'], + 'attr': graph['attr'] + }) + + self.graph_list.append(g) + + # update statistics of graphs + self.n += graph['num_nodes'] + self.m += num_edges + + # if no attr + if not self.nattrs_flag: + log.info('there are no node features in this dataset!') + label2idx = {} + # generate node attr by node 
+        if not self.nattrs_flag:
+            log.info('there are no node features in this dataset!')
+            label2idx = {}
+            # generate node attr by node degree
+            if self.degree_as_nlabel:
+                log.info('generate node features by node degree...')
+                nlabel_set = set([])
+                for g in self.graph_list:
+                    g.node_feat['nlabel'] = g.indegree()
+                    # extract unique node labels
+                    nlabel_set = nlabel_set.union(set(g.node_feat['nlabel']))
+                    g.node_feat['nlabel'] = g.node_feat['nlabel'].reshape(-1, 1)
+
+                nlabel_set = list(nlabel_set)
+                # in case the labels/degrees are not contiguous
+                self.ndegree_dict = {
+                    nlabel_set[i]: i
+                    for i in range(len(nlabel_set))
+                }
+                label2idx = self.ndegree_dict
+            # generate node attr by node label
+            else:
+                log.info('generate node features by node label...')
+                label2idx = self.nlabel_dict
+
+            for g in self.graph_list:
+                attr = np.zeros((g.num_nodes, len(label2idx)))
+                idx = [
+                    label2idx[tag]
+                    for tag in g.node_feat['nlabel'].reshape(-1, )
+                ]
+                # one-hot encoding: node j gets a 1 in the column of its label
+                attr[np.arange(g.num_nodes), idx] = 1
+                g.node_feat['attr'] = attr.astype("float32")
+
+        # after loading, record the numbers of classes and feature dims
+        self.gclasses = len(self.glabel_dict)
+        self.nclasses = len(self.nlabel_dict)
+        self.eclasses = len(self.elabel_dict)
+        self.dim_nfeats = len(self.graph_list[0].node_feat['attr'][0])
+
+        message = "finished loading data\n"
+        message += """
+                num_graph: %d
+                num_graph_class: %d
+                total_num_nodes: %d
+                node_classes: %d
+                node_features_dim: %d
+                num_edges: %d
+                edge_classes: %d
+                Avg. of #Nodes: %.2f
+                Avg. of #Edges: %.2f
+                Graph Relabeled: %s
+                Node Relabeled: %s
+                Degree Relabeled(If degree_as_nlabel=True): %s""" % (
+            self.num_graph,
+            self.gclasses,
+            self.n,
+            self.nclasses,
+            self.dim_nfeats,
+            self.m,
+            self.eclasses,
+            self.n / self.num_graph,
+            self.m / self.num_graph,
+            self.glabel_dict,
+            self.nlabel_dict,
+            self.ndegree_dict, )
+        log.info(message)
+
+
+if __name__ == "__main__":
+    gindataset = GINDataset(
+        "./dataset/", "MUTAG", self_loop=True, degree_as_nlabel=False)
diff --git a/examples/gin/README.md b/examples/gin/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..22145cb04236c2522e7e1d6be8e5b8b6e1bd913a
--- /dev/null
+++ b/examples/gin/README.md
@@ -0,0 +1,32 @@
+# Graph Isomorphism Network (GIN)
+
+[Graph Isomorphism Network \(GIN\)](https://arxiv.org/pdf/1810.00826.pdf) is a simple graph neural network designed to be as powerful as the Weisfeiler-Lehman graph isomorphism test. Based on PGL, we reproduce the GIN model.
+
+### Datasets
+
+The datasets can be downloaded from [here](https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip).
+
+### Dependencies
+
+- paddlepaddle 1.6
+- pgl 1.0.2
+
+### How to run
+
+For example, to train a GIN model on the MUTAG dataset with GPU:
+```
+python main.py --use_cuda --dataset_name MUTAG
+```
+
+### Hyperparameters
+
+- data\_path: the root path of your dataset
+- dataset\_name: the name of the dataset
+- fold\_idx: the $fold\_idx^{th}$ fold of the split dataset; we use 10-fold cross-validation here
+- train\_eps: whether the $\epsilon$ parameter is learnable
+
+### Experiment results (Accuracy)
+|             | MUTAG | COLLAB | IMDBBINARY | IMDBMULTI |
+|-------------|-------|--------|------------|-----------|
+|PGL result   | 90.8  | 78.6   | 76.8       | 50.8      |
+|paper result | 90.0  | 80.0   | 75.1       | 52.3      |
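+
+### Using the dataset in your own script
+
+A minimal sketch of the dataset and splitting utilities (paths are assumptions; point them at wherever you unpacked the archive):
+
+```
+from Dataset import GINDataset, fold10_split
+
+dataset = GINDataset("./dataset/", "MUTAG", self_loop=True)
+train_set, test_set = fold10_split(dataset, fold_idx=0)
+```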
diff --git a/examples/gin/dataloader.py b/examples/gin/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..d27f48be68e969e4374c69d3c5b5aa187c0512e4
--- /dev/null
+++ b/examples/gin/dataloader.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file implements the graph dataloader.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import sys
+import time
+import argparse
+import numpy as np
+import collections
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.layers as fl
+import pgl
+from pgl.utils import mp_reader
+from pgl.utils.logger import log
+
+
+def batch_iter(data, batch_size, fid, num_workers):
+    """batch iterator over shuffled indices"""
+    size = len(data)
+    perm = np.arange(size)
+    np.random.shuffle(perm)
+    start = 0
+    cc = 0
+    while start < size:
+        index = perm[start:start + batch_size]
+        start += batch_size
+        cc += 1
+        # each worker only yields the batches assigned to it
+        if cc % num_workers != fid:
+            continue
+        yield data[index]
+
+
+def scan_batch_iter(data, batch_size, fid, num_workers):
+    """batch iterator in scan (original) order"""
+    batch = []
+    cc = 0
+    for line_example in data.scan():
+        cc += 1
+        if cc % num_workers != fid:
+            continue
+        batch.append(line_example)
+        if len(batch) == batch_size:
+            yield batch
+            batch = []
+
+    if len(batch) > 0:
+        yield batch
+
+
+class GraphDataloader(object):
+    """Graph Dataloader
+    """
+
+    def __init__(self,
+                 dataset,
+                 batch_size,
+                 seed=0,
+                 num_workers=1,
+                 buf_size=1000,
+                 shuffle=True):
+        self.shuffle = shuffle
+        self.seed = seed
+        self.num_workers = num_workers
+        self.buf_size = buf_size
+        self.batch_size = batch_size
+        self.dataset = dataset
+
+    def batch_fn(self, batch_examples):
+        """batch_fn batch producer: join the graphs of a batch into one graph"""
+        graphs = [b[0] for b in batch_examples]
+        labels = [b[1] for b in batch_examples]
+        join_graph = pgl.graph.MultiGraph(graphs)
+        labels = np.array(labels, dtype="int64").reshape(-1, 1)
+        return join_graph, labels
+
+    def batch_iter(self, fid):
+        """batch_iter"""
+        # the dataloader itself supports len() and fancy indexing,
+        # so it can be passed as the data argument
+        if self.shuffle:
+            for batch in batch_iter(self, self.batch_size, fid,
+                                    self.num_workers):
+                yield batch
+        else:
+            for batch in scan_batch_iter(self, self.batch_size, fid,
+                                         self.num_workers):
+                yield batch
+
+    def __len__(self):
+        """__len__"""
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        """__getitem__"""
+        if isinstance(idx, collections.Iterable):
+            return [self[bidx] for bidx in idx]
+        else:
+            return self.dataset[idx]
+
+    def __iter__(self):
+        """__iter__"""
+
+        def worker(filter_id):
+            def func_run():
+                for batch_examples in self.batch_iter(filter_id):
+                    batch_dict = self.batch_fn(batch_examples)
+                    yield batch_dict
+
+            return func_run
+
+        if self.num_workers == 1:
+            r = paddle.reader.buffered(worker(0), self.buf_size)
+        else:
+            worker_pool = [worker(wid) for wid in range(self.num_workers)]
+            multi_worker = mp_reader.multiprocess_reader(
+                worker_pool, use_pipe=True, queue_size=1000)
+            r = paddle.reader.buffered(multi_worker, self.buf_size)
+
+        for batch in r():
+            yield batch
+
+    def scan(self):
+        """scan"""
+        for example in self.dataset:
+            yield example
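+
+
+# A minimal usage sketch (illustrative): each iteration yields a batch as one
+# joined graph plus an int64 label array of shape [batch_size, 1]:
+#
+#     loader = GraphDataloader(dataset, batch_size=32)
+#     for join_graph, labels in loader:
+#         ...  # e.g. build a feed_dict with a GraphWrapper, as in main.py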
diff --git a/examples/gin/main.py b/examples/gin/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..51fc61ee88ff8ef8dd39c7f988fbe9364800e197
--- /dev/null
+++ b/examples/gin/main.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file implements the training process of the GIN model.
+"""
+import os
+import sys
+import time
+import argparse
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.fluid.layers as fl
+import pgl
+from pgl.utils.logger import log
+
+from Dataset import GINDataset, fold10_split, random_split
+from dataloader import GraphDataloader
+from model import GINModel
+
+
+def main(args):
+    """main function"""
+    dataset = GINDataset(
+        args.data_path,
+        args.dataset_name,
+        self_loop=not args.train_eps,
+        degree_as_nlabel=True)
+    train_dataset, test_dataset = fold10_split(
+        dataset, fold_idx=args.fold_idx, seed=args.seed)
+
+    train_loader = GraphDataloader(train_dataset, batch_size=args.batch_size)
+    test_loader = GraphDataloader(
+        test_dataset, batch_size=args.batch_size, shuffle=False)
+
+    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
+    train_program = fluid.Program()
+    startup_program = fluid.Program()
+
+    with fluid.program_guard(train_program, startup_program):
+        gw = pgl.graph_wrapper.GraphWrapper(
+            "gw", place=place, node_feat=dataset[0][0].node_feat_info())
+
+        model = GINModel(args, gw, dataset.gclasses)
+        model.forward()
+
+    infer_program = train_program.clone(for_test=True)
+
+    with fluid.program_guard(train_program, startup_program):
+        epoch_step = int(len(train_dataset) / args.batch_size) + 1
+        # decay boundaries at every 50th epoch
+        boundaries = [
+            i
+            for i in range(50 * epoch_step, args.epochs * epoch_step,
+                           epoch_step * 50)
+        ]
+        values = [args.lr * 0.5**i for i in range(0, len(boundaries) + 1)]
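+        # e.g. with the defaults (lr=0.01, epochs=350) this halves the
+        # learning rate every 50 epochs: 0.01, 0.005, 0.0025, ...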
+        lr = fl.piecewise_decay(boundaries=boundaries, values=values)
+        train_op = fluid.optimizer.Adam(lr).minimize(model.loss)
+
+    exe = fluid.Executor(place)
+    exe.run(startup_program)
+
+    # train and evaluate
+    global_step = 0
+    for epoch in range(1, args.epochs + 1):
+        for idx, batch_data in enumerate(train_loader):
+            g, labels = batch_data
+            feed_dict = gw.to_feed(g)
+            feed_dict['labels'] = labels
+            ret_loss, ret_lr, ret_acc = exe.run(
+                train_program,
+                feed=feed_dict,
+                fetch_list=[model.loss, lr, model.acc])
+
+            global_step += 1
+            if global_step % 10 == 0:
+                message = "epoch %d | step %d | " % (epoch, global_step)
+                message += "lr %.6f | loss %.6f | acc %.4f" % (
+                    ret_lr, ret_loss, ret_acc)
+                log.info(message)
+
+        # evaluate
+        result = evaluate(exe, infer_program, model, gw, test_loader)
+
+        message = "evaluating result"
+        for key, value in result.items():
+            message += " | %s %.6f" % (key, value)
+        log.info(message)
+
+
+def evaluate(exe, prog, model, gw, loader):
+    """evaluate"""
+    total_loss = []
+    total_acc = []
+    for idx, batch_data in enumerate(loader):
+        g, labels = batch_data
+        feed_dict = gw.to_feed(g)
+        feed_dict['labels'] = labels
+        ret_loss, ret_acc = exe.run(prog,
+                                    feed=feed_dict,
+                                    fetch_list=[model.loss, model.acc])
+        total_loss.append(ret_loss)
+        total_acc.append(ret_acc)
+
+    total_loss = np.mean(total_loss)
+    total_acc = np.mean(total_acc)
+
+    return {"loss": total_loss, "acc": total_acc}
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_path', type=str, default='./dataset')
+    parser.add_argument('--dataset_name', type=str, default='MUTAG')
+    parser.add_argument('--batch_size', type=int, default=32)
+    parser.add_argument('--fold_idx', type=int, default=0)
+    parser.add_argument('--output_path', type=str, default='./outputs/')
+    parser.add_argument('--use_cuda', action='store_true')
+    parser.add_argument('--num_layers', type=int, default=5)
+    parser.add_argument('--num_mlp_layers', type=int, default=2)
+    parser.add_argument('--hidden_size', type=int, default=64)
+    parser.add_argument(
+        '--pool_type',
+        type=str,
+        default="sum",
+        choices=["sum", "average", "max"])
+    parser.add_argument('--train_eps', action='store_true')
+    parser.add_argument('--epochs', type=int, default=350)
+    parser.add_argument('--lr', type=float, default=0.01)
+    parser.add_argument('--dropout_prob', type=float, default=0.5)
+    parser.add_argument('--seed', type=int, default=0)
+    args = parser.parse_args()
+
+    log.info(args)
+    if not os.path.exists(args.output_path):
+        os.makedirs(args.output_path)
+
+    main(args)
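+
+# Example invocation (see README.md):
+#     python main.py --use_cuda --dataset_name MUTAG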
+""" + +import numpy as np + +import paddle.fluid as fluid +import paddle.fluid.layers as fl +import pgl +from pgl.layers.conv import gin + + +class GINModel(object): + """GINModel""" + + def __init__(self, args, gw, num_class): + self.args = args + self.num_layers = self.args.num_layers + self.hidden_size = self.args.hidden_size + self.train_eps = self.args.train_eps + self.pool_type = self.args.pool_type + self.dropout_prob = self.args.dropout_prob + self.num_class = num_class + + self.gw = gw + self.labels = fl.data(name="labels", shape=[None, 1], dtype="int64") + + def forward(self): + """forward""" + features_list = [self.gw.node_feat["attr"]] + + for i in range(self.num_layers): + h = gin(self.gw, + features_list[i], + hidden_size=self.hidden_size, + activation="relu", + name="gin_%s" % (i), + init_eps=0.0, + train_eps=self.train_eps) + + h = fl.batch_norm(h) + h = fl.relu(h) + + features_list.append(h) + + output = 0 + for i, h in enumerate(features_list): + pooled_h = pgl.layers.graph_pooling(self.gw, h, self.pool_type) + drop_h = fl.dropout( + pooled_h, + self.dropout_prob, + dropout_implementation="upscale_in_train") + output += fl.fc(drop_h, + size=self.num_class, + act=None, + param_attr=fluid.ParamAttr(name="final_fc_%s" % + (i))) + + # calculate loss + self.loss = fl.softmax_with_cross_entropy(output, self.labels) + self.loss = fl.reduce_mean(self.loss) + self.acc = fl.accuracy(fl.softmax(output), self.labels)