未验证 提交 d882cc6b 编写于 作者: W wawltor 提交者: GitHub

Function Add: add graph classification model SAGPool

# Self-Attention Graph Pooling
SAGPool is a graph pooling method based on self-attention. Self-attention uses graph convolution, which allows the pooling method to consider both node features and graph topology. Based on PGL, we implement the SAGPool algorithm and train the model on five datasets.
## Datasets
There are five datasets, including D&D, PROTEINS, NCI1, NCI109 and FRANKENSTEIN. You can download the datasets from [here](https://bj.bcebos.com/paddle-pgl/SAGPool/data.zip), and unzip it directly. The pkl format datasets should be in directory ./data.
## Dependencies
- [paddlepaddle >= 1.8](https://github.com/PaddlePaddle/paddle)
- [pgl 1.1](https://github.com/PaddlePaddle/PGL)
## How to run
```
python main.py --dataset_name DD --learning_rate 0.005 --weight_decay 0.00001
python main.py --dataset_name PROTEINS --learning_rate 0.001 --hidden_size 32 --weight_decay 0.00001
python main.py --dataset_name NCI1 --learning_rate 0.001 --weight_decay 0.00001
python main.py --dataset_name NCI109 --learning_rate 0.0005 --hidden_size 64 --weight_decay 0.0001 --patience 200
python main.py --dataset_name FRANKENSTEIN --learning_rate 0.001 --weight_decay 0.0001
```
## Hyperparameters
- seed: random seed
- batch\_size: the number of batch size
- learning\_rate: learning rate of optimizer
- weight\_decay: the weight decay for L2 regularization
- hidden\_size: the hidden size of gcn
- pooling\_ratio: the pooling ratio of SAGPool
- dropout\_ratio: the number of dropout ratio
- dataset\_name: the name of datasets, including DD, PROTEINS, NCI1, NCI109, FRANKENSTEIN
- epochs: maximum number of epochs
- patience: patience for early stopping
- use\_cuda: whether to use cuda
- save\_model: the name for the best model
## Performance
We evaluate the implemented method for 20 random seeds using 10-fold cross validation, following the same training procedures as in the paper.
| dataset | mean accuracy | standard deviation | mean accuracy(paper) | standard deviation(paper) |
| ------------ | ------------- | ------------------ | -------------------- | ------------------------- |
| DD | 74.4181 | 1.0244 | 76.19 | 0.94 |
| PROTEINS | 72.7858 | 0.6617 | 70.04 | 1.47 |
| NCI1 | 75.781 | 1.2125 | 74.18 | 1.2 |
| NCI109 | 74.3156 | 1.3 | 74.06 | 0.78 |
| FRANKENSTEIN | 60.7826 | 0.629 | 62.57 | 0.6 |
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=777,
help='seed')
parser.add_argument('--batch_size', type=int, default=128,
help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.0005,
help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.0001,
help='weight decay')
parser.add_argument('--hidden_size', type=int, default=128,
help='gcn hidden size')
parser.add_argument('--pooling_ratio', type=float, default=0.5,
help='pooling ratio of SAGPool')
parser.add_argument('--dropout_ratio', type=float, default=0.5,
help='dropout ratio')
parser.add_argument('--dataset_name', type=str, default='DD',
help='DD/PROTEINS/NCI1/NCI109/FRANKENSTEIN')
parser.add_argument('--epochs', type=int, default=100000,
help='maximum number of epochs')
parser.add_argument('--patience', type=int, default=50,
help='patience for early stopping')
parser.add_argument('--use_cuda', type=bool, default=True,
help='use cuda or cpu')
parser.add_argument('--save_model', type=str,
help='save model name')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import random
import pgl
from pgl.utils.logger import log
from pgl.graph import Graph, MultiGraph
import numpy as np
import pickle
class BaseDataset(object):
def __init__(self):
pass
def __getitem__(self, idx):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class Subset(BaseDataset):
"""Subset of a dataset at specified indices.
Args:
dataset (Dataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
"""
def __init__(self, dataset, indices):
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
class Dataset(BaseDataset):
def __init__(self, args):
self.args = args
with open('data/%s.pkl' % args.dataset_name, 'rb') as f:
graphs_info_list = pickle.load(f)
self.pgl_graph_list = []
self.graph_label_list = []
for i in range(len(graphs_info_list) - 1):
graph = graphs_info_list[i]
edges_l, edges_r = graph["edge_src"], graph["edge_dst"]
# add self-loops
if self.args.dataset_name != "FRANKENSTEIN":
num_nodes = graph["num_nodes"]
x = np.arange(0, num_nodes)
edges_l = np.append(edges_l, x)
edges_r = np.append(edges_r, x)
edges = list(zip(edges_l, edges_r))
g = pgl.graph.Graph(num_nodes=graph["num_nodes"], edges=edges)
g.node_feat["feat"] = graph["node_feat"]
self.pgl_graph_list.append(g)
self.graph_label_list.append(graph["label"])
self.num_classes = graphs_info_list[-1]["num_classes"]
self.num_features = graphs_info_list[-1]["num_features"]
def __getitem__(self, idx):
return self.pgl_graph_list[idx], self.graph_label_list[idx]
def shuffle(self):
"""shuffle the dataset.
"""
cc = list(zip(self.pgl_graph_list, self.graph_label_list))
random.seed(self.args.seed)
random.shuffle(cc)
a, b = zip(*cc)
self.pgl_graph_list[:], self.graph_label_list[:] = a, b
def __len__(self):
return len(self.pgl_graph_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers as L
def norm_gcn(gw, feature, hidden_size, activation, name, norm=None):
"""Implementation of graph convolutional neural networks(GCN), using different
normalization method.
Args:
gw: Graph wrapper object.
feature: A tensor with shape (num_nodes, feature_size).
hidden_size: The hidden size for norm gcn.
activation: The activation for the output.
name: Norm gcn layer names.
norm: If norm is not None, then the feature will be normalized. Norm must
be tensor with shape (num_nodes,) and dtype float32.
Return:
A tensor with shape (num_nodes, hidden_size)
"""
size = feature.shape[-1]
feature = L.fc(feature,
size=hidden_size,
bias_attr=False,
param_attr=fluid.ParamAttr(name=name))
if norm is not None:
src, dst = gw.edges
norm_src = L.gather(norm, src, overwrite=False)
norm_dst = L.gather(norm, dst, overwrite=False)
norm = norm_src * norm_dst
def send_src_copy(src_feat, dst_feat, edge_feat):
return src_feat["h"] * norm
else:
def send_src_copy(src_feat, dst_feat, edge_feat):
return src_feat["h"]
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
output = gw.recv(msg, "sum")
bias = L.create_parameter(
shape=[hidden_size],
dtype='float32',
is_bias=True,
name=name + '_bias')
output = L.elementwise_add(output, bias, act=activation)
return output
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import collections
import paddle
import pgl
from pgl.utils.logger import log
from pgl.graph import Graph, MultiGraph
def batch_iter(data, batch_size):
"""node_batch_iter
"""
size = len(data)
perm = np.arange(size)
np.random.shuffle(perm)
start = 0
while start < size:
index = perm[start:start + batch_size]
start += batch_size
yield data[index]
def scan_batch_iter(data, batch_size):
"""scan_batch_iter
"""
batch = []
for example in data.scan():
batch.append(example)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
def label_to_onehot(labels):
"""Return one-hot representations of labels
"""
onehot_labels = []
for label in labels:
if label == 0:
onehot_labels.append([1, 0])
else:
onehot_labels.append([0, 1])
onehot_labels = np.array(onehot_labels)
return onehot_labels
class GraphDataloader(object):
"""Graph Dataloader
"""
def __init__(self,
dataset,
graph_wrapper,
batch_size,
seed=0,
buf_size=1000,
shuffle=True):
self.shuffle = shuffle
self.seed = seed
self.batch_size = batch_size
self.dataset = dataset
self.buf_size = buf_size
self.graph_wrapper = graph_wrapper
def batch_fn(self, batch_examples):
""" batch_fun batch producer """
graphs = [b[0] for b in batch_examples]
labels = [b[1] for b in batch_examples]
join_graph = MultiGraph(graphs)
# normalize
indegree = join_graph.indegree()
norm = np.zeros_like(indegree, dtype="float32")
norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
join_graph.node_feat["norm"] = np.expand_dims(norm, -1)
feed_dict = self.graph_wrapper.to_feed(join_graph)
labels = np.array(labels)
feed_dict["labels_1dim"] = labels
labels = label_to_onehot(labels)
feed_dict["labels"] = labels
graph_lod = join_graph.graph_lod
graph_id = []
for i in range(1, len(graph_lod)):
graph_node_num = graph_lod[i] - graph_lod[i - 1]
graph_id += [i - 1] * graph_node_num
graph_id = np.array(graph_id, dtype="int32")
feed_dict["graph_id"] = graph_id
return feed_dict
def batch_iter(self):
""" batch_iter """
if self.shuffle:
for batch in batch_iter(self, self.batch_size):
yield batch
else:
for batch in scan_batch_iter(self, self.batch_size):
yield batch
def __len__(self):
"""__len__"""
return len(self.dataset)
def __getitem__(self, idx):
"""__getitem__"""
if isinstance(idx, collections.Iterable):
return [self.dataset[bidx] for bidx in idx]
else:
return self.dataset[idx]
def __iter__(self):
"""__iter__"""
def func_run():
for batch_examples in self.batch_iter():
batch_dict = self.batch_fn(batch_examples)
yield batch_dict
r = paddle.reader.buffered(func_run, self.buf_size)
for batch in r():
yield batch
def scan(self):
"""scan"""
for example in self.dataset:
yield example
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as L
import pgl
from pgl.graph_wrapper import GraphWrapper
from pgl.utils.logger import log
from conv import norm_gcn
from pgl.layers.conv import gcn
def topk_pool(gw, score, graph_id, ratio):
"""Implementation of topk pooling, where k means pooling ratio.
Args:
gw: Graph wrapper object.
score: The attention score of all nodes, which is used to select
important nodes.
graph_id: The graphs that the nodes belong to.
ratio: The pooling ratio of nodes we want to select.
Return:
perm: The index of nodes we choose.
ratio_length: The selected node numbers of each graph.
"""
graph_lod = gw.graph_lod
graph_nodes = gw.num_nodes
num_graph = gw.num_graph
num_nodes = L.ones(shape=[graph_nodes], dtype="float32")
num_nodes = L.lod_reset(num_nodes, graph_lod)
num_nodes_per_graph = L.sequence_pool(num_nodes, pool_type='sum')
max_num_nodes = L.reduce_max(num_nodes_per_graph, dim=0)
max_num_nodes = L.cast(max_num_nodes, dtype="int32")
index = L.arange(0, gw.num_nodes, dtype="int64")
offset = L.gather(graph_lod, graph_id, overwrite=False)
index = (index - temp) + (graph_id * max_num_nodes)
index.stop_gradient = True
# padding
dense_score = L.fill_constant(shape=[num_graph * max_num_nodes],
dtype="float32", value=-999999)
index = L.reshape(index, shape=[-1])
dense_score = L.scatter(dense_score, index, updates=score)
num_graph = L.cast(num_graph, dtype="int32")
dense_score = L.reshape(dense_score,
shape=[num_graph, max_num_nodes])
# record the sorted index
_, sort_index = L.argsort(dense_score, axis=-1, descending=True)
# recover the index range
graph_lod = graph_lod[:-1]
graph_lod = L.reshape(graph_lod, shape=[-1, 1])
graph_lod = L.cast(graph_lod, dtype="int64")
sort_index = L.elementwise_add(sort_index, graph_lod, axis=-1)
sort_index = L.reshape(sort_index, shape=[-1, 1])
# use sequence_slice to choose selected node index
pad_lod = L.arange(0, (num_graph + 1) * max_num_nodes, step=max_num_nodes, dtype="int32")
sort_index = L.lod_reset(sort_index, pad_lod)
ratio_length = L.ceil(num_nodes_per_graph * ratio)
ratio_length = L.cast(ratio_length, dtype="int64")
ratio_length = L.reshape(ratio_length, shape=[-1, 1])
offset = L.zeros(shape=[num_graph, 1], dtype="int64")
choose_index = L.sequence_slice(input=sort_index, offset=offset, length=ratio_length)
perm = L.reshape(choose_index, shape=[-1])
return perm, ratio_length
def sag_pool(gw, feature, ratio, graph_id, dataset, name, activation=L.tanh):
"""Implementation of self-attention graph pooling (SAGPool)
This is an implementation of the paper SELF-ATTENTION GRAPH POOLING
(https://arxiv.org/pdf/1904.08082.pdf)
Args:
gw: Graph wrapper object.
feature: A tensor with shape (num_nodes, feature_size).
ratio: The pooling ratio of nodes we want to select.
graph_id: The graphs that the nodes belong to.
dataset: To differentiate FRANKENSTEIN dataset and other datasets.
name: The name of SAGPool layer.
activation: The activation function.
Return:
new_feature: A tensor with shape (num_nodes, feature_size), and the unselected
nodes' feature is masked by zero.
ratio_length: The selected node numbers of each graph.
"""
if dataset == "FRANKENSTEIN":
gcn_ = gcn
else:
gcn_ = norm_gcn
score = gcn_(gw=gw,
feature=feature,
hidden_size=1,
activation=None,
norm=gw.node_feat["norm"],
name=name)
score = L.squeeze(score, axes=[])
perm, ratio_length = topk_pool(gw, score, graph_id, ratio)
mask = L.zeros_like(score)
mask = L.cast(mask, dtype="float32")
updates = L.ones_like(perm)
updates = L.cast(updates, dtype="float32")
mask = L.scatter(mask, perm, updates)
new_feature = L.elementwise_mul(feature, mask, axis=0)
temp_score = activation(score)
new_feature = L.elementwise_mul(new_feature, temp_score, axis=0)
return new_feature, ratio_length
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import argparse
import pgl
from pgl.utils.logger import log
import paddle
import re
import time
import random
import numpy as np
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as L
import pgl
from pgl.utils.logger import log
from model import GlobalModel
from base_dataset import Subset, Dataset
from dataloader import GraphDataloader
from args import parser
import warnings
from sklearn.model_selection import KFold
warnings.filterwarnings("ignore")
def main(args, train_dataset, val_dataset, test_dataset):
"""main function for running one testing results.
"""
log.info("Train Examples: %s" % len(train_dataset))
log.info("Val Examples: %s" % len(val_dataset))
log.info("Test Examples: %s" % len(test_dataset))
train_program = fluid.Program()
train_program.random_seed = args.seed
startup_program = fluid.Program()
startup_program.random_seed = args.seed
if args.use_cuda:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
log.info("building model")
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
graph_model = GlobalModel(args, dataset)
train_loader = GraphDataloader(train_dataset,
graph_model.graph_wrapper,
batch_size=args.batch_size)
optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(args.weight_decay))
optimizer.minimize(graph_model.loss)
exe.run(startup_program)
test_program = fluid.Program()
test_program = train_program.clone(for_test=True)
val_loader = GraphDataloader(val_dataset,
graph_model.graph_wrapper,
batch_size=args.batch_size,
shuffle=False)
test_loader = GraphDataloader(test_dataset,
graph_model.graph_wrapper,
batch_size=args.batch_size,
shuffle=False)
min_loss = 1e10
global_step = 0
for epoch in range(args.epochs):
for feed_dict in train_loader:
loss, pred = exe.run(train_program,
feed=feed_dict,
fetch_list=[graph_model.loss, graph_model.pred])
log.info("Epoch: %d, global_step: %d, Training loss: %f" \
% (epoch, global_step, loss))
global_step += 1
# validation
valid_loss = 0.
correct = 0.
for feed_dict in val_loader:
valid_loss_, correct_ = exe.run(test_program,
feed=feed_dict,
fetch_list=[graph_model.loss, graph_model.correct])
valid_loss += valid_loss_
correct += correct_
if epoch % 50 == 0:
log.info("Epoch:%d, Validation loss: %f, Validation acc: %f" \
% (epoch, valid_loss, correct / len(val_loader)))
if valid_loss < min_loss:
min_loss = valid_loss
patience = 0
path = "./save/%s" % args.dataset_name
if not os.path.exists(path):
os.makedirs(path)
fluid.save(train_program, "%s/%s" \
% (path, args.save_model))
log.info("Model saved at epoch %d" % epoch)
else:
patience += 1
if patience > args.patience:
break
correct = 0.
new_test_program = fluid.Program()
fluid.load(new_test_program, "./save/%s/%s" \
% (args.dataset_name, args.save_model), exe)
for feed_dict in test_loader:
correct_ = exe.run(test_program,
feed=feed_dict,
fetch_list=[graph_model.correct])
correct += correct_[0]
log.info("Test acc: %f" % (correct / len(test_loader)))
return correct / len(test_loader)
def split_10_cv(dataset, args):
"""10 folds cross validation
"""
dataset.shuffle()
X = np.array([0] * len(dataset))
y = X
kf = KFold(n_splits=10, shuffle=False)
i = 1
test_acc = []
for train_index, test_index in kf.split(X, y):
train_val_dataset = Subset(dataset, train_index)
test_dataset = Subset(dataset, test_index)
train_val_index_range = list(range(0, len(train_val_dataset)))
num_val = int(len(train_val_dataset) / 9)
val_dataset = Subset(train_val_dataset, train_val_index_range[:num_val])
train_dataset = Subset(train_val_dataset, train_val_index_range[num_val:])
log.info("######%d fold of 10-fold cross validation######" % i)
i += 1
test_acc_ = main(args, train_dataset, val_dataset, test_dataset)
test_acc.append(test_acc_)
mean_acc = sum(test_acc) / len(test_acc)
return mean_acc, test_acc
def random_seed_20(args, dataset):
"""run for 20 random seeds
"""
alist = random.sample(range(1,1000),20)
test_acc_fold = []
for seed in alist:
log.info('############ Seed %d ############' % seed)
args.seed = seed
test_acc_fold_, _ = split_10_cv(dataset, args)
log.info('Mean test acc at seed %d: %f' % (seed, test_acc_fold_))
test_acc_fold.append(test_acc_fold_)
mean_acc = sum(test_acc_fold) / len(test_acc_fold)
temp = [(acc - mean_acc) * (acc - mean_acc) for acc in test_acc_fold]
standard_std = math.sqrt(sum(temp) / len(test_acc_fold))
log.info('Final mean test acc using 20 random seeds(mean for 10-fold): %f' % (mean_acc))
log.info('Final standard std using 20 random seeds(mean for 10-fold): %f' % (standard_std))
if __name__ == "__main__":
args = parser.parse_args()
log.info('loading data...')
dataset = Dataset(args)
log.info("preprocess finish.")
args.num_classes = dataset.num_classes
args.num_features = dataset.num_features
random_seed_20(args, dataset)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from random import random
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as L
import pgl
from pgl.graph import Graph, MultiGraph
from pgl.graph_wrapper import GraphWrapper
from pgl.utils.logger import log
from pgl.layers.conv import gcn
from layers import sag_pool
from conv import norm_gcn
class GlobalModel(object):
"""Implementation of global pooling architecture with SAGPool.
"""
def __init__(self, args, dataset):
self.args = args
self.dataset = dataset
self.hidden_size = args.hidden_size
self.num_classes = args.num_classes
self.num_features = args.num_features
self.pooling_ratio = args.pooling_ratio
self.dropout_ratio = args.dropout_ratio
self.batch_size = args.batch_size
graph_data = []
g, label = self.dataset[0]
graph_data.append(g)
g, label = self.dataset[1]
graph_data.append(g)
batch_graph = MultiGraph(graph_data)
indegree = batch_graph.indegree()
norm = np.zeros_like(indegree, dtype="float32")
norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
batch_graph.node_feat["norm"] = np.expand_dims(norm, -1)
graph_data = batch_graph
self.graph_wrapper = GraphWrapper(
name="graph",
node_feat=graph_data.node_feat_info()
)
self.labels = L.data(
"labels",
shape=[None, self.args.num_classes],
dtype="int32",
append_batch_size=False)
self.labels_1dim = L.data(
"labels_1dim",
shape=[None],
dtype="int32",
append_batch_size=False)
self.graph_id = L.data(
"graph_id",
shape=[None],
dtype="int32",
append_batch_size=False)
if self.args.dataset_name == "FRANKENSTEIN":
self.gcn = gcn
else:
self.gcn = norm_gcn
self.build_model()
def build_model(self):
node_features = self.graph_wrapper.node_feat["feat"]
output = self.gcn(gw=self.graph_wrapper,
feature=node_features,
hidden_size=self.hidden_size,
activation="relu",
norm=self.graph_wrapper.node_feat["norm"],
name="gcn_layer_1")
output1 = output
output = self.gcn(gw=self.graph_wrapper,
feature=output,
hidden_size=self.hidden_size,
activation="relu",
norm=self.graph_wrapper.node_feat["norm"],
name="gcn_layer_2")
output2 = output
output = self.gcn(gw=self.graph_wrapper,
feature=output,
hidden_size=self.hidden_size,
activation="relu",
norm=self.graph_wrapper.node_feat["norm"],
name="gcn_layer_3")
output = L.concat(input=[output1, output2, output], axis=-1)
output, ratio_length = sag_pool(gw=self.graph_wrapper,
feature=output,
ratio=self.pooling_ratio,
graph_id=self.graph_id,
dataset=self.args.dataset_name,
name="sag_pool_1")
output = L.lod_reset(output, self.graph_wrapper.graph_lod)
cat1 = L.sequence_pool(output, "sum")
ratio_length = L.cast(ratio_length, dtype="float32")
cat1 = L.elementwise_div(cat1, ratio_length, axis=-1)
cat2 = L.sequence_pool(output, "max")
output = L.concat(input=[cat2, cat1], axis=-1)
output = L.fc(output, size=self.hidden_size, act="relu")
output = L.dropout(output, dropout_prob=self.dropout_ratio)
output = L.fc(output, size=self.hidden_size // 2, act="relu")
output = L.fc(output, size=self.num_classes, act=None,
param_attr=fluid.ParamAttr(name="final_fc"))
self.labels = L.cast(self.labels, dtype="float32")
loss = L.sigmoid_cross_entropy_with_logits(x=output, label=self.labels)
self.loss = L.mean(loss)
pred = L.sigmoid(output)
self.pred = L.argmax(x=pred, axis=-1)
correct = L.equal(self.pred, self.labels_1dim)
correct = L.cast(correct, dtype="int32")
self.correct = L.reduce_sum(correct)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册