Unverified commit 6f775280, authored by Huang Zhengjie, committed by GitHub

Merge pull request #29 from PaddlePaddle/develop

Merge Develop
......@@ -21,7 +21,7 @@ import tqdm
import numpy as np
import logging
import random
from pgl.contrib import heter_graph
from pgl import heter_graph
import pickle as pkl
......
......@@ -21,7 +21,7 @@ import logging
import paddle.fluid as fluid
import paddle.fluid.layers as fl
from pgl.contrib import heter_graph_wrapper
from pgl import heter_graph_wrapper
class GATNE(object):
......
......@@ -19,8 +19,8 @@ import pgl
import time
from pgl.utils import mp_reader
from pgl.utils.logger import log
import train
import time
import copy
def node_batch_iter(nodes, node_label, batch_size):
......@@ -46,12 +46,11 @@ def traverse(item):
yield item
def flat_node_and_edge(nodes, eids):
def flat_node_and_edge(nodes):
"""flat_node_and_edge
"""
nodes = list(set(traverse(nodes)))
eids = list(set(traverse(eids)))
return nodes, eids
return nodes
def worker(batch_info, graph, graph_wrapper, samples):
......@@ -61,31 +60,42 @@ def worker(batch_info, graph, graph_wrapper, samples):
def work():
"""work
"""
first = True
_graph_wrapper = copy.copy(graph_wrapper)
_graph_wrapper.node_feat_tensor_dict = {}
for batch_train_samples, batch_train_labels in batch_info:
start_nodes = batch_train_samples
nodes = start_nodes
eids = []
edges = []
for max_deg in samples:
pred, pred_eid = graph.sample_predecessor(
start_nodes, max_degree=max_deg, return_eids=True)
pred_nodes = graph.sample_predecessor(
start_nodes, max_degree=max_deg)
for dst_node, src_nodes in zip(start_nodes, pred_nodes):
for src_node in src_nodes:
edges.append((src_node, dst_node))
last_nodes = nodes
nodes = [nodes, pred]
eids = [eids, pred_eid]
nodes, eids = flat_node_and_edge(nodes, eids)
nodes = [nodes, pred_nodes]
nodes = flat_node_and_edge(nodes)
# Find new nodes
start_nodes = list(set(nodes) - set(last_nodes))
if len(start_nodes) == 0:
break
subgraph = graph.subgraph(nodes=nodes, eid=eids)
subgraph = graph.subgraph(
nodes=nodes,
edges=edges,
with_node_feat=False,
with_edge_feat=False)
sub_node_index = subgraph.reindex_from_parrent_nodes(
batch_train_samples)
feed_dict = graph_wrapper.to_feed(subgraph)
feed_dict = _graph_wrapper.to_feed(subgraph)
feed_dict["node_label"] = np.expand_dims(
np.array(
batch_train_labels, dtype="int64"), -1)
feed_dict["node_index"] = sub_node_index
feed_dict["parent_node_index"] = np.array(nodes, dtype="int64")
yield feed_dict
return work
......@@ -97,23 +107,25 @@ def multiprocess_graph_reader(graph,
node_index,
batch_size,
node_label,
with_parent_node_index=False,
num_workers=4):
"""multiprocess_graph_reader
"""
def parse_to_subgraph(rd):
def parse_to_subgraph(rd, prefix, node_feat, _with_parent_node_index):
"""parse_to_subgraph
"""
def work():
"""work
"""
last = time.time()
for data in rd():
this = time.time()
feed_dict = data
now = time.time()
last = now
for key in node_feat:
feed_dict[prefix + '/node_feat/' + key] = node_feat[key][
feed_dict["parent_node_index"]]
if not _with_parent_node_index:
del feed_dict["parent_node_index"]
yield feed_dict
return work
......@@ -129,46 +141,17 @@ def multiprocess_graph_reader(graph,
reader_pool.append(
worker(batch_info[block_size * i:block_size * (i + 1)], graph,
graph_wrapper, samples))
multi_process_sample = mp_reader.multiprocess_reader(
reader_pool, use_pipe=True, queue_size=1000)
r = parse_to_subgraph(multi_process_sample)
return paddle.reader.buffered(r, 1000)
return reader()
def graph_reader(graph, graph_wrapper, samples, node_index, batch_size,
node_label):
"""graph_reader"""
def reader():
"""reader"""
for batch_train_samples, batch_train_labels in node_batch_iter(
node_index, node_label, batch_size=batch_size):
start_nodes = batch_train_samples
nodes = start_nodes
eids = []
for max_deg in samples:
pred, pred_eid = graph.sample_predecessor(
start_nodes, max_degree=max_deg, return_eids=True)
last_nodes = nodes
nodes = [nodes, pred]
eids = [eids, pred_eid]
nodes, eids = flat_node_and_edge(nodes, eids)
# Find new nodes
start_nodes = list(set(nodes) - set(last_nodes))
if len(start_nodes) == 0:
break
subgraph = graph.subgraph(nodes=nodes, eid=eids)
feed_dict = graph_wrapper.to_feed(subgraph)
sub_node_index = subgraph.reindex_from_parrent_nodes(
batch_train_samples)
if len(reader_pool) == 1:
r = parse_to_subgraph(reader_pool[0],
repr(graph_wrapper), graph.node_feat,
with_parent_node_index)
else:
multi_process_sample = mp_reader.multiprocess_reader(
reader_pool, use_pipe=True, queue_size=1000)
r = parse_to_subgraph(multi_process_sample,
repr(graph_wrapper), graph.node_feat,
with_parent_node_index)
return paddle.reader.buffered(r, num_workers)
feed_dict["node_label"] = np.expand_dims(
np.array(
batch_train_labels, dtype="int64"), -1)
feed_dict["node_index"] = np.array(sub_node_index, dtype="int32")
yield feed_dict
return paddle.reader.buffered(reader, 1000)
return reader()
......@@ -63,10 +63,7 @@ def load_data(normalize=True, symmetry=True):
log.info("Feature shape %s" % (repr(feature.shape)))
graph = pgl.graph.Graph(
num_nodes=feature.shape[0],
edges=list(zip(src, dst)),
node_feat={"index": np.arange(
0, len(feature), dtype="int64")})
num_nodes=feature.shape[0], edges=list(zip(src, dst)))
return {
"graph": graph,
......@@ -89,7 +86,13 @@ def build_graph_model(graph_wrapper, num_class, k_hop, graphsage_type,
node_label = fluid.layers.data(
"node_label", shape=[None, 1], dtype="int64", append_batch_size=False)
feature = fluid.layers.gather(feature, graph_wrapper.node_feat['index'])
parent_node_index = fluid.layers.data(
"parent_node_index",
shape=[None],
dtype="int64",
append_batch_size=False)
feature = fluid.layers.gather(feature, parent_node_index)
feature.stop_gradient = True
for i in range(k_hop):
......@@ -221,59 +224,35 @@ def main(args):
exe.run(startup_program)
feature_init(place)
if args.sample_workers > 1:
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
else:
train_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
if args.sample_workers > 1:
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
else:
val_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
if args.sample_workers > 1:
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
else:
test_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
with_parent_node_index=True,
node_index=data['train_index'],
node_label=data["train_label"])
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
with_parent_node_index=True,
node_index=data['val_index'],
node_label=data["val_label"])
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
with_parent_node_index=True,
node_index=data['test_index'],
node_label=data["test_label"])
for epoch in range(args.epoch):
run_epoch(
......
......@@ -195,7 +195,7 @@ def run_epoch(batch_iter,
if num_trainer > 1:
num_samples = sum(
[len(batch["node_index"]) for batch in batch_feed_dict])
[len(_batch["node_index"]) for _batch in batch_feed_dict])
else:
num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples
......@@ -262,59 +262,32 @@ def main(args):
else:
train_exe = exe
if args.sample_workers > 1:
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
else:
train_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
if args.sample_workers > 1:
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
else:
val_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
if args.sample_workers > 1:
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
else:
test_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
for epoch in range(args.epoch):
run_epoch(
......
......@@ -97,11 +97,7 @@ def load_data(normalize=True, symmetry=True, scale=1):
graph = pgl.graph.Graph(
num_nodes=feature.shape[0],
edges=edges,
node_feat={
"index": np.arange(
0, len(feature), dtype="int64"),
"feature": feature
})
node_feat={"feature": feature})
return {
"graph": graph,
......@@ -244,59 +240,32 @@ def main(args):
test_program = train_program.clone(for_test=True)
if args.sample_workers > 1:
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
else:
train_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
if args.sample_workers > 1:
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
else:
val_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
if args.sample_workers > 1:
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
else:
test_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
with fluid.program_guard(train_program, startup_program):
adam = fluid.optimizer.Adam(learning_rate=args.lr)
......
......@@ -23,7 +23,7 @@ import tqdm
import time
import logging
import random
from pgl.contrib import heter_graph
from pgl import heter_graph
import pickle as pkl
......@@ -40,8 +40,10 @@ class Dataset(object):
def __init__(self, config):
self.config = config
self.walk_files = config['input_path'] + config['walk_path']
self.word2id_file = config['input_path'] + config['word2id_file']
self.walk_files = os.path.join(config['input_path'],
config['walk_path'])
self.word2id_file = os.path.join(config['input_path'],
config['word2id_file'])
self.word2freq = {}
self.word2id = {}
......@@ -65,12 +67,16 @@ class Dataset(object):
for walk_file in glob.glob(self.walk_files):
with open(walk_file, 'r') as reader:
for walk in reader:
walk = walk.strip().split(' ')
walk = walk.strip().split()
if len(walk) > 1:
self.sentences_count += 1
for word in walk:
self.token_count += 1
word_freq[word] = word_freq.get(word, 0) + 1
if int(word) >= self.config[
'paper_start_index']: # remove paper
continue
else:
self.token_count += 1
word_freq[word] = word_freq.get(word, 0) + 1
wid = 0
logging.info('Read %d sentences.' % self.sentences_count)
......@@ -123,7 +129,11 @@ class Dataset(object):
for filename in walkpath_files:
with open(filename) as reader:
for line in reader:
words = line.strip().split(' ')
words = line.strip().split()
words = [
w for w in words
if int(w) < self.config['paper_start_index']
]
if len(words) > 1:
word_ids = [
self.word2id[w] for w in words if w in self.word2id
......
......@@ -13,9 +13,12 @@ sampler:
new_author_label_file: author_label.txt
new_venue_label_file: venue_label.txt
walk_saved_path: walks/
walk_batch_size: 1000
num_walks: 1000
walk_length: 100
metapath: conf-paper-author-paper-conf
num_sample_workers: 16
first_node_type: conf
metapath: c2p-p2a-a2p-p2c #conf-paper-author-paper-conf
optimizer:
type: Adam
......@@ -39,9 +42,10 @@ data_loader:
walk_path: walks/*
word2id_file: word2id.pkl
batch_size: 32
win_size: 7 # default: 7
win_size: 5 # default: 7
neg_num: 5
min_count: 10
paper_start_index: 1697414
model:
type: SkipgramModel
......
......@@ -101,7 +101,7 @@ class SkipgramModel(object):
pos_score = fl.squeeze(pos_logits, axes=[1])
pos_score = fl.clip(pos_score, min=-10, max=10)
pos_score = -1.0 * fl.logsigmoid(pos_score)
pos_score = -self.neg_num * fl.logsigmoid(pos_score)
neg_logits = fl.matmul(
embed_src, weight_negs,
......@@ -111,4 +111,4 @@ class SkipgramModel(object):
neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score)
neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True)
self.loss = fl.reduce_mean(pos_score + neg_score)
self.loss = fl.reduce_mean(pos_score + neg_score) / self.neg_num / 2
......@@ -18,6 +18,7 @@ training metapath2vec model.
import multiprocessing
from multiprocessing import Pool
from multiprocessing import Process
import argparse
import sys
import os
......@@ -27,7 +28,7 @@ import tqdm
import time
import logging
import random
from pgl.contrib import heter_graph
from pgl import heter_graph
from pgl.sample import metapath_randomwalk
from utils import *
......@@ -77,9 +78,14 @@ class Sampler(object):
self.config['data_path'] + 'paper_conf.txt', self.paper_id2index,
self.conf_id2index)
edges_by_types['edge'] = paper_author_edges + paper_conf_edges
logging.info('%d edges have been loaded.' %
(len(edges_by_types['edge'])))
# edges_by_types['edge'] = paper_author_edges + paper_conf_edges
edges_by_types['p2c'] = paper_conf_edges
edges_by_types['c2p'] = [(dst, src) for src, dst in paper_conf_edges]
edges_by_types['p2a'] = paper_author_edges
edges_by_types['a2p'] = [(dst, src) for src, dst in paper_author_edges]
# logging.info('%d edges have been loaded.' %
# (len(edges_by_types['edge'])))
node_features = {
'index': np.array([i for i in range(num_nodes)]).reshape(
......@@ -110,7 +116,7 @@ class Sampler(object):
return id2index, name2index, node_types
def load_edges(self, file_, src2index, dst2index, symmetry=True):
def load_edges(self, file_, src2index, dst2index, symmetry=False):
"""Load edges from file.
"""
edges = []
......@@ -143,41 +149,65 @@ class Sampler(object):
return index_label_list
def generate_walks(args):
"""Generate metapath random walk and save to file.
def walk_generator(graph, batch_size, metapath, n_type, walk_length):
"""Generate metapath random walk.
"""
g, meta_path, filename, walk_length = args
walks = []
node_types = g._node_types
first_type = meta_path.split('-')[0]
nodes = np.where(node_types == first_type)[0]
if len(nodes) > 4000:
nodes = np.random.choice(nodes, 4000, replace=False)
logging.info('%d number of start nodes' % (len(nodes)))
logging.info('save walks in file: %s' % (filename))
np.random.seed(os.getpid())
while True:
for start_nodes in graph.node_batch_iter(
batch_size=batch_size, n_type=n_type):
walks = metapath_randomwalk(
graph=graph,
start_nodes=start_nodes,
metapath=metapath,
walk_length=walk_length)
yield walks
def walk_to_files(g, batch_size, metapath, n_type, walk_length, max_num,
filename):
"""Generate metapath randomwalk and save in files"""
# g, batch_size, metapath, n_type, walk_length, max_num, filename = args
with open(filename, 'w') as writer:
for start_node in nodes:
walk = metapath_randomwalk(g, start_node, meta_path, walk_length)
walk = [str(walk[i]) for i in range(0, len(walk), 2)] # skip paper
writer.write(' '.join(walk) + '\n')
cc = 0
for walks in walk_generator(g, batch_size, metapath, n_type,
walk_length):
for walk in walks:
writer.write("%s\n" % "\t".join([str(i) for i in walk]))
cc += 1
if cc == max_num:
return
return
def multiprocess_generate_walks_to_files(graph, n_type, meta_path, num_walks,
walk_length, batch_size,
num_sample_workers, saved_path):
"""Use multiprocess to generate metapath random walk to files.
"""
num_nodes_by_type = graph.num_nodes_by_type(n_type)
logging.info("num_nodes_by_type: %s" % num_nodes_by_type)
max_num = (num_walks * num_nodes_by_type // num_sample_workers) + 1
logging.info("max sample number of every worker: %s" % max_num)
def multiprocess_generate_walks(sampler, edge_type, meta_path, num_walks,
walk_length, saved_path):
"""Use multiprocess to generate metapath random walk.
"""
args = []
for i in range(num_walks):
filename = saved_path + '%04d' % (i)
args.append(
(sampler.graph[edge_type], meta_path, filename, walk_length))
pool = Pool(16)
pool.map(generate_walks, args)
pool.close()
pool.join()
for i in range(num_sample_workers):
filename = os.path.join(saved_path, 'part-%05d' % (i))
args.append((graph, batch_size, meta_path, n_type, walk_length,
max_num, filename))
ps = []
for i in range(num_sample_workers):
p = Process(target=walk_to_files, args=args[i])
p.start()
ps.append(p)
for i in range(num_sample_workers):
ps[i].join()
# pool = Pool(num_sample_workers)
# pool.map(walk_to_files, args)
# pool.close()
# pool.join()
if __name__ == "__main__":
......@@ -220,13 +250,15 @@ if __name__ == "__main__":
begin = time.time()
logging.info('multi process sampling')
multiprocess_generate_walks(
sampler=sampler,
edge_type='edge',
multiprocess_generate_walks_to_files(
graph=sampler.graph,
n_type=config['first_node_type'],
meta_path=config['metapath'],
num_walks=config['num_walks'],
walk_length=config['walk_length'],
saved_path=config['walk_saved_path'])
batch_size=config['walk_batch_size'],
num_sample_workers=config['num_sample_workers'],
saved_path=config['walk_saved_path'], )
logging.info('total time: %.4f' % (time.time() - begin))
logging.info('generating multi class data')
......
# STGCN: Spatio-Temporal Graph Convolutional Network
[Spatio-Temporal Graph Convolutional Network \(STGCN\)](https://arxiv.org/pdf/1709.04875.pdf) is a deep learning framework for time-series prediction on graphs. Based on PGL, we reproduce the STGCN algorithm to predict newly confirmed patients in several cities from historical migration records.
### Datasets
You can build a customized dataset in the following format (a toy generation sketch follows the list):
* input.csv: Historical migration records with shape of [num\_time\_steps * num\_cities].
* output.csv: Newly confirmed patient records with shape of [num\_time\_steps * num\_cities].
* W.csv: Weighted adjacency matrix with shape of [num\_cities * num\_cities].
* city.csv: One line per city, giving the city number and the corresponding city name.
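Note that `data_gen_mydata` in `data_loader/data_utils.py` also drops a `date` column and a `武汉` (Wuhan) column from input.csv and output.csv, so real files are expected to contain them. The sketch below, with made-up values and hypothetical city names, only illustrates the file layout; it is not part of the repo.
```
import numpy as np
import pandas as pd

num_time_steps, cities = 30, ["CityA", "CityB", "CityC"]
rng = np.random.default_rng(0)
dates = pd.date_range("2020-01-01", periods=num_time_steps).strftime("%Y-%m-%d")

# input.csv / output.csv: [num_time_steps x num_cities] plus a date column
for name, high in [("input.csv", 1000), ("output.csv", 50)]:
    df = pd.DataFrame(rng.integers(0, high, (num_time_steps, len(cities))), columns=cities)
    df.insert(0, "date", dates)
    df.to_csv(name, index=False)

# W.csv: weighted adjacency matrix, [num_cities x num_cities], no header
np.savetxt("W.csv", rng.random((len(cities), len(cities))), delimiter=",")

# city.csv: one line per city with its number and name
pd.DataFrame({"num": range(1, len(cities) + 1), "city": cities}).to_csv("city.csv", index=False)
```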
### Dependencies
- paddlepaddle 1.6
- pgl 1.0.0
### How to run
For example, to train STGCN on your dataset with a GPU:
```
python main.py --use_cuda --input_file dataset/input.csv --label_file dataset/output.csv --adj_mat_file dataset/W.csv --city_file dataset/city.csv
```
#### Hyperparameters
- n\_route: Number of cities.
- n\_his: Number of previous time steps of migration records used as model input (see the windowing sketch after this list).
- n\_pred: Number of future time steps of newly confirmed patient records to predict.
- Ks: Number of GCN layers (spatial kernel size).
- Kt: Kernel size of the temporal convolution.
- use\_cuda: Train on GPU when this flag is set.
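To make `n_his` and `n_pred` concrete, the toy sketch below (illustrative only; `data_gen_mydata` builds its windows slightly differently but yields the same layout) shows how history and target frames are stacked into samples of shape `[num_samples, n_his + n_pred, n_route, 1]`, using the defaults from `main.py`:
```
import numpy as np

n_his, n_pred, n_route = 23, 3, 74                 # defaults from main.py
series = np.random.rand(60, n_route)               # toy [num_time_steps x num_cities] signal

samples = []
for t in range(len(series) - n_his - n_pred + 1):
    history = series[t:t + n_his]                  # n_his observed frames (model input)
    target = series[t + n_his:t + n_his + n_pred]  # n_pred future frames (labels)
    samples.append(np.concatenate([history, target], axis=0))

data = np.stack(samples)[..., np.newaxis]          # [num_samples, n_his + n_pred, n_route, 1]
print(data.shape)                                  # (35, 26, 74, 1)
```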
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""data processing
"""
import numpy as np
import pandas as pd
from utils.math_utils import z_score
class Dataset(object):
"""Dataset
"""
def __init__(self, data, stats):
self.__data = data
self.mean = stats['mean']
self.std = stats['std']
def get_data(self, type): # type: train, val or test
return self.__data[type]
def get_stats(self):
return {'mean': self.mean, 'std': self.std}
def get_len(self, type):
return len(self.__data[type])
def z_inverse(self, type):
return self.__data[type] * self.std + self.mean
def seq_gen(len_seq, data_seq, offset, n_frame, n_route, day_slot, C_0=1):
"""Generate data in the form of standard sequence unit."""
n_slot = day_slot - n_frame + 1
tmp_seq = np.zeros((len_seq * n_slot, n_frame, n_route, C_0))
for i in range(len_seq):
for j in range(n_slot):
sta = (i + offset) * day_slot + j
end = sta + n_frame
tmp_seq[i * n_slot + j, :, :, :] = np.reshape(
data_seq[sta:end, :], [n_frame, n_route, C_0])
return tmp_seq
def adj_matrx_gen_custom(input_file, city_file):
"""genenrate Adjacency Matrix from file
"""
print("generate adj_matrix data (take long time)...")
# data
df = pd.read_csv(
input_file,
sep='\t',
names=['date', '迁出省份', '迁出城市', '迁入省份', '迁入城市', '人数'])
# Only keep data from 2020
df['date'] = pd.to_datetime(df['date'], format="%Y%m%d")
df = df.set_index('date')
df = df['2020']
city_df = pd.read_csv(city_file)
# Drop Wuhan
city_df = city_df.drop(0)
num = len(city_df)
matrix = np.zeros([num, num])
for i in city_df['city']:
for j in city_df['city']:
if (i == j):
continue
# Select the daily number of people moving from city i to city j
cut = df[df['迁出城市'].str.contains(i)]
cut = cut[cut['迁入城市'].str.contains(j)]
# Use the mean as the edge weight
average = cut['人数'].mean()
# Write the weight into the matrix
i_index = int(city_df[city_df['city'] == i]['num']) - 1
j_index = int(city_df[city_df['city'] == j]['num']) - 1
matrix[i_index, j_index] = average
np.savetxt("dataset/W_74.csv", matrix, delimiter=",")
def data_gen_custom(input_file, output_file, city_file, n, n_his, n_pred,
n_config):
"""data_gen_custom"""
print("generate training data...")
# data
df = pd.read_csv(
input_file,
sep='\t',
names=['date', '迁出省份', '迁出城市', '迁入省份', '迁入城市', '人数'])
# Only keep data from 2020
df['date'] = pd.to_datetime(df['date'], format="%Y%m%d")
df = df.set_index('date')
df = df['2020']
city_df = pd.read_csv(city_file)
input_df = pd.DataFrame()
out_df_wuhan = df[df['迁出城市'].str.contains('武汉')]
for i in city_df['city']:
# Filter by destination city
in_df_i = out_df_wuhan[out_df_wuhan['迁入城市'].str.contains(i)]
# Ensure ascending time order
# in_df_i.sort_values("date",inplace=True)
# Insert the column in time order
in_df_i.reset_index(drop=True, inplace=True)
input_df[i] = in_df_i['人数']
# Replace NaN values
input_df = input_df.replace(np.nan, 0)
x = input_df
y = pd.read_csv(output_file)
# Drop the unnamed first column
x.drop(
x.columns[x.columns.str.contains(
'unnamed', case=False)],
axis=1,
inplace=True)
y = y.drop(columns=['date'])
# Drop records of moves into Wuhan
x = x.drop(columns=['武汉'])
y = y.drop(columns=['武汉'])
# param
n_val, n_test = n_config
n_train = len(y) - n_val - n_test - 2
# (?,26,74,1)
df = pd.DataFrame(columns=x.columns)
for i in range(len(y) - n_pred + 1):
df = df.append(x[i:i + n_his])
df = df.append(y[i:i + n_pred])
data = df.values.reshape(-1, n_his + n_pred, n,
1) # n == num_nodes == city num
x_stats = {'mean': np.mean(data), 'std': np.std(data)}
x_train = data[:n_train]
x_val = data[n_train:n_train + n_val]
x_test = data[n_train + n_val:]
x_data = {'train': x_train, 'val': x_val, 'test': x_test}
dataset = Dataset(x_data, x_stats)
print("generate successfully!")
return dataset
def data_gen_mydata(input_file, label_file, n, n_his, n_pred, n_config):
"""data processing
"""
# data
x = pd.read_csv(input_file)
y = pd.read_csv(label_file)
x = x.drop(columns=['date'])
y = y.drop(columns=['date'])
x = x.drop(columns=['武汉'])
y = y.drop(columns=['武汉'])
# param
n_val, n_test = n_config
n_train = len(y) - n_val - n_test - 2
# (?,26,74,1)
df = pd.DataFrame(columns=x.columns)
for i in range(len(y) - n_pred + 1):
df = df.append(x[i:i + n_his])
df = df.append(y[i:i + n_pred])
data = df.values.reshape(-1, n_his + n_pred, n, 1)
x_stats = {'mean': np.mean(data), 'std': np.std(data)}
x_train = data[:n_train]
x_val = data[n_train:n_train + n_val]
x_test = data[n_train + n_val:]
x_data = {'train': x_train, 'val': x_val, 'test': x_test}
dataset = Dataset(x_data, x_stats)
return dataset
def data_gen(file_path, data_config, n_route, n_frame=21, day_slot=288):
"""Source file load and dataset generation."""
n_train, n_val, n_test = data_config
# generate training, validation and test data
try:
data_seq = pd.read_csv(file_path, header=None).values
except FileNotFoundError:
print(f'ERROR: input file was not found in {file_path}.')
raise
seq_train = seq_gen(n_train, data_seq, 0, n_frame, n_route, day_slot)
seq_val = seq_gen(n_val, data_seq, n_train, n_frame, n_route, day_slot)
seq_test = seq_gen(n_test, data_seq, n_train + n_val, n_frame, n_route,
day_slot)
# x_stats: dict, the stats for the train dataset, including the value of mean and standard deviation.
x_stats = {'mean': np.mean(seq_train), 'std': np.std(seq_train)}
# x_train, x_val, x_test: np.array, [sample_size, n_frame, n_route, channel_size].
x_train = z_score(seq_train, x_stats['mean'], x_stats['std'])
x_val = z_score(seq_val, x_stats['mean'], x_stats['std'])
x_test = z_score(seq_test, x_stats['mean'], x_stats['std'])
x_data = {'train': x_train, 'val': x_val, 'test': x_test}
dataset = Dataset(x_data, x_stats)
return dataset
def gen_batch(inputs, batch_size, dynamic_batch=False, shuffle=False):
"""Data iterator in batch.
Args:
inputs: np.ndarray, [len_seq, n_frame, n_route, C_0], standard sequence units.
batch_size: int, size of batch.
dynamic_batch: bool, whether changes the batch size in the last batch
if its length is less than the default.
shuffle: bool, whether shuffle the batches.
"""
len_inputs = len(inputs)
if shuffle:
idx = np.arange(len_inputs)
np.random.shuffle(idx)
for start_idx in range(0, len_inputs, batch_size):
end_idx = start_idx + batch_size
if end_idx > len_inputs:
if dynamic_batch:
end_idx = len_inputs
else:
break
if shuffle:
slide = idx[start_idx:end_idx]
else:
slide = slice(start_idx, end_idx)
yield inputs[slide]
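# A small usage sketch (assumption: toy random data, not from the repo) showing
# how gen_batch iterates standard sequence units of shape [len_seq, n_frame, n_route, C_0].
if __name__ == "__main__":
    toy_inputs = np.random.rand(25, 26, 74, 1).astype("float32")
    for batch in gen_batch(toy_inputs, batch_size=10, dynamic_batch=True, shuffle=True):
        # with dynamic_batch=True the last batch may hold fewer than 10 samples
        print(batch.shape)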
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PGL Graph
"""
import sys
import os
import numpy as np
import pandas as pd
from pgl.graph import Graph
def weight_matrix(file_path, sigma2=0.1, epsilon=0.5, scaling=True):
"""Load weight matrix function."""
try:
W = pd.read_csv(file_path, header=None).values
except FileNotFoundError:
print(f'ERROR: input file was not found in {file_path}.')
raise
# check whether W is a 0/1 matrix.
if set(np.unique(W)) == {0, 1}:
print('The input graph is a 0/1 matrix; set "scaling" to False.')
scaling = False
if scaling:
n = W.shape[0]
W = W / 10000.
W2, W_mask = W * W, np.ones([n, n]) - np.identity(n)
# Eq. 10 in the STGCN paper: keep exp(-w_ij^2 / sigma^2) only where it is >= epsilon, and zero out the diagonal
return np.exp(-W2 / sigma2) * (
np.exp(-W2 / sigma2) >= epsilon) * W_mask
else:
return W
class GraphFactory(object):
"""GraphFactory"""
def __init__(self, args):
self.args = args
self.adj_matrix = weight_matrix(self.args.adj_mat_file)
L = np.eye(self.adj_matrix.shape[0]) + self.adj_matrix
D = np.sum(self.adj_matrix, axis=1)
# L = D - self.adj_matrix
# import ipdb; ipdb.set_trace()
edges = []
weights = []
for i in range(self.adj_matrix.shape[0]):
for j in range(self.adj_matrix.shape[1]):
edges.append([i, j])
weights.append(L[i][j])
self.edges = np.array(edges, dtype=np.int64)
self.weights = np.array(weights, dtype=np.float32).reshape(-1, 1)
self.norm = np.zeros_like(D, dtype=np.float32)
self.norm[D > 0] = np.power(D[D > 0], -0.5)
self.norm = self.norm.reshape(-1, 1)
def build_graph(self, x_batch):
"""build graph"""
B, T, n, _ = x_batch.shape
batch = B * T
batch_edges = []
for i in range(batch):
batch_edges.append(self.edges + (i * n))
batch_edges = np.vstack(batch_edges)
num_nodes = B * T * n
node_feat = {'norm': np.tile(self.norm, [batch, 1])}
edge_feat = {'weights': np.tile(self.weights, [batch, 1])}
graph = Graph(
num_nodes=num_nodes,
edges=batch_edges,
node_feat=node_feat,
edge_feat=edge_feat)
return graph
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implement the training process of STGCN model.
"""
import os
import sys
import time
import argparse
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
from pgl.utils.logger import log
from data_loader.data_utils import data_gen_mydata, gen_batch
from data_loader.graph import GraphFactory
from models.model import STGCNModel
from models.tester import model_inference, model_test
def main(args):
"""main"""
PeMS = data_gen_mydata(args.input_file, args.label_file, args.n_route,
args.n_his, args.n_pred, (args.n_val, args.n_test))
log.info(PeMS.get_stats())
log.info(PeMS.get_len('train'))
gf = GraphFactory(args)
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper(
"gw",
place,
node_feat=[('norm', [None, 1], "float32")],
edge_feat=[('weights', [None, 1], "float32")])
model = STGCNModel(args, gw)
train_loss, y_pred = model.forward()
infer_program = train_program.clone(for_test=True)
with fluid.program_guard(train_program, startup_program):
epoch_step = int(PeMS.get_len('train') / args.batch_size) + 1
lr = fl.exponential_decay(
learning_rate=args.lr,
decay_steps=5 * epoch_step,
decay_rate=0.7,
staircase=True)
if args.opt == 'RMSProp':
train_op = fluid.optimizer.RMSPropOptimizer(lr).minimize(
train_loss)
elif args.opt == 'ADAM':
train_op = fluid.optimizer.Adam(lr).minimize(train_loss)
exe = fluid.Executor(place)
exe.run(startup_program)
if args.inf_mode == 'sep':
# for inference mode 'sep', the type of step index is int.
step_idx = args.n_pred - 1
tmp_idx = [step_idx]
min_val = min_va_val = np.array([4e1, 1e5, 1e5])
elif args.inf_mode == 'merge':
# for inference mode 'merge', the type of step index is np.ndarray.
step_idx = tmp_idx = np.arange(3, args.n_pred + 1, 3) - 1
min_val = min_va_val = np.array([4e1, 1e5, 1e5]) * len(step_idx)
else:
raise ValueError(f'ERROR: test mode "{args.inf_mode}" is not defined.')
step = 0
for epoch in range(1, args.epochs + 1):
for idx, x_batch in enumerate(
gen_batch(
PeMS.get_data('train'),
args.batch_size,
dynamic_batch=True,
shuffle=True)):
x = np.array(x_batch[:, 0:args.n_his, :, :], dtype=np.float32)
graph = gf.build_graph(x)
feed = gw.to_feed(graph)
feed['input'] = np.array(
x_batch[:, 0:args.n_his + 1, :, :], dtype=np.float32)
b_loss, b_lr = exe.run(train_program,
feed=feed,
fetch_list=[train_loss, lr])
if idx % 5 == 0:
log.info("epoch %d | step %d | lr %.6f | loss %.6f" %
(epoch, idx, b_lr[0], b_loss[0]))
min_va_val, min_val = \
model_inference(exe, gw, gf, infer_program, y_pred, PeMS, args, \
step_idx, min_va_val, min_val)
for ix in tmp_idx:
va, te = min_va_val[ix - 2:ix + 1], min_val[ix - 2:ix + 1]
print(f'Time Step {ix + 1}: '
f'MAPE {va[0]:7.3%}, {te[0]:7.3%}; '
f'MAE {va[1]:4.3f}, {te[1]:4.3f}; '
f'RMSE {va[2]:6.3f}, {te[2]:6.3f}.')
if epoch % 5 == 0:
model_test(exe, gw, gf, infer_program, y_pred, PeMS, args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--n_route', type=int, default=74)
parser.add_argument('--n_his', type=int, default=23)
parser.add_argument('--n_pred', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=10)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--save', type=int, default=10)
parser.add_argument('--Ks', type=int, default=3) #equal to num_layers
parser.add_argument('--Kt', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-2)
parser.add_argument('--keep_prob', type=float, default=1.0)
parser.add_argument('--opt', type=str, default='RMSProp')
parser.add_argument('--inf_mode', type=str, default='sep')
parser.add_argument('--input_file', type=str, default='dataset/input.csv')
parser.add_argument('--label_file', type=str, default='dataset/output.csv')
parser.add_argument(
'--city_file', type=str, default='dataset/crawl_list.csv')
parser.add_argument('--adj_mat_file', type=str, default='dataset/W_74.csv')
parser.add_argument('--output_path', type=str, default='./outputs/')
parser.add_argument('--n_val', type=str, default=1)
parser.add_argument('--n_test', type=str, default=1)
parser.add_argument('--use_cuda', action='store_true')
args = parser.parse_args()
blocks = [[1, 32, 64], [64, 32, 128]]
args.blocks = blocks
log.info(args)
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file implement the STGCN model.
"""
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
class STGCNModel(object):
"""Implementation of Spatio-Temporal Graph Convolutional Networks"""
def __init__(self, args, gw):
self.args = args
self.gw = gw
self.input = fl.data(
name="input",
shape=[None, args.n_his + 1, args.n_route, 1],
dtype="float32")
def forward(self):
"""forward"""
x = self.input[:, 0:self.args.n_his, :, :]
# Ko>0: kernel size of temporal convolution in the output layer.
Ko = self.args.n_his
# ST-Block
for i, channels in enumerate(self.args.blocks):
x = self.st_conv_block(
x,
self.args.Ks,
self.args.Kt,
channels,
"st_conv_%d" % i,
self.args.keep_prob,
act_func='GLU')
# output layer
if Ko > 1:
y = self.output_layer(x, Ko, 'output_layer')
else:
raise ValueError(f'ERROR: kernel size Ko must be greater than 1, \
but received "{Ko}".')
label = self.input[:, self.args.n_his:self.args.n_his + 1, :, :]
train_loss = fl.reduce_sum((y - label) * (y - label))
single_pred = y[:, 0, :, :] # shape: [batch, n, 1]
return train_loss, single_pred
def st_conv_block(self,
x,
Ks,
Kt,
channels,
name,
keep_prob,
act_func='GLU'):
"""Spatio-Temporal convolution block"""
c_si, c_t, c_oo = channels
x_s = self.temporal_conv_layer(
x, Kt, c_si, c_t, "%s_tconv_in" % name, act_func=act_func)
x_t = self.spatio_conv_layer(x_s, Ks, c_t, c_t, "%s_sonv" % name)
x_o = self.temporal_conv_layer(x_t, Kt, c_t, c_oo,
"%s_tconv_out" % name)
x_ln = fl.layer_norm(x_o)
return fl.dropout(x_ln, dropout_prob=(1.0 - keep_prob))
def temporal_conv_layer(self, x, Kt, c_in, c_out, name, act_func='relu'):
"""Temporal convolution layer"""
_, T, n, _ = x.shape
if c_in > c_out:
x_input = fl.conv2d(
input=x,
num_filters=c_out,
filter_size=[1, 1],
stride=[1, 1],
padding="SAME",
data_format="NHWC",
param_attr=fluid.ParamAttr(name="%s_conv2d_1" % name))
elif c_in < c_out:
# if the size of input channel is less than the output,
# padding x to the same size of output channel.
pad = fl.fill_constant_batch_size_like(
input=x,
shape=[-1, T, n, c_out - c_in],
dtype="float32",
value=0.0)
x_input = fl.concat([x, pad], axis=3)
else:
x_input = x
# x_input = x_input[:, Kt - 1:T, :, :]
if act_func == 'GLU':
# gated linear unit (GLU)
bt_init = fluid.initializer.ConstantInitializer(value=0.0)
bt = fl.create_parameter(
shape=[2 * c_out],
dtype="float32",
attr=fluid.ParamAttr(
name="%s_bt" % name, trainable=True, initializer=bt_init),
)
x_conv = fl.conv2d(
input=x,
num_filters=2 * c_out,
filter_size=[Kt, 1],
stride=[1, 1],
padding="SAME",
data_format="NHWC",
param_attr=fluid.ParamAttr(name="%s_conv2d_wt" % name))
x_conv = x_conv + bt
return (x_conv[:, :, :, 0:c_out] + x_input
) * fl.sigmoid(x_conv[:, :, :, -c_out:])
else:
bt_init = fluid.initializer.ConstantInitializer(value=0.0)
bt = fl.create_parameter(
shape=[c_out],
dtype="float32",
attr=fluid.ParamAttr(
name="%s_bt" % name, trainable=True, initializer=bt_init),
)
x_conv = fl.conv2d(
input=x,
num_filters=c_out,
filter_size=[Kt, 1],
stride=[1, 1],
padding="SAME",
data_format="NHWC",
param_attr=fluid.ParamAttr(name="%s_conv2d_wt" % name))
x_conv = x_conv + bt
if act_func == "linear":
return x_conv
elif act_func == "sigmoid":
return fl.sigmoid(x_conv)
elif act_func == "relu":
return fl.relu(x_conv + x_input)
else:
raise ValueError(
f'ERROR: activation function "{act_func}" is not defined.')
def spatio_conv_layer(self, x, Ks, c_in, c_out, name):
"""Spatio convolution layer"""
_, T, n, _ = x.shape
if c_in > c_out:
x_input = fl.conv2d(
input=x,
num_filters=c_out,
filter_size=[1, 1],
stride=[1, 1],
padding="SAME",
data_format="NHWC",
param_attr=fluid.ParamAttr(name="%s_conv2d_1" % name))
elif c_in < c_out:
# if the size of input channel is less than the output,
# padding x to the same size of output channel.
pad = fl.fill_constant_batch_size_like(
input=x,
shape=[-1, T, n, c_out - c_in],
dtype="float32",
value=0.0)
x_input = fl.concat([x, pad], axis=3)
else:
x_input = x
for i in range(Ks):
# x_input shape: [B,T, num_nodes, c_out]
x_input = fl.reshape(x_input, [-1, c_out])
x_input = self.message_passing(
self.gw,
x_input,
name="%s_mp_%d" % (name, i),
norm=self.gw.node_feat["norm"])
x_input = fl.fc(x_input,
size=c_out,
bias_attr=False,
param_attr=fluid.ParamAttr(name="%s_gcn_fc_%d" %
(name, i)))
bias = fluid.layers.create_parameter(
shape=[c_out],
dtype='float32',
is_bias=True,
name='%s_gcn_bias_%d' % (name, i))
x_input = fluid.layers.elementwise_add(x_input, bias, act="relu")
x_input = fl.reshape(x_input, [-1, T, n, c_out])
return x_input
def message_passing(self, gw, feature, name, norm=None):
"""Message passing layer"""
def send_src_copy(src_feat, dst_feat, edge_feat):
"""send function"""
return src_feat["h"] * edge_feat['w']
if norm is not None:
feature = feature * norm
msg = gw.send(
send_src_copy,
nfeat_list=[("h", feature)],
efeat_list=[('w', gw.edge_feat['weights'])])
output = gw.recv(msg, "sum")
if norm is not None:
output = output * norm
return output
def output_layer(self, x, T, name, act_func='GLU'):
"""Output layer"""
_, _, n, channel = x.shape
# maps multi-steps to one.
x_i = self.temporal_conv_layer(
x=x,
Kt=T,
c_in=channel,
c_out=channel,
name="%s_in" % name,
act_func=act_func)
x_ln = fl.layer_norm(x_i)
x_o = self.temporal_conv_layer(
x=x_ln,
Kt=1,
c_in=channel,
c_out=channel,
name="%s_out" % name,
act_func='sigmoid')
# maps multi-channels to one.
x_fc = self.fully_con_layer(
x=x_o, n=n, channel=channel, name="%s_fc" % name)
return x_fc
def fully_con_layer(self, x, n, channel, name):
"""Fully connected layer"""
bt_init = fluid.initializer.ConstantInitializer(value=0.0)
bt = fl.create_parameter(
shape=[n, 1],
dtype="float32",
attr=fluid.ParamAttr(
name="%s_bt" % name, trainable=True, initializer=bt_init), )
x_conv = fl.conv2d(
input=x,
num_filters=1,
filter_size=[1, 1],
stride=[1, 1],
padding="SAME",
data_format="NHWC",
param_attr=fluid.ParamAttr(name="%s_conv2d" % name))
x_conv = x_conv + bt
return x_conv
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file implement the testing process of STGCN model.
"""
import os
import sys
import time
import argparse
import numpy as np
import pandas as pd
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
from pgl.utils.logger import log
from data_loader.data_utils import gen_batch
from utils.math_utils import evaluation
def multi_pred(exe, gw, gf, program, y_pred, seq, batch_size, \
n_his, n_pred, step_idx, dynamic_batch=True):
"""multi step prediction"""
pred_list = []
for i in gen_batch(
seq, min(batch_size, len(seq)), dynamic_batch=dynamic_batch):
# Note: use np.copy() to avoid the modification of source data.
test_seq = np.copy(i[:, 0:n_his + 1, :, :]).astype(np.float32)
graph = gf.build_graph(i[:, 0:n_his, :, :])
feed = gw.to_feed(graph)
step_list = []
for j in range(n_pred):
feed['input'] = test_seq
pred = exe.run(program, feed=feed, fetch_list=[y_pred])
if isinstance(pred, list):
pred = np.array(pred[0])
test_seq[:, 0:n_his - 1, :, :] = test_seq[:, 1:n_his, :, :]
test_seq[:, n_his - 1, :, :] = pred
step_list.append(pred)
pred_list.append(step_list)
# pred_array -> [n_pred, len(seq), n_route, C_0)
pred_array = np.concatenate(pred_list, axis=1)
return pred_array, pred_array.shape[1]
def model_inference(exe, gw, gf, program, pred, inputs, args, step_idx,
min_va_val, min_val):
"""inference model"""
x_val, x_test, x_stats = inputs.get_data('val'), inputs.get_data(
'test'), inputs.get_stats()
if args.n_his + args.n_pred > x_val.shape[1]:
raise ValueError(
f'ERROR: the value of n_pred "{args.n_pred}" exceeds the length limit.'
)
# y_val shape: [n_pred, len(x_val), n_route, C_0)
y_val, len_val = multi_pred(exe, gw, gf, program, pred, \
x_val, args.batch_size, args.n_his, args.n_pred, step_idx)
evl_val = evaluation(x_val[0:len_val, step_idx + args.n_his, :, :],
y_val[step_idx], x_stats)
# chks: boolean mask marking where evl_val improves on (is smaller than) min_va_val.
chks = evl_val < min_va_val
# update the test-set metrics if the model's performance improved on the validation set.
if sum(chks):
min_va_val[chks] = evl_val[chks]
y_pred, len_pred = multi_pred(exe, gw, gf, program, pred, \
x_test, args.batch_size, args.n_his, args.n_pred, step_idx)
evl_pred = evaluation(x_test[0:len_pred, step_idx + args.n_his, :, :],
y_pred[step_idx], x_stats)
min_val = evl_pred
return min_va_val, min_val
def model_test(exe, gw, gf, program, pred, inputs, args):
"""test model"""
if args.inf_mode == 'sep':
# for inference mode 'sep', the type of step index is int.
step_idx = args.n_pred - 1
tmp_idx = [step_idx]
elif args.inf_mode == 'merge':
# for inference mode 'merge', the type of step index is np.ndarray.
step_idx = tmp_idx = np.arange(3, args.n_pred + 1, 3) - 1
print(step_idx)
else:
raise ValueError(f'ERROR: test mode "{args.inf_mode}" is not defined.')
x_test, x_stats = inputs.get_data('test'), inputs.get_stats()
y_test, len_test = multi_pred(exe, gw, gf, program, pred, \
x_test, args.batch_size, args.n_his, args.n_pred, step_idx)
# save result
gt = x_test[0:len_test, args.n_his:, :, :].reshape(-1, args.n_route)
y_pred = y_test.reshape(-1, args.n_route)
city_df = pd.read_csv(args.city_file)
city_df = city_df.drop(0)
np.savetxt(
os.path.join(args.output_path, "groundtruth.csv"),
gt.astype(np.int32),
fmt='%d',
delimiter=',',
header=",".join(city_df['city']))
np.savetxt(
os.path.join(args.output_path, "prediction.csv"),
y_pred.astype(np.int32),
fmt='%d',
delimiter=",",
header=",".join(city_df['city']))
for i in range(step_idx + 1):
evl = evaluation(x_test[0:len_test, step_idx + args.n_his, :, :],
y_test[i], x_stats)
for ix in tmp_idx:
te = evl[ix - 2:ix + 1]
print(
f'Time Step {i + 1}: MAPE {te[0]:7.3%}; MAE {te[1]:4.3f}; RMSE {te[2]:6.3f}.'
)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation"""
import os
import sys
import time
import argparse
import numpy as np
def z_score(x, mean, std):
"""z_score"""
return (x - mean) / std
def z_inverse(x, mean, std):
"""The inverse of function z_score"""
return x * std + mean
def MAPE(v, v_):
"""Mean absolute percentage error."""
return np.mean(np.abs(v_ - v) / (v + 1e-5))
def RMSE(v, v_):
"""Mean squared error."""
return np.sqrt(np.mean((v_ - v)**2))
def MAE(v, v_):
"""Mean absolute error."""
return np.mean(np.abs(v_ - v))
def evaluation(y, y_, x_stats):
"""Calculate MAPE, MAE and RMSE between ground truth and prediction."""
dim = len(y_.shape)
if dim == 3:
# single_step case
v = z_inverse(y, x_stats['mean'], x_stats['std'])
v_ = z_inverse(y_, x_stats['mean'], x_stats['std'])
return np.array([MAPE(v, v_), MAE(v, v_), RMSE(v, v_)])
else:
# multi_step case
tmp_list = []
# y -> [time_step, batch_size, n_route, 1]
y = np.swapaxes(y, 0, 1)
# recursively call
for i in range(y_.shape[0]):
tmp_res = evaluation(y[i], y_[i], x_stats)
tmp_list.append(tmp_res)
return np.concatenate(tmp_list, axis=-1)
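# A small usage sketch with made-up numbers (not repo data): normalize toy ground
# truth and predictions with z_score, then recover MAPE/MAE/RMSE via evaluation().
if __name__ == "__main__":
    stats = {'mean': 10.0, 'std': 2.0}
    y_true = z_score(np.array([[8.0, 12.0], [10.0, 14.0]]), stats['mean'], stats['std'])
    y_hat = z_score(np.array([[9.0, 11.0], [10.0, 13.0]]), stats['mean'], stats['std'])
    # single-step case: evaluation() expects a 3-D array, so add a trailing channel axis
    mape, mae, rmse = evaluation(y_true[..., None], y_hat[..., None], stats)
    print("MAPE %.3f | MAE %.3f | RMSE %.3f" % (mape, mae, rmse))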
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test ogb
"""
import argparse
import pgl
import numpy as np
import paddle.fluid as fluid
from pgl.contrib.ogb.linkproppred.dataset_pgl import PglLinkPropPredDataset
from pgl.utils import paddle_helper
from ogb.linkproppred import Evaluator
def send_func(src_feat, dst_feat, edge_feat):
"""send_func"""
return src_feat["h"]
def recv_func(feat):
"""recv_func"""
return fluid.layers.sequence_pool(feat, pool_type="sum")
class GNNModel(object):
"""GNNModel"""
def __init__(self, name, num_nodes, emb_dim, num_layers):
self.num_nodes = num_nodes
self.emb_dim = emb_dim
self.num_layers = num_layers
self.name = name
self.src_nodes = fluid.layers.data(
name='src_nodes',
shape=[None, 1],
dtype='int64', )
self.dst_nodes = fluid.layers.data(
name='dst_nodes',
shape=[None, 1],
dtype='int64', )
self.edge_label = fluid.layers.data(
name='edge_label',
shape=[None, 1],
dtype='float32', )
def forward(self, graph):
"""forward"""
h = fluid.layers.create_parameter(
shape=[self.num_nodes, self.emb_dim],
dtype="float32",
name=self.name + "_embedding")
# edge_attr = fluid.layers.fc(graph.edge_feat["feat"], size=self.emb_dim)
for layer in range(self.num_layers):
msg = graph.send(
send_func,
nfeat_list=[("h", h)], )
h = graph.recv(msg, recv_func)
h = fluid.layers.fc(
h,
size=self.emb_dim,
bias_attr=False,
param_attr=fluid.ParamAttr(name=self.name + '_%s' % layer))
h = h * graph.node_feat["norm"]
bias = fluid.layers.create_parameter(
shape=[self.emb_dim],
dtype='float32',
is_bias=True,
name=self.name + '_bias_%s' % layer)
h = fluid.layers.elementwise_add(h, bias, act="relu")
src = fluid.layers.gather(h, self.src_nodes)
dst = fluid.layers.gather(h, self.dst_nodes)
edge_embed = src * dst
pred = fluid.layers.fc(input=edge_embed,
size=1,
name=self.name + "_pred_output")
prob = fluid.layers.sigmoid(pred)
loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred,
self.edge_label)
loss = fluid.layers.reduce_mean(loss)
return pred, prob, loss
def main():
"""main
"""
# Training settings
parser = argparse.ArgumentParser(description='Graph Dataset')
parser.add_argument(
'--epochs',
type=int,
default=100,
help='number of epochs to train (default: 100)')
parser.add_argument(
'--dataset',
type=str,
default="ogbl-ppa",
help='dataset name (default: protein protein associations)')
args = parser.parse_args()
#place = fluid.CUDAPlace(0)
place = fluid.CPUPlace() # Dataset too big to use GPU
### automatic dataloading and splitting
print("loadding dataset")
dataset = PglLinkPropPredDataset(name=args.dataset)
splitted_edge = dataset.get_edge_split()
print(splitted_edge['train_edge'].shape)
print(splitted_edge['train_edge_label'].shape)
print("building evaluator")
### automatic evaluator. takes dataset name as input
evaluator = Evaluator(args.dataset)
graph_data = dataset[0]
print("num_nodes: %d" % graph_data.num_nodes)
train_program = fluid.Program()
startup_program = fluid.Program()
test_program = fluid.Program()
# degree normalize
indegree = graph_data.indegree()
norm = np.zeros_like(indegree, dtype="float32")
norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
graph_data.node_feat["norm"] = np.expand_dims(norm, -1).astype("float32")
with fluid.program_guard(train_program, startup_program):
model = GNNModel(
name="gnn",
num_nodes=graph_data.num_nodes,
emb_dim=64,
num_layers=2)
gw = pgl.graph_wrapper.GraphWrapper(
"graph",
place,
node_feat=graph_data.node_feat_info(),
edge_feat=graph_data.edge_feat_info())
pred, prob, loss = model.forward(gw)
val_program = train_program.clone(for_test=True)
with fluid.program_guard(train_program, startup_program):
adam = fluid.optimizer.Adam(
learning_rate=1e-2,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0005))
adam.minimize(loss)
exe = fluid.Executor(place)
exe.run(startup_program)
feed = gw.to_feed(graph_data)
for epoch in range(1, args.epochs + 1):
feed['src_nodes'] = splitted_edge["train_edge"][:, 0].reshape(-1, 1)
feed['dst_nodes'] = splitted_edge["train_edge"][:, 1].reshape(-1, 1)
feed['edge_label'] = splitted_edge["train_edge_label"].astype(
"float32").reshape(-1, 1)
res_loss, y_pred = exe.run(train_program,
feed=feed,
fetch_list=[loss, prob])
print("Loss %s" % res_loss[0])
result = {}
print("Evaluating...")
feed['src_nodes'] = splitted_edge["valid_edge"][:, 0].reshape(-1, 1)
feed['dst_nodes'] = splitted_edge["valid_edge"][:, 1].reshape(-1, 1)
feed['edge_label'] = splitted_edge["valid_edge_label"].astype(
"float32").reshape(-1, 1)
y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
input_dict = {
"y_true": splitted_edge["valid_edge_label"],
"y_pred": y_pred.reshape(-1, ),
}
result["valid"] = evaluator.eval(input_dict)
feed['src_nodes'] = splitted_edge["test_edge"][:, 0].reshape(-1, 1)
feed['dst_nodes'] = splitted_edge["test_edge"][:, 1].reshape(-1, 1)
feed['edge_label'] = splitted_edge["test_edge_label"].astype(
"float32").reshape(-1, 1)
y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
input_dict = {
"y_true": splitted_edge["test_edge_label"],
"y_pred": y_pred.reshape(-1, ),
}
result["test"] = evaluator.eval(input_dict)
print(result)
if __name__ == "__main__":
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test ogb
"""
import argparse
import pgl
import numpy as np
import paddle.fluid as fluid
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
from pgl.utils import paddle_helper
from ogb.nodeproppred import Evaluator
def train():
pass
def send_func(src_feat, dst_feat, edge_feat):
return (src_feat["h"] + edge_feat["h"]) * src_feat["norm"]
class GNNModel(object):
def __init__(self, name, emb_dim, num_task, num_layers):
self.num_task = num_task
self.emb_dim = emb_dim
self.num_layers = num_layers
self.name = name
def forward(self, graph):
h = fluid.layers.embedding(
graph.node_feat["x"],
size=(2, self.emb_dim)) # name=self.name + "_embedding")
edge_attr = fluid.layers.fc(graph.edge_feat["feat"], size=self.emb_dim)
for layer in range(self.num_layers):
msg = graph.send(
send_func,
nfeat_list=[("h", h), ("norm", graph.node_feat["norm"])],
efeat_list=[("h", edge_attr)])
h = graph.recv(msg, "sum")
h = fluid.layers.fc(
h,
size=self.emb_dim,
bias_attr=False,
param_attr=fluid.ParamAttr(name=self.name + '_%s' % layer))
h = h * graph.node_feat["norm"]
bias = fluid.layers.create_parameter(
shape=[self.emb_dim],
dtype='float32',
is_bias=True,
name=self.name + '_bias_%s' % layer)
h = fluid.layers.elementwise_add(h, bias, act="relu")
pred = fluid.layers.fc(h,
self.num_task,
act=None,
name=self.name + "_pred_output")
return pred
def main():
"""main
"""
# Training settings
parser = argparse.ArgumentParser(description='Graph Dataset')
parser.add_argument(
'--epochs',
type=int,
default=100,
help='number of epochs to train (default: 100)')
parser.add_argument(
'--dataset',
type=str,
default="ogbn-proteins",
help='dataset name (default: proteinfunc)')
args = parser.parse_args()
#device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
#place = fluid.CUDAPlace(0)
place = fluid.CPUPlace() # Dataset too big to use GPU
### automatic dataloading and splitting
dataset = PglNodePropPredDataset(name=args.dataset)
splitted_idx = dataset.get_idx_split()
### automatic evaluator. takes dataset name as input
evaluator = Evaluator(args.dataset)
graph_data, label = dataset[0]
train_program = fluid.Program()
startup_program = fluid.Program()
test_program = fluid.Program()
# degree normalize
indegree = graph_data.indegree()
norm = np.zeros_like(indegree, dtype="float32")
norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
graph_data.node_feat["norm"] = np.expand_dims(norm, -1).astype("float32")
graph_data.node_feat["x"] = np.zeros((len(indegree), 1), dtype="int64")
graph_data.edge_feat["feat"] = graph_data.edge_feat["feat"].astype(
"float32")
model = GNNModel(
name="gnn", num_task=dataset.num_tasks, emb_dim=64, num_layers=2)
with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.StaticGraphWrapper("graph", graph_data, place)
pred = model.forward(gw)
sigmoid_pred = fluid.layers.sigmoid(pred)
val_program = train_program.clone(for_test=True)
initializer = []
with fluid.program_guard(train_program, startup_program):
train_node_index, init = paddle_helper.constant(
"train_node_index", dtype="int64", value=splitted_idx["train"])
initializer.append(init)
train_node_label, init = paddle_helper.constant(
"train_node_label",
dtype="float32",
value=label[splitted_idx["train"]].astype("float32"))
initializer.append(init)
train_pred_t = fluid.layers.gather(pred, train_node_index)
train_loss_t = fluid.layers.sigmoid_cross_entropy_with_logits(
x=train_pred_t, label=train_node_label)
train_loss_t = fluid.layers.reduce_sum(train_loss_t)
train_pred_t = fluid.layers.sigmoid(train_pred_t)
adam = fluid.optimizer.Adam(
learning_rate=1e-2,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0005))
adam.minimize(train_loss_t)
exe = fluid.Executor(place)
exe.run(startup_program)
gw.initialize(place)
for init in initializer:
init(place)
for epoch in range(1, args.epochs + 1):
loss = exe.run(train_program, feed={}, fetch_list=[train_loss_t])
print("Loss %s" % loss[0])
print("Evaluating...")
y_pred = exe.run(val_program, feed={}, fetch_list=[sigmoid_pred])[0]
result = {}
input_dict = {
"y_true": label[splitted_idx["train"]],
"y_pred": y_pred[splitted_idx["train"]]
}
result["train"] = evaluator.eval(input_dict)
input_dict = {
"y_true": label[splitted_idx["valid"]],
"y_pred": y_pred[splitted_idx["valid"]]
}
result["valid"] = evaluator.eval(input_dict)
input_dict = {
"y_true": label[splitted_idx["test"]],
"y_pred": y_pred[splitted_idx["test"]]
}
result["test"] = evaluator.eval(input_dict)
print(result)
if __name__ == "__main__":
main()
......@@ -18,4 +18,6 @@ from pgl import layers
from pgl import graph_wrapper
from pgl import graph
from pgl import data_loader
from pgl import heter_graph
from pgl import heter_graph_wrapper
from pgl import contrib
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__.py"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PglGraphPropPredDataset
"""
import pandas as pd
import shutil, os
import os.path as osp
import numpy as np
from ogb.utils.url import decide_download, download_url, extract_zip
from ogb.graphproppred import make_master_file
from pgl.contrib.ogb.io.read_graph_pgl import read_csv_graph_pgl
def to_bool(value):
"""to_bool"""
return np.array([value], dtype="bool")[0]
class PglGraphPropPredDataset(object):
"""PglGraphPropPredDataset"""
def __init__(self, name, root="dataset"):
self.name = name ## original name, e.g., ogbg-mol-tox21
self.dir_name = "_".join(
name.split("-")
) + "_pgl" ## replace hyphen with underline, e.g., ogbg_mol_tox21_dgl
self.original_root = root
self.root = osp.join(root, self.dir_name)
self.meta_info = make_master_file.df #pd.read_csv(
#os.path.join(os.path.dirname(__file__), "master.csv"), index_col=0)
if not self.name in self.meta_info:
print(self.name)
error_mssg = "Invalid dataset name {}.\n".format(self.name)
error_mssg += "Available datasets are as follows:\n"
error_mssg += "\n".join(self.meta_info.keys())
raise ValueError(error_mssg)
self.download_name = self.meta_info[self.name][
"download_name"] ## name of downloaded file, e.g., tox21
self.num_tasks = int(self.meta_info[self.name]["num tasks"])
self.task_type = self.meta_info[self.name]["task type"]
super(PglGraphPropPredDataset, self).__init__()
self.pre_process()
def pre_process(self):
"""Pre-processing"""
processed_dir = osp.join(self.root, 'processed')
raw_dir = osp.join(self.root, 'raw')
pre_processed_file_path = osp.join(processed_dir, 'pgl_data_processed')
if os.path.exists(pre_processed_file_path):
# TODO: Load Preprocessed
pass
else:
### download
url = self.meta_info[self.name]["url"]
if decide_download(url):
path = download_url(url, self.original_root)
extract_zip(path, self.original_root)
os.unlink(path)
# delete the folder if it already exists
try:
shutil.rmtree(self.root)
except:
pass
shutil.move(
osp.join(self.original_root, self.download_name),
self.root)
else:
print("Stop download.")
exit(-1)
### preprocess
add_inverse_edge = to_bool(self.meta_info[self.name][
"add_inverse_edge"])
self.graphs = read_csv_graph_pgl(
raw_dir, add_inverse_edge=add_inverse_edge)
self.graphs = np.array(self.graphs)
self.labels = np.array(
pd.read_csv(
osp.join(raw_dir, "graph-label.csv.gz"),
compression="gzip",
header=None).values)
# TODO: Load Graph
### load preprocessed files
def get_idx_split(self):
"""Train/Valid/Test split"""
split_type = self.meta_info[self.name]["split"]
path = osp.join(self.root, "split", split_type)
train_idx = pd.read_csv(
osp.join(path, "train.csv.gz"), compression="gzip",
header=None).values.T[0]
valid_idx = pd.read_csv(
osp.join(path, "valid.csv.gz"), compression="gzip",
header=None).values.T[0]
test_idx = pd.read_csv(
osp.join(path, "test.csv.gz"), compression="gzip",
header=None).values.T[0]
return {
"train": np.array(
train_idx, dtype="int64"),
"valid": np.array(
valid_idx, dtype="int64"),
"test": np.array(
test_idx, dtype="int64")
}
def __getitem__(self, idx):
"""Get datapoint with index"""
return self.graphs[idx], self.labels[idx]
def __len__(self):
"""Length of the dataset
Returns
-------
int
Length of Dataset
"""
return len(self.graphs)
def __repr__(self): # pragma: no cover
return '{}({})'.format(self.__class__.__name__, len(self))
if __name__ == "__main__":
pgl_dataset = PglGraphPropPredDataset(name="ogbg-mol-bace")
splitted_index = pgl_dataset.get_idx_split()
print(pgl_dataset)
print(pgl_dataset[3:20])
#print(pgl_dataset[splitted_index["train"]])
#print(pgl_dataset[splitted_index["valid"]])
#print(pgl_dataset[splitted_index["test"]])
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,8 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate Contrib api
"""__init__.py
"""
from pgl.contrib import heter_graph
from pgl.contrib import heter_graph_wrapper
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""pgl read_csv_graph for ogb
"""
import pandas as pd
import os.path as osp
import numpy as np
import pgl
from ogb.io.read_graph_raw import read_csv_graph_raw
def read_csv_graph_pgl(raw_dir, add_inverse_edge=False):
"""Read CSV data and build PGL Graph
"""
graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge)
pgl_graph_list = []
for graph in graph_list:
edges = list(zip(graph["edge_index"][0], graph["edge_index"][1]))
g = pgl.graph.Graph(num_nodes=graph["num_nodes"], edges=edges)
if graph["edge_feat"] is not None:
g.edge_feat["feat"] = graph["edge_feat"]
if graph["node_feat"] is not None:
g.node_feat["feat"] = graph["node_feat"]
pgl_graph_list.append(g)
return pgl_graph_list
if __name__ == "__main__":
# graph_list = read_csv_graph_dgl('dataset/proteinfunc_v2/raw', add_inverse_edge = True)
graph_list = read_csv_graph_pgl(
'dataset/ogbn_proteins_pgl/raw', add_inverse_edge=True)
print(graph_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__.py
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LinkPropPredDataset for pgl
"""
import pandas as pd
import shutil, os
import os.path as osp
import numpy as np
from ogb.utils.url import decide_download, download_url, extract_zip
from ogb.linkproppred import make_master_file
from pgl.contrib.ogb.io.read_graph_pgl import read_csv_graph_pgl
def to_bool(value):
"""to_bool"""
return np.array([value], dtype="bool")[0]
class PglLinkPropPredDataset(object):
"""PglLinkPropPredDataset
"""
def __init__(self, name, root="dataset"):
self.name = name ## original name, e.g., ogbl-ppa
self.dir_name = "_".join(name.split(
"-")) + "_pgl" ## replace hyphen with underline, e.g., ogbl_ppa_pgl
self.original_root = root
self.root = osp.join(root, self.dir_name)
self.meta_info = make_master_file.df #pd.read_csv(os.path.join(os.path.dirname(__file__), "master.csv"), index_col=0)
if not self.name in self.meta_info:
print(self.name)
error_mssg = "Invalid dataset name {}.\n".format(self.name)
error_mssg += "Available datasets are as follows:\n"
error_mssg += "\n".join(self.meta_info.keys())
raise ValueError(error_mssg)
self.download_name = self.meta_info[self.name][
"download_name"] ## name of downloaded file, e.g., ppassoc
self.task_type = self.meta_info[self.name]["task type"]
super(PglLinkPropPredDataset, self).__init__()
self.pre_process()
def pre_process(self):
"""pre_process downlaoding data
"""
processed_dir = osp.join(self.root, 'processed')
pre_processed_file_path = osp.join(processed_dir, 'pgl_data_processed')
if osp.exists(pre_processed_file_path):
#TODO: Reload Preprocess files
pass
else:
### check download
if not osp.exists(osp.join(self.root, "raw", "edge.csv.gz")):
url = self.meta_info[self.name]["url"]
if decide_download(url):
path = download_url(url, self.original_root)
extract_zip(path, self.original_root)
os.unlink(path)
# delete the folder if it already exists
try:
shutil.rmtree(self.root)
except:
pass
shutil.move(
osp.join(self.original_root, self.download_name),
self.root)
else:
print("Stop download.")
exit(-1)
raw_dir = osp.join(self.root, "raw")
### pre-process and save
add_inverse_edge = to_bool(self.meta_info[self.name][
"add_inverse_edge"])
self.graph = read_csv_graph_pgl(
raw_dir, add_inverse_edge=add_inverse_edge)
#TODO: SAVE preprocess graph
def get_edge_split(self):
"""Train/Validation/Test split
"""
split_type = self.meta_info[self.name]["split"]
path = osp.join(self.root, "split", split_type)
train_idx = pd.read_csv(
osp.join(path, "train.csv.gz"), compression="gzip",
header=None).values
valid_idx = pd.read_csv(
osp.join(path, "valid.csv.gz"), compression="gzip",
header=None).values
test_idx = pd.read_csv(
osp.join(path, "test.csv.gz"), compression="gzip",
header=None).values
if self.task_type == "link prediction":
target_type = np.int64
else:
target_type = np.float32
return {
"train_edge": np.array(
train_idx[:, :2], dtype="int64"),
"train_edge_label": np.array(
train_idx[:, 2], dtype=target_type),
"valid_edge": np.array(
valid_idx[:, :2], dtype="int64"),
"valid_edge_label": np.array(
valid_idx[:, 2], dtype=target_type),
"test_edge": np.array(
test_idx[:, :2], dtype="int64"),
"test_edge_label": np.array(
test_idx[:, 2], dtype=target_type)
}
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self.graph[0]
def __len__(self):
return 1
def __repr__(self): # pragma: no cover
return '{}({})'.format(self.__class__.__name__, len(self))
if __name__ == "__main__":
pgl_dataset = PglLinkPropPredDataset(name="ogbl-ppa")
splitted_edge = pgl_dataset.get_edge_split()
print(pgl_dataset[0])
print(splitted_edge)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__.py
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NodePropPredDataset for pgl
"""
import pandas as pd
import shutil, os
import os.path as osp
import numpy as np
from ogb.utils.url import decide_download, download_url, extract_zip
from ogb.nodeproppred import make_master_file # create master.csv
from pgl.contrib.ogb.io.read_graph_pgl import read_csv_graph_pgl
def to_bool(value):
"""to_bool"""
return np.array([value], dtype="bool")[0]
class PglNodePropPredDataset(object):
"""PglNodePropPredDataset
"""
def __init__(self, name, root="dataset"):
self.name = name ## original name, e.g., ogbn-proteins
self.dir_name = "_".join(
name.split("-")
) + "_pgl" ## replace hyphen with underline, e.g., ogbn_proteins_pgl
self.original_root = root
self.root = osp.join(root, self.dir_name)
self.meta_info = make_master_file.df #pd.read_csv(
#os.path.join(os.path.dirname(__file__), "master.csv"), index_col=0)
if not self.name in self.meta_info:
error_mssg = "Invalid dataset name {}.\n".format(self.name)
error_mssg += "Available datasets are as follows:\n"
error_mssg += "\n".join(self.meta_info.keys())
raise ValueError(error_mssg)
self.download_name = self.meta_info[self.name][
"download_name"] ## name of downloaded file, e.g., tox21
self.num_tasks = int(self.meta_info[self.name]["num tasks"])
self.task_type = self.meta_info[self.name]["task type"]
super(PglNodePropPredDataset, self).__init__()
self.pre_process()
def pre_process(self):
"""pre_process downlaoding data
"""
processed_dir = osp.join(self.root, 'processed')
pre_processed_file_path = osp.join(processed_dir, 'pgl_data_processed')
if osp.exists(pre_processed_file_path):
# TODO: Reload Preprocess files
pass
else:
### check download
if not osp.exists(osp.join(self.root, "raw", "edge.csv.gz")):
url = self.meta_info[self.name]["url"]
if decide_download(url):
path = download_url(url, self.original_root)
extract_zip(path, self.original_root)
os.unlink(path)
# delete the folder if it already exists
try:
shutil.rmtree(self.root)
except:
pass
shutil.move(
osp.join(self.original_root, self.download_name),
self.root)
else:
print("Stop download.")
exit(-1)
raw_dir = osp.join(self.root, "raw")
### pre-process and save
add_inverse_edge = to_bool(self.meta_info[self.name][
"add_inverse_edge"])
self.graph = read_csv_graph_pgl(
raw_dir, add_inverse_edge=add_inverse_edge)
### adding prediction target
node_label = pd.read_csv(
osp.join(raw_dir, 'node-label.csv.gz'),
compression="gzip",
header=None).values
if "classification" in self.task_type:
node_label = np.array(node_label, dtype=np.int64)
else:
node_label = np.array(node_label, dtype=np.float32)
label_dict = {"labels": node_label}
# TODO: SAVE preprocess graph
self.labels = label_dict['labels']
def get_idx_split(self):
"""Train/Validation/Test split
"""
split_type = self.meta_info[self.name]["split"]
path = osp.join(self.root, "split", split_type)
train_idx = pd.read_csv(
osp.join(path, "train.csv.gz"), compression="gzip",
header=None).values.T[0]
valid_idx = pd.read_csv(
osp.join(path, "valid.csv.gz"), compression="gzip",
header=None).values.T[0]
test_idx = pd.read_csv(
osp.join(path, "test.csv.gz"), compression="gzip",
header=None).values.T[0]
return {
"train": np.array(
train_idx, dtype="int64"),
"valid": np.array(
valid_idx, dtype="int64"),
"test": np.array(
test_idx, dtype="int64")
}
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self.graph[idx], self.labels
def __len__(self):
return 1
def __repr__(self): # pragma: no cover
return '{}({})'.format(self.__class__.__name__, len(self))
if __name__ == "__main__":
pgl_dataset = PglNodePropPredDataset(name="ogbn-proteins")
splitted_index = pgl_dataset.get_idx_split()
print(pgl_dataset[0])
print(splitted_index)
......@@ -15,6 +15,7 @@
This package implement Graph structure for handling graph data.
"""
import os
import numpy as np
import pickle as pkl
import time
......@@ -43,8 +44,8 @@ class EdgeIndex(object):
"""
def __init__(self, u, v, num_nodes):
self._v, self._eid, self._degree, self._sorted_u,\
self._sorted_v, self._sorted_eid = graph_kernel.build_index(u, v, num_nodes)
self._degree, self._sorted_v, self._sorted_u, \
self._sorted_eid, self._indptr = graph_kernel.build_index(u, v, num_nodes)
@property
def degree(self):
......@@ -52,23 +53,40 @@ class EdgeIndex(object):
"""
return self._degree
@property
def v(self):
"""Return the compressed v.
def view_v(self, u=None):
"""Return the compressed v for given u.
"""
return self._v
if u is None:
return np.split(self._sorted_v, self._indptr[1:])
else:
u = np.array(u, dtype="int64")
return graph_kernel.slice_by_index(
self._sorted_v, self._indptr, index=u)
@property
def eid(self):
"""Return the edge id.
def view_eid(self, u=None):
"""Return the compressed edge id for given u.
"""
return self._eid
if u is None:
return np.split(self._sorted_eid, self._indptr[1:])
else:
u = np.array(u, dtype="int64")
return graph_kernel.slice_by_index(
self._sorted_eid, self._indptr, index=u)
def triples(self):
"""Return the sorted (u, v, eid) tuples.
"""
return self._sorted_u, self._sorted_v, self._sorted_eid
def dump(self, path):
if not os.path.exists(path):
os.makedirs(path)
np.save(os.path.join(path, 'degree.npy'), self._degree)
np.save(os.path.join(path, 'sorted_u.npy'), self._sorted_u)
np.save(os.path.join(path, 'sorted_v.npy'), self._sorted_v)
np.save(os.path.join(path, 'sorted_eid.npy'), self._sorted_eid)
np.save(os.path.join(path, 'indptr.npy'), self._indptr)
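# Illustrative note (not part of the original file): the sorted arrays plus
# `indptr` form a CSR-style index, so view_v(u) slices the neighbors of each
# node in u. For edges (0, 1), (0, 2), (1, 2) on 3 nodes (src = u):
#   degree   = [2, 1, 0]
#   indptr   = [0, 2, 3, 3]
#   sorted_v = [1, 2, 2]
# and view_v([0]) returns [array([1, 2])], while view_v([2]) returns an empty
# array because node 2 has no outgoing edge.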
class Graph(object):
"""Implementation of graph structure in pgl.
......@@ -128,6 +146,31 @@ class Graph(object):
self._adj_src_index = None
self._adj_dst_index = None
def dump(self, path):
if not os.path.exists(path):
os.makedirs(path)
np.save(os.path.join(path, 'num_nodes.npy'), self._num_nodes)
np.save(os.path.join(path, 'edges.npy'), self._edges)
if self._adj_src_index:
self._adj_src_index.dump(os.path.join(path, 'adj_src'))
if self._adj_dst_index:
self._adj_dst_index.dump(os.path.join(path, 'adj_dst'))
def dump_feat(feat_path, feat):
"""Dump all features to .npy file.
"""
if len(feat) == 0:
return
if not os.path.exists(feat_path):
os.makedirs(feat_path)
for key in feat:
np.save(os.path.join(feat_path, key + ".npy"), feat[key])
dump_feat(os.path.join(path, "node_feat"), self.node_feat)
dump_feat(os.path.join(path, "edge_feat"), self.edge_feat)
@property
def adj_src_index(self):
"""Return an EdgeIndex object for src.
......@@ -287,17 +330,11 @@ class Graph(object):
[]]
"""
if nodes is None:
if return_eids:
return self.adj_src_index.v, self.adj_src_index.eid
else:
return self.adj_src_index.v
if return_eids:
return self.adj_src_index.view_v(
nodes), self.adj_src_index.view_eid(nodes)
else:
if return_eids:
return self.adj_src_index.v[nodes], self.adj_src_index.eid[
nodes]
else:
return self.adj_src_index.v[nodes]
return self.adj_src_index.view_v(nodes)
def sample_successor(self,
nodes,
......@@ -385,17 +422,11 @@ class Graph(object):
[2]]
"""
if nodes is None:
if return_eids:
return self.adj_dst_index.v, self.adj_dst_index.eid
else:
return self.adj_dst_index.v
if return_eids:
return self.adj_dst_index.view_v(
nodes), self.adj_dst_index.view_eid(nodes)
else:
if return_eids:
return self.adj_dst_index.v[nodes], self.adj_dst_index.eid[
nodes]
else:
return self.adj_dst_index.v[nodes]
return self.adj_dst_index.view_v(nodes)
def sample_predecessor(self,
nodes,
......@@ -510,7 +541,13 @@ class Graph(object):
(key, _hide_num_nodes(value.shape), value.dtype))
return edge_feat_info
def subgraph(self, nodes, eid=None, edges=None):
def subgraph(self,
nodes,
eid=None,
edges=None,
edge_feats=None,
with_node_feat=True,
with_edge_feat=True):
"""Generate subgraph with nodes and edge ids.
This function will generate a :code:`pgl.graph.Subgraph` object and
......@@ -525,6 +562,10 @@ class Graph(object):
eid (optional): Edge ids which will be included in the subgraph.
edges (optional): Edge(src, dst) list which will be included in the subgraph.
with_node_feat: Whether to inherit node features from parent graph.
with_edge_feat: Whether to inherit edge features from parent graph.
Return:
A :code:`pgl.graph.Subgraph` object.
......@@ -547,14 +588,20 @@ class Graph(object):
len(edges), dtype="int64"), edges, reindex)
sub_edge_feat = {}
for key, value in self._edge_feat.items():
if eid is None:
raise ValueError("Eid can not be None with edge features.")
sub_edge_feat[key] = value[eid]
if edges is None:
if with_edge_feat:
for key, value in self._edge_feat.items():
if eid is None:
raise ValueError(
"Eid can not be None with edge features.")
sub_edge_feat[key] = value[eid]
else:
sub_edge_feat = edge_feats
sub_node_feat = {}
for key, value in self._node_feat.items():
sub_node_feat[key] = value[nodes]
if with_node_feat:
for key, value in self._node_feat.items():
sub_node_feat[key] = value[nodes]
subgraph = SubGraph(
num_nodes=len(nodes),
......@@ -783,3 +830,45 @@ class SubGraph(Graph):
A list of node ids in parent graph.
"""
return graph_kernel.map_nodes(nodes, self._to_reindex)
class MemmapEdgeIndex(EdgeIndex):
def __init__(self, path):
self._degree = np.load(os.path.join(path, 'degree.npy'), mmap_mode="r")
self._sorted_u = np.load(
os.path.join(path, 'sorted_u.npy'), mmap_mode="r")
self._sorted_v = np.load(
os.path.join(path, 'sorted_v.npy'), mmap_mode="r")
self._sorted_eid = np.load(
os.path.join(path, 'sorted_eid.npy'), mmap_mode="r")
self._indptr = np.load(os.path.join(path, 'indptr.npy'), mmap_mode="r")
class MemmapGraph(Graph):
def __init__(self, path):
self._num_nodes = np.load(os.path.join(path, 'num_nodes.npy'))
self._edges = np.load(os.path.join(path, 'edges.npy'), mmap_mode="r")
if os.path.isdir(os.path.join(path, 'adj_src')):
self._adj_src_index = MemmapEdgeIndex(
os.path.join(path, 'adj_src'))
else:
self._adj_src_index = None
if os.path.isdir(os.path.join(path, 'adj_dst')):
self._adj_dst_index = MemmapEdgeIndex(
os.path.join(path, 'adj_dst'))
else:
self._adj_dst_index = None
def load_feat(feat_path):
"""Load features from .npy file.
"""
feat = {}
if os.path.isdir(feat_path):
for feat_name in os.listdir(feat_path):
feat[os.path.splitext(feat_name)[0]] = np.load(
os.path.join(feat_path, feat_name), mmap_mode="r")
return feat
self._node_feat = load_feat(os.path.join(path, 'node_feat'))
self._edge_feat = load_feat(os.path.join(path, 'edge_feat'))
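# Minimal usage sketch (illustrative, not part of the original file): dump a
# small graph to disk and reopen it with memory-mapped arrays. The temporary
# path below is hypothetical.
if __name__ == "__main__":
    import tempfile

    g = Graph(num_nodes=3, edges=[(0, 1), (0, 2), (1, 2)])
    g.adj_src_index  # build the source adjacency index so it is dumped too
    dump_path = tempfile.mkdtemp()
    g.dump(dump_path)
    mg = MemmapGraph(dump_path)
    print(mg.num_nodes, len(mg.edges))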
......@@ -53,14 +53,21 @@ def build_index(np.ndarray[np.int64_t, ndim=1] u,
_tmp_eid[indptr[u[i]] + count[u[i]]] = i
_tmp_u[indptr[u[i]] + count[u[i]]] = u[i]
count[u[i]] += 1
return degree, _tmp_v, _tmp_u, _tmp_eid, indptr
cdef list output_eid = []
cdef list output_v = []
for i in xrange(n_size):
output_eid.append(_tmp_eid[indptr[i]:indptr[i+1]])
output_v.append(_tmp_v[indptr[i]:indptr[i+1]])
return np.array(output_v), np.array(output_eid), degree, _tmp_u, _tmp_v, _tmp_eid
@cython.boundscheck(False)
@cython.wraparound(False)
def slice_by_index(np.ndarray[np.int64_t, ndim=1] u,
np.ndarray[np.int64_t, ndim=1] indptr,
np.ndarray[np.int64_t, ndim=1] index):
cdef list output = []
cdef long long i
cdef long long h = len(index)
cdef long long j
for i in xrange(h):
j = index[i]
output.append(u[indptr[j]:indptr[j+1]])
return np.array(output)
@cython.boundscheck(False)
@cython.wraparound(False)
......@@ -253,22 +260,10 @@ def sample_subset_with_eid(list nids, list eids, long long maxdegree, shuffle=Fa
@cython.boundscheck(False)
@cython.wraparound(False)
def skip_gram_gen_pair(vector[long long] walk_path, long win_size=5):
"""Return node paris generated by skip-gram algorithm.
This function will auto remove the pair which src node is the same
as dst node.
Args:
walk_path: List of nodes as a walk path.
win_size: the windows size used in skip-gram.
Return:
A tuple of (src node list, dst node list).
"""
def skip_gram_gen_pair(vector[long long] walk, long win_size=5):
cdef vector[long long] src
cdef vector[long long] dst
cdef long long l = len(walk_path)
cdef long long l = len(walk)
cdef long long real_win_size, left, right, i
cdef np.ndarray[np.int64_t, ndim=1] rnd = np.random.randint(1, win_size+1,
dtype=np.int64, size=l)
......@@ -282,23 +277,15 @@ def skip_gram_gen_pair(vector[long long] walk_path, long win_size=5):
if right >= l:
right = l - 1
for j in xrange(left, right+1):
if walk_path[i] == walk_path[j]:
if walk[i] == walk[j]:
continue
src.push_back(walk_path[i])
dst.push_back(walk_path[j])
src.push_back(walk[i])
dst.push_back(walk[j])
return src, dst
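# Illustrative note (not part of the original source): with win_size=1 the
# sampled window is always 1, so the result is deterministic, e.g.
#   skip_gram_gen_pair([1, 2, 3], win_size=1)
# returns src [1, 2, 2, 3] and dst [2, 1, 3, 2].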
@cython.boundscheck(False)
@cython.wraparound(False)
def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs):
"""Return the alias table and event table for alias sampling.
Args:
probs: A list of float numbers as the probabilities.
Return:
A tuple of (alias table, event table).
"""
cdef long long l = len(probs)
cdef np.ndarray[np.float64_t, ndim=1] alias = probs * l
cdef np.ndarray[np.int64_t, ndim=1] events = np.zeros(l, dtype=np.int64)
......
......@@ -89,8 +89,8 @@ class BaseGraphWrapper(object):
"""
def __init__(self):
self._node_feat_tensor_dict = {}
self._edge_feat_tensor_dict = {}
self.node_feat_tensor_dict = {}
self.edge_feat_tensor_dict = {}
self._edges_src = None
self._edges_dst = None
self._num_nodes = None
......@@ -98,6 +98,10 @@ class BaseGraphWrapper(object):
self._edge_uniq_dst = None
self._edge_uniq_dst_count = None
self._node_ids = None
self._data_name_prefix = ""
def __repr__(self):
return self._data_name_prefix
def send(self, message_func, nfeat_list=None, efeat_list=None):
"""Send message from all src nodes to dst nodes.
......@@ -220,7 +224,7 @@ class BaseGraphWrapper(object):
A dictionary whose keys are the feature names and the values
are feature tensor.
"""
return self._edge_feat_tensor_dict
return self.edge_feat_tensor_dict
@property
def node_feat(self):
......@@ -230,7 +234,7 @@ class BaseGraphWrapper(object):
A dictionary whose keys are the feature names and the values
are feature tensor.
"""
return self._node_feat_tensor_dict
return self.node_feat_tensor_dict
def indegree(self):
"""Return the indegree tensor for all nodes.
......@@ -298,8 +302,8 @@ class StaticGraphWrapper(BaseGraphWrapper):
def __init__(self, name, graph, place):
super(StaticGraphWrapper, self).__init__()
self._data_name_prefix = name
self._initializers = []
self.__data_name_prefix = name
self.__create_graph_attr(graph)
def __create_graph_attr(self, graph):
......@@ -326,43 +330,43 @@ class StaticGraphWrapper(BaseGraphWrapper):
self._edges_src, init = paddle_helper.constant(
dtype="int64",
value=src,
name=self.__data_name_prefix + '/edges_src')
name=self._data_name_prefix + '/edges_src')
self._initializers.append(init)
self._edges_dst, init = paddle_helper.constant(
dtype="int64",
value=dst,
name=self.__data_name_prefix + '/edges_dst')
name=self._data_name_prefix + '/edges_dst')
self._initializers.append(init)
self._num_nodes, init = paddle_helper.constant(
dtype="int64",
hide_batch_size=False,
value=np.array([graph.num_nodes]),
name=self.__data_name_prefix + '/num_nodes')
name=self._data_name_prefix + '/num_nodes')
self._initializers.append(init)
self._edge_uniq_dst, init = paddle_helper.constant(
name=self.__data_name_prefix + "/uniq_dst",
name=self._data_name_prefix + "/uniq_dst",
dtype="int64",
value=uniq_dst)
self._initializers.append(init)
self._edge_uniq_dst_count, init = paddle_helper.constant(
name=self.__data_name_prefix + "/uniq_dst_count",
name=self._data_name_prefix + "/uniq_dst_count",
dtype="int32",
value=uniq_dst_count)
self._initializers.append(init)
node_ids_value = np.arange(0, graph.num_nodes, dtype="int64")
self._node_ids, init = paddle_helper.constant(
name=self.__data_name_prefix + "/node_ids",
name=self._data_name_prefix + "/node_ids",
dtype="int64",
value=node_ids_value)
self._initializers.append(init)
self._indegree, init = paddle_helper.constant(
name=self.__data_name_prefix + "/indegree",
name=self._data_name_prefix + "/indegree",
dtype="int64",
value=indegree)
self._initializers.append(init)
......@@ -373,9 +377,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
for node_feat_name, node_feat_value in node_feat.items():
node_feat_shape = node_feat_value.shape
node_feat_dtype = node_feat_value.dtype
self._node_feat_tensor_dict[
self.node_feat_tensor_dict[
node_feat_name], init = paddle_helper.constant(
name=self.__data_name_prefix + '/node_feat/' +
name=self._data_name_prefix + '/node_feat/' +
node_feat_name,
dtype=node_feat_dtype,
value=node_feat_value)
......@@ -387,9 +391,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
for edge_feat_name, edge_feat_value in edge_feat.items():
edge_feat_shape = edge_feat_value.shape
edge_feat_dtype = edge_feat_value.dtype
self._edge_feat_tensor_dict[
self.edge_feat_tensor_dict[
edge_feat_name], init = paddle_helper.constant(
name=self.__data_name_prefix + '/edge_feat/' +
name=self._data_name_prefix + '/edge_feat/' +
edge_feat_name,
dtype=edge_feat_dtype,
value=edge_feat_value)
......@@ -477,8 +481,8 @@ class GraphWrapper(BaseGraphWrapper):
def __init__(self, name, place, node_feat=[], edge_feat=[]):
super(GraphWrapper, self).__init__()
# collect holders for PyReader
self._data_name_prefix = name
self._holder_list = []
self.__data_name_prefix = name
self._place = place
self.__create_graph_attr_holders()
for node_feat_name, node_feat_shape, node_feat_dtype in node_feat:
......@@ -493,43 +497,43 @@ class GraphWrapper(BaseGraphWrapper):
"""Create data holders for graph attributes.
"""
self._edges_src = fluid.layers.data(
self.__data_name_prefix + '/edges_src',
self._data_name_prefix + '/edges_src',
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._edges_dst = fluid.layers.data(
self.__data_name_prefix + '/edges_dst',
self._data_name_prefix + '/edges_dst',
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._num_nodes = fluid.layers.data(
self.__data_name_prefix + '/num_nodes',
self._data_name_prefix + '/num_nodes',
shape=[1],
append_batch_size=False,
dtype='int64',
stop_gradient=True)
self._edge_uniq_dst = fluid.layers.data(
self.__data_name_prefix + "/uniq_dst",
self._data_name_prefix + "/uniq_dst",
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._edge_uniq_dst_count = fluid.layers.data(
self.__data_name_prefix + "/uniq_dst_count",
self._data_name_prefix + "/uniq_dst_count",
shape=[None],
append_batch_size=False,
dtype="int32",
stop_gradient=True)
self._node_ids = fluid.layers.data(
self.__data_name_prefix + "/node_ids",
self._data_name_prefix + "/node_ids",
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._indegree = fluid.layers.data(
self.__data_name_prefix + "/indegree",
self._data_name_prefix + "/indegree",
shape=[None],
append_batch_size=False,
dtype="int64",
......@@ -545,12 +549,12 @@ class GraphWrapper(BaseGraphWrapper):
"""Create data holders for node features.
"""
feat_holder = fluid.layers.data(
self.__data_name_prefix + '/node_feat/' + node_feat_name,
self._data_name_prefix + '/node_feat/' + node_feat_name,
shape=node_feat_shape,
append_batch_size=False,
dtype=node_feat_dtype,
stop_gradient=True)
self._node_feat_tensor_dict[node_feat_name] = feat_holder
self.node_feat_tensor_dict[node_feat_name] = feat_holder
self._holder_list.append(feat_holder)
def __create_graph_edge_feat_holders(self, edge_feat_name, edge_feat_shape,
......@@ -558,12 +562,12 @@ class GraphWrapper(BaseGraphWrapper):
"""Create edge holders for edge features.
"""
feat_holder = fluid.layers.data(
self.__data_name_prefix + '/edge_feat/' + edge_feat_name,
self._data_name_prefix + '/edge_feat/' + edge_feat_name,
shape=edge_feat_shape,
append_batch_size=False,
dtype=edge_feat_dtype,
stop_gradient=True)
self._edge_feat_tensor_dict[edge_feat_name] = feat_holder
self.edge_feat_tensor_dict[edge_feat_name] = feat_holder
self._holder_list.append(feat_holder)
def to_feed(self, graph):
......@@ -594,20 +598,21 @@ class GraphWrapper(BaseGraphWrapper):
edge_feat[key] = value[eid]
node_feat = graph.node_feat
feed_dict[self.__data_name_prefix + '/edges_src'] = src
feed_dict[self.__data_name_prefix + '/edges_dst'] = dst
feed_dict[self.__data_name_prefix + '/num_nodes'] = np.array(graph.num_nodes)
feed_dict[self.__data_name_prefix + '/uniq_dst'] = uniq_dst
feed_dict[self.__data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
feed_dict[self.__data_name_prefix + '/node_ids'] = graph.nodes
feed_dict[self.__data_name_prefix + '/indegree'] = indegree
for key in self._node_feat_tensor_dict:
feed_dict[self.__data_name_prefix + '/node_feat/' +
feed_dict[self._data_name_prefix + '/edges_src'] = src
feed_dict[self._data_name_prefix + '/edges_dst'] = dst
feed_dict[self._data_name_prefix + '/num_nodes'] = np.array(
graph.num_nodes)
feed_dict[self._data_name_prefix + '/uniq_dst'] = uniq_dst
feed_dict[self._data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
feed_dict[self._data_name_prefix + '/node_ids'] = graph.nodes
feed_dict[self._data_name_prefix + '/indegree'] = indegree
for key in self.node_feat_tensor_dict:
feed_dict[self._data_name_prefix + '/node_feat/' +
key] = node_feat[key]
for key in self._edge_feat_tensor_dict:
feed_dict[self.__data_name_prefix + '/edge_feat/' +
for key in self.edge_feat_tensor_dict:
feed_dict[self._data_name_prefix + '/edge_feat/' +
key] = edge_feat[key]
return feed_dict
......
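# Illustrative usage sketch (not part of the original file; the feature name
# and shapes below are hypothetical): a GraphWrapper declares data holders at
# build time, and to_feed() turns a numpy pgl.graph.Graph into the matching
# feed dict.
#
#   gw = GraphWrapper(name="graph", place=place,
#                     node_feat=[("feat", [None, 16], "float32")])
#   # ... build the model from gw.node_feat["feat"], gw.send(...), gw.recv(...)
#   feed_dict = gw.to_feed(numpy_graph)
#   exe.run(program, feed=feed_dict, fetch_list=[loss])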
......@@ -21,7 +21,7 @@ import time
import pgl.graph_kernel as graph_kernel
from pgl.graph import Graph
__all__ = ['HeterGraph']
__all__ = ['HeterGraph', 'SubHeterGraph']
def _hide_num_nodes(shape):
......@@ -32,31 +32,6 @@ def _hide_num_nodes(shape):
return shape
class NodeGraph(Graph):
"""Implementation of a graph that has multple node types.
Args:
num_nodes: number of nodes in the graph
edges: list of (u, v) tuples
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of numpy array as edge features
"""
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None):
super(NodeGraph, self).__init__(num_nodes, edges, node_feat, edge_feat)
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
class HeterGraph(object):
"""Implementation of heterogeneous graph structure in pgl
......@@ -102,6 +77,16 @@ class HeterGraph(object):
self._num_nodes = num_nodes
self._edges_dict = edges
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
self._nodes_type_dict = {}
for n_type in np.unique(self._node_types):
self._nodes_type_dict[n_type] = np.where(
self._node_types == n_type)[0]
if node_feat is not None:
self._node_feat = node_feat
else:
......@@ -113,30 +98,262 @@ class HeterGraph(object):
self._edge_feat = {}
self._multi_graph = {}
for key, value in self._edges_dict.items():
if not self._edge_feat:
edge_feat = None
else:
edge_feat = self._edge_feat[key]
self._multi_graph[key] = NodeGraph(
self._multi_graph[key] = Graph(
num_nodes=self._num_nodes,
edges=value,
node_types=node_types,
node_feat=self._node_feat,
edge_feat=edge_feat)
self._edge_types = self.edge_types_info()
@property
def edge_types(self):
"""Return a list of edge types.
"""
return self._edge_types
@property
def num_nodes(self):
"""Return the number of nodes.
"""
return self._num_nodes
@property
def num_edges(self):
"""Return edges number of all edge types.
"""
n_edges = {}
for e_type in self._edge_types:
n_edges[e_type] = self._multi_graph[e_type].num_edges
return n_edges
@property
def node_types(self):
"""Return the node types.
"""
return self._node_types
@property
def edge_feat(self, edge_type=None):
"""Return edge features of all edge types.
"""
return self._edge_feat
@property
def node_feat(self):
"""Return a dictionary of node features.
"""
return self._node_feat
@property
def nodes(self):
"""Return all nodes id from 0 to :code:`num_nodes - 1`
"""
return np.arange(self._num_nodes, dtype='int64')
def __getitem__(self, edge_type):
"""__getitem__
"""
return self._multi_graph[edge_type]
def num_nodes_by_type(self, n_type=None):
"""Return the number of nodes with the specified node type.
"""
if n_type not in self._nodes_type_dict:
raise ("%s is not in valid node type" % n_type)
else:
return len(self._nodes_type_dict[n_type])
def indegree(self, nodes=None, edge_type=None):
"""Return the indegree of the given nodes with the specified edge_type.
Args:
nodes: Return the indegree of given nodes.
if nodes is None, return indegree for all nodes.
edge_type: Return the indegree with the specified edge_type.
if edge_type is None, return the total indegree of the given nodes.
Return:
A numpy.ndarray as the given nodes' indegree.
"""
if edge_type is None:
indegrees = []
for e_type in self._edge_types:
indegrees.append(self._multi_graph[e_type].indegree(nodes))
indegrees = np.sum(np.vstack(indegrees), axis=0)
return indegrees
else:
return self._multi_graph[edge_type].indegree(nodes)
def outdegree(self, nodes=None, edge_type=None):
"""Return the outdegree of the given nodes with the specified edge_type.
Args:
nodes: Return the outdegree of given nodes,
if nodes is None, return outdegree for all nodes
edge_type: Return the outdegree with the specified edge_type.
if edge_type is None, return the total outdegree of the given nodes.
Return:
A numpy.array as the given nodes' outdegree.
"""
if edge_type is None:
outdegrees = []
for e_type in self._edge_types:
outdegrees.append(self._multi_graph[e_type].outdegree(nodes))
outdegrees = np.sum(np.vstack(outdegrees), axis=0)
return outdegrees
else:
return self._multi_graph[edge_type].outdegree(nodes)
def successor(self, edge_type, nodes=None, return_eids=False):
"""Find successor of given nodes with the specified edge_type.
Args:
nodes: Return the successor of given nodes,
if nodes is None, return successor for all nodes
edge_type: Return the successor with the specified edge_type.
if edge_type is None, return the total successor of the given nodes
and eids are invalid in this way.
return_eids: If True return nodes together with corresponding eid
"""
return self._multi_graph[edge_type].successor(nodes, return_eids)
def sample_successor(self,
edge_type,
nodes,
max_degree,
return_eids=False,
shuffle=False):
"""Sample successors of given nodes with the specified edge_type.
Args:
edge_type: The specified edge_type.
nodes: Given nodes whose successors will be sampled.
max_degree: The max sampled successors for each nodes.
return_eids: Whether to return the corresponding eids.
Return:
Return a list of numpy.ndarray and each numpy.ndarray represent a list
of sampled successor ids for given nodes with specified edge type.
If :code:`return_eids=True`, there will be an additional list of
numpy.ndarray and each numpy.ndarray represent a list of eids that
connected nodes to their successors.
"""
return self._multi_graph[edge_type].sample_successor(
nodes=nodes,
max_degree=max_degree,
return_eids=return_eids,
shuffle=shuffle)
def predecessor(self, edge_type, nodes=None, return_eids=False):
"""Find predecessor of given nodes with the specified edge_type.
Args:
nodes: Return the predecessor of given nodes,
if nodes is None, return predecessor for all nodes
edge_type: Return the predecessor with the specified edge_type.
return_eids: If True return nodes together with corresponding eid
"""
return self._multi_graph[edge_type].predecessor(nodes, return_eids)
def sample_predecessor(self,
edge_type,
nodes,
max_degree,
return_eids=False,
shuffle=False):
"""Sample predecessors of given nodes with the specified edge_type.
Args:
edge_type: The specified edge_type.
nodes: Given nodes whose predecessors will be sampled.
max_degree: The max sampled predecessors for each nodes.
return_eids: Whether to return the corresponding eids.
Return:
Return a list of numpy.ndarray and each numpy.ndarray represent a list
of sampled predecessor ids for given nodes with specified edge type.
If :code:`return_eids=True`, there will be an additional list of
numpy.ndarray and each numpy.ndarray represent a list of eids that
connected nodes to their predecessors.
"""
return self._multi_graph[edge_type].sample_predecessor(
nodes=nodes,
max_degree=max_degree,
return_eids=return_eids,
shuffle=shuffle)
def node_batch_iter(self, batch_size, shuffle=True, n_type=None):
"""Node batch iterator
Iterate all nodes by batch with the specified node type.
Args:
batch_size: The batch size of each batch of nodes.
shuffle: Whether shuffle the nodes.
n_type: Iterate the nodes with the specified node type. If n_type is None,
iterate all nodes by batch.
Return:
Batch iterator
"""
if n_type is None:
nodes = np.arange(self._num_nodes, dtype="int64")
else:
nodes = self._nodes_type_dict[n_type]
if shuffle:
np.random.shuffle(nodes)
start = 0
while start < len(nodes):
yield nodes[start:start + batch_size]
start += batch_size
def sample_nodes(self, sample_num, n_type=None):
"""Sample nodes with the specified n_type from the graph
This function helps to sample nodes with the specified n_type from the graph.
If n_type is None, this function will sample nodes from all nodes.
Nodes might be duplicated.
Args:
sample_num: The number of samples
n_type: The nodes of type to be sampled
Return:
A list of nodes
"""
if n_type is not None:
return np.random.choice(
self._nodes_type_dict[n_type], size=sample_num)
else:
return np.random.randint(
low=0, high=self._num_nodes, size=sample_num)
def node_feat_info(self):
"""Return the information of node feature for HeterGraphWrapper.
......@@ -186,3 +403,60 @@ class HeterGraph(object):
edge_types_info.append(key)
return edge_types_info
class SubHeterGraph(HeterGraph):
"""Implementation of SubHeterGraph in pgl.
SubHeterGraph is inherit from :code:`HeterGraph`.
Args:
num_nodes: number of nodes in a heterogeneous graph
edges: dict, every element in dict is a list of (u, v) tuples.
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of dict as edge features for every edge type
reindex: A dictionary that maps parent heterogeneous graph node ids to sub-heterogeneous graph node ids.
"""
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None,
reindex=None):
super(SubHeterGraph, self).__init__(
num_nodes=num_nodes,
edges=edges,
node_types=node_types,
node_feat=node_feat,
edge_feat=edge_feat)
if reindex is None:
reindex = {}
self._from_reindex = reindex
self._to_reindex = {u: v for v, u in reindex.items()}
def reindex_from_parrent_nodes(self, nodes):
"""Map the given parent graph node id to subgraph id.
Args:
nodes: A list of nodes from parent graph.
Return:
A list of subgraph ids.
"""
return graph_kernel.map_nodes(nodes, self._from_reindex)
def reindex_to_parrent_nodes(self, nodes):
"""Map the given subgraph node id to parent graph id.
Args:
nodes: A list of nodes in this subgraph.
Return:
A list of node ids in parent graph.
"""
return graph_kernel.map_nodes(nodes, self._to_reindex)
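if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file); the
    # node and edge types below are made up.
    num_nodes = 4
    node_types = [(0, 'user'), (1, 'item'), (2, 'item'), (3, 'user')]
    edges = {
        'click': [(0, 1), (0, 2), (3, 1)],
        'buy': [(0, 1), (3, 2)],
    }
    hg = HeterGraph(num_nodes=num_nodes, edges=edges, node_types=node_types)
    print(hg.edge_types)
    print(hg.num_nodes_by_type('item'))
    print(hg.indegree(edge_type='click'))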
......@@ -64,8 +64,8 @@ class HeterGraphWrapper(object):
import paddle.fluid as fluid
import numpy as np
from pgl.contrib import heter_graph
from pgl.contrib import heter_graph_wrapper
from pgl import heter_graph
from pgl import heter_graph_wrapper
num_nodes = 4
node_types = [(0, 'user'), (1, 'item'), (2, 'item'), (3, 'user')]
edges = {
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""redis_hetergraph"""
import pgl
import redis
from redis import BlockingConnectionPool, StrictRedis
from redis._compat import b, unicode, bytes, long, basestring
from rediscluster.nodemanager import NodeManager
from rediscluster.crc import crc16
from collections import OrderedDict
import threading
import numpy as np
import time
import json
import pgl.graph as pgraph
import pickle as pkl
from pgl.utils.logger import log
import pgl.graph_kernel as graph_kernel
from pgl import heter_graph
import pgl.redis_graph as rg
class RedisHeterGraph(rg.RedisGraph):
"""Redis Heterogeneous Graph"""
def __init__(self, name, edge_types, redis_config, num_parts):
super(RedisHeterGraph, self).__init__(name, redis_config, num_parts)
self._num_edges = {}
self.edge_types = edge_types
self.e_type = None
self._edge_feat_info = {}
self._edge_feat_dtype = {}
self._edge_feat_shape = {}
def num_edges_by_type(self, e_type):
"""get edge number by specified edge type"""
if e_type not in self._num_edges:
self._num_edges[e_type] = int(
self._rs.get("%s:num_edges" % e_type))
return self._num_edges[e_type]
def num_edges(self):
"""num_edges"""
num_edges = {}
for e_type in self.edge_types:
num_edges[e_type] = self.num_edges_by_type(e_type)
return num_edges
def edge_feat_info_by_type(self, e_type):
"""get edge features information by specified edge type"""
if e_type not in self._edge_feat_info:
buff = self._rs.get("%s:ef:infos" % e_type)
if buff is not None:
self._edge_feat_info[e_type] = json.loads(buff.decode())
else:
self._edge_feat_info[e_type] = []
return self._edge_feat_info[e_type]
def edge_feat_info(self):
"""edge_feat_info"""
edge_feat_info = {}
for e_type in self.edge_types:
edge_feat_info[e_type] = self.edge_feat_info_by_type(e_type)
return edge_feat_info
def edge_feat_shape(self, e_type, key):
"""edge_feat_shape"""
if e_type not in self._edge_feat_shape:
e_feat_shape = {}
for k, shape, _ in self.edge_feat_info()[e_type]:
e_feat_shape[k] = shape
self._edge_feat_shape[e_type] = e_feat_shape
return self._edge_feat_shape[e_type][key]
def edge_feat_dtype(self, e_type, key):
"""edge_feat_dtype"""
if e_type not in self._edge_feat_dtype:
e_feat_dtype = {}
for k, _, dtype in self.edge_feat_info()[e_type]:
e_feat_dtype[k] = dtype
self._edge_feat_dtype[e_type] = e_feat_dtype
return self._edge_feat_dtype[e_type][key]
def sample_predecessor(self, e_type, nodes, max_degree, return_eids=False):
"""sample predecessor with the specified edge type"""
query = ["%s:d:%s" % (e_type, n) for n in nodes]
rets = rg.hmget_sample_helper(self._rs, query, self.num_parts,
max_degree)
v = []
eid = []
for buff in rets:
if buff is None:
v.append(np.array([], dtype="int64"))
eid.append(np.array([], dtype="int64"))
else:
npret = np.frombuffer(
buff, dtype="int64").reshape([-1, 2]).astype("int64")
v.append(npret[:, 0])
eid.append(npret[:, 1])
if return_eids:
return np.array(v), np.array(eid)
else:
return np.array(v)
def sample_successor(self, e_type, nodes, max_degree, return_eids=False):
"""sample successor with the specified edge type"""
query = ["%s:s:%s" % (e_type, n) for n in nodes]
rets = rg.hmget_sample_helper(self._rs, query, self.num_parts,
max_degree)
v = []
eid = []
for buff in rets:
if buff is None:
v.append(np.array([], dtype="int64"))
eid.append(np.array([], dtype="int64"))
else:
npret = np.frombuffer(
buff, dtype="int64").reshape([-1, 2]).astype("int64")
v.append(npret[:, 0])
eid.append(npret[:, 1])
if return_eids:
return np.array(v), np.array(eid)
else:
return np.array(v)
def predecessor(self, e_type, nodes, return_eids=False):
"""predecessor with the specified edge type"""
query = ["%s:d:%s" % (e_type, n) for n in nodes]
ret = rg.hmget_helper(self._rs, query, self.num_parts)
v = []
eid = []
for buff in ret:
if buff is not None:
npret = np.frombuffer(
buff, dtype="int64").reshape([-1, 2]).astype("int64")
v.append(npret[:, 0])
eid.append(npret[:, 1])
else:
v.append(np.array([], dtype="int64"))
eid.append(np.array([], dtype="int64"))
if return_eids:
return np.array(v), np.array(eid)
else:
return np.array(v)
def successor(self, e_type, nodes, return_eids=False):
"""successor with the specified edge type"""
query = ["%s:s:%s" % (e_type, n) for n in nodes]
ret = rg.hmget_helper(self._rs, query, self.num_parts)
v = []
eid = []
for buff in ret:
if buff is not None:
npret = np.frombuffer(
buff, dtype="int64").reshape([-1, 2]).astype("int64")
v.append(npret[:, 0])
eid.append(npret[:, 1])
else:
v.append(np.array([], dtype="int64"))
eid.append(np.array([], dtype="int64"))
if return_eids:
return np.array(v), np.array(eid)
else:
return np.array(v)
def get_edges_by_id(self, e_type, eids):
"""get_edges_by_id"""
queries = ["%s:e:%s" % (e_type, e) for e in eids]
ret = rg.hmget_helper(self._rs, queries, self.num_parts)
o = np.asarray(ret, dtype="int64")
dst = o % self.num_nodes
src = o // self.num_nodes
data = np.hstack(
[src.reshape([-1, 1]), dst.reshape([-1, 1])]).astype("int64")
return data
def get_edge_feat_by_id(self, e_type, key, eids):
"""get_edge_feat_by_id"""
queries = ["%s:ef:%s:%i" % (e_type, key, e) for e in eids]
ret = rg.hmget_helper(self._rs, queries, self.num_parts)
if ret is None:
return None
else:
ret = b"".join(ret)
data = np.frombuffer(ret, dtype=self.edge_feat_dtype(e_type, key))
data = data.reshape(self.edge_feat_shape(e_type, key))
return data
def get_node_types(self, nodes):
"""get_node_types """
queries = ["nt:%i" % n for n in nodes]
ret = rg.hmget_helper(self._rs, queries, self.num_parts)
node_types = []
for buff in ret:
if buff:
node_types.append(buff.decode())
else:
node_types = None
return node_types
def subgraph(self, nodes, eid, edges=None):
"""Generate heterogeneous subgraph with nodes and edge ids.
WARNING: ALL NODES IN EID MUST BE INCLUDED BY NODES
Args:
nodes: Node ids which will be included in the subgraph.
eid: Edge ids which will be included in the subgraph.
Return:
A :code:`pgl.heter_graph.Subgraph` object.
"""
reindex = {}
for ind, node in enumerate(nodes):
reindex[node] = ind
_node_types = self.get_node_types(nodes)
if _node_types is None:
node_types = None
else:
node_types = []
for idx, t in zip(nodes, _node_types):
node_types.append([reindex[idx], t])
if edges is None:
edges = {}
for e_type, eid_list in eid.items():
edges[e_type] = self.get_edges_by_id(e_type, eid_list)
sub_edges = {}
for e_type, edges_list in edges.items():
sub_edges[e_type] = graph_kernel.map_edges(
np.arange(
len(edges_list), dtype="int64"), edges_list, reindex)
sub_edge_feat = {}
for e_type, edge_feat_info in self.edge_feat_info().items():
type_edge_feat = {}
for key, _, _ in edge_feat_info:
type_edge_feat[key] = self.get_edge_feat_by_id(e_type, key, eid[e_type])
sub_edge_feat[e_type] = type_edge_feat
sub_node_feat = {}
for key, _, _ in self.node_feat_info():
sub_node_feat[key] = self.get_node_feat_by_id(key, nodes)
subgraph = heter_graph.SubHeterGraph(
num_nodes=len(nodes),
edges=sub_edges,
node_types=node_types,
node_feat=sub_node_feat,
edge_feat=sub_edge_feat,
reindex=reindex)
return subgraph
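# Illustrative usage sketch (not part of the original file; the Redis
# configuration below is hypothetical):
#   redis_config = [{"host": "127.0.0.1", "port": "7430"}]
#   rhg = RedisHeterGraph(name="hg", edge_types=["u2i", "i2u"],
#                         redis_config=redis_config, num_parts=64)
#   nbrs, eids = rhg.sample_predecessor("u2i", nodes=[1, 2], max_degree=10,
#                                       return_eids=True)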
......@@ -24,10 +24,29 @@ from pgl import graph_kernel
__all__ = [
'graphsage_sample', 'node2vec_sample', 'deepwalk_sample',
'metapath_randomwalk'
'metapath_randomwalk', 'pinsage_sample'
]
def traverse(item):
"""traverse the list or numpy"""
if isinstance(item, list) or isinstance(item, np.ndarray):
for i in iter(item):
for j in traverse(i):
yield j
else:
yield item
def flat_node_and_edge(nodes, eids, weights=None):
"""flatten the sub-lists to one list"""
nodes = list(set(traverse(nodes)))
eids = list(traverse(eids))
if weights is not None:
weights = list(traverse(weights))
return nodes, eids, weights
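# Illustrative note (not part of the original source): `traverse` yields every
# scalar inside arbitrarily nested lists/arrays, and `flat_node_and_edge`
# de-duplicates the node ids, e.g.
#   flat_node_and_edge([[0, 1], [1, 2]], [[10], [11, 12]])
# returns ([0, 1, 2], [10, 11, 12], None) (node order may vary because a set
# is used for de-duplication).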
def edge_hash(src, dst):
"""edge_hash
"""
......@@ -88,7 +107,6 @@ def graphsage_sample(graph, nodes, samples, ignore_edges=[]):
start_nodes = list(nodes_set - last_nodes_set)
layer_nodes = [nodes] + layer_nodes
layer_eids = [eids] + layer_eids
log.debug("flat time: %s" % (time.time() - start))
start = time.time()
# Find new nodes
......@@ -256,43 +274,207 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0):
return walk
def metapath_randomwalk(graph, start_node, metapath, walk_length):
def metapath_randomwalk(graph,
start_nodes,
metapath,
walk_length,
alias_name=None,
events_name=None):
"""Implementation of metapath random walk in heterogeneous graph.
Args:
graph: instance of pgl heterogeneous graph
start_node: start node to generate walk
start_nodes: start nodes to generate walk
metapath: meta path for sample nodes.
e.g: "user-item-user"
e.g: "c2p-p2a-a2p-p2c"
walk_length: the walk length
Return:
a list of metapath walk, each element is a node id.
a list of metapath walks.
"""
np.random.seed()
edge_types = metapath.split('-')
walk = []
metapath = metapath.split('-')
assert metapath[0] == metapath[
-1], "The last meta path item should be the same as the first one"
mp_len = len(metapath) - 1
walk.append(start_node)
for i in range(1, walk_length):
cur_node = walk[-1]
succs = graph.successor(cur_node)
if succs.size > 0:
succs_node_types = graph._node_types[succs]
for node in start_nodes:
walk.append([node])
cur_walk_ids = np.arange(0, len(start_nodes))
cur_nodes = np.array(start_nodes)
mp_len = len(edge_types)
for i in range(0, walk_length - 1):
g = graph[edge_types[i % mp_len]]
cur_succs = g.successor(cur_nodes)
mask = [len(succ) > 0 for succ in cur_succs]
if np.any(mask):
cur_walk_ids = cur_walk_ids[mask]
cur_nodes = cur_nodes[mask]
cur_succs = cur_succs[mask]
else:
# no successor of current node
# stop when all nodes have no successor
break
succs_nodes = succs[np.where(succs_node_types == metapath[i % mp_len])[
0]]
if succs_nodes.size > 0:
walk.append(np.random.choice(succs_nodes))
if alias_name is not None and events_name is not None:
sample_index = [
alias_sample([1], g.node_feat[alias_name][node],
g.node_feat[events_name][node])[0]
for node in cur_nodes
]
else:
# no successor of such node type
break
outdegree = [len(cur_succ) for cur_succ in cur_succs]
sample_index = np.floor(
np.random.rand(cur_succs.shape[0]) * outdegree).astype("int64")
nxt_cur_nodes = []
for s, ind, walk_id in zip(cur_succs, sample_index, cur_walk_ids):
walk[walk_id].append(s[ind])
nxt_cur_nodes.append(s[ind])
cur_nodes = np.array(nxt_cur_nodes)
return walk
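# A hedged usage sketch for ``metapath_randomwalk``, assuming a
# ``pgl.heter_graph.HeterGraph`` with edge types 'c2p', 'p2a', 'a2p' and
# 'p2c' (the same layout as the unit tests further below); start node ids
# are illustrative only.
def _example_metapath_randomwalk(hgraph):
    walks = metapath_randomwalk(
        graph=hgraph,
        start_nodes=[0, 1, 2, 3],
        metapath="c2p-p2a-a2p-p2c",
        walk_length=10)
    # one walk per start node; node types along a walk follow c, p, a, p, c, ...
    return walks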
def random_walk_with_start_prob(graph, nodes, max_depth, proba=0.5):
"""Implement of random walk with the probability of returning the origin node.
This function get random walks path for given nodes and depth.
Args:
nodes: Walk starting from nodes
max_depth: Max walking depth
proba: the proba to return the origin node
Return:
A list of walks.
"""
walk = []
# init
for node in nodes:
walk.append([node])
walk_ids = np.arange(0, len(nodes))
cur_nodes = np.array(nodes)
nodes = np.array(nodes)
for l in range(max_depth):
# select the walks not end
if l >= 1:
return_proba = np.random.rand(cur_nodes.shape[0])
proba_mask = (return_proba < proba)
cur_nodes[proba_mask] = nodes[proba_mask]
outdegree = graph.outdegree(cur_nodes)
mask = (outdegree != 0)
if np.any(mask):
cur_walk_ids = walk_ids[mask]
outdegree = outdegree[mask]
        else:
            # all current nodes have zero outdegree; skip this step and let the
            # next iteration possibly jump some of them back to their origin nodes
            continue
succ = graph.successor(cur_nodes[mask])
sample_index = np.floor(
np.random.rand(outdegree.shape[0]) * outdegree).astype("int64")
nxt_cur_nodes = cur_nodes
for s, ind, walk_id in zip(succ, sample_index, cur_walk_ids):
walk[walk_id].append(s[ind])
nxt_cur_nodes[walk_id] = s[ind]
cur_nodes = np.array(nxt_cur_nodes)
return walk
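# A minimal sketch of the restart-style walk above, assuming a homogeneous
# ``pgl.graph.Graph`` instance ``g``; after the first step, every step has a
# ``proba`` chance of jumping back to the walk's origin node.
def _example_random_walk_with_start_prob(g):
    walks = random_walk_with_start_prob(g, nodes=[0, 1, 2], max_depth=3, proba=0.5)
    # walks[i] starts with the i-th start node followed by up to max_depth visited nodes
    return walks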
def pinsage_sample(graph,
nodes,
samples,
top_k=10,
proba=0.5,
norm_bais=1.0,
ignore_edges=set()):
"""Implement of graphsage sample.
Reference paper: .
Args:
graph: A pgl graph instance
nodes: Sample starting from nodes
samples: A list, number of neighbors in each layer
top_k: select the top_k visit count nodes to construct the edges
proba: the probability to return the origin node
norm_bais: the normlization for the visit count
ignore_edges: list of edge(src, dst) will be ignored.
Return:
A list of subgraphs
"""
start = time.time()
num_layers = len(samples)
start_nodes = nodes
edges, weights = [], []
layer_nodes, layer_edges, layer_weights = [], [], []
ignore_edge_set = set([edge_hash(src, dst) for src, dst in ignore_edges])
for layer_idx in reversed(range(num_layers)):
if len(start_nodes) == 0:
layer_nodes = [nodes] + layer_nodes
layer_edges = [edges] + layer_edges
            layer_weights = [weights] + layer_weights
continue
walks = random_walk_with_start_prob(
graph, start_nodes, samples[layer_idx], proba=proba)
walks = [walk[1:] for walk in walks]
pred_edges = []
pred_weights = []
pred_nodes = []
for node, walk in zip(start_nodes, walks):
walk_nodes = []
walk_weights = []
count_sum = 0
for random_walk_node in walk:
if len(ignore_edge_set) > 0 and random_walk_node != node and \
edge_hash(random_walk_node, node) in ignore_edge_set:
continue
walk_nodes.append(random_walk_node)
unique, counts = np.unique(walk_nodes, return_counts=True)
frequencies = np.asarray((unique, counts)).T
frequencies = frequencies[np.argsort(frequencies[:, 1])]
frequencies = frequencies[-1 * top_k:, :]
for random_walk_node, random_count in zip(
frequencies[:, 0].tolist(), frequencies[:, 1].tolist()):
pred_nodes.append(random_walk_node)
pred_edges.append((random_walk_node, node))
walk_weights.append(random_count)
count_sum += random_count
count_sum += len(walk_weights) * norm_bais
walk_weights = (np.array(walk_weights) + norm_bais) / (count_sum)
pred_weights.extend(walk_weights.tolist())
last_node_set = set(nodes)
nodes, edges, weights = flat_node_and_edge([nodes, pred_nodes], \
[edges, pred_edges], [weights, pred_weights])
layer_edges = [edges] + layer_edges
layer_weights = [weights] + layer_weights
layer_nodes = [nodes] + layer_nodes
start_nodes = list(set(nodes) - last_node_set)
start = time.time()
subgraphs = []
for i in range(num_layers):
edge_feat_dict = {
"weight": np.array(
layer_weights[i], dtype='float32')
}
subgraphs.append(
graph.subgraph(
nodes=layer_nodes[0],
edges=layer_edges[i],
edge_feats=edge_feat_dict))
subgraphs[i].node_feat["index"] = np.array(
layer_nodes[0], dtype="int64")
return subgraphs
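# A hedged usage sketch of ``pinsage_sample``, assuming a homogeneous
# ``pgl.graph.Graph`` instance ``g``; two layers of 16-step random walks,
# keeping the 5 most frequently visited neighbors per node.
def _example_pinsage_sample(g):
    subgraphs = pinsage_sample(g, nodes=[0, 1, 2], samples=[16, 16], top_k=5)
    # one subgraph per layer; visit-count weights are stored in edge_feat["weight"]
    return subgraphs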
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_hetergraph"""
import time
import unittest
import json
import os
import numpy as np
from pgl.sample import metapath_randomwalk
from pgl.graph import Graph
from pgl import heter_graph
class HeterGraphTest(unittest.TestCase):
"""HeterGraph test
"""
@classmethod
def setUpClass(cls):
np.random.seed(1)
edges = {}
# for test no successor
edges['c2p'] = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
(3, 7), (3, 4), (3, 8)]
edges['p2c'] = [(v, u) for u, v in edges['c2p']]
edges['p2a'] = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12),
(6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]
edges['a2p'] = [(v, u) for u, v in edges['p2a']]
# for test speed
# edges['c2p'] = [(0, 4), (0, 5), (1, 9), (1,8), (2,8), (2,5), (3,6), (3,7), (3,4), (3,8)]
# edges['p2c'] = [(v,u) for u, v in edges['c2p']]
# edges['p2a'] = [(4,10), (4,11), (4,12), (4,14), (5,13), (6,13), (6,11), (6,14), (7,12), (7,11), (8,14), (9,13)]
# edges['a2p'] = [(v,u) for u, v in edges['p2a']]
node_types = ['c' for _ in range(4)] + ['p' for _ in range(6)
] + ['a' for _ in range(5)]
node_types = [(i, t) for i, t in enumerate(node_types)]
cls.graph = heter_graph.HeterGraph(
num_nodes=len(node_types), edges=edges, node_types=node_types)
def test_num_nodes_by_type(self):
print()
n_types = {'c': 4, 'p': 6, 'a': 5}
for nt in n_types:
num_nodes = self.graph.num_nodes_by_type(nt)
self.assertEqual(num_nodes, n_types[nt])
def test_node_batch_iter(self):
print()
batch_size = 2
ground = [[4, 5], [6, 7], [8, 9]]
for idx, nodes in enumerate(
self.graph.node_batch_iter(
batch_size=batch_size, shuffle=False, n_type='p')):
self.assertEqual(len(nodes), batch_size)
self.assertListEqual(list(nodes), ground[idx])
def test_sample_nodes(self):
print()
p_ground = [4, 5, 6, 7, 8, 9]
sample_num = 10
nodes = self.graph.sample_nodes(sample_num=sample_num, n_type='p')
self.assertEqual(len(nodes), sample_num)
for n in nodes:
self.assertIn(n, p_ground)
# test n_type == None
ground = [i for i in range(15)]
nodes = self.graph.sample_nodes(sample_num=sample_num, n_type=None)
self.assertEqual(len(nodes), sample_num)
for n in nodes:
self.assertIn(n, ground)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_metapath_randomwalk"""
import time
import unittest
import json
import os
import numpy as np
from pgl.sample import metapath_randomwalk
from pgl.graph import Graph
from pgl import heter_graph
np.random.seed(1)
class MetapathRandomwalkTest(unittest.TestCase):
"""metapath_randomwalk test
"""
def setUp(self):
edges = {}
# for test no successor
edges['c2p'] = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
(3, 7), (3, 4), (3, 8)]
edges['p2c'] = [(v, u) for u, v in edges['c2p']]
edges['p2a'] = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12),
(6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]
edges['a2p'] = [(v, u) for u, v in edges['p2a']]
# for test speed
# edges['c2p'] = [(0, 4), (0, 5), (1, 9), (1,8), (2,8), (2,5), (3,6), (3,7), (3,4), (3,8)]
# edges['p2c'] = [(v,u) for u, v in edges['c2p']]
# edges['p2a'] = [(4,10), (4,11), (4,12), (4,14), (5,13), (6,13), (6,11), (6,14), (7,12), (7,11), (8,14), (9,13)]
# edges['a2p'] = [(v,u) for u, v in edges['p2a']]
self.node_types = ['c' for _ in range(4)] + [
'p' for _ in range(6)
] + ['a' for _ in range(5)]
node_types = [(i, t) for i, t in enumerate(self.node_types)]
self.graph = heter_graph.HeterGraph(
num_nodes=len(node_types), edges=edges, node_types=node_types)
def test_metapath_randomwalk(self):
meta_path = 'c2p-p2a-a2p-p2c'
path = ['c', 'p', 'a', 'p', 'c']
start_nodes = [0, 1, 2, 3]
walk_len = 10
walks = metapath_randomwalk(
graph=self.graph,
start_nodes=start_nodes,
metapath=meta_path,
walk_length=walk_len)
self.assertEqual(len(walks), 4)
for walk in walks:
for i in range(len(walk)):
idx = i % (len(path) - 1)
self.assertEqual(self.node_types[walk[i]], path[idx])
if __name__ == "__main__":
unittest.main()
......@@ -25,6 +25,8 @@ except:
import numpy as np
import time
import paddle.fluid as fluid
from queue import Queue
import threading
def serialize_data(data):
......@@ -129,22 +131,39 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000, pipe_size=10):
p.start()
reader_num = len(readers)
finish_num = 0
conn_to_remove = []
finish_flag = np.zeros(len(conns), dtype="int32")
start = time.time()
def queue_worker(sub_conn, que):
while True:
buff = sub_conn.recv()
sample = deserialize_data(buff)
if sample is None:
que.put(None)
sub_conn.close()
break
que.put(sample)
thread_pool = []
output_queue = Queue(maxsize=reader_num)
for i in range(reader_num):
t = threading.Thread(
target=queue_worker, args=(conns[i], output_queue))
t.daemon = True
t.start()
thread_pool.append(t)
finish_num = 0
while finish_num < reader_num:
for conn_id, conn in enumerate(conns):
if finish_flag[conn_id] > 0:
continue
if conn.poll(0.01):
buff = conn.recv()
sample = deserialize_data(buff)
if sample is None:
finish_num += 1
conn.close()
finish_flag[conn_id] = 1
else:
yield sample
sample = output_queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for thread in thread_pool:
thread.join()
if use_pipe:
return pipe_reader
......
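# A hedged usage sketch for ``multiprocess_reader`` with the pipe-based path
# above: each sub-reader runs in its own process, and a daemon thread per pipe
# pushes deserialized samples into a single shared queue. The sample contents
# below are illustrative only.
def _example_multiprocess_reader():
    def make_reader(offset):
        def reader():
            for i in range(3):
                yield {"value": np.array([offset + i], dtype="int64")}
        return reader

    merged = multiprocess_reader([make_reader(0), make_reader(100)], use_pipe=True)
    for feed_dict in merged():
        pass  # consume the merged sample stream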