未验证 提交 29cd2939 编写于 作者: H Huang Zhengjie 提交者: GitHub

Merge pull request #120 from sys1874/main

add UniPM 
## Masked Label Prediction: Unified Massage Passing Model for Semi-Supervised Classification
This experiment is based on stanford OGB (1.2.1) benchmark. The description of 《Masked Label Prediction: Unified Massage Passing Model for Semi-Supervised Classification》 is [avaiable here](). The steps are:
### Install environment:
```
git clone https://github.com/PaddlePaddle/PGL.git
cd PGL
pip install -e
pip install -r requirements.txt
```
### Arxiv dataset:
1. ```python main_arxiv.py --place 0 --log_file arxiv_baseline.txt``` to get the baseline result of arxiv dataset.
2. ```python main_arxiv.py --place 0 --use_label_e --log_file arxiv_unipm.txt``` to get the UniPM result of arxiv dataset.
### Products dataset:
1. ```python main_product.py --place 0 --log_file product_label_embedding.txt --use_label_e``` to get the UniPM result of Products dataset.
### Proteins dataset:
1. ```python main_protein.py --place 0 --log_file protein_baseline.txt ``` to get the baseline result of Proteins dataset.
2. ```python main_protein.py --place 0 --use_label_e --log_file protein_label_embedding.txt``` to get the UniPM result of Proteins dataset.
### The **detailed hyperparameter** is:
```
Arxiv_dataset(Full Batch): Products_dataset(NeighborSampler): Proteins_dataset(Random Partition):
--num_layers 3 --num_layers 3 --num_layers 7
--hidden_size 128 --hidden_size 128 --hidden_size 64
--num_heads 2 --num_heads 4 --num_heads 4
--dropout 0.3 --dropout 0.3 --dropout 0.1
--lr 0.001 --lr 0.001 --lr 0.001
--use_label_e True --use_label_e True --use_label_e True
--label_rate 0.625 --label_rate 0.625 --label_rate 0.5
--weight_decay. 0.0005
```
### Reference performance for OGB:
| Model |Test Accuracy |Valid Accuracy | Parameters | Hardware |
| ------------------ |-------------- | --------------- | -------------- |----------|
| Arxiv_baseline | 0.7225 ± 0.0015 | 0.7367 ± 0.0012 | 468,369 | Tesla V100 (32GB) |
| Arxiv_UniPM | 0.7311 ± 0.0021 | 0.7450 ± 0.0005 | 473,489 | Tesla V100 (32GB) |
| Products_baseline | 0.8023 ± 0.0026 | 0.9286 ± 0.0017 | 1,470,905 | Tesla V100 (32GB) |
| Products_UniPM | 0.8256 ± 0.0031 | 0.9308 ± 0.0017 | 1,475,605 | Tesla V100 (32GB) |
| Proteins_baseline | 0.8611 ± 0.0017 | 0.9128 ± 0.0007 | 1,879,664 | Tesla V100 (32GB) |
| Proteins_UniPM | 0.8643 ± 0.0016 | 0.9175 ± 0.0007 | 1,909,104 | Tesla V100 (32GB) |
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base DataLoader
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os
import sys
import six
from io import open
from collections import namedtuple
import numpy as np
import tqdm
import paddle
from pgl.utils import mp_reader
import collections
import time
import pgl
def batch_iter(data, perm, batch_size, fid, num_workers):
"""node_batch_iter
"""
size = len(data)
start = 0
cc = 0
while start < size:
index = perm[start:start + batch_size]
start += batch_size
cc += 1
if cc % num_workers != fid:
continue
yield data[index]
def scan_batch_iter(data, batch_size, fid, num_workers):
"""node_batch_iter
"""
batch = []
cc = 0
for line_example in data.scan():
cc += 1
if cc % num_workers != fid:
continue
batch.append(line_example)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
class BaseDataGenerator(object):
"""Base Data Geneartor"""
def __init__(self, buf_size, batch_size, num_workers, shuffle=True):
self.num_workers = num_workers
self.batch_size = batch_size
self.line_examples = []
self.buf_size = buf_size
self.shuffle = shuffle
def batch_fn(self, batch_examples):
""" batch_fn batch producer"""
raise NotImplementedError("No defined Batch Fn")
def batch_iter(self, fid, perm):
""" batch iterator"""
if self.shuffle:
for batch in batch_iter(self, perm, self.batch_size, fid,
self.num_workers):
yield batch
else:
for batch in scan_batch_iter(self, self.batch_size, fid,
self.num_workers):
yield batch
def __len__(self):
return len(self.line_examples)
def __getitem__(self, idx):
if isinstance(idx, collections.Iterable):
return [self[bidx] for bidx in idx]
else:
return self.line_examples[idx]
def generator(self):
"""batch dict generator"""
def worker(filter_id, perm):
""" multiprocess worker"""
def func_run():
""" func_run """
pid = os.getpid()
np.random.seed(pid + int(time.time()))
for batch_examples in self.batch_iter(filter_id, perm):
batch_dict = self.batch_fn(batch_examples)
yield batch_dict
return func_run
# consume a seed
np.random.rand()
if self.shuffle:
perm = np.arange(0, len(self))
np.random.shuffle(perm)
else:
perm = None
if self.num_workers == 1:
r = paddle.reader.buffered(worker(0, perm), self.buf_size)
else:
worker_pool = [
worker(wid, perm) for wid in range(self.num_workers)
]
worker = mp_reader.multiprocess_reader(
worker_pool, use_pipe=True, queue_size=1000)
r = paddle.reader.buffered(worker, self.buf_size)
for batch in r():
yield batch
def scan(self):
'''scan
'''
for line_example in self.line_examples:
yield line_example
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
ogb_products_dataloader
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from dataloader.base_dataloader import BaseDataGenerator
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
import tqdm
from collections import namedtuple
import pgl
import numpy as np
import copy
def add_self_loop_for_subgraph(graph):
'''add_self_loop_for_subgraph
'''
self_loop_edges = np.zeros((graph.num_nodes, 2))
self_loop_edges[:, 0] = self_loop_edges[:, 1] = np.arange(graph.num_nodes)
edges = np.vstack((graph.edges, self_loop_edges))
edges = np.unique(edges, axis=0)
g = pgl.graph.SubGraph(num_nodes=graph.num_nodes, edges=edges, reindex=graph._from_reindex)
for k, v in graph._node_feat.items():
g._node_feat[k] = v
return graph
def traverse(item):
"""traverse
"""
if isinstance(item, list) or isinstance(item, np.ndarray):
for i in iter(item):
for j in traverse(i):
yield j
else:
yield item
def flat_node_and_edge(nodes):
"""flat_node_and_edge
"""
nodes = list(set(traverse(nodes)))
return nodes
def k_hop_sampler(graph, samples, batch_nodes):
graph_list = []
for max_deg in samples:
start_nodes = copy.deepcopy(batch_nodes)
edges = []
if max_deg == -1:
pred_nodes = graph.predecessor(start_nodes)
else:
pred_nodes = graph.sample_predecessor(start_nodes, max_degree=max_deg)
for dst_node, src_nodes in zip(start_nodes, pred_nodes):
for src_node in src_nodes:
edges.append((src_node, dst_node))
nodes = [start_nodes, pred_nodes]
nodes = flat_node_and_edge(nodes)
subgraph = graph.subgraph(
nodes=nodes, edges=edges, with_node_feat=False, with_edge_feat=False)
subgraph = add_self_loop_for_subgraph(subgraph)
sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
batch_nodes = nodes
graph_list.append((subgraph, batch_nodes, sub_node_index))
graph_list = graph_list[::-1]
# for k, v in graph._node_feat.items():
# graph_list[0][0]._node_feat[k] = v
# sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
return graph_list
class SampleDataGenerator(BaseDataGenerator):
def __init__(self,
graph_wrappers=None,
buf_size=1000,
batch_size=128,
num_workers=1,
sizes=[30, 30],
shuffle=True,
dataset=None,
nodes_idx=None):
super(SampleDataGenerator, self).__init__(
buf_size=buf_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle)
self.sizes = sizes
self.graph_wrappers = graph_wrappers
self.dataset = dataset
graph, labels = dataset[0]
self.graph = graph
self.num_nodes = graph.num_nodes
if nodes_idx is not None:
self.nodes_idx = nodes_idx
else:
self.nodes_idx = np.arange(self.num_nodes)
self.labels_all = labels
self.labels = labels[self.nodes_idx]
self.sample_based_line_example(self.nodes_idx, self.labels)
def sample_based_line_example(self, nodes_idx, labels):
self.line_examples = []
Example = namedtuple('Example', ["node", "label"])
for node, label in zip(nodes_idx, labels):
self.line_examples.append(Example(node=node, label=label))
print("Len Examples", len(self.line_examples))
def batch_fn(self, batch_ex):
batch_nodes = []
cc = 0
batch_node_id = []
batch_labels = []
for ex in batch_ex:
batch_nodes.append(ex.node)
batch_labels.append(ex.label)
# _graph_wrapper = copy.copy(self.graph_wrapper)
# graph_list
graph_list = k_hop_sampler(self.graph, self.sizes,
batch_nodes) # -1 = 全采样操作
feed_dict_all = {}
for i in range(len(self.sizes)):
feed_dict = self.graph_wrappers[i].to_feed(graph_list[i][0])
feed_dict_all.update(feed_dict)
if i == 0:
feed_dict_all["batch_nodes_" + str(i)] = np.array(graph_list[i][1])
feed_dict_all["sub_node_index_" + str(i)] = graph_list[i][2]
# feed_dict = _graph_wrapper.to_feed(subgraph)
# feed_dict["batch_nodes"] = np.array(batch_nodes)
# feed_dict["sub_node_index"] = sub_node_index
feed_dict_all["label_all"] = self.labels_all
feed_dict_all["label"] = np.array(batch_labels, dtype="int64")
return feed_dict_all
\ No newline at end of file
import math
import torch
import paddle
import pgl
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
from ogb.nodeproppred import Evaluator
from utils import to_undirected, add_self_loop, linear_warmup_decay
from model import Arxiv_baseline_model, Arxiv_label_embedding_model
import argparse
from tqdm import tqdm
evaluator = Evaluator(name='ogbn-arxiv')
# place=F.CUDAPlace(6)
def get_config():
parser = argparse.ArgumentParser()
## 基本模型参数
model_group=parser.add_argument_group('model_base_arg')
model_group.add_argument('--num_layers', default=3, type=int)
model_group.add_argument('--hidden_size', default=128, type=int)
model_group.add_argument('--num_heads', default=2, type=int)
model_group.add_argument('--dropout', default=0.3, type=float)
model_group.add_argument('--attn_dropout', default=0, type=float)
## label embedding模型参数
embed_group=parser.add_argument_group('embed_arg')
embed_group.add_argument('--use_label_e', action='store_true')
embed_group.add_argument('--label_rate', default=0.625, type=float)
## train_arg
train_group=parser.add_argument_group('train_arg')
train_group.add_argument('--runs', default=10, type=int )
train_group.add_argument('--epochs', default=2000, type=int )
train_group.add_argument('--lr', default=0.001, type=float)
train_group.add_argument('--place', default=-1, type=int)
train_group.add_argument('--log_file', default='result_arxiv.txt', type=str)
return parser.parse_args()
# def optimizer_func(lr=0.01):
# return F.optimizer.AdamOptimizer(learning_rate=lr, regularization=F.regularizer.L2Decay(
# regularization_coeff=0.001))
def optimizer_func(lr=0.01):
return F.optimizer.AdamOptimizer(learning_rate=lr, regularization=F.regularizer.L2Decay(
regularization_coeff=0.0005))
def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
feed_dict=model.gw.to_feed(graph)
# feed_dict={}
if parser.use_label_e:
feed_dict['label']=y_true
feed_dict['label_idx']=split_idx['train']
avg_cost_np = test_exe.run(
program=program,
feed=feed_dict,
fetch_list=[model.out_feat])
y_pred=avg_cost_np[0].argmax(axis=-1)
y_pred=np.expand_dims(y_pred, 1)
train_acc = evaluator.eval({
'y_true': y_true[split_idx['train']],
'y_pred': y_pred[split_idx['train']],
})['acc']
val_acc = evaluator.eval({
'y_true': y_true[split_idx['valid']],
'y_pred': y_pred[split_idx['valid']],
})['acc']
test_acc = evaluator.eval({
'y_true': y_true[split_idx['test']],
'y_pred': y_pred[split_idx['test']],
})['acc']
return train_acc, val_acc, test_acc
def train_loop(parser, start_program, main_program, test_program,
model, graph, label, split_idx, exe, run_id, wf=None):
#启动上文构建的训练器
exe.run(start_program)
max_acc=0 # 最佳test_acc
max_step=0 # 最佳test_acc 对应step
max_val_acc=0 # 最佳val_acc
max_cor_acc=0 # 最佳val_acc对应test_acc
max_cor_step=0 # 最佳val_acc对应step
#训练循环
for epoch_id in tqdm(range(parser.epochs)):
#运行训练器
if parser.use_label_e:
feed_dict=model.gw.to_feed(graph)
# feed_dict={}
train_idx_temp = split_idx['train']
np.random.shuffle(train_idx_temp)
label_idx=train_idx_temp[ :int(parser.label_rate*len(train_idx_temp))]
unlabel_idx=train_idx_temp[int(parser.label_rate*len(train_idx_temp)): ]
feed_dict['label']=label
feed_dict['label_idx']= label_idx
feed_dict['train_idx']= unlabel_idx
else:
feed_dict=model.gw.to_feed(graph)
# feed_dict={}
feed_dict['label']=label
feed_dict['train_idx']= split_idx['train']
loss = exe.run(main_program,
feed=feed_dict,
fetch_list=[model.avg_cost])
# print(loss[1][0])
loss = loss[0]
#测试结果
result = eval_test(parser, test_program, model, exe, graph, label, split_idx)
train_acc, valid_acc, test_acc = result
max_acc = max(test_acc, max_acc)
if max_acc == test_acc:
max_step=epoch_id
max_val_acc=max(valid_acc, max_val_acc)
if max_val_acc==valid_acc:
max_cor_acc=test_acc
max_cor_step=epoch_id
max_acc=max(result[2], max_acc)
if max_acc==result[2]:
max_step=epoch_id
result_t=(f'Run: {run_id:02d}, '
f'Epoch: {epoch_id:02d}, '
f'Loss: {loss[0]:.4f}, '
f'Train: {100 * train_acc:.2f}%, '
f'Valid: {100 * valid_acc:.2f}%, '
f'Test: {100 * test_acc:.2f}% \n'
f'max_Test: {100 * max_acc:.2f}%, '
f'max_step: {max_step}\n'
f'max_val: {100 * max_val_acc:.2f}%, '
f'max_val_Test: {100 * max_cor_acc:.2f}%, '
f'max_val_step: {max_cor_step}\n'
)
if (epoch_id+1)%100==0:
print(result_t)
wf.write(result_t)
wf.write('\n')
wf.flush()
return max_cor_acc
if __name__ == '__main__':
parser = get_config()
print('===========args==============')
print(parser)
print('=============================')
startup_prog = F.default_startup_program()
train_prog = F.default_main_program()
place=F.CPUPlace() if parser.place <0 else F.CUDAPlace(parser.place)
dataset = PglNodePropPredDataset(name="ogbn-arxiv")
split_idx=dataset.get_idx_split()
graph, label = dataset[0]
print(label.shape)
graph=to_undirected(graph)
graph=add_self_loop(graph)
with F.program_guard(train_prog, startup_prog):
with F.unique_name.guard():
gw = pgl.graph_wrapper.GraphWrapper(
name="arxiv", node_feat=graph.node_feat_info(), place=place)
# gw = pgl.graph_wrapper.StaticGraphWrapper(name="graph",
# graph=graph,
# place=place)
# gw.initialize(place)
#gw, hidden_size, num_heads, dropout, num_layers)
if parser.use_label_e:
model=Arxiv_label_embedding_model(gw, parser.hidden_size, parser.num_heads,
parser.dropout, parser.num_layers)
else:
model=Arxiv_baseline_model(gw, parser.hidden_size, parser.num_heads,
parser.dropout, parser.num_layers)
test_prog=train_prog.clone(for_test=True)
model.train_program()
# ave_loss = train_program(pred_output)#训练程序
# lr, global_step= linear_warmup_decay(parser.lr, parser.epochs*0.1, parser.epochs)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer = optimizer_func(parser.lr)#训练优化函数
adam_optimizer.minimize(model.avg_cost)
exe = F.Executor(place)
wf = open(parser.log_file, 'w', encoding='utf-8')
total_test_acc=0.0
for run_i in range(parser.runs):
total_test_acc+=train_loop(parser, startup_prog, train_prog, test_prog, model,
graph, label, split_idx, exe, run_i, wf)
wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%')
wf.close()
\ No newline at end of file
import math
import torch
import paddle
import pgl
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
import copy
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
from ogb.nodeproppred import Evaluator
from utils import to_undirected, add_self_loop, linear_warmup_decay
from model import Products_label_embedding_model
from dataloader.ogb_products_dataloader import SampleDataGenerator
import paddle.fluid.profiler as profiler
from pgl.utils import paddle_helper
import argparse
from tqdm import tqdm
evaluator = Evaluator(name='ogbn-products')
def get_config():
parser = argparse.ArgumentParser()
## 采样参数
data_group= parser.add_argument_group('data_arg')
data_group.add_argument('--batch_size', default=1500, type=int)
data_group.add_argument('--num_workers', default=12, type=int)
data_group.add_argument('--sizes', default=[10, 10, 10], type=int, nargs='+' )
data_group.add_argument('--buf_size', default=1000, type=int)
## 基本模型参数
model_group=parser.add_argument_group('model_base_arg')
model_group.add_argument('--num_layers', default=3, type=int)
model_group.add_argument('--hidden_size', default=128, type=int)
model_group.add_argument('--num_heads', default=4, type=int)
model_group.add_argument('--dropout', default=0.3, type=float)
model_group.add_argument('--attn_dropout', default=0, type=float)
## label embedding模型参数
embed_group=parser.add_argument_group('embed_arg')
embed_group.add_argument('--use_label_e', action='store_true')
embed_group.add_argument('--label_rate', default=0.625, type=float)
## train_arg
train_group=parser.add_argument_group('train_arg')
train_group.add_argument('--runs', default=10, type=int )
train_group.add_argument('--epochs', default=100, type=int )
train_group.add_argument('--lr', default=0.001, type=float)
train_group.add_argument('--place', default=-1, type=int)
train_group.add_argument('--log_file', default='result_products.txt', type=str)
return parser.parse_args()
def optimizer_func(lr):
return F.optimizer.AdamOptimizer(learning_rate=lr)
def eval_test(parser, test_p_list, model, test_exe, dataset, split_idx):
eval_gg=SampleDataGenerator(graph_wrappers=[model.gw_list[0]], buf_size=parser.buf_size,
batch_size=parser.batch_size , num_workers=1,
sizes=[-1,], shuffle=False,
dataset=dataset,
nodes_idx=None)
out_r_temp=[]
test_p, out=test_p_list[0]
pbar = tqdm(total=eval_gg.num_nodes* model.num_layers)
pbar.set_description('Evaluating')
for feed_batch in tqdm(eval_gg.generator()):
feed_batch['label_idx']=split_idx['train']
feat_batch= test_exe.run(test_p,
feed=feed_batch,
fetch_list=out)
out_r_temp.append(feat_batch[0])
pbar.update(feed_batch['label'].shape[0])
our_r=np.concatenate(out_r_temp, axis=0)
for test_p, out in test_p_list[1:]: #np.concatenate
out_r_temp=[]
for feed_batch in tqdm(eval_gg.generator()):
feed_batch['hidden_node_feat'] = our_r[feed_batch['batch_nodes_0']]
feat_batch= test_exe.run(test_p,
feed=feed_batch,
fetch_list=out)
out_r_temp.append(feat_batch[0])
pbar.update(feed_batch['label'].shape[0])
our_r=np.concatenate(out_r_temp, axis=0)
pbar.close()
y_pred=our_r.argmax(axis=-1)
y_pred=np.expand_dims(y_pred, 1)
y_true=eval_gg.labels
train_acc = evaluator.eval({
'y_true': y_true[split_idx['train']],
'y_pred': y_pred[split_idx['train']],
})['acc']
val_acc = evaluator.eval({
'y_true': y_true[split_idx['valid']],
'y_pred': y_pred[split_idx['valid']],
})['acc']
test_acc = evaluator.eval({
'y_true': y_true[split_idx['test']],
'y_pred': y_pred[split_idx['test']],
})['acc']
return train_acc, val_acc, test_acc
def train_loop(parser, start_program, main_program, test_p_list,
model, feat_init, place, dataset, split_idx, exe, run_id, wf=None):
#启动上文构建的训练器
exe.run(start_program)
feat_init(place)
max_acc=0 # 最佳test_acc
max_step=0 # 最佳test_acc 对应step
max_val_acc=0 # 最佳val_acc
max_cor_acc=0 # 最佳val_acc对应test_acc
max_cor_step=0 # 最佳val_acc对应step
#训练循环
for epoch_id in range(parser.epochs):
#运行训练器
if parser.use_label_e:
train_idx_temp=copy.deepcopy(split_idx['train'])
np.random.shuffle(train_idx_temp)
label_idx=train_idx_temp[ :int(parser.label_rate*len(train_idx_temp))]
unlabel_idx=train_idx_temp[int(parser.label_rate*len(train_idx_temp)):]
train_gg=SampleDataGenerator(graph_wrappers=model.gw_list, buf_size=parser.buf_size,
batch_size=parser.batch_size , num_workers=parser.num_workers,
sizes=parser.sizes, shuffle=True,
dataset=dataset,
nodes_idx=unlabel_idx)
pbar = tqdm(total=unlabel_idx.shape[0])
pbar.set_description(f'Epoch {epoch_id:02d}')
total=0.0
acc_num=0.0
for batch_feed in tqdm(train_gg.generator()):
batch_feed['label_idx']=label_idx
loss = exe.run(main_program,
feed=batch_feed,
fetch_list=[model.avg_cost, model.out_feat])
total+=loss[0][0]
acc_num=(loss[1].argmax(axis=-1)==batch_feed['label'].reshape(-1)).sum()+acc_num
pbar.update(batch_feed['label'].shape[0])
pbar.close()
print(total/(len(train_gg)/parser.batch_size))
print('acc: ', (acc_num/unlabel_idx.shape[0])*100)
#测试结果
# total=0.0
if (epoch_id+1)>=50 and (epoch_id+1)%10==0:
result = eval_test(parser, test_p_list, model, exe, dataset, split_idx)
train_acc, valid_acc, test_acc = result
max_acc = max(test_acc, max_acc)
if max_acc == test_acc:
max_step=epoch_id
max_val_acc=max(valid_acc, max_val_acc)
if max_val_acc==valid_acc:
max_cor_acc=test_acc
max_cor_step=epoch_id
max_acc=max(result[2], max_acc)
if max_acc==result[2]:
max_step=epoch_id
result_t=(f'Run: {run_id:02d}, '
f'Epoch: {epoch_id:02d}, '
f'Loss: {total:.4f}, '
f'Train: {100 * train_acc:.2f}%, '
f'Valid: {100 * valid_acc:.2f}%, '
f'Test: {100 * test_acc:.2f}% \n'
f'max_Test: {100 * max_acc:.2f}%, '
f'max_step: {max_step}\n'
f'max_val: {100 * max_val_acc:.2f}%, '
f'max_val_Test: {100 * max_cor_acc:.2f}%, '
f'max_val_step: {max_cor_step}\n'
)
# if (epoch_id+1)%50==0:
print(result_t)
wf.write(result_t)
wf.write('\n')
wf.flush()
return max_cor_acc
if __name__ == '__main__':
parser = get_config()
print('===========args==============')
print(parser)
print('=============================')
startup_prog = F.default_startup_program()
train_prog = F.default_main_program()
place=F.CPUPlace() if parser.place <0 else F.CUDAPlace(parser.place)
dataset = PglNodePropPredDataset(name="ogbn-products")
# dataset = PglNodePropPredDataset(name="ogbn-arxiv")
split_idx=dataset.get_idx_split()
graph, label = dataset[0]
print(label.shape)
with F.program_guard(train_prog, startup_prog):
with F.unique_name.guard():
gw_list=[]
for i in range(len(parser.sizes)):
gw_list.append(pgl.graph_wrapper.GraphWrapper(
name="product_"+str(i)))
feature_input, feat_init=paddle_helper.constant(
name='node_feat_input',
dtype='float32',
value=graph.node_feat['feat'])
if parser.use_label_e:
model=Products_label_embedding_model(feature_input, gw_list,
parser.hidden_size, parser.num_heads,
parser.dropout, parser.num_layers)
else:
model=Arxiv_baseline_model(gw, parser.hidden_size, parser.num_heads,
parser.dropout, parser.num_layers)
# test_prog=train_prog.clone(for_test=True)
model.train_program()
# ave_loss = train_program(pred_output)#训练程序
# lr, global_step= linear_warmup_decay(0.01, 50, 500)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer = optimizer_func(parser.lr)#训练优化函数
adam_optimizer.minimize(model.avg_cost)
test_p_list=[]
with F.unique_name.guard():
## input层
test_p=F.Program()
with F.program_guard(test_p, ):
gw_test=pgl.graph_wrapper.GraphWrapper(
name="product_"+str(0))
feature_input, feat_init__=paddle_helper.constant(
name='node_feat_input',
dtype='float32',
value=graph.node_feat['feat'])
label_feature=model.label_embed_input(model.feature_input)
feature_batch=model.get_batch_feature(label_feature) # 把batch_feat打出来
feature_batch=model.get_gat_layer(0, gw_test, feature_batch,
hidden_size=model.hidden_size,
num_heads=model.num_heads,
concat=True,
layer_norm=True, relu=True)
sub_node_index=F.data(name='sub_node_index_0', shape=[None],
dtype="int64")
feature_batch=L.gather(feature_batch, sub_node_index, overwrite=False)
# test_p=test_p.clone(for_test=True)
test_p_list.append((test_p, feature_batch))
for i in range(1,model.num_layers-1):
test_p=F.Program()
with F.program_guard(test_p, ):
gw_test=pgl.graph_wrapper.GraphWrapper(
name="product_"+str(0))
# feature_batch=model.get_batch_feature(label_feature, test=True) # 把图在CPU存起
feature_batch = F.data( 'hidden_node_feat',
shape=[None, model.num_heads*model.hidden_size],
dtype='float32')
feature_batch=model.get_gat_layer(i, gw_test, feature_batch,
hidden_size=model.hidden_size,
num_heads=model.num_heads,
concat=True,
layer_norm=True, relu=True)
sub_node_index=F.data(name='sub_node_index_0', shape=[None],
dtype="int64")
feature_batch=L.gather(feature_batch, sub_node_index, overwrite=False)
# test_p=test_p.clone(for_test=True)
test_p_list.append((test_p, feature_batch))
test_p=F.Program()
with F.program_guard(test_p, ):
gw_test=pgl.graph_wrapper.GraphWrapper(
name="product_"+str(0))
# feature_batch=model.get_batch_feature(label_feature, test=True)
feature_batch = F.data( 'hidden_node_feat',
shape=[None, model.num_heads*model.hidden_size ],
dtype='float32')
feature_batch = model.get_gat_layer(model.num_layers-1, gw_test, feature_batch,
hidden_size=model.out_size, num_heads=model.num_heads,
concat=False, layer_norm=False, relu=False, gate=True)
sub_node_index=F.data(name='sub_node_index_0', shape=[None],
dtype="int64")
feature_batch=L.gather(feature_batch, sub_node_index, overwrite=False)
# test_p=test_p.clone(for_test=True)
test_p_list.append((test_p, feature_batch))
exe = F.Executor(place)
wf = open(parser.log_file, 'w', encoding='utf-8')
total_test_acc=0.0
for run_i in range(parser.runs):
total_test_acc+=train_loop(parser, startup_prog, train_prog, test_p_list, model, feat_init,
place, dataset, split_idx, exe, run_i, wf)
wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%')
wf.close()
\ No newline at end of file
import math
import torch
import paddle
import pgl
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
import time
import copy
from ogb.nodeproppred import Evaluator
from utils import to_undirected, add_self_loop, linear_warmup_decay
from model import Proteins_baseline_model, Proteins_label_embedding_model
from partition import random_partition_v2 as random_partition
import argparse
from tqdm import tqdm
evaluator = Evaluator(name='ogbn-proteins')
# place=F.CUDAPlace(6)
def get_config():
parser = argparse.ArgumentParser()
## 基本模型参数
model_group=parser.add_argument_group('model_base_arg')
model_group.add_argument('--num_layers', default=7, type=int)
model_group.add_argument('--hidden_size', default=64, type=int)
model_group.add_argument('--num_heads', default=4, type=int)
model_group.add_argument('--dropout', default=0.1, type=float)
model_group.add_argument('--attn_dropout', default=0, type=float)
## label embedding模型参数
embed_group=parser.add_argument_group('embed_arg')
embed_group.add_argument('--use_label_e', action='store_true')
embed_group.add_argument('--label_rate', default=0.5, type=float)
## train_arg
train_group=parser.add_argument_group('train_arg')
train_group.add_argument('--runs', default=10, type=int )
train_group.add_argument('--epochs', default=2000, type=int )
train_group.add_argument('--lr', default=0.001, type=float)
train_group.add_argument('--place', default=-1, type=int)
train_group.add_argument('--log_file', default='result_proteins.txt', type=str)
return parser.parse_args()
def optimizer_func(lr=0.01):
return F.optimizer.AdamOptimizer(learning_rate=lr)
def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
y_pred = np.zeros_like(y_true)
graph.node_feat["label"] = y_true
graph.node_feat["nid"] = np.arange(0, graph.num_nodes)
for subgraph in random_partition(num_clusters=5, graph=graph, shuffle=False):
feed_dict = model.gw.to_feed(subgraph)
if parser.use_label_e:
feed_dict['label'] = subgraph.node_feat["label"]
train_idx_temp = set(split_idx['train']) & set(subgraph.node_feat["nid"])
train_idx_temp = subgraph.reindex_from_parrent_nodes(list(train_idx_temp))
feed_dict['label_idx'] = train_idx_temp
batch_y_pred = test_exe.run(
program=program,
feed=feed_dict,
fetch_list=[model.out_feat])[0]
y_pred[subgraph.node_feat["nid"]] = batch_y_pred
train_acc = evaluator.eval({
'y_true': y_true[split_idx['train']],
'y_pred': y_pred[split_idx['train']],
})['rocauc']
val_acc = evaluator.eval({
'y_true': y_true[split_idx['valid']],
'y_pred': y_pred[split_idx['valid']],
})['rocauc']
test_acc = evaluator.eval({
'y_true': y_true[split_idx['test']],
'y_pred': y_pred[split_idx['test']],
})['rocauc']
return train_acc, val_acc, test_acc
def train_loop(parser, start_program, main_program, test_program,
model, graph, label, split_idx, exe, run_id, wf=None):
#启动上文构建的训练器
exe.run(start_program)
max_acc=0 # 最佳test_acc
max_step=0 # 最佳test_acc 对应step
max_val_acc=0 # 最佳val_acc
max_cor_acc=0 # 最佳val_acc对应test_acc
max_cor_step=0 # 最佳val_acc对应step
#训练循环
graph.node_feat["label"] = label
graph.node_feat["nid"] = np.arange(0, graph.num_nodes)
if parser.use_label_e:
train_idx=copy.deepcopy(split_idx['train'])
np.random.shuffle(train_idx[:50125])
label_idx = train_idx[: int(50125*parser.label_rate)]
unlabel_idx = train_idx[int(50125*parser.label_rate): ]
label_idx_total= set(label_idx)
unlabel_idx_total= set(unlabel_idx)
for epoch_id in tqdm(range(parser.epochs)):
for subgraph in random_partition(num_clusters=9, graph=graph, shuffle=True):
#运行训练器
if parser.use_label_e:
feed_dict = model.gw.to_feed(subgraph)
sub_idx = set(subgraph.node_feat["nid"])
train_idx_temp = label_idx_total & sub_idx
label_idx = subgraph.reindex_from_parrent_nodes(list(train_idx_temp))
train_idx_temp = unlabel_idx_total & sub_idx
unlabel_idx = subgraph.reindex_from_parrent_nodes(list(train_idx_temp))
feed_dict['label'] = subgraph.node_feat["label"]
feed_dict['label_idx'] = label_idx
feed_dict['train_idx'] = unlabel_idx
else:
feed_dict = model.gw.to_feed(subgraph)
#feed_dict['label'] = label
train_idx_temp = set(split_idx['train']) & set(subgraph.node_feat["nid"])
train_idx_temp = subgraph.reindex_from_parrent_nodes(list(train_idx_temp))
feed_dict['label'] = subgraph.node_feat["label"]
feed_dict['train_idx'] = train_idx_temp
loss = exe.run(main_program,
feed=feed_dict,
fetch_list=[model.avg_cost])
loss = loss[0]
#测试结果
if (epoch_id+1) > parser.epochs*0.9:
result = eval_test(parser, test_program, model, exe, graph, label, split_idx)
train_acc, valid_acc, test_acc = result
max_acc = max(test_acc, max_acc)
if max_acc == test_acc:
max_step=epoch_id
max_val_acc=max(valid_acc, max_val_acc)
if max_val_acc==valid_acc:
max_cor_acc=test_acc
max_cor_step=epoch_id
max_acc=max(result[2], max_acc)
if max_acc==result[2]:
max_step=epoch_id
result_t=(f'Run: {run_id:02d}, '
f'Epoch: {epoch_id:02d}, '
#f'Loss: {loss[0]:.4f}, '
f'Train: {100 * train_acc:.2f}%, '
f'Valid: {100 * valid_acc:.2f}%, '
f'Test: {100 * test_acc:.2f}% \n'
f'max_Test: {100 * max_acc:.2f}%, '
f'max_step: {max_step}\n'
f'max_val: {100 * max_val_acc:.2f}%, '
f'max_val_Test: {100 * max_cor_acc:.2f}%, '
f'max_val_step: {max_cor_step}\n'
)
print(result_t)
wf.write(result_t)
wf.write('\n')
wf.flush()
return max_cor_acc
def np_scatter(idx, vals, target):
"""target[idx] += vals, but allowing for repeats in idx"""
np.add.at(target, idx, vals)
def aggregate_node_features(graph):
efeat = graph.edge_feat["feat"]
graph.edge_feat["feat"] = efeat
nfeat = np.zeros((graph.num_nodes, efeat.shape[-1]), dtype="float32")
edges_dst = graph.edges[:, 1]
np_scatter(edges_dst, efeat, nfeat)
graph.node_feat["feat"] = nfeat
if __name__ == '__main__':
parser = get_config()
print('===========args==============')
print(parser)
print('=============================')
dataset = PglNodePropPredDataset(name="ogbn-proteins")
split_idx=dataset.get_idx_split()
graph, label = dataset[0]
aggregate_node_features(graph)
place=F.CPUPlace() if parser.place <0 else F.CUDAPlace(parser.place)
startup_prog = F.default_startup_program()
train_prog = F.default_main_program()
with F.program_guard(train_prog, startup_prog):
with F.unique_name.guard():
gw = pgl.graph_wrapper.GraphWrapper(
name="proteins",
node_feat=graph.node_feat_info(),
edge_feat=graph.edge_feat_info())
if parser.use_label_e:
model = Proteins_label_embedding_model(gw, parser.hidden_size, parser.num_heads,
parser.dropout, parser.num_layers)
else:
model = Proteins_baseline_model(gw, parser.hidden_size, parser.num_heads,
parser.dropout, parser.num_layers)
test_prog=train_prog.clone(for_test=True)
model.train_program()
adam_optimizer = optimizer_func(parser.lr)#训练优化函数
adam_optimizer.minimize(model.avg_cost)
exe = F.Executor(place)
wf = open(parser.log_file, 'w', encoding='utf-8')
total_test_acc=0.0
for run_i in range(parser.runs):
total_test_acc+=train_loop(parser, startup_prog, train_prog, test_prog, model,
graph, label, split_idx, exe, run_i, wf)
wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%')
wf.close()
此差异已折叠。
'''transformer_gcn
'''
import paddle.fluid as fluid
from pgl import graph_wrapper
from pgl.utils import paddle_helper
import math
def transformer_gat_pgl(gw,
feature,
hidden_size,
name,
num_heads=4,
attn_drop=0,
edge_feature=None,
concat=True,
is_test=False):
'''transformer_gat_pgl
'''
def send_attention(src_feat, dst_feat, edge_feat):
if edge_feat is None or not edge_feat:
output = src_feat["k_h"] * dst_feat["q_h"]
output = fluid.layers.reduce_sum(output, -1)
output = output / (hidden_size ** 0.5)
return {"alpha": output, "v": src_feat["v_h"]} # batch x h batch x h x feat
else:
edge_feat = edge_feat["edge"]
edge_feat = fluid.layers.reshape(edge_feat, [-1, num_heads, hidden_size])
output = (src_feat["k_h"] + edge_feat) * dst_feat["q_h"]
output = fluid.layers.reduce_sum(output, -1)
output = output / (hidden_size ** 0.5)
return {"alpha": output, "v": (src_feat["v_h"] + edge_feat)} # batch x h batch x h x feat
def reduce_attention(msg):
alpha = msg["alpha"] # lod-tensor (batch_size, seq_len, num_heads)
h = msg["v"]
alpha = paddle_helper.sequence_softmax(alpha)
old_h = h
if attn_drop > 1e-15:
alpha = fluid.layers.dropout(
alpha,
dropout_prob=attn_drop,
is_test=is_test,
dropout_implementation="upscale_in_train")
h = h * alpha
h = fluid.layers.lod_reset(h, old_h)
h = fluid.layers.sequence_pool(h, "sum")
if concat:
h = fluid.layers.reshape(h, [-1, num_heads * hidden_size])
else:
h = fluid.layers.reduce_mean(h, dim=1)
return h
# stdv = math.sqrt(6.0 / (feature.shape[-1] + hidden_size * num_heads))
# q_w_attr=fluid.ParamAttr(initializer=fluid.initializer.UniformInitializer(low=-stdv, high=stdv))
q_w_attr=fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer())
q_bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(0.0))
q = fluid.layers.fc(feature,
hidden_size * num_heads,
name=name + '_q_weight',
param_attr=q_w_attr,
bias_attr=q_bias_attr)
# k_w_attr=fluid.ParamAttr(initializer=fluid.initializer.UniformInitializer(low=-stdv, high=stdv))
k_w_attr=fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer())
k_bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(0.0))
k = fluid.layers.fc(feature,
hidden_size * num_heads,
name=name + '_k_weight',
param_attr=k_w_attr,
bias_attr=k_bias_attr)
# v_w_attr=fluid.ParamAttr(initializer=fluid.initializer.UniformInitializer(low=-stdv, high=stdv))
v_w_attr=fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer())
v_bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(0.0))
v = fluid.layers.fc(feature,
hidden_size * num_heads,
name=name + '_v_weight',
param_attr=v_w_attr,
bias_attr=v_bias_attr)
reshape_q = fluid.layers.reshape(q, [-1, num_heads, hidden_size])
reshape_k = fluid.layers.reshape(k, [-1, num_heads, hidden_size])
reshape_v = fluid.layers.reshape(v, [-1, num_heads, hidden_size])
msg = gw.send(
send_attention,
nfeat_list=[("q_h", reshape_q), ("k_h", reshape_k),
("v_h", reshape_v)],
efeat_list=edge_feature)
output = gw.recv(msg, reduce_attention)
return output
"""Loader"""
import numpy as np
import math
from pgl.sample import extract_edges_from_nodes
def random_partition(num_clusters, graph, shuffle=True):
"""random partition"""
batch_size = int(math.ceil(graph.num_nodes / num_clusters))
perm = np.arange(0, graph.num_nodes)
if shuffle:
np.random.shuffle(perm)
batch_no = 0
while batch_no < graph.num_nodes:
batch_nodes = perm[batch_no:batch_no + batch_size]
batch_no += batch_size
eids = extract_edges_from_nodes(graph, batch_nodes)
sub_g = graph.subgraph(nodes=batch_nodes, eid=eids,
with_node_feat=True, with_edge_feat=False)
for key, value in graph.edge_feat.items():
sub_g.edge_feat[key] = graph.edge_feat[key][eids]
yield sub_g
def random_partition_v2(num_clusters, graph, shuffle=True, save_e=[]):
"""random partition v2"""
if shuffle:
cluster_id = np.random.randint(low=0, high=num_clusters, size=graph.num_nodes)
else:
if not save_e:
cluster_id = np.random.randint(low=0, high=num_clusters, size=graph.num_nodes)
save_e.append(cluster_id)
else:
cluster_id = save_e[0]
# assert cluster_id is not None
perm = np.arange(0, graph.num_nodes)
batch_no = 0
while batch_no < num_clusters:
batch_nodes = perm[cluster_id == batch_no]
batch_no += 1
eids = extract_edges_from_nodes(graph, batch_nodes)
sub_g = graph.subgraph(nodes=batch_nodes, eid=eids,
with_node_feat=True, with_edge_feat=False)
for key, value in graph.edge_feat.items():
sub_g.edge_feat[key] = graph.edge_feat[key][eids]
yield sub_g
paddlepaddle_gpu==1.8.3.post107
torch==1.5.1
tqdm==4.31.1
six==1.12.0
numpy==1.19.1
ogb==1.2.1
\ No newline at end of file
""" utils """
import numpy as np
import pgl
import paddle.fluid as fluid
def to_undirected(graph):
""" to_undirected """
inv_edges = np.zeros(graph.edges.shape)
inv_edges[:, 0] = graph.edges[:, 1]
inv_edges[:, 1] = graph.edges[:, 0]
edges = np.vstack((graph.edges, inv_edges))
edges = np.unique(edges, axis=0)
# print(edges.shape)
g = pgl.graph.Graph(num_nodes=graph.num_nodes, edges=edges)
for k, v in graph._node_feat.items():
g._node_feat[k] = v
return g
def add_self_loop(graph):
""" add_self_loop """
self_loop_edges = np.zeros((graph.num_nodes, 2))
self_loop_edges[:, 0] = self_loop_edges[:, 1]=np.arange(graph.num_nodes)
edges = np.vstack((graph.edges, self_loop_edges))
edges = np.unique(edges, axis=0)
# print(edges.shape)
g = pgl.graph.Graph(num_nodes=graph.num_nodes, edges=edges)
for k, v in graph._node_feat.items():
g._node_feat[k] = v
return g
def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
""" Applies linear warmup of learning rate from 0 and decay to 0."""
with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="scheduled_learning_rate")
global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < warmup_steps):
warmup_lr = learning_rate * (global_step / warmup_steps)
fluid.layers.tensor.assign(warmup_lr, lr)
with switch.default():
decayed_lr = fluid.layers.learning_rate_scheduler.polynomial_decay(
learning_rate=learning_rate,
decay_steps=num_train_steps,
end_learning_rate=0.0,
power=1.0,
cycle=False)
fluid.layers.tensor.assign(decayed_lr, lr)
return lr, global_step
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册