diff --git a/examples/SAGPool/README.md b/examples/SAGPool/README.md index 24c672299cb6f63c98073688c21039d12f147f7f..6db58ae8c57ad0583b141ffea7a6c14d386a3006 100644 --- a/examples/SAGPool/README.md +++ b/examples/SAGPool/README.md @@ -1,10 +1,10 @@ # Self-Attention Graph Pooling -SAGPool is a graph pooling method based on self-attention. Self-attention uses graph convolution, which allows the pooling method to consider both node features and graph topology. Based on PGL, we implement the SAGPool algorithm and train the model on four datasets. +SAGPool is a graph pooling method based on self-attention. Self-attention uses graph convolution, which allows the pooling method to consider both node features and graph topology. Based on PGL, we implement the SAGPool algorithm and train the model on five datasets. ## Datasets -There are four datasets, including D&D, PROTEINS, NCI1, NCI109. You can download the datasets from [here](https://bj.bcebos.com/paddle-pgl/SAGPool/data.zip), and unzip it directly. The pkl format datasets should be in directory ./data. +There are five datasets, including D&D, PROTEINS, NCI1, NCI109 and FRANKENSTEIN. You can download the datasets from [here](https://bj.bcebos.com/paddle-pgl/SAGPool/data.zip), and unzip it directly. The pkl format datasets should be in directory ./data. ## Dependencies @@ -21,18 +21,20 @@ python main.py --dataset_name PROTEINS --learning_rate 0.001 --hidden_size 32 -- python main.py --dataset_name NCI1 --learning_rate 0.001 --weight_decay 0.00001 python main.py --dataset_name NCI109 --learning_rate 0.0005 --hidden_size 64 --weight_decay 0.0001 --patience 200 + +python main.py --dataset_name FRANKENSTEIN --learning_rate 0.001 --weight_decay 0.0001 ``` ## Hyperparameters - seed: random seed - batch\_size: the number of batch size -- learning\_rate: the number of learning rate +- learning\_rate: learning rate of optimizer - weight\_decay: the weight decay for L2 regularization - hidden\_size: the hidden size of gcn - pooling\_ratio: the pooling ratio of SAGPool - dropout\_ratio: the number of dropout ratio -- dataset\_name: the name of datasets, including DD, PROTEINS, NCI1, NCI109 +- dataset\_name: the name of datasets, including DD, PROTEINS, NCI1, NCI109, FRANKENSTEIN - epochs: maximum number of epochs - patience: patience for early stopping - use\_cuda: whether to use cuda @@ -48,3 +50,4 @@ We evaluate the implemented method for 20 random seeds using 10-fold cross valid | PROTEINS | 72.7858 | 0.6617 | | NCI1 | 75.781 | 1.2125 | | NCI109 | 74.3156 | 1.3 | +| FRANKENSTEIN | 60.7826 | 0.629 | diff --git a/examples/SAGPool/base_dataset.py b/examples/SAGPool/base_dataset.py index 4241a6dc81a7d25ba5b4373d6043a44ab90107c6..711e9203e311bb2c9cf5b6e9afc2834122a93a01 100644 --- a/examples/SAGPool/base_dataset.py +++ b/examples/SAGPool/base_dataset.py @@ -57,6 +57,7 @@ class Dataset(BaseDataset): with open('data/%s.pkl' % args.dataset_name, 'rb') as f: graphs_info_list = pickle.load(f) + self.pgl_graph_list = [] self.graph_label_list = [] for i in range(len(graphs_info_list) - 1): @@ -64,10 +65,11 @@ class Dataset(BaseDataset): edges_l, edges_r = graph["edge_src"], graph["edge_dst"] # add self-loops - node_nums = graph["num_nodes"] - x = np.arange(0, node_nums) - edges_l = np.append(edges_l, x) - edges_r = np.append(edges_r, x) + if self.args.dataset_name != "FRANKENSTEIN": + num_nodes = graph["num_nodes"] + x = np.arange(0, num_nodes) + edges_l = np.append(edges_l, x) + edges_r = np.append(edges_r, x) edges = list(zip(edges_l, edges_r)) g = pgl.graph.Graph(num_nodes=graph["num_nodes"], edges=edges) diff --git a/examples/SAGPool/layers.py b/examples/SAGPool/layers.py index fa800ecc48e822821806ca0ce22e41152501c811..650beefa8f89bcbfef708d0587dffe2243e99275 100644 --- a/examples/SAGPool/layers.py +++ b/examples/SAGPool/layers.py @@ -20,6 +20,7 @@ import pgl from pgl.graph_wrapper import GraphWrapper from pgl.utils.logger import log from conv import norm_gcn +from pgl.layers.conv import gcn def topk_pool(gw, score, graph_id, ratio): """Implementation of topk pooling, where k means pooling ratio. @@ -53,6 +54,7 @@ def topk_pool(gw, score, graph_id, ratio): index = L.arange(0, gw.num_nodes, dtype="int64") temp = L.gather(graph_lod, graph_id, overwrite=False) index = (index - temp) + (graph_id * max_num_nodes) + index.stop_gradient = True # padding dense_score = L.fill_constant(shape=[num_graph * max_num_nodes], @@ -86,7 +88,7 @@ def topk_pool(gw, score, graph_id, ratio): return perm, ratio_length -def sag_pool(gw, feature, ratio, graph_id, name, activation=L.tanh): +def sag_pool(gw, feature, ratio, graph_id, dataset, name, activation=L.tanh): """Implementation of self-attention graph pooling (SAGPool) This is an implementation of the paper SELF-ATTENTION GRAPH POOLING @@ -100,9 +102,11 @@ def sag_pool(gw, feature, ratio, graph_id, name, activation=L.tanh): ratio: The pooling ratio of nodes we want to select. graph_id: The graphs that the nodes belong to. - - gcn: To use the official gcn or norm gcn. + dataset: To differentiate FRANKENSTEIN dataset and other datasets. + + name: The name of SAGPool layer. + activation: The activation function. Return: @@ -112,7 +116,12 @@ def sag_pool(gw, feature, ratio, graph_id, name, activation=L.tanh): ratio_length: The selected node numbers of each graph. """ - score = norm_gcn(gw=gw, + if dataset == "FRANKENSTEIN": + gcn_ = gcn + else: + gcn_ = norm_gcn + + score = gcn_(gw=gw, feature=feature, hidden_size=1, activation=None, diff --git a/examples/SAGPool/main.py b/examples/SAGPool/main.py index ddf0400cf5c6b1fe78bd83b0905a7200ce1b01e3..8895311e0f729bf572919faf07df50c7921cb545 100644 --- a/examples/SAGPool/main.py +++ b/examples/SAGPool/main.py @@ -15,8 +15,6 @@ import sys import os import argparse -import torch -from torch_geometric.datasets import TUDataset import pgl from pgl.utils.logger import log import paddle diff --git a/examples/SAGPool/model.py b/examples/SAGPool/model.py index 7ebe9af2d253bffaa65dbda47cfc2d4bb1745a28..cdfbe6f4f2d911dff041522d3e75266b4867c5a3 100644 --- a/examples/SAGPool/model.py +++ b/examples/SAGPool/model.py @@ -22,6 +22,7 @@ import pgl from pgl.graph import Graph, MultiGraph from pgl.graph_wrapper import GraphWrapper from pgl.utils.logger import log +from pgl.layers.conv import gcn from layers import sag_pool from conv import norm_gcn @@ -72,27 +73,32 @@ class GlobalModel(object): shape=[None], dtype="int32", append_batch_size=False) + + if self.args.dataset_name == "FRANKENSTEIN": + self.gcn = gcn + else: + self.gcn = norm_gcn self.build_model() def build_model(self): node_features = self.graph_wrapper.node_feat["feat"] - output = gcn(gw=self.graph_wrapper, + output = self.gcn(gw=self.graph_wrapper, feature=node_features, hidden_size=self.hidden_size, activation="relu", norm=self.graph_wrapper.node_feat["norm"], name="gcn_layer_1") output1 = output - output = gcn(gw=self.graph_wrapper, + output = self.gcn(gw=self.graph_wrapper, feature=output, hidden_size=self.hidden_size, activation="relu", norm=self.graph_wrapper.node_feat["norm"], name="gcn_layer_2") output2 = output - output = gcn(gw=self.graph_wrapper, + output = self.gcn(gw=self.graph_wrapper, feature=output, hidden_size=self.hidden_size, activation="relu", @@ -105,6 +111,7 @@ class GlobalModel(object): feature=output, ratio=self.pooling_ratio, graph_id=self.graph_id, + dataset=self.args.dataset_name, name="sag_pool_1") output = L.lod_reset(output, self.graph_wrapper.graph_lod) cat1 = L.sequence_pool(output, "sum")