Commit 2bd2b054 authored by W wangwenjin

add GaAN example in examples

Parent adf9f8d1
# GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs
[GaAN](https://arxiv.org/abs/1803.07294) is a powerful neural network designed for machine learning on graphs. It introduces a gated attention mechanism. Based on PGL, we reproduce the GaAN algorithm and train the model on [ogbn-proteins](https://ogb.stanford.edu/docs/nodeprop/#ogbn-proteins).
## Datasets
The ogbn-proteins dataset will be downloaded automatically into the ./dataset directory.
## Dependencies
- paddlepaddle
- pgl
- ogb
## How to run
```bash
python train.py --lr 1e-2 --rc 0 --batch_size 1024 --epochs 100
```
### Hyperparameters
- use_gpu: whether to use GPU
- mini_data: whether to use a small dataset to test the code
- epochs: number of training epochs
- lr: learning rate
- rc: regularization coefficient
- log_path: path of the log directory
- batch_size: batch size
- heads: the number of attention heads
- hidden_size_a: the size of query and key vectors
- hidden_size_v: the size of value vectors
- hidden_size_m: the size of the projection space for computing gates
- hidden_size_o: the size of each GaAN layer's output
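For example, to make the attention sizes explicit as well (every flag below is defined in `train.py`; the values shown are the defaults):
```bash
python train.py --lr 1e-2 --rc 0 --batch_size 1024 --epochs 100 \
    --heads 8 --hidden_size_a 24 --hidden_size_v 32 \
    --hidden_size_m 64 --hidden_size_o 128
```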
"""
将 ogb_proteins 的数据处理为 PGL 的 graph 数据,并返回 graph, label, train/valid/test 等信息
"""
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
from ogb.nodeproppred import NodePropPredDataset, Evaluator
import pgl
import numpy as np
import os
import time
def get_graph_data(d_name="ogbn-proteins", mini_data=False):
"""
Param:
d_name: name of dataset
mini_data: if mini_data==True, only use a small dataset (for test)
"""
    # load the ogb dataset
    dataset = NodePropPredDataset(name=d_name)
    num_tasks = dataset.num_tasks  # number of prediction tasks in the dataset
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
graph, label = dataset[0]
    # transpose so edges are stored as (num_edges, 2), as PGL's Graph expects
graph["edge_index"] = graph["edge_index"].T
    # use a small subset of the data: the first 500 nodes
if mini_data:
graph['num_nodes'] = 500
mask = (graph['edge_index'][:, 0] < 500)*(graph['edge_index'][:, 1] < 500)
graph["edge_index"] = graph["edge_index"][mask]
graph["edge_feat"] = graph["edge_feat"][mask]
label = label[:500]
train_idx = np.arange(0,400)
valid_idx = np.arange(400,450)
test_idx = np.arange(450,500)
    # print dataset information
    print(graph.keys())
    print("number of nodes:", graph["num_nodes"])
    print("smallest node id:", graph['edge_index'].min())
    print("number of edges:", graph["edge_index"].shape[0])
    print("edge_index shape:", graph["edge_index"].shape)
    print("edge_feat shape:", graph["edge_feat"].shape)
    print("node_feat:", graph["node_feat"])
    print("species shape:", graph['species'].shape)
    print("label shape:", label.shape)
    # load or compute node features
    # path of the cached node feature file
if mini_data:
node_feat_path = './dataset/ogbn_proteins_node_feat_small.npy'
else:
node_feat_path = './dataset/ogbn_proteins_node_feat.npy'
new_node_feat = None
if os.path.exists(node_feat_path):
        # cache hit: load node features from disk
        print("loading node features".center(50, '='))
        new_node_feat = np.load(node_feat_path)
        print("node features loaded".center(50, '='))
    else:
        # cache miss: compute them.
        # the feature of node i is the mean of the features of its incident edges
        print("computing node features".center(50, '='))
start = time.perf_counter()
for i in range(graph['num_nodes']):
            if i % 100 == 0:
                dur = time.perf_counter() - start
                print("{}/{} ({:.1f}%), time: {:.2f}s".format(
                    i, graph['num_nodes'], i / graph['num_nodes'] * 100, dur))
            mask = (graph['edge_index'][:, 0] == i)  # all edges whose source is i
            # mean over the selected edge features
current_node_feat = np.mean(np.compress(mask, graph['edge_feat'], axis=0),
axis=0, keepdims=True)
if i == 0:
new_node_feat = [current_node_feat]
else:
new_node_feat.append(current_node_feat)
new_node_feat = np.concatenate(new_node_feat, axis=0)
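        # Note: the loop above is O(num_nodes * num_edges). A vectorized
        # alternative (an untested sketch using numpy scatter-add) would be:
        #   sums = np.zeros((graph['num_nodes'], graph['edge_feat'].shape[1]))
        #   np.add.at(sums, graph['edge_index'][:, 0], graph['edge_feat'])
        #   counts = np.bincount(graph['edge_index'][:, 0], minlength=graph['num_nodes'])
        #   new_node_feat = sums / np.maximum(counts, 1)[:, None]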
print("计算 node feature 结束".center(50,'='))
print("存储 node feature 中,在"+node_feat_path.center(50, '='))
np.save(node_feat_path, new_node_feat)
print("存储 node feature 结束".center(50,'='))
print(new_node_feat)
    # build the PGL Graph object
    g = pgl.graph.Graph(
        num_nodes=graph["num_nodes"],
        edges=graph["edge_index"],
        node_feat={'node_feat': new_node_feat},
        edge_feat=None)
    print("Graph object created")
print(g)
return g, label, train_idx, valid_idx, test_idx, Evaluator(d_name)
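# Example usage (illustrative; mirrors how train.py calls this module):
#   g, label, train_idx, valid_idx, test_idx, evaluator = get_graph_data(
#       d_name="ogbn-proteins", mini_data=False)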
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pickle as pkl
import paddle
import paddle.fluid as fluid
import pgl
import time
from pgl.utils import mp_reader
from pgl.utils.logger import log
import copy
def node_batch_iter(nodes, node_label, batch_size):
    """Yield (nodes, labels) batches in a random order."""
perm = np.arange(len(nodes))
np.random.shuffle(perm)
start = 0
while start < len(nodes):
index = perm[start:start + batch_size]
start += batch_size
yield nodes[index], node_label[index]
def traverse(item):
    """Recursively yield the scalar elements of nested lists/arrays."""
    if isinstance(item, (list, np.ndarray)):
        for i in iter(item):
            for j in traverse(i):
                yield j
    else:
        yield item
def flat_node_and_edge(nodes):
    """Flatten nested node lists into a list of unique node ids."""
nodes = list(set(traverse(nodes)))
return nodes
def worker(batch_info, graph, graph_wrapper, samples):
"""Worker
"""
def work():
"""work
"""
_graph_wrapper = copy.copy(graph_wrapper)
_graph_wrapper.node_feat_tensor_dict = {}
for batch_train_samples, batch_train_labels in batch_info:
start_nodes = batch_train_samples
nodes = start_nodes
edges = []
for max_deg in samples:
pred_nodes = graph.sample_predecessor(
start_nodes, max_degree=max_deg)
for dst_node, src_nodes in zip(start_nodes, pred_nodes):
for src_node in src_nodes:
edges.append((src_node, dst_node))
last_nodes = nodes
nodes = [nodes, pred_nodes]
nodes = flat_node_and_edge(nodes)
# Find new nodes
start_nodes = list(set(nodes) - set(last_nodes))
if len(start_nodes) == 0:
break
subgraph = graph.subgraph(
nodes=nodes,
edges=edges,
with_node_feat=True,
with_edge_feat=True)
sub_node_index = subgraph.reindex_from_parrent_nodes(
batch_train_samples)
feed_dict = _graph_wrapper.to_feed(subgraph)
feed_dict["node_label"] = batch_train_labels
feed_dict["node_index"] = sub_node_index
feed_dict["parent_node_index"] = np.array(nodes, dtype="int64")
yield feed_dict
return work
def multiprocess_graph_reader(graph,
graph_wrapper,
samples,
node_index,
batch_size,
node_label,
with_parent_node_index=False,
num_workers=4):
"""multiprocess_graph_reader
"""
def parse_to_subgraph(rd, prefix, node_feat, _with_parent_node_index):
"""parse_to_subgraph
"""
def work():
"""work
"""
for data in rd():
feed_dict = data
for key in node_feat:
feed_dict[prefix + '/node_feat/' + key] = node_feat[key][
feed_dict["parent_node_index"]]
if not _with_parent_node_index:
del feed_dict["parent_node_index"]
yield feed_dict
return work
def reader():
"""reader"""
batch_info = list(
node_batch_iter(
node_index, node_label, batch_size=batch_size))
block_size = int(len(batch_info) / num_workers + 1)
reader_pool = []
for i in range(num_workers):
reader_pool.append(
worker(batch_info[block_size * i:block_size * (i + 1)], graph,
graph_wrapper, samples))
if len(reader_pool) == 1:
r = parse_to_subgraph(reader_pool[0],
repr(graph_wrapper), graph.node_feat,
with_parent_node_index)
else:
multi_process_sample = mp_reader.multiprocess_reader(
reader_pool, use_pipe=True, queue_size=1000)
r = parse_to_subgraph(multi_process_sample,
repr(graph_wrapper), graph.node_feat,
with_parent_node_index)
return paddle.reader.buffered(r, num_workers)
return reader()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from preprocess import get_graph_data
import pgl
import argparse
import numpy as np
import time
from paddle import fluid
from visualdl import LogWriter
import reader
from train_tool import train_epoch, valid_epoch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Training")
parser.add_argument("--d_name", type=str, choices=["ogbn-proteins"], default="ogbn-proteins",
help="the name of dataset in ogb")
parser.add_argument("--mini_data", type=str, choices=["True", "False"], default="False",
help="use a small dataset to test the code")
parser.add_argument("--use_gpu", type=bool, choices=[True, False], default=True,
help="use gpu")
parser.add_argument("--gpu_id", type=int, default=0,
help="the id of gpu")
parser.add_argument("--exp_id", type=int, default=0,
help="the id of experiment")
parser.add_argument("--epochs", type=int, default=100,
help="the number of training epochs")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate of Adam")
parser.add_argument("--rc", type=float, default=0,
help="regularization coefficient")
parser.add_argument("--log_path", type=str, default="./log",
help="the path of log")
parser.add_argument("--batch_size", type=int, default=1024,
help="the number of batch size")
parser.add_argument("--heads", type=int, default=8,
help="the number of heads of attention")
parser.add_argument("--hidden_size_a", type=int, default=24,
help="the hidden size of query and key vectors")
parser.add_argument("--hidden_size_v", type=int, default=32,
help="the hidden size of value vectors")
parser.add_argument("--hidden_size_m", type=int, default=64,
help="the hidden size of projection for computing gates")
parser.add_argument("--hidden_size_o", type=int ,default=128,
help="the hidden size of each layer in GaAN")
args = parser.parse_args()
print("setting".center(50, "="))
print("lr = {}, rc = {}, epochs = {}, batch_size = {}".format(args.lr, args.rc, args.epochs,
args.batch_size))
print("Experiment ID: {}".format(args.exp_id).center(50, "="))
print("training in GPU: {}".format(args.gpu_id).center(50, "="))
d_name = args.d_name
# get data
g, label, train_idx, valid_idx, test_idx, evaluator = get_graph_data(
d_name=d_name,
mini_data=eval(args.mini_data))
# create log writer
log_writer = LogWriter(args.log_path, sync_cycle=10)
with log_writer.mode("train") as logger:
log_train_loss_epoch = logger.scalar("loss")
log_train_rocauc_epoch = logger.scalar("rocauc")
with log_writer.mode("valid") as logger:
log_valid_loss_epoch = logger.scalar("loss")
log_valid_rocauc_epoch = logger.scalar("rocauc")
log_text = log_writer.text("text")
log_time = log_writer.scalar("time")
log_test_loss = log_writer.scalar("test_loss")
log_test_rocauc = log_writer.scalar("test_rocauc")
# training
samples = [25, 10] # 2-hop sample size
batch_size = args.batch_size
sample_workers = 1
    place = fluid.CUDAPlace(args.gpu_id) if eval(args.use_gpu) else fluid.CPUPlace()
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
        gw = pgl.graph_wrapper.GraphWrapper(
            name='graph',
            place=place,
            node_feat=g.node_feat_info(),
            edge_feat=g.edge_feat_info())
node_index = fluid.layers.data('node_index', shape=[None, 1], dtype="int64",
append_batch_size=False)
node_label = fluid.layers.data('node_label', shape=[None, 112], dtype="float32",
append_batch_size=False)
parent_node_index = fluid.layers.data('parent_node_index', shape=[None, 1], dtype="int64",
append_batch_size=False)
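        # model: 3 stacked GaAN layers followed by a linear head that maps to
        # the 112 binary protein-function tasks; the loss is the mean sigmoid
        # cross-entropy over the batch nodes selected by node_index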
feature = gw.node_feat['node_feat']
for i in range(3):
feature = pgl.layers.GaAN(gw, feature, args.hidden_size_a, args.hidden_size_v,
args.hidden_size_m, args.hidden_size_o, args.heads, name='GaAN_'+str(i))
output = fluid.layers.fc(feature, 112, act=None)
output = fluid.layers.gather(output, node_index)
score = fluid.layers.sigmoid(output)
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=output, label=node_label)
loss = fluid.layers.mean(loss)
val_program = train_program.clone(for_test=True)
with fluid.program_guard(train_program, startup_program):
lr = args.lr
adam = fluid.optimizer.Adam(
learning_rate=lr,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=args.rc))
adam.minimize(loss)
exe = fluid.Executor(place)
exe.run(startup_program)
train_iter = reader.multiprocess_graph_reader(
g,
gw,
samples=samples,
num_workers=sample_workers,
batch_size=batch_size,
with_parent_node_index=True,
node_index=train_idx,
node_label=np.array(label[train_idx], dtype='float32'))
val_iter = reader.multiprocess_graph_reader(
g,
gw,
samples=samples,
num_workers=sample_workers,
batch_size=batch_size,
with_parent_node_index=True,
node_index=valid_idx,
node_label=np.array(label[valid_idx], dtype='float32'))
test_iter = reader.multiprocess_graph_reader(
g,
gw,
samples=samples,
num_workers=sample_workers,
batch_size=batch_size,
with_parent_node_index=True,
node_index=test_idx,
node_label=np.array(label[test_idx], dtype='float32'))
start = time.time()
print("Training Begin".center(50, "="))
log_text.add_record(0, "Training Begin".center(50, "="))
for epoch in range(args.epochs):
start_e = time.time()
# print("Train Epoch {}".format(epoch).center(50, "="))
train_loss, train_rocauc = train_epoch(
train_iter, program=train_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch
)
print("Valid Epoch {}".format(epoch).center(50, "="))
valid_loss, valid_rocauc = valid_epoch(
val_iter, program=val_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch)
end_e = time.time()
print("Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format(
epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e
))
log_text.add_record(epoch+1,
"Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}".format(
epoch, train_loss, valid_loss, train_rocauc, valid_rocauc, end_e-start_e
))
log_train_loss_epoch.add_record(epoch, train_loss)
log_valid_loss_epoch.add_record(epoch, valid_loss)
log_train_rocauc_epoch.add_record(epoch, train_rocauc)
log_valid_rocauc_epoch.add_record(epoch, valid_rocauc)
log_time.add_record(epoch, end_e-start_e)
print("Test Stage".center(50, "="))
log_text.add_record(args.epochs+1, "Test Stage".center(50, "="))
test_loss, test_rocauc = valid_epoch(
test_iter, program=val_program, exe=exe, loss=loss, score=score,
evaluator=evaluator, epoch=epoch)
log_test_loss.add_record(0, test_loss)
log_test_rocauc.add_record(0, test_rocauc)
end = time.time()
print("test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format(
test_loss, test_rocauc, end-start
))
print("End".center(50, "="))
log_text.add_record(args.epochs+2, "test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}".format(
test_loss, test_rocauc, end-start
))
log_text.add_record(args.epochs+3, "End".center(50, "="))
import time
from pgl.utils.logger import log
def train_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
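    """Run one training epoch; return (average loss, average per-batch ROC-AUC)."""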
batch = 0
total_loss = 0.0
total_sample = 0
result = 0
for batch_feed_dict in batch_iter():
batch += 1
batch_loss, y_pred = exe.run(program, fetch_list=[loss, score], feed=batch_feed_dict)
num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples
total_sample += num_samples
input_dict = {
"y_true": batch_feed_dict["node_label"],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred": y_pred
}
result += evaluator.eval(input_dict)["rocauc"]
# if batch % log_per_step == 0:
# print("Batch {}: Loss={}".format(batch, batch_loss))
# log.info("Batch %s %s-Loss %s %s-Acc %s" %
# (batch, prefix, batch_loss, prefix, batch_acc))
# print("Epoch {} Train: Loss={}, rocauc={}, Speed(per batch)={}".format(
# epoch, total_loss/total_sample, result/batch, (end-start)/batch))
return total_loss.item()/total_sample, result/batch
def inference(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
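    """Inference pass: accumulates per-batch ROC-AUC only; the loss is not
    fetched, so the returned loss term is always 0."""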
batch = 0
total_sample = 0
total_loss = 0
result = 0
start = time.time()
for batch_feed_dict in batch_iter():
batch += 1
y_pred = exe.run(program, fetch_list=[score], feed=batch_feed_dict)[0]
input_dict = {
"y_true": batch_feed_dict["node_label"],
"y_pred": y_pred[batch_feed_dict["node_index"]]
}
result += evaluator.eval(input_dict)["rocauc"]
if batch % log_per_step == 0:
print(batch, result/batch)
num_samples = len(batch_feed_dict["node_index"])
# total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample += num_samples
end = time.time()
print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample,
(end-start)/batch))
return total_loss/total_sample, result/batch
def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per_step=1):
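    """Run one validation epoch; return (average loss, average per-batch ROC-AUC)."""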
batch = 0
total_sample = 0
result = 0
total_loss = 0.0
for batch_feed_dict in batch_iter():
batch += 1
batch_loss, y_pred = exe.run(program, fetch_list=[loss, score], feed=batch_feed_dict)
input_dict = {
"y_true": batch_feed_dict["node_label"],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred": y_pred
}
# print(evaluator.eval(input_dict))
result += evaluator.eval(input_dict)["rocauc"]
# if batch % log_per_step == 0:
# print(batch, result/batch)
num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample += num_samples
# print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample, (end-start)/batch))
return total_loss.item()/total_sample, result/batch
def run_epoch(batch_iter, exe, program, prefix, model_loss, model_acc, epoch, log_per_step=100):
"""
已废弃
"""
batch = 0
total_loss = 0.
total_acc = 0.
total_sample = 0
start = time.time()
for batch_feed_dict in batch_iter():
batch += 1
batch_loss, batch_acc = exe.run(program,
fetch_list=[model_loss, model_acc],
feed=batch_feed_dict)
if batch % log_per_step == 0:
log.info("Batch %s %s-Loss %s %s-Acc %s" %
(batch, prefix, batch_loss, prefix, batch_acc))
num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples
total_acc += batch_acc * num_samples
total_sample += num_samples
end = time.time()
log.info("%s Epoch %s Loss %.5lf Acc %.5lf Speed(per batch) %.5lf sec" %
(prefix, epoch, total_loss / total_sample,
total_acc / total_sample, (end - start) / batch))