提交 a84c739b 编写于 作者: D DesmonDay

add FRANKENSTEIN

上级 9cbdc6ed
# Self-Attention Graph Pooling
SAGPool is a graph pooling method based on self-attention. Self-attention uses graph convolution, which allows the pooling method to consider both node features and graph topology. Based on PGL, we implement the SAGPool algorithm and train the model on four datasets.
SAGPool is a graph pooling method based on self-attention. Self-attention uses graph convolution, which allows the pooling method to consider both node features and graph topology. Based on PGL, we implement the SAGPool algorithm and train the model on five datasets.
## Datasets
There are four datasets, including D&D, PROTEINS, NCI1, NCI109. You can download the datasets from [here](https://bj.bcebos.com/paddle-pgl/SAGPool/data.zip), and unzip it directly. The pkl format datasets should be in directory ./data.
There are five datasets, including D&D, PROTEINS, NCI1, NCI109 and FRANKENSTEIN. You can download the datasets from [here](https://bj.bcebos.com/paddle-pgl/SAGPool/data.zip), and unzip it directly. The pkl format datasets should be in directory ./data.
## Dependencies
......@@ -21,18 +21,20 @@ python main.py --dataset_name PROTEINS --learning_rate 0.001 --hidden_size 32 --
python main.py --dataset_name NCI1 --learning_rate 0.001 --weight_decay 0.00001
python main.py --dataset_name NCI109 --learning_rate 0.0005 --hidden_size 64 --weight_decay 0.0001 --patience 200
python main.py --dataset_name FRANKENSTEIN --learning_rate 0.001 --weight_decay 0.0001
```
## Hyperparameters
- seed: random seed
- batch\_size: the number of batch size
- learning\_rate: the number of learning rate
- learning\_rate: learning rate of optimizer
- weight\_decay: the weight decay for L2 regularization
- hidden\_size: the hidden size of gcn
- pooling\_ratio: the pooling ratio of SAGPool
- dropout\_ratio: the number of dropout ratio
- dataset\_name: the name of datasets, including DD, PROTEINS, NCI1, NCI109
- dataset\_name: the name of datasets, including DD, PROTEINS, NCI1, NCI109, FRANKENSTEIN
- epochs: maximum number of epochs
- patience: patience for early stopping
- use\_cuda: whether to use cuda
......@@ -48,3 +50,4 @@ We evaluate the implemented method for 20 random seeds using 10-fold cross valid
| PROTEINS | 72.7858 | 0.6617 |
| NCI1 | 75.781 | 1.2125 |
| NCI109 | 74.3156 | 1.3 |
| FRANKENSTEIN | 60.7826 | 0.629 |
......@@ -57,6 +57,7 @@ class Dataset(BaseDataset):
with open('data/%s.pkl' % args.dataset_name, 'rb') as f:
graphs_info_list = pickle.load(f)
self.pgl_graph_list = []
self.graph_label_list = []
for i in range(len(graphs_info_list) - 1):
......@@ -64,8 +65,9 @@ class Dataset(BaseDataset):
edges_l, edges_r = graph["edge_src"], graph["edge_dst"]
# add self-loops
node_nums = graph["num_nodes"]
x = np.arange(0, node_nums)
if self.args.dataset_name != "FRANKENSTEIN":
num_nodes = graph["num_nodes"]
x = np.arange(0, num_nodes)
edges_l = np.append(edges_l, x)
edges_r = np.append(edges_r, x)
......
......@@ -20,6 +20,7 @@ import pgl
from pgl.graph_wrapper import GraphWrapper
from pgl.utils.logger import log
from conv import norm_gcn
from pgl.layers.conv import gcn
def topk_pool(gw, score, graph_id, ratio):
"""Implementation of topk pooling, where k means pooling ratio.
......@@ -53,6 +54,7 @@ def topk_pool(gw, score, graph_id, ratio):
index = L.arange(0, gw.num_nodes, dtype="int64")
temp = L.gather(graph_lod, graph_id, overwrite=False)
index = (index - temp) + (graph_id * max_num_nodes)
index.stop_gradient = True
# padding
dense_score = L.fill_constant(shape=[num_graph * max_num_nodes],
......@@ -86,7 +88,7 @@ def topk_pool(gw, score, graph_id, ratio):
return perm, ratio_length
def sag_pool(gw, feature, ratio, graph_id, name, activation=L.tanh):
def sag_pool(gw, feature, ratio, graph_id, dataset, name, activation=L.tanh):
"""Implementation of self-attention graph pooling (SAGPool)
This is an implementation of the paper SELF-ATTENTION GRAPH POOLING
......@@ -101,7 +103,9 @@ def sag_pool(gw, feature, ratio, graph_id, name, activation=L.tanh):
graph_id: The graphs that the nodes belong to.
gcn: To use the official gcn or norm gcn.
dataset: To differentiate FRANKENSTEIN dataset and other datasets.
name: The name of SAGPool layer.
activation: The activation function.
......@@ -112,7 +116,12 @@ def sag_pool(gw, feature, ratio, graph_id, name, activation=L.tanh):
ratio_length: The selected node numbers of each graph.
"""
score = norm_gcn(gw=gw,
if dataset == "FRANKENSTEIN":
gcn_ = gcn
else:
gcn_ = norm_gcn
score = gcn_(gw=gw,
feature=feature,
hidden_size=1,
activation=None,
......
......@@ -15,8 +15,6 @@
import sys
import os
import argparse
import torch
from torch_geometric.datasets import TUDataset
import pgl
from pgl.utils.logger import log
import paddle
......
......@@ -22,6 +22,7 @@ import pgl
from pgl.graph import Graph, MultiGraph
from pgl.graph_wrapper import GraphWrapper
from pgl.utils.logger import log
from pgl.layers.conv import gcn
from layers import sag_pool
from conv import norm_gcn
......@@ -73,26 +74,31 @@ class GlobalModel(object):
dtype="int32",
append_batch_size=False)
if self.args.dataset_name == "FRANKENSTEIN":
self.gcn = gcn
else:
self.gcn = norm_gcn
self.build_model()
def build_model(self):
node_features = self.graph_wrapper.node_feat["feat"]
output = gcn(gw=self.graph_wrapper,
output = self.gcn(gw=self.graph_wrapper,
feature=node_features,
hidden_size=self.hidden_size,
activation="relu",
norm=self.graph_wrapper.node_feat["norm"],
name="gcn_layer_1")
output1 = output
output = gcn(gw=self.graph_wrapper,
output = self.gcn(gw=self.graph_wrapper,
feature=output,
hidden_size=self.hidden_size,
activation="relu",
norm=self.graph_wrapper.node_feat["norm"],
name="gcn_layer_2")
output2 = output
output = gcn(gw=self.graph_wrapper,
output = self.gcn(gw=self.graph_wrapper,
feature=output,
hidden_size=self.hidden_size,
activation="relu",
......@@ -105,6 +111,7 @@ class GlobalModel(object):
feature=output,
ratio=self.pooling_ratio,
graph_id=self.graph_id,
dataset=self.args.dataset_name,
name="sag_pool_1")
output = L.lod_reset(output, self.graph_wrapper.graph_lod)
cat1 = L.sequence_pool(output, "sum")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册