未验证 提交 5782fb81 编写于 作者: W Weiyue Su 提交者: GitHub

Merge pull request #84 from WenjinW/master

Update example/GaAN
# data and log # data and log
.examples/GaAN/datase/t /examples/GaAN/dataset/
.examples/GaAN/log/ /examples/GaAN/log/
.examples/GaAN/__pycache__/ /examples/GaAN/__pycache__/
/examples/GaAN/params/
/DoorGod
# Virtualenv # Virtualenv
/.venv/ /.venv/
/venv/ /venv/
......
...@@ -6,14 +6,20 @@ ...@@ -6,14 +6,20 @@
The ogbn-proteins dataset will be downloaded in directory ./dataset automatically. The ogbn-proteins dataset will be downloaded in directory ./dataset automatically.
## Dependencies ## Dependencies
- paddlepaddle - [paddlepaddle >= 1.6](https://github.com/paddlepaddle/paddle)
- pgl - [pgl 1.1](https://github.com/PaddlePaddle/PGL)
- ogb - [ogb 1.1.1](https://github.com/snap-stanford/ogb)
## How to run ## How to run
```bash ```bash
python train.py --lr 1e-2 --rc 0 --batch_size 1024 --epochs 100 python train.py --lr 1e-2 --rc 0 --batch_size 1024 --epochs 100
``` ```
or
```bash
source main.sh
```
### Hyperparameters ### Hyperparameters
- use_gpu: whether to use gpu or not - use_gpu: whether to use gpu or not
- mini_data: use a small dataset to test code - mini_data: use a small dataset to test code
...@@ -32,4 +38,4 @@ python train.py --lr 1e-2 --rc 0 --batch_size 1024 --epochs 100 ...@@ -32,4 +38,4 @@ python train.py --lr 1e-2 --rc 0 --batch_size 1024 --epochs 100
We train our models for 100 epochs and report the **rocauc** on the test dataset. We train our models for 100 epochs and report the **rocauc** on the test dataset.
|dataset|mean|std| |dataset|mean|std|
|-|-|-| |-|-|-|
|ogbn-proteins|0.7786|0.0048| |ogbn-proteins|0.7803|0.0073|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This package implements common layers to help building
graph neural networks.
"""
import paddle.fluid as fluid
from pgl import graph_wrapper
from pgl.utils import paddle_helper
__all__ = ['gcn', 'gat', 'gin', 'gaan']
def gcn(gw, feature, hidden_size, activation, name, norm=None):
    """Graph convolutional network (GCN) layer.

    Implementation of the paper SEMI-SUPERVISED CLASSIFICATION WITH GRAPH
    CONVOLUTIONAL NETWORKS (https://arxiv.org/pdf/1609.02907.pdf).

    The linear projection is placed so that neighbour aggregation runs on
    the narrower representation: before the message passing when the input
    feature is wider than ``hidden_size``, after it otherwise.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or
            :code:`GraphWrapper`).
        feature: A tensor with shape (num_nodes, feature_size).
        hidden_size: The hidden size for gcn.
        activation: The activation for the output.
        name: Gcn layer name. Used as the weight parameter name;
            ``name + '_bias'`` is used for the bias.
        norm: If not None, the feature is multiplied by it both before and
            after aggregation. Norm must be a tensor with shape
            (num_nodes,) and dtype float32.

    Return:
        A tensor with shape (num_nodes, hidden_size)
    """

    def _copy_src(src_feat, dst_feat, edge_feat):
        # Message on each edge is simply the source node feature.
        return src_feat["h"]

    def _project(tensor):
        # Bias-free linear map sharing the single weight called `name`.
        return fluid.layers.fc(tensor,
                               size=hidden_size,
                               bias_attr=False,
                               param_attr=fluid.ParamAttr(name=name))

    project_first = feature.shape[-1] > hidden_size
    use_norm = norm is not None

    if project_first:
        feature = _project(feature)
    if use_norm:
        feature = feature * norm

    msg = gw.send(_copy_src, nfeat_list=[("h", feature)])
    output = gw.recv(msg, "sum")

    if not project_first:
        output = _project(output)
    if use_norm:
        output = output * norm

    bias = fluid.layers.create_parameter(
        shape=[hidden_size],
        dtype='float32',
        is_bias=True,
        name=name + '_bias')
    return fluid.layers.elementwise_add(output, bias, act=activation)
def gat(gw,
        feature,
        hidden_size,
        activation,
        name,
        num_heads=8,
        feat_drop=0.6,
        attn_drop=0.6,
        is_test=False):
    """Graph attention network (GAT) layer.

    Implementation of the paper GRAPH ATTENTION NETWORKS
    (https://arxiv.org/abs/1710.10903).

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or
            :code:`GraphWrapper`).
        feature: A tensor with shape (num_nodes, feature_size).
        hidden_size: The hidden size (per head) for gat.
        activation: The activation for the output.
        name: Gat layer name, used as prefix of every parameter name.
        num_heads: The head number in gat.
        feat_drop: Dropout rate for the input feature.
        attn_drop: Dropout rate for the attention weights.
        is_test: Whether in test phase (forwarded to the dropout layers).

    Return:
        A tensor with shape (num_nodes, hidden_size * num_heads)
    """
    per_head_shape = [-1, num_heads, hidden_size]
    flat_width = num_heads * hidden_size

    def _edge_message(src_feat, dst_feat, edge_feat):
        # Un-normalised attention logit per edge and head.
        logits = fluid.layers.leaky_relu(
            src_feat["left_a"] + dst_feat["right_a"],
            alpha=0.2)  # (num_edges, num_heads)
        return {"alpha": logits, "h": src_feat["h"]}

    def _aggregate(msg):
        # lod-tensor (batch_size, seq_len, num_heads)
        weights = paddle_helper.sequence_softmax(msg["alpha"])
        packed_h = msg["h"]  # kept to restore the LoD after reshaping
        values = fluid.layers.reshape(packed_h, per_head_shape)
        weights = fluid.layers.reshape(weights, [-1, num_heads, 1])
        if attn_drop > 1e-15:
            weights = fluid.layers.dropout(
                weights,
                dropout_prob=attn_drop,
                is_test=is_test,
                dropout_implementation="upscale_in_train")
        weighted = fluid.layers.reshape(values * weights, [-1, flat_width])
        weighted = fluid.layers.lod_reset(weighted, packed_h)
        return fluid.layers.sequence_pool(weighted, "sum")

    if feat_drop > 1e-15:
        feature = fluid.layers.dropout(
            feature,
            dropout_prob=feat_drop,
            is_test=is_test,
            dropout_implementation='upscale_in_train')

    projected = fluid.layers.fc(
        feature,
        flat_width,
        bias_attr=False,
        param_attr=fluid.ParamAttr(name=name + '_weight'))
    attn_left = fluid.layers.create_parameter(
        shape=[num_heads, hidden_size],
        dtype='float32',
        name=name + '_gat_l_A')
    attn_right = fluid.layers.create_parameter(
        shape=[num_heads, hidden_size],
        dtype='float32',
        name=name + '_gat_r_A')

    # Per-node, per-head scalar contributions to the edge logits.
    projected_heads = fluid.layers.reshape(projected, per_head_shape)
    left_score = fluid.layers.reduce_sum(projected_heads * attn_left, -1)
    right_score = fluid.layers.reduce_sum(projected_heads * attn_right, -1)

    msg = gw.send(
        _edge_message,
        nfeat_list=[("h", projected), ("left_a", left_score),
                    ("right_a", right_score)])
    output = gw.recv(msg, _aggregate)

    bias = fluid.layers.create_parameter(
        shape=[flat_width],
        dtype='float32',
        is_bias=True,
        name=name + '_bias')
    # NOTE(review): gradient is stopped on the bias, so presumably it keeps
    # its initial value during training -- confirm this is intended.
    bias.stop_gradient = True
    return fluid.layers.elementwise_add(output, bias, act=activation)
def gin(gw,
        feature,
        hidden_size,
        activation,
        name,
        init_eps=0.0,
        train_eps=False):
    r"""Implementation of Graph Isomorphism Network (GIN) layer.

    This is an implementation of the paper How Powerful are Graph Neural
    Networks? (https://arxiv.org/pdf/1810.00826.pdf).

    In their implementation, all MLPs have 2 layers. Batch normalization is
    applied on every hidden layer. This layer uses two fully connected
    layers with a layer normalization (``fluid.layers.layer_norm``) after
    the first one.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or
            :code:`GraphWrapper`).
        feature: A tensor with shape (num_nodes, feature_size).
        hidden_size: The hidden size for gin.
        activation: Name of a function in ``fluid.layers`` applied after
            the normalization and used as ``act`` of the second fully
            connected layer, or None for no activation.
        name: GIN layer name, used to build every parameter name.
        init_eps: float, optional
            Initial :math:`\epsilon` value, default is 0.
        train_eps: bool, optional
            if True, :math:`\epsilon` will be a learnable parameter.

    Return:
        A tensor with shape (num_nodes, hidden_size).
    """

    def send_src_copy(src_feat, dst_feat, edge_feat):
        # Message on each edge is the source node feature.
        return src_feat["h"]

    epsilon = fluid.layers.create_parameter(
        shape=[1, 1],
        dtype="float32",
        attr=fluid.ParamAttr(name="%s_eps" % name),
        default_initializer=fluid.initializer.ConstantInitializer(
            value=init_eps))

    if not train_eps:
        # Keep epsilon fixed at init_eps.
        epsilon.stop_gradient = True

    # Sum over neighbours plus (1 + eps) times the node's own feature.
    msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
    output = gw.recv(msg, "sum") + feature * (epsilon + 1.0)

    # First MLP layer: linear -> layer norm -> optional activation.
    output = fluid.layers.fc(output,
                             size=hidden_size,
                             act=None,
                             param_attr=fluid.ParamAttr(name="%s_w_0" % name),
                             bias_attr=fluid.ParamAttr(name="%s_b_0" % name))

    output = fluid.layers.layer_norm(
        output,
        begin_norm_axis=1,
        param_attr=fluid.ParamAttr(
            name="norm_scale_%s" % (name),
            initializer=fluid.initializer.Constant(1.0)),
        bias_attr=fluid.ParamAttr(
            name="norm_bias_%s" % (name),
            initializer=fluid.initializer.Constant(0.0)), )

    if activation is not None:
        output = getattr(fluid.layers, activation)(output)

    # Second MLP layer, with the activation fused into the fc op.
    output = fluid.layers.fc(output,
                             size=hidden_size,
                             act=activation,
                             param_attr=fluid.ParamAttr(name="%s_w_1" % name),
                             bias_attr=fluid.ParamAttr(name="%s_b_1" % name))

    return output
def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o, heads, name):
    """Implementation of a gated attention network (GaAN) layer.

    This is an implementation of the paper GaAN: Gated Attention Networks
    for Learning on Large and Spatiotemporal Graphs
    (https://arxiv.org/abs/1803.07294).

    Shape notation used in the comments: N nodes, E edges, M heads,
    D1 = hidden_size_a, D2 = hidden_size_v, Dm = hidden_size_m.

    Args:
        gw: Graph wrapper object providing ``send`` and ``recv``.
        feature: Node feature tensor fed to the four projections.
        hidden_size_a: Per-head size of the query and key projections.
        hidden_size_v: Per-head size of the value projection.
        hidden_size_m: Size of the projection that is max-pooled over
            neighbours to feed the gate.
        hidden_size_o: Size of the final output projection.
        heads: Number of attention heads.
        name: Layer name, used as prefix of the projection parameter names.

    Return:
        The output of the final projection (size ``hidden_size_o``) after
        ``leaky_relu(alpha=0.1)`` and dropout with probability 0.1.
    """

    def send_func(src_feat, dst_feat, edge_feat):
        # Attention score of each edge: per-head dot product between the
        # query of the destination node and the key of the source node.
        # E * (M * D1)
        feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
        # E * M * D1
        feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a])
        feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
        # E * M
        alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)

        return {'dst_node_feat': dst_feat['node_feat'],
                'src_node_feat': src_feat['node_feat'],
                'feat_value': src_feat['feat_value'],
                'alpha': alpha,
                'feat_gate': src_feat['feat_gate']}

    def recv_func(message):
        # Gate inputs: mean of the destination features, mean of the
        # neighbour features and element-wise max of the gate projection.
        dst_feat = message['dst_node_feat']
        src_feat = message['src_node_feat']
        x = fluid.layers.sequence_pool(dst_feat, 'average')
        z = fluid.layers.sequence_pool(src_feat, 'average')

        feat_gate = message['feat_gate']
        g_max = fluid.layers.sequence_pool(feat_gate, 'max')

        # One sigmoid gate value per head; no explicit parameter name is
        # given for this fc layer.
        g = fluid.layers.concat([x, g_max, z], axis=1)
        g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid")

        # softmax
        alpha = message['alpha']
        alpha = paddle_helper.sequence_softmax(alpha)  # E * M

        feat_value = message['feat_value']  # E * (M * D2)
        old = feat_value  # kept to restore the LoD after the reshapes
        feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])  # E * M * D2
        feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
        feat_value = fluid.layers.reshape(feat_value, [-1, heads * hidden_size_v])  # E * (M * D2)
        feat_value = fluid.layers.lod_reset(feat_value, old)

        feat_value = fluid.layers.sequence_pool(feat_value, 'sum')  # N * (M * D2)
        feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])  # N * M * D2

        # Scale each head's aggregated value by its gate.
        output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
        output = fluid.layers.reshape(output, [-1, heads * hidden_size_v])  # N * (M * D2)

        # Concatenate the pooled destination feature with the gated heads.
        output = fluid.layers.concat([x, output], axis=1)

        return output

    # N * (D1 * M)
    feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
                               param_attr=fluid.ParamAttr(name=name + '_project_key'))
    # N * (D2 * M)
    feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
                                 param_attr=fluid.ParamAttr(name=name + '_project_value'))
    # N * (D1 * M)
    feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
                                 param_attr=fluid.ParamAttr(name=name + '_project_query'))
    # N * Dm
    feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
                                param_attr=fluid.ParamAttr(name=name + '_project_gate'))

    # send stage
    message = gw.send(
        send_func,
        nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value),
                    ('feat_query', feat_query), ('feat_gate', feat_gate)],
        efeat_list=None,
    )

    # recv stage
    output = gw.recv(message, recv_func)
    output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
                             param_attr=fluid.ParamAttr(name=name + '_project_output'))
    output = fluid.layers.leaky_relu(output, alpha=0.1)
    # NOTE(review): dropout is created without an explicit is_test flag;
    # presumably it is disabled by cloning the program for test -- confirm.
    output = fluid.layers.dropout(output, dropout_prob=0.1)

    return output
python3 train.py --epochs 100 --lr 1e-2 --rc 0 --batch_size 1024 --gpu_id 0 --exp_id 0
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
from pgl.utils import paddle_helper
# from pgl.layers import gaan
from conv import gaan
class GaANModel(object):
    """Stack of GaAN layers followed by a linear prediction layer.

    Args:
        num_class: Output size of the final prediction layer.
        num_layers: Number of stacked ``gaan`` layers.
        hidden_size_a: Forwarded to ``gaan`` as ``hidden_size_a``.
        hidden_size_v: Forwarded to ``gaan`` as ``hidden_size_v``.
        hidden_size_m: Forwarded to ``gaan`` as ``hidden_size_m``.
        hidden_size_o: Forwarded to ``gaan`` as ``hidden_size_o``.
        heads: Number of attention heads forwarded to ``gaan``.
        act: Stored on the instance; not read by ``forward``.
        name: Prefix for the layer names and the prediction layer name.
    """

    def __init__(self, num_class, num_layers, hidden_size_a=24,
                 hidden_size_v=32, hidden_size_m=64, hidden_size_o=128,
                 heads=8, act='relu', name="GaAN"):
        # Identification and overall architecture.
        self.name = name
        self.num_class = num_class
        self.num_layers = num_layers
        # Per-layer GaAN hyper-parameters.
        self.heads = heads
        self.hidden_size_a = hidden_size_a
        self.hidden_size_v = hidden_size_v
        self.hidden_size_m = hidden_size_m
        self.hidden_size_o = hidden_size_o
        self.act = act

    def forward(self, gw):
        """Build the network on graph wrapper ``gw`` and return the logits.

        Reads ``gw.node_feat['node_feat']`` as the input feature, applies
        ``num_layers`` GaAN layers named ``<name>_<index>`` and a final
        fully connected layer without activation.
        """
        hidden = gw.node_feat['node_feat']

        for layer_idx in range(self.num_layers):
            layer_name = '_'.join([self.name, str(layer_idx)])
            hidden = gaan(gw, hidden, self.hidden_size_a, self.hidden_size_v,
                          self.hidden_size_m, self.hidden_size_o, self.heads,
                          layer_name)

        return fluid.layers.fc(
            hidden, self.num_class, act=None, name=self.name + "_pred_output")
\ No newline at end of file
""" # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
将 ogb_proteins 的数据处理为 PGL 的 graph 数据,并返回 graph, label, train/valid/test 等信息 #
""" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ssl import ssl
ssl._create_default_https_context = ssl._create_unverified_context ssl._create_default_https_context = ssl._create_unverified_context
from ogb.nodeproppred import NodePropPredDataset, Evaluator from ogb.nodeproppred import NodePropPredDataset, Evaluator
...@@ -17,7 +27,7 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -17,7 +27,7 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
d_name: name of dataset d_name: name of dataset
mini_data: if mini_data==True, only use a small dataset (for test) mini_data: if mini_data==True, only use a small dataset (for test)
""" """
# 导入 ogb 数据 # import ogb data
dataset = NodePropPredDataset(name = d_name) dataset = NodePropPredDataset(name = d_name)
num_tasks = dataset.num_tasks # obtaining the number of prediction tasks in a dataset num_tasks = dataset.num_tasks # obtaining the number of prediction tasks in a dataset
...@@ -25,10 +35,10 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -25,10 +35,10 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
graph, label = dataset[0] graph, label = dataset[0]
# 调整维度,符合 PGL 的 Graph 要求 # reshape
graph["edge_index"] = graph["edge_index"].T graph["edge_index"] = graph["edge_index"].T
# 使用小规模数据,500个节点 # mini dataset
if mini_data: if mini_data:
graph['num_nodes'] = 500 graph['num_nodes'] = 500
mask = (graph['edge_index'][:, 0] < 500)*(graph['edge_index'][:, 1] < 500) mask = (graph['edge_index'][:, 0] < 500)*(graph['edge_index'][:, 1] < 500)
...@@ -39,19 +49,9 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -39,19 +49,9 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
valid_idx = np.arange(400,450) valid_idx = np.arange(400,450)
test_idx = np.arange(450,500) test_idx = np.arange(450,500)
# 输出 dataset 的信息
print(graph.keys())
print("节点个数 ", graph["num_nodes"]) # read/compute node feature
print("节点最小编号", graph['edge_index'][0].min())
print("边个数 ", graph["edge_index"].shape[1])
print("边索引 shape ", graph["edge_index"].shape)
print("边特征 shape ", graph["edge_feat"].shape)
print("节点特征是 ", graph["node_feat"])
print("species shape", graph['species'].shape)
print("label shape ", label.shape)
# 读取/计算 node feature
# 确定读取文件的路径
if mini_data: if mini_data:
node_feat_path = './dataset/ogbn_proteins_node_feat_small.npy' node_feat_path = './dataset/ogbn_proteins_node_feat_small.npy'
else: else:
...@@ -59,14 +59,11 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -59,14 +59,11 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
new_node_feat = None new_node_feat = None
if os.path.exists(node_feat_path): if os.path.exists(node_feat_path):
# 如果文件存在,直接读取 print("Begin: read node feature".center(50, '='))
print("读取 node feature 开始".center(50, '='))
new_node_feat = np.load(node_feat_path) new_node_feat = np.load(node_feat_path)
print("读取 node feature 成功".center(50, '=')) print("End: read node feature".center(50, '='))
else: else:
# 如果文件不存在,则计算 print("Begin: compute node feature".center(50, '='))
# 每个节点 i 的特征为其邻边特征的均值
print("计算 node feature 开始".center(50, '='))
start = time.perf_counter() start = time.perf_counter()
for i in range(graph['num_nodes']): for i in range(graph['num_nodes']):
if i % 100 == 0: if i % 100 == 0:
...@@ -74,8 +71,8 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -74,8 +71,8 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
print("{}/{}({}%), times: {:.2f}s".format( print("{}/{}({}%), times: {:.2f}s".format(
i, graph['num_nodes'], i/graph['num_nodes']*100, dur i, graph['num_nodes'], i/graph['num_nodes']*100, dur
)) ))
mask = (graph['edge_index'][:, 0] == i) # 选择 i 的所有邻边 mask = (graph['edge_index'][:, 0] == i)
# 计算均值
current_node_feat = np.mean(np.compress(mask, graph['edge_feat'], axis=0), current_node_feat = np.mean(np.compress(mask, graph['edge_feat'], axis=0),
axis=0, keepdims=True) axis=0, keepdims=True)
if i == 0: if i == 0:
...@@ -84,23 +81,23 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False): ...@@ -84,23 +81,23 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
new_node_feat.append(current_node_feat) new_node_feat.append(current_node_feat)
new_node_feat = np.concatenate(new_node_feat, axis=0) new_node_feat = np.concatenate(new_node_feat, axis=0)
print("计算 node feature 结束".center(50,'=')) print("End: compute node feature".center(50,'='))
print("存储 node feature 中,在"+node_feat_path.center(50, '=')) print("Saving node feature in "+node_feat_path.center(50, '='))
np.save(node_feat_path, new_node_feat) np.save(node_feat_path, new_node_feat)
print("存储 node feature 结束".center(50,'=')) print("Saving finish".center(50,'='))
print(new_node_feat) print(new_node_feat)
# 构造 Graph 对象 # create graph
g = pgl.graph.Graph( g = pgl.graph.Graph(
num_nodes=graph["num_nodes"], num_nodes=graph["num_nodes"],
edges = graph["edge_index"], edges = graph["edge_index"],
node_feat = {'node_feat': new_node_feat}, node_feat = {'node_feat': new_node_feat},
edge_feat = None edge_feat = None
) )
print("创建 Graph 对象成功") print("Create graph")
print(g) print(g)
return g, label, train_idx, valid_idx, test_idx, Evaluator(d_name) return g, label, train_idx, valid_idx, test_idx, Evaluator(d_name)
\ No newline at end of file
...@@ -11,27 +11,31 @@ ...@@ -11,27 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from preprocess import get_graph_data from preprocess import get_graph_data
import pgl import pgl
import argparse import argparse
import numpy as np import numpy as np
import time import time
from paddle import fluid from paddle import fluid
from visualdl import LogWriter
import reader import reader
from train_tool import train_epoch, valid_epoch from train_tool import train_epoch, valid_epoch
from model import GaANModel
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Training") parser = argparse.ArgumentParser(description="ogb Training")
parser.add_argument("--d_name", type=str, choices=["ogbn-proteins"], default="ogbn-proteins", parser.add_argument("--d_name", type=str, choices=["ogbn-proteins"], default="ogbn-proteins",
help="the name of dataset in ogb") help="the name of dataset in ogb")
parser.add_argument("--model", type=str, choices=["GaAN"], default="GaAN",
help="the name of model")
parser.add_argument("--mini_data", type=str, choices=["True", "False"], default="False", parser.add_argument("--mini_data", type=str, choices=["True", "False"], default="False",
help="use a small dataset to test the code") help="use a small dataset to test the code")
parser.add_argument("--use_gpu", type=bool, choices=[True, False], default=True, parser.add_argument("--use_gpu", type=bool, choices=[True, False], default=True,
help="use gpu") help="use gpu")
parser.add_argument("--gpu_id", type=int, default=0, parser.add_argument("--gpu_id", type=int, default=4,
help="the id of gpu") help="the id of gpu")
parser.add_argument("--exp_id", type=int, default=0, parser.add_argument("--exp_id", type=int, default=0,
help="the id of experiment") help="the id of experiment")
...@@ -58,7 +62,7 @@ if __name__ == "__main__": ...@@ -58,7 +62,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
print("setting".center(50, "=")) print("Parameters Setting".center(50, "="))
print("lr = {}, rc = {}, epochs = {}, batch_size = {}".format(args.lr, args.rc, args.epochs, print("lr = {}, rc = {}, epochs = {}, batch_size = {}".format(args.lr, args.rc, args.epochs,
args.batch_size)) args.batch_size))
print("Experiment ID: {}".format(args.exp_id).center(50, "=")) print("Experiment ID: {}".format(args.exp_id).center(50, "="))
...@@ -66,24 +70,12 @@ if __name__ == "__main__": ...@@ -66,24 +70,12 @@ if __name__ == "__main__":
d_name = args.d_name d_name = args.d_name
# get data # get data
g, label, train_idx, valid_idx, test_idx, evaluator = get_graph_data( g, label, train_idx, valid_idx, test_idx, evaluator = get_graph_data(d_name=d_name,
d_name=d_name,
mini_data=eval(args.mini_data)) mini_data=eval(args.mini_data))
if args.model == "GaAN":
# create log writer graph_model = GaANModel(112, 3, args.hidden_size_a, args.hidden_size_v, args.hidden_size_m,
log_writer = LogWriter(args.log_path, sync_cycle=10) args.hidden_size_o, args.heads)
with log_writer.mode("train") as logger:
log_train_loss_epoch = logger.scalar("loss")
log_train_rocauc_epoch = logger.scalar("rocauc")
with log_writer.mode("valid") as logger:
log_valid_loss_epoch = logger.scalar("loss")
log_valid_rocauc_epoch = logger.scalar("rocauc")
log_text = log_writer.text("text")
log_time = log_writer.scalar("time")
log_test_loss = log_writer.scalar("test_loss")
log_test_rocauc = log_writer.scalar("test_rocauc")
# training # training
samples = [25, 10] # 2-hop sample size samples = [25, 10] # 2-hop sample size
...@@ -102,6 +94,7 @@ if __name__ == "__main__": ...@@ -102,6 +94,7 @@ if __name__ == "__main__":
edge_feat=g.edge_feat_info() edge_feat=g.edge_feat_info()
) )
node_index = fluid.layers.data('node_index', shape=[None, 1], dtype="int64", node_index = fluid.layers.data('node_index', shape=[None, 1], dtype="int64",
append_batch_size=False) append_batch_size=False)
...@@ -109,11 +102,8 @@ if __name__ == "__main__": ...@@ -109,11 +102,8 @@ if __name__ == "__main__":
append_batch_size=False) append_batch_size=False)
parent_node_index = fluid.layers.data('parent_node_index', shape=[None, 1], dtype="int64", parent_node_index = fluid.layers.data('parent_node_index', shape=[None, 1], dtype="int64",
append_batch_size=False) append_batch_size=False)
feature = gw.node_feat['node_feat']
for i in range(3): output = graph_model.forward(gw)
feature = pgl.layers.GaAN(gw, feature, args.hidden_size_a, args.hidden_size_v,
args.hidden_size_m, args.hidden_size_o, args.heads, name='GaAN_'+str(i))
output = fluid.layers.fc(feature, 112, act=None)
output = fluid.layers.gather(output, node_index) output = fluid.layers.gather(output, node_index)
score = fluid.layers.sigmoid(output) score = fluid.layers.sigmoid(output)
...@@ -168,16 +158,13 @@ if __name__ == "__main__": ...@@ -168,16 +158,13 @@ if __name__ == "__main__":
start = time.time() start = time.time()
print("Training Begin".center(50, "=")) print("Training Begin".center(50, "="))
log_text.add_record(0, "Training Begin".center(50, "=")) best_valid = -1.0
for epoch in range(args.epochs): for epoch in range(args.epochs):
start_e = time.time() start_e = time.time()
# print("Train Epoch {}".format(epoch).center(50, "="))
train_loss, train_rocauc = train_epoch( train_loss, train_rocauc = train_epoch(
train_iter, program=train_program, exe=exe, loss=loss, score=score, train_iter, program=train_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch evaluator=evaluator, epoch=epoch
) )
print("Valid Epoch {}".format(epoch).center(50, "="))
valid_loss, valid_rocauc = valid_epoch( valid_loss, valid_rocauc = valid_epoch(
val_iter, program=val_program, exe=exe, loss=loss, score=score, val_iter, program=val_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch) evaluator=evaluator, epoch=epoch)
...@@ -185,32 +172,23 @@ if __name__ == "__main__": ...@@ -185,32 +172,23 @@ if __name__ == "__main__":
print("Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format( print("Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format(
epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e
)) ))
log_text.add_record(epoch+1,
"Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format( if valid_rocauc > best_valid:
epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e print("Update: new {}, old {}".format(valid_rocauc, best_valid))
)) best_valid = valid_rocauc
log_train_loss_epoch.add_record(epoch, train_loss)
log_valid_loss_epoch.add_record(epoch, valid_loss) fluid.io.save_params(executor=exe, dirname='./params/'+str(args.exp_id), main_program=val_program)
log_train_rocauc_epoch.add_record(epoch, train_rocauc)
log_valid_rocauc_epoch.add_record(epoch, valid_rocauc)
log_time.add_record(epoch, end_e-start_e)
print("Test Stage".center(50, "=")) print("Test Stage".center(50, "="))
log_text.add_record(args.epochs+1, "Test Stage".center(50, "="))
fluid.io.load_params(executor=exe, dirname='./params/'+str(args.exp_id), main_program=val_program)
test_loss, test_rocauc = valid_epoch( test_loss, test_rocauc = valid_epoch(
test_iter, program=val_program, exe=exe, loss=loss, score=score, test_iter, program=val_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch) evaluator=evaluator, epoch=epoch)
log_test_loss.add_record(0, test_loss)
log_test_rocauc.add_record(0, test_rocauc)
end = time.time() end = time.time()
print("test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format( print("test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format(
test_loss, test_rocauc, end-start test_loss, test_rocauc, end-start
)) ))
print("End".center(50, "=")) print("End".center(50, "="))
log_text.add_record(args.epochs+2, "test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format(
test_loss, test_rocauc, end-start
))
log_text.add_record(args.epochs+3, "End".center(50, "="))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time import time
from pgl.utils.logger import log from pgl.utils.logger import log
...@@ -15,50 +28,12 @@ def train_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per ...@@ -15,50 +28,12 @@ def train_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
total_sample += num_samples total_sample += num_samples
input_dict = { input_dict = {
"y_true": batch_feed_dict["node_label"], "y_true": batch_feed_dict["node_label"],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred": y_pred "y_pred": y_pred
} }
result += evaluator.eval(input_dict)["rocauc"] result += evaluator.eval(input_dict)["rocauc"]
# if batch % log_per_step == 0:
# print("Batch {}: Loss={}".format(batch, batch_loss))
# log.info("Batch %s %s-Loss %s %s-Acc %s" %
# (batch, prefix, batch_loss, prefix, batch_acc))
# print("Epoch {} Train: Loss={}, rocauc={}, Speed(per batch)={}".format(
# epoch, total_loss/total_sample, result/batch, (end-start)/batch))
return total_loss.item()/total_sample, result/batch return total_loss.item()/total_sample, result/batch
def inference(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
batch = 0
total_sample = 0
total_loss = 0
result = 0
start = time.time()
for batch_feed_dict in batch_iter():
batch += 1
y_pred = exe.run(program, fetch_list=[score], feed=batch_feed_dict)[0]
input_dict = {
"y_true": batch_feed_dict["node_label"],
"y_pred": y_pred[batch_feed_dict["node_index"]]
}
result += evaluator.eval(input_dict)["rocauc"]
if batch % log_per_step == 0:
print(batch, result/batch)
num_samples = len(batch_feed_dict["node_index"])
# total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample += num_samples
end = time.time()
print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample,
(end-start)/batch))
return total_loss/total_sample, result/batch
def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1): def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
batch = 0 batch = 0
total_sample = 0 total_sample = 0
...@@ -69,53 +44,13 @@ def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per ...@@ -69,53 +44,13 @@ def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
batch_loss, y_pred = exe.run(program, fetch_list=[loss, score], feed=batch_feed_dict) batch_loss, y_pred = exe.run(program, fetch_list=[loss, score], feed=batch_feed_dict)
input_dict = { input_dict = {
"y_true": batch_feed_dict["node_label"], "y_true": batch_feed_dict["node_label"],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred": y_pred "y_pred": y_pred
} }
# print(evaluator.eval(input_dict))
result += evaluator.eval(input_dict)["rocauc"] result += evaluator.eval(input_dict)["rocauc"]
# if batch % log_per_step == 0:
# print(batch, result/batch)
num_samples = len(batch_feed_dict["node_index"]) num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample += num_samples total_sample += num_samples
# print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample, (end-start)/batch))
return total_loss.item()/total_sample, result/batch return total_loss.item()/total_sample, result/batch
def run_epoch(batch_iter, exe, program, prefix, model_loss, model_acc, epoch, log_per_step=100):
"""
已废弃
"""
batch = 0
total_loss = 0.
total_acc = 0.
total_sample = 0
start = time.time()
for batch_feed_dict in batch_iter():
batch += 1
batch_loss, batch_acc = exe.run(program,
fetch_list=[model_loss, model_acc],
feed=batch_feed_dict)
if batch % log_per_step == 0:
log.info("Batch %s %s-Loss %s %s-Acc %s" %
(batch, prefix, batch_loss, prefix, batch_acc))
num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples
total_acc += batch_acc * num_samples
total_sample += num_samples
end = time.time()
log.info("%s Epoch %s Loss %.5lf Acc %.5lf Speed(per batch) %.5lf sec" %
(prefix, epoch, total_loss / total_sample,
total_acc / total_sample, (end - start) / batch))
...@@ -18,7 +18,7 @@ import paddle.fluid as fluid ...@@ -18,7 +18,7 @@ import paddle.fluid as fluid
from pgl import graph_wrapper from pgl import graph_wrapper
from pgl.utils import paddle_helper from pgl.utils import paddle_helper
__all__ = ['gcn', 'gat', 'gin', 'GaAN'] __all__ = ['gcn', 'gat', 'gin', 'gaan']
def gcn(gw, feature, hidden_size, activation, name, norm=None): def gcn(gw, feature, hidden_size, activation, name, norm=None):
...@@ -259,27 +259,19 @@ def gin(gw, ...@@ -259,27 +259,19 @@ def gin(gw,
return output return output
def GaAN(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o, heads,
name):
"""
This is an implementation of the paper GaAN: Gated Attention Networks for Learning
on Large and Spatiotemporal Graphs(https://arxiv.org/abs/1803.07294)
"""
# project the feature of nodes into new vector spaces
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key'))
feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_value'))
feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_query'))
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send function def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o, heads, name):
"""Implementation of GaAN"""
def send_func(src_feat, dst_feat, edge_feat): def send_func(src_feat, dst_feat, edge_feat):
# 计算每条边上的注意力分数
# E * (M * D1), 每个 dst 点都查询它的全部邻边的 src 点
feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key'] feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# E * M * D1
old = feat_query
feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a]) feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a])
feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a]) feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
# E * M
alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1) alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)
return {'dst_node_feat': dst_feat['node_feat'], return {'dst_node_feat': dst_feat['node_feat'],
...@@ -288,53 +280,75 @@ def GaAN(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -288,53 +280,75 @@ def GaAN(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
'alpha': alpha, 'alpha': alpha,
'feat_gate': src_feat['feat_gate']} 'feat_gate': src_feat['feat_gate']}
# send stage
message = gw.send(send_func, nfeat_list=[('node_feat', feature),
('feat_key', feat_key), ('feat_value', feat_value),
('feat_query', feat_query), ('feat_gate', feat_gate)],
efeat_list=None,
)
# recv function
def recv_func(message): def recv_func(message):
dst_feat = message['dst_node_feat'] # feature of dst nodes on each edge # 每条边的终点的特征
src_feat = message['src_node_feat'] # feature of src nodes on each edge dst_feat = message['dst_node_feat']
x = fluid.layers.sequence_pool(dst_feat, 'average') # feature of center nodes # 每条边的出发点的特征
z = fluid.layers.sequence_pool(src_feat, 'average') # mean feature of neighbors src_feat = message['src_node_feat']
# 每个中心点自己的特征
# compute gate x = fluid.layers.sequence_pool(dst_feat, 'average')
# 每个中心点的邻居的特征的平均值
z = fluid.layers.sequence_pool(src_feat, 'average')
# 计算 gate
feat_gate = message['feat_gate'] feat_gate = message['feat_gate']
g_max = fluid.layers.sequence_pool(feat_gate, 'max') g_max = fluid.layers.sequence_pool(feat_gate, 'max')
g = fluid.layers.concat([x, g_max, z], axis=1) g = fluid.layers.concat([x, g_max, z], axis=1)
g = fluid.layers.fc(g, heads, bias_attr=False, act='sigmoid') g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid")
# softmax of attention coefficient # softmax
alpha = message['alpha'] alpha = message['alpha']
alpha = paddle_helper.sequence_softmax(alpha) alpha = paddle_helper.sequence_softmax(alpha) # E * M
feat_value = message['feat_value'] feat_value = message['feat_value'] # E * (M * D2)
old = feat_value old = feat_value
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2
feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0) feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
feat_value = fluid.layers.reshape(feat_value, [-1, heads * hidden_size_v]) feat_value = fluid.layers.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2)
feat_value = fluid.layers.lod_reset(feat_value, old) feat_value = fluid.layers.lod_reset(feat_value, old)
feat_value = fluid.layers.sequence_pool(feat_value, 'sum') feat_value = fluid.layers.sequence_pool(feat_value, 'sum') # N * (M * D2)
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2
output = fluid.layers.elementwise_mul(feat_value, g, axis=0) output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
output = fluid.layers.reshape(output, [-1, heads*hidden_size_v]) output = fluid.layers.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2)
output = fluid.layers.concat([x, output], axis=1) output = fluid.layers.concat([x, output], axis=1)
return output return output
# recv stage # feature N * D
output = gw.recv(message, recv_func)
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M)
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key'))
# N * (D2 * M)
feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_value'))
# N * (D1 * M)
feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_query'))
# N * Dm
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send 阶段
# output message = gw.send(
send_func,
nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value),
('feat_query', feat_query), ('feat_gate', feat_gate)],
efeat_list=None,
)
# 聚合邻居特征
output = gw.recv(message, recv_func)
output = fluid.layers.fc(output, hidden_size_o, bias_attr=False, output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
param_attr=fluid.ParamAttr(name=name+'_project_output')) param_attr=fluid.ParamAttr(name=name + '_project_output'))
outout = fluid.layers.leaky_relu(output, alpha=0.1) output = fluid.layers.leaky_relu(output, alpha=0.1)
output = fluid.layers.dropout(output, dropout_prob=0.1) output = fluid.layers.dropout(output, dropout_prob=0.1)
return output return output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册