未验证 提交 5782fb81 编写于 作者: W Weiyue Su 提交者: GitHub

Merge pull request #84 from WenjinW/master

Update example/GaAN
# data and log
.examples/GaAN/datase/t
.examples/GaAN/log/
.examples/GaAN/__pycache__/
/examples/GaAN/dataset/
/examples/GaAN/log/
/examples/GaAN/__pycache__/
/examples/GaAN/params/
/DoorGod
# Virtualenv
/.venv/
/venv/
......
......@@ -6,14 +6,20 @@
The ogbn-proteins dataset will be downloaded in directory ./dataset automatically.
## Dependencies
- paddlepaddle
- pgl
- ogb
- [paddlepaddle >= 1.6](https://github.com/paddlepaddle/paddle)
- [pgl 1.1](https://github.com/PaddlePaddle/PGL)
- [ogb 1.1.1](https://github.com/snap-stanford/ogb)
## How to run
```bash
python train.py --lr 1e-2 --rc 0 --batch_size 1024 --epochs 100
```
```
or
```bash
source main.sh
```
### Hyperparameters
- use_gpu: whether to use gpu or not
- mini_data: use a small dataset to test code
......@@ -32,4 +38,4 @@ python train.py --lr 1e-2 --rc 0 --batch_size 1024 --epochs 100
We train our models for 100 epochs and report the **rocauc** on the test dataset.
|dataset|mean|std|
|-|-|-|
|ogbn-proteins|0.7786|0.0048|
|ogbn-proteins|0.7803|0.0073|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This package implements common layers to help building
graph neural networks.
"""
import paddle.fluid as fluid
from pgl import graph_wrapper
from pgl.utils import paddle_helper
__all__ = ['gcn', 'gat', 'gin', 'gaan']
def gcn(gw, feature, hidden_size, activation, name, norm=None):
"""Implementation of graph convolutional neural networks (GCN)
This is an implementation of the paper SEMI-SUPERVISED CLASSIFICATION
WITH GRAPH CONVOLUTIONAL NETWORKS (https://arxiv.org/pdf/1609.02907.pdf).
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
hidden_size: The hidden size for gcn.
activation: The activation for the output.
name: Gcn layer names.
norm: If :code:`norm` is not None, then the feature will be normalized. Norm must
be tensor with shape (num_nodes,) and dtype float32.
Return:
A tensor with shape (num_nodes, hidden_size)
"""
def send_src_copy(src_feat, dst_feat, edge_feat):
return src_feat["h"]
size = feature.shape[-1]
if size > hidden_size:
feature = fluid.layers.fc(feature,
size=hidden_size,
bias_attr=False,
param_attr=fluid.ParamAttr(name=name))
if norm is not None:
feature = feature * norm
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
if size > hidden_size:
output = gw.recv(msg, "sum")
else:
output = gw.recv(msg, "sum")
output = fluid.layers.fc(output,
size=hidden_size,
bias_attr=False,
param_attr=fluid.ParamAttr(name=name))
if norm is not None:
output = output * norm
bias = fluid.layers.create_parameter(
shape=[hidden_size],
dtype='float32',
is_bias=True,
name=name + '_bias')
output = fluid.layers.elementwise_add(output, bias, act=activation)
return output
def gat(gw,
feature,
hidden_size,
activation,
name,
num_heads=8,
feat_drop=0.6,
attn_drop=0.6,
is_test=False):
"""Implementation of graph attention networks (GAT)
This is an implementation of the paper GRAPH ATTENTION NETWORKS
(https://arxiv.org/abs/1710.10903).
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
hidden_size: The hidden size for gat.
activation: The activation for the output.
name: Gat layer names.
num_heads: The head number in gat.
feat_drop: Dropout rate for feature.
attn_drop: Dropout rate for attention.
is_test: Whether in test phrase.
Return:
A tensor with shape (num_nodes, hidden_size * num_heads)
"""
def send_attention(src_feat, dst_feat, edge_feat):
output = src_feat["left_a"] + dst_feat["right_a"]
output = fluid.layers.leaky_relu(
output, alpha=0.2) # (num_edges, num_heads)
return {"alpha": output, "h": src_feat["h"]}
def reduce_attention(msg):
alpha = msg["alpha"] # lod-tensor (batch_size, seq_len, num_heads)
h = msg["h"]
alpha = paddle_helper.sequence_softmax(alpha)
old_h = h
h = fluid.layers.reshape(h, [-1, num_heads, hidden_size])
alpha = fluid.layers.reshape(alpha, [-1, num_heads, 1])
if attn_drop > 1e-15:
alpha = fluid.layers.dropout(
alpha,
dropout_prob=attn_drop,
is_test=is_test,
dropout_implementation="upscale_in_train")
h = h * alpha
h = fluid.layers.reshape(h, [-1, num_heads * hidden_size])
h = fluid.layers.lod_reset(h, old_h)
return fluid.layers.sequence_pool(h, "sum")
if feat_drop > 1e-15:
feature = fluid.layers.dropout(
feature,
dropout_prob=feat_drop,
is_test=is_test,
dropout_implementation='upscale_in_train')
ft = fluid.layers.fc(feature,
hidden_size * num_heads,
bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_weight'))
left_a = fluid.layers.create_parameter(
shape=[num_heads, hidden_size],
dtype='float32',
name=name + '_gat_l_A')
right_a = fluid.layers.create_parameter(
shape=[num_heads, hidden_size],
dtype='float32',
name=name + '_gat_r_A')
reshape_ft = fluid.layers.reshape(ft, [-1, num_heads, hidden_size])
left_a_value = fluid.layers.reduce_sum(reshape_ft * left_a, -1)
right_a_value = fluid.layers.reduce_sum(reshape_ft * right_a, -1)
msg = gw.send(
send_attention,
nfeat_list=[("h", ft), ("left_a", left_a_value),
("right_a", right_a_value)])
output = gw.recv(msg, reduce_attention)
bias = fluid.layers.create_parameter(
shape=[hidden_size * num_heads],
dtype='float32',
is_bias=True,
name=name + '_bias')
bias.stop_gradient = True
output = fluid.layers.elementwise_add(output, bias, act=activation)
return output
def gin(gw,
feature,
hidden_size,
activation,
name,
init_eps=0.0,
train_eps=False):
"""Implementation of Graph Isomorphism Network (GIN) layer.
This is an implementation of the paper How Powerful are Graph Neural Networks?
(https://arxiv.org/pdf/1810.00826.pdf).
In their implementation, all MLPs have 2 layers. Batch normalization is applied
on every hidden layer.
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
name: GIN layer names.
hidden_size: The hidden size for gin.
activation: The activation for the output.
init_eps: float, optional
Initial :math:`\epsilon` value, default is 0.
train_eps: bool, optional
if True, :math:`\epsilon` will be a learnable parameter.
Return:
A tensor with shape (num_nodes, hidden_size).
"""
def send_src_copy(src_feat, dst_feat, edge_feat):
return src_feat["h"]
epsilon = fluid.layers.create_parameter(
shape=[1, 1],
dtype="float32",
attr=fluid.ParamAttr(name="%s_eps" % name),
default_initializer=fluid.initializer.ConstantInitializer(
value=init_eps))
if not train_eps:
epsilon.stop_gradient = True
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
output = gw.recv(msg, "sum") + feature * (epsilon + 1.0)
output = fluid.layers.fc(output,
size=hidden_size,
act=None,
param_attr=fluid.ParamAttr(name="%s_w_0" % name),
bias_attr=fluid.ParamAttr(name="%s_b_0" % name))
output = fluid.layers.layer_norm(
output,
begin_norm_axis=1,
param_attr=fluid.ParamAttr(
name="norm_scale_%s" % (name),
initializer=fluid.initializer.Constant(1.0)),
bias_attr=fluid.ParamAttr(
name="norm_bias_%s" % (name),
initializer=fluid.initializer.Constant(0.0)), )
if activation is not None:
output = getattr(fluid.layers, activation)(output)
output = fluid.layers.fc(output,
size=hidden_size,
act=activation,
param_attr=fluid.ParamAttr(name="%s_w_1" % name),
bias_attr=fluid.ParamAttr(name="%s_b_1" % name))
return output
def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o, heads, name):
"""Implementation of GaAN"""
def send_func(src_feat, dst_feat, edge_feat):
# attention score of each edge
# E * (M * D1)
feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# E * M * D1
old = feat_query
feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a])
feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
# E * M
alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)
return {'dst_node_feat': dst_feat['node_feat'],
'src_node_feat': src_feat['node_feat'],
'feat_value': src_feat['feat_value'],
'alpha': alpha,
'feat_gate': src_feat['feat_gate']}
def recv_func(message):
dst_feat = message['dst_node_feat']
src_feat = message['src_node_feat']
x = fluid.layers.sequence_pool(dst_feat, 'average')
z = fluid.layers.sequence_pool(src_feat, 'average')
feat_gate = message['feat_gate']
g_max = fluid.layers.sequence_pool(feat_gate, 'max')
g = fluid.layers.concat([x, g_max, z], axis=1)
g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid")
# softmax
alpha = message['alpha']
alpha = paddle_helper.sequence_softmax(alpha) # E * M
feat_value = message['feat_value'] # E * (M * D2)
old = feat_value
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2
feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
feat_value = fluid.layers.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2)
feat_value = fluid.layers.lod_reset(feat_value, old)
feat_value = fluid.layers.sequence_pool(feat_value, 'sum') # N * (M * D2)
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2
output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
output = fluid.layers.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2)
output = fluid.layers.concat([x, output], axis=1)
return output
# N * (D1 * M)
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key'))
# N * (D2 * M)
feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_value'))
# N * (D1 * M)
feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_query'))
# N * Dm
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send stage
message = gw.send(
send_func,
nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value),
('feat_query', feat_query), ('feat_gate', feat_gate)],
efeat_list=None,
)
# recv stage
output = gw.recv(message, recv_func)
output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_output'))
output = fluid.layers.leaky_relu(output, alpha=0.1)
output = fluid.layers.dropout(output, dropout_prob=0.1)
return output
python3 train.py --epochs 100 --lr 1e-2 --rc 0 --batch_size 1024 --gpu_id 0 --exp_id 0
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
from pgl.utils import paddle_helper
# from pgl.layers import gaan
from conv import gaan
class GaANModel(object):
def __init__(self, num_class, num_layers, hidden_size_a=24,
hidden_size_v=32, hidden_size_m=64, hidden_size_o=128,
heads=8, act='relu', name="GaAN"):
self.num_class = num_class
self.num_layers = num_layers
self.hidden_size_a = hidden_size_a
self.hidden_size_v = hidden_size_v
self.hidden_size_m = hidden_size_m
self.hidden_size_o = hidden_size_o
self.act = act
self.name = name
self.heads = heads
def forward(self, gw):
feature = gw.node_feat['node_feat']
for i in range(self.num_layers):
feature = gaan(gw, feature, self.hidden_size_a, self.hidden_size_v,
self.hidden_size_m, self.hidden_size_o, self.heads,
self.name+'_'+str(i))
pred = fluid.layers.fc(
feature, self.num_class, act=None, name=self.name + "_pred_output")
return pred
\ No newline at end of file
"""
将 ogb_proteins 的数据处理为 PGL 的 graph 数据,并返回 graph, label, train/valid/test 等信息
"""
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
from ogb.nodeproppred import NodePropPredDataset, Evaluator
......@@ -17,7 +27,7 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
d_name: name of dataset
mini_data: if mini_data==True, only use a small dataset (for test)
"""
# 导入 ogb 数据
# import ogb data
dataset = NodePropPredDataset(name = d_name)
num_tasks = dataset.num_tasks # obtaining the number of prediction tasks in a dataset
......@@ -25,10 +35,10 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
graph, label = dataset[0]
# 调整维度,符合 PGL 的 Graph 要求
# reshape
graph["edge_index"] = graph["edge_index"].T
# 使用小规模数据,500个节点
# mini dataset
if mini_data:
graph['num_nodes'] = 500
mask = (graph['edge_index'][:, 0] < 500)*(graph['edge_index'][:, 1] < 500)
......@@ -39,19 +49,9 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
valid_idx = np.arange(400,450)
test_idx = np.arange(450,500)
# 输出 dataset 的信息
print(graph.keys())
print("节点个数 ", graph["num_nodes"])
print("节点最小编号", graph['edge_index'][0].min())
print("边个数 ", graph["edge_index"].shape[1])
print("边索引 shape ", graph["edge_index"].shape)
print("边特征 shape ", graph["edge_feat"].shape)
print("节点特征是 ", graph["node_feat"])
print("species shape", graph['species'].shape)
print("label shape ", label.shape)
# 读取/计算 node feature
# 确定读取文件的路径
# read/compute node feature
if mini_data:
node_feat_path = './dataset/ogbn_proteins_node_feat_small.npy'
else:
......@@ -59,14 +59,11 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
new_node_feat = None
if os.path.exists(node_feat_path):
# 如果文件存在,直接读取
print("读取 node feature 开始".center(50, '='))
print("Begin: read node feature".center(50, '='))
new_node_feat = np.load(node_feat_path)
print("读取 node feature 成功".center(50, '='))
print("End: read node feature".center(50, '='))
else:
# 如果文件不存在,则计算
# 每个节点 i 的特征为其邻边特征的均值
print("计算 node feature 开始".center(50, '='))
print("Begin: compute node feature".center(50, '='))
start = time.perf_counter()
for i in range(graph['num_nodes']):
if i % 100 == 0:
......@@ -74,8 +71,8 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
print("{}/{}({}%), times: {:.2f}s".format(
i, graph['num_nodes'], i/graph['num_nodes']*100, dur
))
mask = (graph['edge_index'][:, 0] == i) # 选择 i 的所有邻边
# 计算均值
mask = (graph['edge_index'][:, 0] == i)
current_node_feat = np.mean(np.compress(mask, graph['edge_feat'], axis=0),
axis=0, keepdims=True)
if i == 0:
......@@ -84,23 +81,23 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
new_node_feat.append(current_node_feat)
new_node_feat = np.concatenate(new_node_feat, axis=0)
print("计算 node feature 结束".center(50,'='))
print("End: compute node feature".center(50,'='))
print("存储 node feature 中,在"+node_feat_path.center(50, '='))
print("Saving node feature in "+node_feat_path.center(50, '='))
np.save(node_feat_path, new_node_feat)
print("存储 node feature 结束".center(50,'='))
print("Saving finish".center(50,'='))
print(new_node_feat)
# 构造 Graph 对象
# create graph
g = pgl.graph.Graph(
num_nodes=graph["num_nodes"],
edges = graph["edge_index"],
node_feat = {'node_feat': new_node_feat},
edge_feat = None
)
print("创建 Graph 对象成功")
print("Create graph")
print(g)
return g, label, train_idx, valid_idx, test_idx, Evaluator(d_name)
\ No newline at end of file
......@@ -11,27 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from preprocess import get_graph_data
import pgl
import argparse
import numpy as np
import time
from paddle import fluid
from visualdl import LogWriter
import reader
from train_tool import train_epoch, valid_epoch
from train_tool import train_epoch, valid_epoch
from model import GaANModel
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Training")
parser = argparse.ArgumentParser(description="ogb Training")
parser.add_argument("--d_name", type=str, choices=["ogbn-proteins"], default="ogbn-proteins",
help="the name of dataset in ogb")
parser.add_argument("--model", type=str, choices=["GaAN"], default="GaAN",
help="the name of model")
parser.add_argument("--mini_data", type=str, choices=["True", "False"], default="False",
help="use a small dataset to test the code")
parser.add_argument("--use_gpu", type=bool, choices=[True, False], default=True,
help="use gpu")
parser.add_argument("--gpu_id", type=int, default=0,
parser.add_argument("--gpu_id", type=int, default=4,
help="the id of gpu")
parser.add_argument("--exp_id", type=int, default=0,
help="the id of experiment")
......@@ -58,7 +62,7 @@ if __name__ == "__main__":
args = parser.parse_args()
print("setting".center(50, "="))
print("Parameters Setting".center(50, "="))
print("lr = {}, rc = {}, epochs = {}, batch_size = {}".format(args.lr, args.rc, args.epochs,
args.batch_size))
print("Experiment ID: {}".format(args.exp_id).center(50, "="))
......@@ -66,24 +70,12 @@ if __name__ == "__main__":
d_name = args.d_name
# get data
g, label, train_idx, valid_idx, test_idx, evaluator = get_graph_data(
d_name=d_name,
mini_data=eval(args.mini_data))
g, label, train_idx, valid_idx, test_idx, evaluator = get_graph_data(d_name=d_name,
mini_data=eval(args.mini_data))
# create log writer
log_writer = LogWriter(args.log_path, sync_cycle=10)
with log_writer.mode("train") as logger:
log_train_loss_epoch = logger.scalar("loss")
log_train_rocauc_epoch = logger.scalar("rocauc")
with log_writer.mode("valid") as logger:
log_valid_loss_epoch = logger.scalar("loss")
log_valid_rocauc_epoch = logger.scalar("rocauc")
log_text = log_writer.text("text")
log_time = log_writer.scalar("time")
log_test_loss = log_writer.scalar("test_loss")
log_test_rocauc = log_writer.scalar("test_rocauc")
if args.model == "GaAN":
graph_model = GaANModel(112, 3, args.hidden_size_a, args.hidden_size_v, args.hidden_size_m,
args.hidden_size_o, args.heads)
# training
samples = [25, 10] # 2-hop sample size
......@@ -102,6 +94,7 @@ if __name__ == "__main__":
edge_feat=g.edge_feat_info()
)
node_index = fluid.layers.data('node_index', shape=[None, 1], dtype="int64",
append_batch_size=False)
......@@ -109,11 +102,8 @@ if __name__ == "__main__":
append_batch_size=False)
parent_node_index = fluid.layers.data('parent_node_index', shape=[None, 1], dtype="int64",
append_batch_size=False)
feature = gw.node_feat['node_feat']
for i in range(3):
feature = pgl.layers.GaAN(gw, feature, args.hidden_size_a, args.hidden_size_v,
args.hidden_size_m, args.hidden_size_o, args.heads, name='GaAN_'+str(i))
output = fluid.layers.fc(feature, 112, act=None)
output = graph_model.forward(gw)
output = fluid.layers.gather(output, node_index)
score = fluid.layers.sigmoid(output)
......@@ -168,16 +158,13 @@ if __name__ == "__main__":
start = time.time()
print("Training Begin".center(50, "="))
log_text.add_record(0, "Training Begin".center(50, "="))
best_valid = -1.0
for epoch in range(args.epochs):
start_e = time.time()
# print("Train Epoch {}".format(epoch).center(50, "="))
train_loss, train_rocauc = train_epoch(
train_iter, program=train_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch
)
print("Valid Epoch {}".format(epoch).center(50, "="))
valid_loss, valid_rocauc = valid_epoch(
val_iter, program=val_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch)
......@@ -185,32 +172,23 @@ if __name__ == "__main__":
print("Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format(
epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e
))
log_text.add_record(epoch+1,
"Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format(
epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e
))
log_train_loss_epoch.add_record(epoch, train_loss)
log_valid_loss_epoch.add_record(epoch, valid_loss)
log_train_rocauc_epoch.add_record(epoch, train_rocauc)
log_valid_rocauc_epoch.add_record(epoch, valid_rocauc)
log_time.add_record(epoch, end_e-start_e)
if valid_rocauc > best_valid:
print("Update: new {}, old {}".format(valid_rocauc, best_valid))
best_valid = valid_rocauc
fluid.io.save_params(executor=exe, dirname='./params/'+str(args.exp_id), main_program=val_program)
print("Test Stage".center(50, "="))
log_text.add_record(args.epochs+1, "Test Stage".center(50, "="))
fluid.io.load_params(executor=exe, dirname='./params/'+str(args.exp_id), main_program=val_program)
test_loss, test_rocauc = valid_epoch(
test_iter, program=val_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch)
log_test_loss.add_record(0, test_loss)
log_test_rocauc.add_record(0, test_rocauc)
end = time.time()
print("test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format(
test_loss, test_rocauc, end-start
))
print("End".center(50, "="))
log_text.add_record(args.epochs+2, "test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format(
test_loss, test_rocauc, end-start
))
log_text.add_record(args.epochs+3, "End".center(50, "="))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from pgl.utils.logger import log
......@@ -15,50 +28,12 @@ def train_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
total_sample += num_samples
input_dict = {
"y_true": batch_feed_dict["node_label"],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred": y_pred
}
result += evaluator.eval(input_dict)["rocauc"]
# if batch % log_per_step == 0:
# print("Batch {}: Loss={}".format(batch, batch_loss))
# log.info("Batch %s %s-Loss %s %s-Acc %s" %
# (batch, prefix, batch_loss, prefix, batch_acc))
# print("Epoch {} Train: Loss={}, rocauc={}, Speed(per batch)={}".format(
# epoch, total_loss/total_sample, result/batch, (end-start)/batch))
return total_loss.item()/total_sample, result/batch
def inference(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
batch = 0
total_sample = 0
total_loss = 0
result = 0
start = time.time()
for batch_feed_dict in batch_iter():
batch += 1
y_pred = exe.run(program, fetch_list=[score], feed=batch_feed_dict)[0]
input_dict = {
"y_true": batch_feed_dict["node_label"],
"y_pred": y_pred[batch_feed_dict["node_index"]]
}
result += evaluator.eval(input_dict)["rocauc"]
if batch % log_per_step == 0:
print(batch, result/batch)
num_samples = len(batch_feed_dict["node_index"])
# total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample += num_samples
end = time.time()
print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample,
(end-start)/batch))
return total_loss/total_sample, result/batch
def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
batch = 0
total_sample = 0
......@@ -69,53 +44,13 @@ def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
batch_loss, y_pred = exe.run(program, fetch_list=[loss, score], feed=batch_feed_dict)
input_dict = {
"y_true": batch_feed_dict["node_label"],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred": y_pred
}
# print(evaluator.eval(input_dict))
result += evaluator.eval(input_dict)["rocauc"]
# if batch % log_per_step == 0:
# print(batch, result/batch)
num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample += num_samples
# print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample, (end-start)/batch))
return total_loss.item()/total_sample, result/batch
def run_epoch(batch_iter, exe, program, prefix, model_loss, model_acc, epoch, log_per_step=100):
"""
已废弃
"""
batch = 0
total_loss = 0.
total_acc = 0.
total_sample = 0
start = time.time()
for batch_feed_dict in batch_iter():
batch += 1
batch_loss, batch_acc = exe.run(program,
fetch_list=[model_loss, model_acc],
feed=batch_feed_dict)
if batch % log_per_step == 0:
log.info("Batch %s %s-Loss %s %s-Acc %s" %
(batch, prefix, batch_loss, prefix, batch_acc))
num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples
total_acc += batch_acc * num_samples
total_sample += num_samples
end = time.time()
log.info("%s Epoch %s Loss %.5lf Acc %.5lf Speed(per batch) %.5lf sec" %
(prefix, epoch, total_loss / total_sample,
total_acc / total_sample, (end - start) / batch))
......@@ -593,7 +593,7 @@ class Graph(object):
edges = self._edges[eid]
else:
edges = np.array(edges, dtype="int64")
sub_edges = graph_kernel.map_edges(
np.arange(
len(edges), dtype="int64"), edges, reindex)
......
......@@ -18,7 +18,7 @@ import paddle.fluid as fluid
from pgl import graph_wrapper
from pgl.utils import paddle_helper
__all__ = ['gcn', 'gat', 'gin', 'GaAN']
__all__ = ['gcn', 'gat', 'gin', 'gaan']
def gcn(gw, feature, hidden_size, activation, name, norm=None):
......@@ -259,27 +259,19 @@ def gin(gw,
return output
def GaAN(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o, heads,
name):
"""
This is an implementation of the paper GaAN: Gated Attention Networks for Learning
on Large and Spatiotemporal Graphs(https://arxiv.org/abs/1803.07294)
"""
# project the feature of nodes into new vector spaces
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key'))
feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_value'))
feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_query'))
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send function
def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o, heads, name):
"""Implementation of GaAN"""
def send_func(src_feat, dst_feat, edge_feat):
# 计算每条边上的注意力分数
# E * (M * D1), 每个 dst 点都查询它的全部邻边的 src 点
feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# E * M * D1
old = feat_query
feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a])
feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
# E * M
alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)
return {'dst_node_feat': dst_feat['node_feat'],
......@@ -288,53 +280,75 @@ def GaAN(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
'alpha': alpha,
'feat_gate': src_feat['feat_gate']}
# send stage
message = gw.send(send_func, nfeat_list=[('node_feat', feature),
('feat_key', feat_key), ('feat_value', feat_value),
('feat_query', feat_query), ('feat_gate', feat_gate)],
efeat_list=None,
)
# recv function
def recv_func(message):
dst_feat = message['dst_node_feat'] # feature of dst nodes on each edge
src_feat = message['src_node_feat'] # feature of src nodes on each edge
x = fluid.layers.sequence_pool(dst_feat, 'average') # feature of center nodes
z = fluid.layers.sequence_pool(src_feat, 'average') # mean feature of neighbors
# compute gate
# 每条边的终点的特征
dst_feat = message['dst_node_feat']
# 每条边的出发点的特征
src_feat = message['src_node_feat']
# 每个中心点自己的特征
x = fluid.layers.sequence_pool(dst_feat, 'average')
# 每个中心点的邻居的特征的平均值
z = fluid.layers.sequence_pool(src_feat, 'average')
# 计算 gate
feat_gate = message['feat_gate']
g_max = fluid.layers.sequence_pool(feat_gate, 'max')
g = fluid.layers.concat([x, g_max, z], axis=1)
g = fluid.layers.fc(g, heads, bias_attr=False, act='sigmoid')
g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid")
# softmax of attention coefficient
# softmax
alpha = message['alpha']
alpha = paddle_helper.sequence_softmax(alpha)
alpha = paddle_helper.sequence_softmax(alpha) # E * M
feat_value = message['feat_value']
feat_value = message['feat_value'] # E * (M * D2)
old = feat_value
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2
feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
feat_value = fluid.layers.reshape(feat_value, [-1, heads * hidden_size_v])
feat_value = fluid.layers.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2)
feat_value = fluid.layers.lod_reset(feat_value, old)
feat_value = fluid.layers.sequence_pool(feat_value, 'sum')
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])
feat_value = fluid.layers.sequence_pool(feat_value, 'sum') # N * (M * D2)
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2
output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
output = fluid.layers.reshape(output, [-1, heads*hidden_size_v])
output = fluid.layers.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2)
output = fluid.layers.concat([x, output], axis=1)
return output
# recv stage
# feature N * D
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M)
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key'))
# N * (D2 * M)
feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_value'))
# N * (D1 * M)
feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_query'))
# N * Dm
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send 阶段
message = gw.send(
send_func,
nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value),
('feat_query', feat_query), ('feat_gate', feat_gate)],
efeat_list=None,
)
# 聚合邻居特征
output = gw.recv(message, recv_func)
# output
output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
param_attr=fluid.ParamAttr(name=name+'_project_output'))
outout = fluid.layers.leaky_relu(output, alpha=0.1)
param_attr=fluid.ParamAttr(name=name + '_project_output'))
output = fluid.layers.leaky_relu(output, alpha=0.1)
output = fluid.layers.dropout(output, dropout_prob=0.1)
return output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册