Commit 83b10999 authored by: Webbley

add gin example

Parent 11bfcf8a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the dataset for the GIN model.
"""
import os
import sys
import numpy as np
from sklearn.model_selection import StratifiedKFold
import pgl
from pgl.utils.logger import log
def fold10_split(dataset, fold_idx=0, seed=0, shuffle=True):
    """10-fold splitter"""
    assert 0 <= fold_idx < 10, "fold_idx must be from 0 to 9."

    skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)
    labels = []
    for i in range(len(dataset)):
        g, c = dataset[i]
        labels.append(c)

    idx_list = []
    for idx in skf.split(np.zeros(len(labels)), labels):
        idx_list.append(idx)
    train_idx, valid_idx = idx_list[fold_idx]

    log.info("train_set : test_set == %d : %d" %
             (len(train_idx), len(valid_idx)))
    return Subset(dataset, train_idx), Subset(dataset, valid_idx)
def random_split(dataset, split_ratio=0.7, seed=0, shuffle=True):
    """random splitter"""
    np.random.seed(seed)
    indices = list(range(len(dataset)))
    np.random.shuffle(indices)
    split = int(split_ratio * len(dataset))
    train_idx, valid_idx = indices[:split], indices[split:]

    log.info("train_set : test_set == %d : %d" %
             (len(train_idx), len(valid_idx)))
    return Subset(dataset, train_idx), Subset(dataset, valid_idx)
class BaseDataset(object):
    """BaseDataset"""

    def __init__(self):
        pass

    def __getitem__(self, idx):
        """getitem"""
        raise NotImplementedError

    def __len__(self):
        """len"""
        raise NotImplementedError
class Subset(BaseDataset):
    """
    Subset of a dataset at specified indices.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        """getitem"""
        return self.dataset[self.indices[idx]]

    def __len__(self):
        """len"""
        return len(self.indices)
class GINDataset(BaseDataset):
    """Dataset for Graph Isomorphism Network (GIN).

    Adapted from https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip.
    """

    def __init__(self,
                 data_path,
                 dataset_name,
                 self_loop,
                 degree_as_nlabel=False):
        self.data_path = data_path
        self.dataset_name = dataset_name
        self.self_loop = self_loop
        self.degree_as_nlabel = degree_as_nlabel

        self.graph_list = []
        self.glabel_list = []

        # relabel
        self.glabel_dict = {}
        self.nlabel_dict = {}
        self.elabel_dict = {}
        self.ndegree_dict = {}

        # global counts
        self.num_graph = 0  # total number of graphs
        self.n = 0  # total number of nodes
        self.m = 0  # total number of edges

        # global numbers of classes
        self.gclasses = 0
        self.nclasses = 0
        self.eclasses = 0
        self.dim_nfeats = 0

        # flags
        self.nattrs_flag = False
        self.nlabels_flag = False

        self._load_data()

    def __len__(self):
        """return the number of graphs"""
        return len(self.graph_list)

    def __getitem__(self, idx):
        """getitem"""
        return self.graph_list[idx], self.glabel_list[idx]
    def _load_data(self):
        """Load the dataset from disk."""
        filename = os.path.join(self.data_path, self.dataset_name,
                                "%s.txt" % self.dataset_name)
        log.info("loading data from %s" % filename)

        with open(filename, 'r') as reader:
            # first line --> N, the total number of graphs
            self.num_graph = int(reader.readline().strip())

            for i in range(self.num_graph):
                # log progress roughly every 10% (guard against tiny datasets)
                if (i + 1) % max(int(self.num_graph / 10), 1) == 0:
                    log.info("processing graph %s" % (i + 1))
                graph = dict()

                # next line --> [num_node, label], i.e.
                # [number of nodes in the graph, class label of the graph]
                grow = reader.readline().strip().split()
                n_nodes, glabel = [int(w) for w in grow]

                # relabel graph classes
                if glabel not in self.glabel_dict:
                    mapped = len(self.glabel_dict)
                    self.glabel_dict[glabel] = mapped

                graph['num_nodes'] = n_nodes
                self.glabel_list.append(self.glabel_dict[glabel])

                nlabels = []
                node_features = []
                num_edges = 0
                edges = []
                for j in range(graph['num_nodes']):
                    slots = reader.readline().strip().split()

                    # handle edges and node features (if any)
                    tmp = int(slots[1]) + 2  # tmp == 2 + num_edges of this node
                    if tmp == len(slots):
                        # no node features
                        nrow = [int(w) for w in slots]
                        nfeat = None
                    elif tmp < len(slots):
                        nrow = [int(w) for w in slots[:tmp]]
                        nfeat = [float(w) for w in slots[tmp:]]
                        node_features.append(nfeat)
                    else:
                        raise Exception('edge number is not correct!')

                    # relabel nodes if the dataset has node labels;
                    # if it doesn't, every nrow[0] == 0
                    if nrow[0] not in self.nlabel_dict:
                        mapped = len(self.nlabel_dict)
                        self.nlabel_dict[nrow[0]] = mapped

                    nlabels.append(self.nlabel_dict[nrow[0]])
                    num_edges += nrow[1]
                    edges.extend([(j, u) for u in nrow[2:]])

                    if self.self_loop:
                        num_edges += 1
                        edges.append((j, j))

                if node_features != []:
                    node_features = np.stack(node_features)
                    graph['attr'] = node_features
                    self.nattrs_flag = True
                else:
                    node_features = None
                    graph['attr'] = node_features

                graph['nlabel'] = np.array(
                    nlabels, dtype="int64").reshape(-1, 1)
                if len(self.nlabel_dict) > 1:
                    self.nlabels_flag = True

                graph['edges'] = edges
                assert num_edges == len(edges)

                g = pgl.graph.Graph(
                    num_nodes=graph['num_nodes'],
                    edges=graph['edges'],
                    node_feat={
                        'nlabel': graph['nlabel'],
                        'attr': graph['attr']
                    })
                self.graph_list.append(g)

                # update global statistics
                self.n += graph['num_nodes']
                self.m += num_edges

        # if there are no node attributes, generate them
        if not self.nattrs_flag:
            log.info('there are no node features in this dataset!')
            label2idx = {}
            # generate node attr by node degree
            if self.degree_as_nlabel:
                log.info('generate node features by node degree...')
                nlabel_set = set([])
                for g in self.graph_list:
                    g.node_feat['nlabel'] = g.indegree()
                    # extract unique node labels
                    nlabel_set = nlabel_set.union(set(g.node_feat['nlabel']))
                    g.node_feat['nlabel'] = g.node_feat['nlabel'].reshape(-1, 1)

                nlabel_set = list(nlabel_set)
                # in case the labels/degrees are not continuous numbers
                self.ndegree_dict = {
                    nlabel_set[i]: i
                    for i in range(len(nlabel_set))
                }
                label2idx = self.ndegree_dict
            # generate node attr by node label
            else:
                log.info('generate node features by node label...')
                label2idx = self.nlabel_dict

            for g in self.graph_list:
                attr = np.zeros((g.num_nodes, len(label2idx)))
                idx = [
                    label2idx[tag]
                    for tag in g.node_feat['nlabel'].reshape(-1, )
                ]
                # one-hot encode each node's label
                attr[np.arange(g.num_nodes), idx] = 1
                g.node_feat['attr'] = attr.astype("float32")
        # after loading, record the numbers of classes and the feature dim
        self.gclasses = len(self.glabel_dict)
        self.nclasses = len(self.nlabel_dict)
        self.eclasses = len(self.elabel_dict)
        self.dim_nfeats = len(self.graph_list[0].node_feat['attr'][0])

        message = "finished loading data\n"
        message += """
                num_graph: %d
                num_graph_class: %d
                total_num_nodes: %d
                node_classes: %d
                node_features_dim: %d
                num_edges: %d
                edge_classes: %d
                Avg. of #Nodes: %.2f
                Avg. of #Edges: %.2f
                Graph Relabeled: %s
                Node Relabeled: %s
                Degree Relabeled (if degree_as_nlabel=True): %s""" % (
            self.num_graph,
            self.gclasses,
            self.n,
            self.nclasses,
            self.dim_nfeats,
            self.m,
            self.eclasses,
            self.n / self.num_graph,
            self.m / self.num_graph,
            self.glabel_dict,
            self.nlabel_dict,
            self.ndegree_dict, )
        log.info(message)
if __name__ == "__main__":
    gindataset = GINDataset(
        "./dataset/", "MUTAG", self_loop=True, degree_as_nlabel=False)
# Graph Isomorphism Network (GIN)
[Graph Isomorphism Network \(GIN\)](https://arxiv.org/pdf/1810.00826.pdf) is a simple graph neural network designed to be as expressive as the Weisfeiler-Lehman graph isomorphism test. Based on PGL, we reproduce the GIN model.
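The layer-wise update implemented in `model.py` follows the formulation from the paper, where $\mathrm{MLP}^{(k)}$ is the multi-layer perceptron of the $k^{th}$ layer and $\mathcal{N}(v)$ denotes the neighbors of node $v$:

$$h_v^{(k)} = \mathrm{MLP}^{(k)}\Big(\big(1 + \epsilon^{(k)}\big) \cdot h_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)}\Big)$$

After all layers are computed, each layer's node representations are pooled into a graph representation and the per-layer classifier outputs are summed to produce the final prediction.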
### Datasets
The datasets can be downloaded from [here](https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip). A sketch of the expected text format is shown below.
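For reference, `Dataset.py` parses each `<dataset_name>/<dataset_name>.txt` file in roughly the following layout (a hypothetical two-graph file for illustration; the `#` annotations are not part of the actual format):

```
2          # number of graphs
3 0        # graph 1 header: num_nodes, graph label
0 2 1 2    # node 0: node label, num_edges, neighbor ids [, float features...]
0 1 0      # node 1
0 1 0      # node 2
2 1        # graph 2 header
0 1 1
0 1 0
```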
### Dependencies
- paddlepaddle 1.6
- pgl 1.0.2
### How to run
For example, to use a GPU to train the GIN model on the MUTAG dataset:
```
python main.py --use_cuda --dataset_name MUTAG
```
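Omitting `--use_cuda` falls back to CPU training (see the `place` selection in `main.py`).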
### Hyperparameters
- data\_path: the root path of the dataset
- dataset\_name: the name of the dataset
- fold\_idx: use the $fold\_idx^{th}$ fold of the split dataset. Here we use 10-fold cross-validation (see the example command below)
- train\_eps: whether the $\epsilon$ parameter is learnable
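For instance, a fuller command that sets these options explicitly might look like the following (all flags exist in `main.py`; the values are only an illustration):

```
python main.py --use_cuda --data_path ./dataset --dataset_name MUTAG --fold_idx 0 --epochs 350 --lr 0.01 --train_eps
```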
### Experiment results (Accuracy)
| |MUTAG | COLLAB | IMDBBINARY | IMDBMULTI |
|--|-------------|----------|------------|-----------------|
|PGL result | 90.8 | 78.6 | 76.8 | 50.8 |
|paper result |90.0 | 80.0 | 75.1 | 52.3 |
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the graph dataloader.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import time
import argparse
import numpy as np
try:
    import collections.abc as collections_abc  # Python 3
except ImportError:
    import collections as collections_abc  # Python 2
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
from pgl.utils import mp_reader
from pgl.utils.logger import log
def batch_iter(data, batch_size, fid, num_workers):
    """node_batch_iter"""
    size = len(data)
    perm = np.arange(size)
    np.random.shuffle(perm)
    start = 0
    cc = 0
    while start < size:
        index = perm[start:start + batch_size]
        start += batch_size
        cc += 1
        if cc % num_workers != fid:
            continue
        yield data[index]
def scan_batch_iter(data, batch_size, fid, num_workers):
    """scan_batch_iter"""
    batch = []
    cc = 0
    for line_example in data.scan():
        cc += 1
        if cc % num_workers != fid:
            continue
        batch.append(line_example)
        if len(batch) == batch_size:
            yield batch
            batch = []

    if len(batch) > 0:
        yield batch
class GraphDataloader(object):
    """Graph Dataloader"""

    def __init__(self,
                 dataset,
                 batch_size,
                 seed=0,
                 num_workers=1,
                 buf_size=1000,
                 shuffle=True):
        self.shuffle = shuffle
        self.seed = seed
        self.num_workers = num_workers
        self.buf_size = buf_size
        self.batch_size = batch_size
        self.dataset = dataset

    def batch_fn(self, batch_examples):
        """batch_fn batch producer"""
        graphs = [b[0] for b in batch_examples]
        labels = [b[1] for b in batch_examples]
        join_graph = pgl.graph.MultiGraph(graphs)
        labels = np.array(labels, dtype="int64").reshape(-1, 1)
        return join_graph, labels

    def batch_iter(self, fid):
        """batch_iter"""
        if self.shuffle:
            for batch in batch_iter(self, self.batch_size, fid,
                                    self.num_workers):
                yield batch
        else:
            for batch in scan_batch_iter(self, self.batch_size, fid,
                                         self.num_workers):
                yield batch

    def __len__(self):
        """__len__"""
        return len(self.dataset)

    def __getitem__(self, idx):
        """__getitem__"""
        if isinstance(idx, collections_abc.Iterable):
            return [self[bidx] for bidx in idx]
        else:
            return self.dataset[idx]

    def __iter__(self):
        """__iter__"""

        def worker(filter_id):
            def func_run():
                for batch_examples in self.batch_iter(filter_id):
                    batch_dict = self.batch_fn(batch_examples)
                    yield batch_dict

            return func_run

        if self.num_workers == 1:
            r = paddle.reader.buffered(worker(0), self.buf_size)
        else:
            worker_pool = [worker(wid) for wid in range(self.num_workers)]
            multi_worker = mp_reader.multiprocess_reader(
                worker_pool, use_pipe=True, queue_size=1000)
            r = paddle.reader.buffered(multi_worker, self.buf_size)

        for batch in r():
            yield batch

    def scan(self):
        """scan"""
        for example in self.dataset:
            yield example
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the training process of the GIN model.
"""
import os
import sys
import time
import argparse
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
from pgl.utils.logger import log
from Dataset import GINDataset, fold10_split, random_split
from dataloader import GraphDataloader
from model import GINModel
def main(args):
    """main function"""
    dataset = GINDataset(
        args.data_path,
        args.dataset_name,
        self_loop=not args.train_eps,
        degree_as_nlabel=True)

    train_dataset, test_dataset = fold10_split(
        dataset, fold_idx=args.fold_idx, seed=args.seed)

    train_loader = GraphDataloader(train_dataset, batch_size=args.batch_size)
    test_loader = GraphDataloader(
        test_dataset, batch_size=args.batch_size, shuffle=False)

    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    train_program = fluid.Program()
    startup_program = fluid.Program()

    with fluid.program_guard(train_program, startup_program):
        gw = pgl.graph_wrapper.GraphWrapper(
            "gw", place=place, node_feat=dataset[0][0].node_feat_info())
        model = GINModel(args, gw, dataset.gclasses)
        model.forward()

    infer_program = train_program.clone(for_test=True)

    with fluid.program_guard(train_program, startup_program):
        epoch_step = int(len(train_dataset) / args.batch_size) + 1
        boundaries = [
            i
            for i in range(50 * epoch_step, args.epochs * epoch_step,
                           epoch_step * 50)
        ]
        values = [args.lr * 0.5**i for i in range(0, len(boundaries) + 1)]
        lr = fl.piecewise_decay(boundaries=boundaries, values=values)
        train_op = fluid.optimizer.Adam(lr).minimize(model.loss)

    exe = fluid.Executor(place)
    exe.run(startup_program)

    # train and evaluate
    global_step = 0
    for epoch in range(1, args.epochs + 1):
        for idx, batch_data in enumerate(train_loader):
            g, labels = batch_data
            feed_dict = gw.to_feed(g)
            feed_dict['labels'] = labels
            ret_loss, ret_lr, ret_acc = exe.run(
                train_program,
                feed=feed_dict,
                fetch_list=[model.loss, lr, model.acc])
            global_step += 1
            if global_step % 10 == 0:
                message = "epoch %d | step %d | " % (epoch, global_step)
                message += "lr %.6f | loss %.6f | acc %.4f" % (
                    ret_lr, ret_loss, ret_acc)
                log.info(message)

        # evaluate
        result = evaluate(exe, infer_program, model, gw, test_loader)
        message = "evaluating result"
        for key, value in result.items():
            message += " | %s %.6f" % (key, value)
        log.info(message)
def evaluate(exe, prog, model, gw, loader):
    """evaluate"""
    total_loss = []
    total_acc = []
    for idx, batch_data in enumerate(loader):
        g, labels = batch_data
        feed_dict = gw.to_feed(g)
        feed_dict['labels'] = labels
        ret_loss, ret_acc = exe.run(prog,
                                    feed=feed_dict,
                                    fetch_list=[model.loss, model.acc])
        total_loss.append(ret_loss)
        total_acc.append(ret_acc)

    total_loss = np.mean(total_loss)
    total_acc = np.mean(total_acc)
    return {"loss": total_loss, "acc": total_acc}
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='./dataset')
    parser.add_argument('--dataset_name', type=str, default='MUTAG')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--fold_idx', type=int, default=0)
    parser.add_argument('--output_path', type=str, default='./outputs/')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--num_layers', type=int, default=5)
    parser.add_argument('--num_mlp_layers', type=int, default=2)
    parser.add_argument('--hidden_size', type=int, default=64)
    parser.add_argument(
        '--pool_type',
        type=str,
        default="sum",
        choices=["sum", "average", "max"])
    parser.add_argument('--train_eps', action='store_true')
    parser.add_argument('--epochs', type=int, default=350)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--dropout_prob', type=float, default=0.5)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()
    log.info(args)

    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file implement the GIN model.
"""
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
from pgl.layers.conv import gin
class GINModel(object):
    """GINModel"""

    def __init__(self, args, gw, num_class):
        self.args = args
        self.num_layers = self.args.num_layers
        self.hidden_size = self.args.hidden_size
        self.train_eps = self.args.train_eps
        self.pool_type = self.args.pool_type
        self.dropout_prob = self.args.dropout_prob
        self.num_class = num_class

        self.gw = gw
        self.labels = fl.data(name="labels", shape=[None, 1], dtype="int64")

    def forward(self):
        """forward"""
        features_list = [self.gw.node_feat["attr"]]

        for i in range(self.num_layers):
            h = gin(self.gw,
                    features_list[i],
                    hidden_size=self.hidden_size,
                    activation="relu",
                    name="gin_%s" % (i),
                    init_eps=0.0,
                    train_eps=self.train_eps)
            h = fl.batch_norm(h)
            h = fl.relu(h)
            features_list.append(h)

        # readout: pool each layer's node representations into a graph
        # representation and sum the per-layer classifier outputs
        pooled_h = pgl.layers.graph_pooling(self.gw, features_list[0],
                                            self.pool_type)
        output = fl.dropout(
            pooled_h,
            self.dropout_prob,
            dropout_implementation="upscale_in_train")
        output = fl.fc(output,
                       size=self.num_class,
                       act=None,
                       param_attr=fluid.ParamAttr(name="final_fc_0"))

        for i, h in enumerate(features_list):
            if i == 0:
                continue
            pooled_h = pgl.layers.graph_pooling(self.gw, h, self.pool_type)
            drop_h = fl.dropout(
                pooled_h,
                self.dropout_prob,
                dropout_implementation="upscale_in_train")
            output += fl.fc(drop_h,
                            size=self.num_class,
                            act=None,
                            param_attr=fluid.ParamAttr(name="final_fc_%s" %
                                                       (i)))

        # calculate loss
        self.loss = fl.softmax_with_cross_entropy(output, self.labels)
        self.loss = fl.reduce_mean(self.loss)
        self.acc = fl.accuracy(fl.softmax(output), self.labels)