Commit 216e2829 authored by Zhong Hui

tmp commit for rgcn

Parent 867dbd26
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['CPU_NUM'] = str(20)
import numpy as np
import copy
import paddle
import paddle.fluid as fluid
import pgl
#from pgl.sample import graph_saint_random_walk_sample
from pgl.sample import deepwalk_sample
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
from rgcn import RGCNModel, softmax_loss, paper_mask
from pgl.utils.mp_mapper import mp_reader_mapper
from pgl.utils.share_numpy import ToShareMemGraph
def hetero2homo(heterograph):
edge = []
for edge_type in heterograph.edge_types_info():
edge.append(heterograph[edge_type].edges)
edges = np.vstack(edge)
g = pgl.graph.Graph(num_nodes=heterograph.num_nodes, edges=edges)
g.outdegree()
ToShareMemGraph(g)
return g
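# A minimal usage sketch for `hetero2homo` (illustrative only, assuming an
# ogbn-mag style HeterGraph from PglNodePropPredDataset): all typed edge lists
# are stacked into one edge list over the shared global node id space, so the
# homogeneous graph can drive random walks across every relation at once.
def _example_hetero2homo():
    dataset = PglNodePropPredDataset('ogbn-mag')
    g, _ = dataset[0]
    homograph = hetero2homo(g)
    # Same node set; only the edges are merged across types.
    assert homograph.num_nodes == g.num_nodes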
def extract_edges_from_nodes(hetergraph, sample_nodes):
eids = {}
for key in hetergraph.edge_types_info():
graph = hetergraph[key]
eids[key] = pgl.graph_kernel.extract_edges_from_nodes(
graph.adj_src_index._indptr, graph.adj_src_index._sorted_v,
graph.adj_src_index._sorted_eid, sample_nodes)
return eids
def graph_saint_random_walk_sample(graph,
hetergraph,
nodes,
max_depth,
alias_name=None,
events_name=None):
"""Implement of graph saint random walk sample.
First, this function will get random walks path for given nodes and depth.
Then, it will create subgraph from all sampled nodes.
Reference Paper: https://arxiv.org/abs/1907.04931
Args:
graph: A pgl graph instance
nodes: Walk starting from nodes
max_depth: Max walking depth
Return:
a subgraph of sampled nodes.
"""
    # NumPy's RNG must be re-seeded in each worker process, otherwise all
    # workers would produce identical samples.
    np.random.seed()
    graph.outdegree()
    # NOTE: the passed-in `nodes` are ignored; each call restarts the walks
    # from 20000 uniformly sampled nodes instead.
    nodes = np.random.choice(
        np.arange(
            graph.num_nodes, dtype='int64'), size=20000, replace=False)
walks = deepwalk_sample(graph, nodes, max_depth, alias_name, events_name)
sample_nodes = []
for walk in walks:
sample_nodes.extend(walk)
print("length of sample_nodes ", len(sample_nodes))
sample_nodes = np.unique(sample_nodes)
print("length of unique sample_nodes ", len(sample_nodes))
eids = extract_edges_from_nodes(hetergraph, sample_nodes)
subgraph = hetergraph.subgraph(
nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True)
#subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64")
all_label = graph._node_feat['train_label'][sample_nodes]
train_index = np.where(all_label > -1)[0]
train_label = all_label[train_index]
#print("sample", train_index.shape)
#print("sample", train_label.shape)
return subgraph, sample_nodes, train_index, train_label
def graph_saint_hetero(graph, hetergraph, batch_nodes, max_depth=2):
subgraph, sample_nodes, train_index, train_label = graph_saint_random_walk_sample(
graph, hetergraph, batch_nodes, max_depth)
# train_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
return subgraph, train_index, sample_nodes, train_label
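# A hedged usage sketch for `graph_saint_hetero` (names mirror those in
# `main`; this is illustrative, not part of the training loop): one call
# yields one GraphSAINT-style training subgraph plus the positions and labels
# of the labeled paper nodes that fall inside it.
def _example_graph_saint_hetero(homograph, hetergraph, seed_nodes):
    subgraph, train_index, sample_nodes, train_label = graph_saint_hetero(
        homograph, hetergraph, seed_nodes, max_depth=2)
    # `train_index` indexes into `sample_nodes`; labels align with it.
    assert len(train_index) == len(train_label)
    return subgraph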
def traverse(item):
"""traverse
"""
if isinstance(item, list) or isinstance(item, np.ndarray):
for i in iter(item):
for j in traverse(i):
yield j
else:
yield item
def flat_node_and_edge(nodes):
"""flat_node_and_edge
"""
nodes = list(set(traverse(nodes)))
return nodes
def k_hop_sampler(graph, hetergraph, batch_nodes, samples=[30, 30]):
# for batch_train_samples, batch_train_labels in batch_info:
np.random.seed()
start_nodes = copy.deepcopy(batch_nodes)
nodes = start_nodes
edges = []
for max_deg in samples:
pred_nodes = graph.sample_predecessor(start_nodes, max_degree=max_deg)
for dst_node, src_nodes in zip(start_nodes, pred_nodes):
for src_node in src_nodes:
edges.append((src_node, dst_node))
last_nodes = nodes
nodes = [nodes, pred_nodes]
nodes = flat_node_and_edge(nodes)
# Find new nodes
start_nodes = list(set(nodes) - set(last_nodes))
if len(start_nodes) == 0:
break
nodes = np.unique(np.array(nodes, dtype='int64'))
eids = extract_edges_from_nodes(hetergraph, nodes)
subgraph = hetergraph.subgraph(
nodes=nodes, eid=eids, with_node_feat=True, with_edge_feat=True)
#sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
train_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
return subgraph, train_index, np.array(nodes, dtype='int64'), None
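# Sketch (illustrative): at eval time `k_hop_sampler` is used instead of the
# GraphSAINT sampler. It expands `batch_nodes` by up to 30 predecessors per
# hop for two hops, then cuts the induced heterogeneous subgraph.
def _example_k_hop_sampler(homograph, hetergraph, batch_nodes):
    subgraph, train_index, nodes, _ = k_hop_sampler(
        homograph, hetergraph, batch_nodes, samples=[30, 30])
    # `train_index` locates the original batch nodes inside the subgraph.
    return subgraph, train_index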
def dataloader(source_node, label, batch_size=1024):
index = np.arange(len(source_node))
np.random.shuffle(index)
def loader():
start = 0
while start < len(source_node):
end = min(start + batch_size, len(source_node))
yield source_node[index[start:end]], label[index[start:end]]
start = end
return loader
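# Minimal sketch of the shuffled batch loader on hypothetical toy inputs:
def _example_dataloader():
    source_node = np.arange(10, dtype='int64')
    label = np.arange(10, dtype='int64')
    loader = dataloader(source_node, label, batch_size=4)
    for batch_nodes, batch_labels in loader():
        print(batch_nodes.shape, batch_labels.shape)  # up to 4 per batch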
def sample_loader(phase, homograph, hetergraph, gw, source_node, label):
#print(source_node)
#print(label)
if phase == 'train':
sample_func = graph_saint_hetero
batch_size = 20000
else:
sample_func = k_hop_sampler
batch_size = 512
def map_fun(node_label):
node, label = node_label
subgraph, train_index, sample_nodes, train_label = sample_func(
homograph, hetergraph, node)
#print(train_index.shape)
#print(sample_nodes.shape)
#print(sum(subgraph['p2p'].edges[:,0] * subgraph['p2p'].edges[:, 1] == 0) /len(subgraph['p2p'].edges) )
feed_dict = gw.to_feed(subgraph)
feed_dict['label'] = label if train_label is None else train_label
feed_dict['train_index'] = train_index
feed_dict['sub_node_index'] = sample_nodes
return feed_dict
loader = dataloader(source_node, label, batch_size)
reader = mp_reader_mapper(loader, func=map_fun, num_works=6)
for feed_dict in reader():
yield feed_dict
def run_epoch(exe, loss, acc, homograph, hetergraph, gw, train_program,
test_program, all_label, split_idx, split_real_idx):
    best_acc = 0.0  # best validation accuracy seen so far
for epoch in range(1000):
for phase in ['train', 'valid', 'test']:
# if phase == 'train':
# continue
running_loss = []
running_acc = []
for feed_dict in sample_loader(
phase, homograph, hetergraph, gw,
split_real_idx[phase]['paper'],
all_label['paper'][split_idx[phase]['paper']]):
print("train_shape\t", feed_dict['train_index'].shape)
print("allnode_shape\t", feed_dict['sub_node_index'].shape)
res = exe.run(
train_program if phase == 'train' else test_program,
# test_program,
feed=feed_dict,
fetch_list=[loss.name, acc.name],
use_prune=True)
running_loss.append(res[0])
running_acc.append(res[1])
if phase == 'train':
print("training_acc %f" % res[1])
avg_loss = sum(running_loss) / len(running_loss)
avg_acc = sum(running_acc) / len(running_acc)
if phase == 'valid':
if avg_acc > best_acc:
fluid.io.save_persistables(exe, './output/checkpoint',
train_program)
best_acc = avg_acc
print('new best_acc %f' % best_acc)
print("%d, %s %f %f" % (epoch, phase, avg_loss, avg_acc))
def main():
num_class = 349
num_nodes = 1939743
start_paper_index = 1203354
hidden_size = 128
dataset = PglNodePropPredDataset('ogbn-mag')
g, all_label = dataset[0]
homograph = hetero2homo(g)
for key in g.edge_types_info():
g[key].outdegree()
ToShareMemGraph(g[key])
split_idx = dataset.get_idx_split()
split_real_idx = copy.deepcopy(split_idx)
start_paper_index = g.num_node_dict['paper'][1]
# reindex the original idx of each type of node
for t, idx in split_real_idx.items():
for k, v in idx.items():
split_real_idx[t][k] += g.num_node_dict[k][1]
homograph._node_feat['train_label'] = -1 * np.ones(
[homograph.num_nodes, 1], dtype='int64')
train_label = all_label['paper'][split_idx['train']['paper']]
train_index = split_real_idx['train']['paper']
homograph._node_feat['train_label'][train_index] = train_label
#place = fluid.CUDAPlace(0)
place = fluid.CPUPlace()
train_program = fluid.Program()
startup_program = fluid.Program()
test_program = fluid.Program()
additional_paper_feature = g.node_feat_dict[
'paper'][:, :hidden_size].astype('float32')
    extract_index = (np.arange(start_paper_index, num_nodes)).astype('int32')
with fluid.program_guard(train_program, startup_program):
paper_feature = fluid.layers.create_parameter(
shape=additional_paper_feature.shape,
dtype='float32',
default_initializer=fluid.initializer.NumpyArrayInitializer(
additional_paper_feature),
name='paper_feature')
        paper_index = fluid.layers.create_parameter(
            shape=extract_index.shape,
            dtype='int32',
            default_initializer=fluid.initializer.NumpyArrayInitializer(
                extract_index),
            name='paper_index')
#paper_feature.stop_gradient=True
paper_index.stop_gradient = True
sub_node_index = fluid.layers.data(
shape=[-1], dtype='int64', name='sub_node_index')
train_index = fluid.layers.data(
shape=[-1], dtype='int64', name='train_index')
label = fluid.layers.data(shape=[-1], dtype="int64", name='label')
label = fluid.layers.reshape(label, [-1, 1])
label.stop_gradient = True
gw = pgl.heter_graph_wrapper.HeterGraphWrapper(
name="heter_graph",
edge_types=g.edge_types_info(),
node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info())
feat = fluid.layers.create_parameter(
shape=[num_nodes, hidden_size], dtype='float32')
        # TODO: the paper features should replace the corresponding rows of
        # `feat`, not be added to them (scatter with overwrite=False accumulates).
feat = fluid.layers.scatter(
feat, paper_index, paper_feature, overwrite=False)
sub_node_feat = fluid.layers.gather(feat, sub_node_index)
model = RGCNModel(gw, 2, num_class, num_nodes, g.edge_types_info())
feat = model.forward(sub_node_feat)
#feat = paper_mask(feat, gw, start_paper_index)
feat = fluid.layers.gather(feat, train_index)
loss, logit, acc = softmax_loss(feat, label, num_class)
opt = fluid.optimizer.AdamOptimizer(learning_rate=0.002)
opt.minimize(loss)
test_program = train_program.clone(for_test=True)
from paddle.fluid.contrib import summary
summary(train_program)
exe = fluid.Executor(place)
exe.run(startup_program)
# fluid.io.load_persistables(executor=exe, dirname='./output/checkpoint',
# main_program=train_program)
run_epoch(exe, loss, acc, homograph, g, gw, train_program, test_program,
all_label, split_idx, split_real_idx)
return None
feed_dict = gw.to_feed(g)
#rand_label = (np.random.rand(num_nodes - start_paper_index) >
# 0.5).astype('int64')
#feed_dict['label'] = rand_label
feed_dict['label'] = all_label['paper'][split_idx['train']['paper']]
feed_dict['train_index'] = split_real_idx['train']['paper']
#feed_dict['sub_node_index'] = np.arange(num_nodes).astype('int64')
    #feed_dict['paper_index'] = extract_index
#feed_dict['paper_feature'] = additional_paper_feature
for epoch in range(10):
feed_dict['label'] = all_label['paper'][split_idx['train']['paper']]
feed_dict['train_index'] = split_real_idx['train']['paper']
for step in range(10):
res = exe.run(train_program,
feed=feed_dict,
fetch_list=[loss.name, acc.name])
print("%d,%d %f %f" % (epoch, step, res[0], res[1]))
#print(res[1])
feed_dict['label'] = all_label['paper'][split_idx['valid']['paper']]
feed_dict['train_index'] = split_real_idx['valid']['paper']
res = exe.run(test_program,
feed=feed_dict,
fetch_list=[loss.name, acc.name])
print("Test %d, %f %f" % (epoch, res[0], res[1]))
if __name__ == "__main__":
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import pgl
import paddle.fluid as fluid
import numpy as np
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
def rgcn_conv(graph_wrapper,
feature,
hidden_size,
edge_types,
name="rgcn_conv"):
def __message(src_feat, dst_feat, edge_feat):
"""send function
"""
return src_feat['h']
def __reduce(feat):
"""recv function
"""
return fluid.layers.sequence_pool(feat, pool_type='average')
gw = graph_wrapper
if not isinstance(edge_types, list):
edge_types = [edge_types]
#output = fluid.layers.zeros((feature.shape[0], hidden_size), dtype='float32')
output = None
    for i in range(len(edge_types)):
        assert feature is not None
        # Transform the shared input feature once per edge type; keep the
        # original `feature` intact so the transforms do not stack across types.
        h = fluid.layers.fc(
            feature,
            size=hidden_size,
            param_attr=fluid.ParamAttr(name='%s_edge_fc_%s' %
                                       (name, edge_types[i])),
            act=None)
        if output is None:
            output = fluid.layers.zeros_like(h)
        msg = gw[edge_types[i]].send(__message, nfeat_list=[('h', h)])
neigh_feat = gw[edge_types[i]].recv(msg, __reduce)
        # FC weights are shared across edge types that end at the same
        # destination node type; the edge type string must be `A2B`
        # (from type A to type B).
neigh_feat = fluid.layers.fc(
neigh_feat,
size=hidden_size,
param_attr=fluid.ParamAttr(name='%s_node_fc_%s' %
(name, edge_types[i].split("2")[-1])),
act=None)
output = output + neigh_feat
#output = fluid.layers.relu(out)
return output
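# A NumPy sketch of what one rgcn_conv layer computes (an assumption-labeled
# illustration, not the fluid implementation above; square weight matrices
# are assumed for simplicity): for each relation r,
#   out += mean_{u in N_r(v)} (h_u @ W_r_edge) @ W_r_node
def _numpy_rgcn_sketch(h, adj_by_rel, w_edge, w_node):
    # h: [N, d] node features; adj_by_rel: {rel: list of (src, dst)};
    # w_edge / w_node: {rel: [d, d]} per-relation weights (hypothetical).
    out = np.zeros_like(h)
    for rel, edges in adj_by_rel.items():
        msg = h.dot(w_edge[rel])          # per-relation source transform
        agg = np.zeros_like(h)
        cnt = np.zeros((h.shape[0], 1))
        for src, dst in edges:
            agg[dst] += msg[src]
            cnt[dst] += 1
        agg = agg / np.maximum(cnt, 1)    # mean over incoming neighbors
        out += agg.dot(w_node[rel])       # per-destination-type transform
    return out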
class RGCNModel:
def __init__(self, gw, layers, num_class, num_nodes, edge_types):
self.hidden_size = 64
self.layers = layers
self.num_nodes = num_nodes
self.edge_types = edge_types
self.gw = gw
self.num_class = num_class
def forward(self, feat):
for i in range(self.layers - 1):
feat = rgcn_conv(
self.gw,
feat,
self.hidden_size,
self.edge_types,
name="rgcn_%d" % i)
feat = fluid.layers.relu(feat)
feat = fluid.layers.dropout(feat, dropout_prob=0.5)
feat = rgcn_conv(
self.gw,
feat,
self.num_class,
self.edge_types,
name="rgcn_%d" % (self.layers - 1))
return feat
def softmax_loss(feat, label, class_num):
#logit = fluid.layers.fc(feat, class_num)
logit = feat
loss = fluid.layers.softmax_with_cross_entropy(logit, label)
loss = fluid.layers.mean(loss)
acc = fluid.layers.accuracy(fluid.layers.softmax(logit), label)
return loss, logit, acc
def paper_mask(feat, gw, start_index):
    # NOTE: currently unused (its call site is commented out). Selects the
    # rows of `feat` belonging to paper nodes, whose global ids start at
    # `start_index`; assumes a masked-select style op is available.
    mask = fluid.layers.cast(
        gw[0].node_feat['index'] > start_index, dtype='bool')
    feat = fluid.layers.masked_select(feat, mask)
    return feat
if __name__ == "__main__":
#PglNodePropPredDataset('ogbn-mag')
num_nodes = 4
num_class = 2
node_types = [(0, 'user'), (1, 'user'), (2, 'item'), (3, 'item')]
edges = {
'U2U': [(0, 1), (1, 0)],
'U2I': [(1, 2), (0, 3), (1, 3)],
'I2I': [(2, 3), (3, 2)],
}
node_feat = {'feature': np.random.randn(4, 16)}
edges_feat = {
'U2U': {
'h': np.random.randn(2, 16)
},
'U2I': {
'h': np.random.randn(3, 16)
},
'I2I': {
'h': np.random.randn(2, 16)
},
}
g = pgl.heter_graph.HeterGraph(
num_nodes=num_nodes,
edges=edges,
node_types=node_types,
node_feat=node_feat,
edge_feat=edges_feat)
place = fluid.CPUPlace()
train_program = fluid.Program()
startup_program = fluid.Program()
test_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
label = fluid.layers.data(shape=[-1], dtype="int64", name='label')
#label = fluid.layers.create_global_var(shape=[4], value=1, dtype="int64")
label = fluid.layers.reshape(label, [-1, 1])
label.stop_gradient = True
gw = pgl.heter_graph_wrapper.HeterGraphWrapper(
name="heter_graph",
edge_types=g.edge_types_info(),
node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info())
feat = fluid.layers.create_parameter(
shape=[num_nodes, 16], dtype='float32')
model = RGCNModel(gw, 3, num_class, num_nodes, g.edge_types_info())
feat = model.forward(feat)
loss, logit, acc = softmax_loss(feat, label, 2)
opt = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
opt.minimize(loss)
from paddle.fluid.contrib import summary
summary(train_program)
exe = fluid.Executor(place)
exe.run(startup_program)
feed_dict = gw.to_feed(g)
feed_dict['label'] = np.array([1, 0, 1, 1]).astype('int64')
for i in range(100):
res = exe.run(train_program,
feed=feed_dict,
fetch_list=[loss.name, logit.name, acc.name])
print("%d %f %f" % (i, res[0], res[2]))
#print(res[1])
@@ -18,7 +18,11 @@ import pandas as pd
import os.path as osp
import numpy as np
import pgl
from pgl import heter_graph
from ogb.io.read_graph_raw import read_csv_graph_raw, read_csv_heterograph_raw
from collections import OrderedDict
import logging
logger = logging.getLogger(__name__)
def read_csv_graph_pgl(raw_dir, add_inverse_edge=False):
@@ -42,6 +46,67 @@ def read_csv_graph_pgl(raw_dir, add_inverse_edge=False):
return pgl_graph_list
def read_csv_heterograph_pgl(raw_dir,
add_inverse_edge=False,
additional_node_files=[],
additional_edge_files=[]):
"""Read CSV data and build PGL heterograph
"""
graph_list = read_csv_heterograph_raw(
raw_dir,
add_inverse_edge,
additional_node_files=additional_node_files,
additional_edge_files=additional_edge_files)
pgl_graph_list = []
logger.info('Converting graphs into PGL objects...')
for graph in graph_list:
# logger.info(graph)
node_index = OrderedDict()
node_types = []
num_nodes = 0
for k, v in graph["num_nodes_dict"].items():
node_types.append(
np.ones(
shape=[v, 1], dtype='int64') * len(node_index))
node_index[k] = (v, num_nodes)
num_nodes += v
# logger.info(node_index)
node_types = np.vstack(node_types)
edges_by_types = {}
for k, v in graph["edge_index_dict"].items():
v[0, :] += node_index[k[0]][1]
v[1, :] += node_index[k[2]][1]
inverse_v = np.array(v)
inverse_v[0, :] = v[1, :]
inverse_v[1, :] = v[0, :]
            if k[0] != k[2]:  # source and destination node types differ
edges_by_types["{}2{}".format(k[0][0], k[2][0])] = v.T
edges_by_types["{}2{}".format(k[2][0], k[0][0])] = inverse_v.T
else:
edges = np.hstack((v, inverse_v))
edges_by_types["{}2{}".format(k[0][0], k[2][0])] = edges.T
        node_features = {
            'index': np.arange(
                num_nodes, dtype=np.int64).reshape(-1, 1)
        }
# logger.info(edges_by_types.items())
g = heter_graph.HeterGraph(
num_nodes=num_nodes,
edges=edges_by_types,
node_types=node_types,
node_feat=node_features)
g.edge_feat_dict = graph['edge_feat_dict']
g.node_feat_dict = graph['node_feat_dict']
g.num_node_dict = node_index
pgl_graph_list.append(g)
logger.info("Done, converted!")
return pgl_graph_list
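# Hedged usage sketch: for ogbn-mag the raw directory produced by OGB can be
# converted directly (the path and file names below are illustrative
# assumptions, not fixed by this module):
# g_list = read_csv_heterograph_pgl('dataset/ogbn_mag/raw',
#                                   add_inverse_edge=False,
#                                   additional_node_files=['node_year'])
# g = g_list[0]
# print(g.num_node_dict)  # {node_type: (count, global id offset)}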
if __name__ == "__main__":
# graph_list = read_csv_graph_dgl('dataset/proteinfunc_v2/raw', add_inverse_edge = True)
graph_list = read_csv_graph_pgl(
......
@@ -19,7 +19,8 @@ import os.path as osp
import numpy as np
from ogb.utils.url import decide_download, download_url, extract_zip
from ogb.nodeproppred import make_master_file # create master.csv
from pgl.contrib.ogb.io.read_graph_pgl import read_csv_graph_pgl, read_csv_heterograph_pgl
from ogb.io.read_graph_raw import read_node_label_hetero, read_nodesplitidx_split_hetero
def to_bool(value):
@@ -53,6 +54,9 @@ class PglNodePropPredDataset(object):
self.num_tasks = int(self.meta_info[self.name]["num tasks"])
self.task_type = self.meta_info[self.name]["task type"]
self.eval_metric = self.meta_info[self.name]["eval metric"]
self.__num_classes__ = int(self.meta_info[self.name]["num classes"])
self.is_hetero = self.meta_info[self.name]["is hetero"]
super(PglNodePropPredDataset, self).__init__()
@@ -65,11 +69,11 @@ class PglNodePropPredDataset(object):
pre_processed_file_path = osp.join(processed_dir, 'pgl_data_processed')
if osp.exists(pre_processed_file_path):
            # TODO: Reload Preprocess files
pass
else:
### check download
            if not osp.exists(osp.join(self.root, "raw")):
url = self.meta_info[self.name]["url"]
if decide_download(url):
path = download_url(url, self.original_root)
@@ -88,52 +92,112 @@ class PglNodePropPredDataset(object):
exit(-1)
raw_dir = osp.join(self.root, "raw")
self.raw_dir = raw_dir
            ### pre-process and save
            add_inverse_edge = self.meta_info[self.name][
                "add_inverse_edge"] == "True"

            if self.meta_info[self.name]["additional node files"] == 'None':
                additional_node_files = []
            else:
                additional_node_files = self.meta_info[self.name][
                    "additional node files"].split(',')

            if self.meta_info[self.name]["additional edge files"] == 'None':
                additional_edge_files = []
            else:
                additional_edge_files = self.meta_info[self.name][
                    "additional edge files"].split(',')
if self.is_hetero:
self.graph = read_csv_heterograph_pgl(
self.raw_dir,
add_inverse_edge=add_inverse_edge,
additional_node_files=additional_node_files,
additional_edge_files=additional_edge_files)
node_label_dict = read_node_label_hetero(self.raw_dir)
y_dict = {}
if "classification" in self.task_type:
for nodetype, node_label in node_label_dict.items():
# detect if there is any nan
if np.isnan(node_label).any():
y_dict[nodetype] = np.array(
node_label, dtype='float32')
else:
y_dict[nodetype] = np.array(
node_label, dtype='int64')
else:
for nodetype, node_label in node_label_dict.items():
y_dict[nodetype] = np.array(
node_label, dtype='float32')
self.labels = y_dict
label_dict = {"labels": node_label}
else:
self.graph = read_csv_graph_pgl(
raw_dir, add_inverse_edge=add_inverse_edge)
### adding prediction target
node_label = pd.read_csv(
osp.join(raw_dir, 'node-label.csv.gz'),
compression="gzip",
header=None).values
if "classification" in self.task_type:
node_label = np.array(node_label, dtype=np.int64)
else:
node_label = np.array(node_label, dtype=np.float32)
                label_dict = {"labels": node_label}
                self.labels = label_dict['labels']

            # TODO: SAVE preprocessed graph
def get_idx_split(self):
"""Train/Validation/Test split
"""
split_type = self.meta_info[self.name]["split"]
path = osp.join(self.root, "split", split_type)
if self.is_hetero:
train_idx_dict, valid_idx_dict, test_idx_dict = read_nodesplitidx_split_hetero(
path)
for nodetype in train_idx_dict.keys():
train_idx_dict[nodetype] = np.array(
train_idx_dict[nodetype], dtype='int64')
valid_idx_dict[nodetype] = np.array(
valid_idx_dict[nodetype], dtype='int64')
test_idx_dict[nodetype] = np.array(
test_idx_dict[nodetype], dtype='int64')
            # This follows the PyG implementation (dataset_pyg).
            # TODO: check the code
return {
"train": train_idx_dict,
"valid": valid_idx_dict,
"test": test_idx_dict
}
else:
train_idx = pd.read_csv(
osp.join(path, "train.csv.gz"),
compression="gzip",
header=None).values.T[0]
valid_idx = pd.read_csv(
osp.join(path, "valid.csv.gz"),
compression="gzip",
header=None).values.T[0]
test_idx = pd.read_csv(
osp.join(path, "test.csv.gz"), compression="gzip",
header=None).values.T[0]
return {
"train": np.array(
train_idx, dtype="int64"),
"valid": np.array(
valid_idx, dtype="int64"),
"test": np.array(
test_idx, dtype="int64")
}
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
@@ -147,7 +211,7 @@ class PglNodePropPredDataset(object):
if __name__ == "__main__":
pgl_dataset = PglNodePropPredDataset(name="ogbn-proteins")
pgl_dataset = PglNodePropPredDataset(name="ogbn-mag")
splitted_index = pgl_dataset.get_idx_split()
print(pgl_dataset[0])
print(splitted_index)
@@ -167,7 +167,7 @@ class HeterGraph(object):
"""Return the number of nodes with the specified node type.
"""
if n_type not in self._nodes_type_dict:
raise ("%s is not in valid node type" % n_type)
raise ValueError("%s is not in valid node type" % n_type)
else:
return len(self._nodes_type_dict[n_type])
@@ -250,9 +250,9 @@ class HeterGraph(object):
        Return:
            A list of numpy.ndarray, where each numpy.ndarray is a list of
            sampled successor ids for the given nodes with the specified edge
            type. If :code:`return_eids=True`, there is an additional list of
            numpy.ndarray, where each numpy.ndarray is a list of eids that
            connect nodes to their successors.
"""
return self._multi_graph[edge_type].sample_successor(
@@ -294,9 +294,9 @@ class HeterGraph(object):
        Return:
            A list of numpy.ndarray, where each numpy.ndarray is a list of
            sampled predecessor ids for the given nodes with the specified
            edge type. If :code:`return_eids=True`, there is an additional
            list of numpy.ndarray, where each numpy.ndarray is a list of eids
            that connect nodes to their predecessors.
"""
return self._multi_graph[edge_type].sample_predecessor(
@@ -314,8 +314,8 @@ class HeterGraph(object):
batch_size: The batch size of each batch of nodes.
shuffle: Whether shuffle the nodes.
            n_type: Iterate the nodes with the specified node type. If n_type is None,
iterate all nodes by batch.
Return:
@@ -354,6 +354,68 @@ class HeterGraph(object):
return np.random.randint(
low=0, high=self._num_nodes, size=sample_num)
def subgraph(self,
nodes,
eid=None,
edges=None,
edge_feats=None,
with_node_feat=True,
with_edge_feat=True):
"""Generate subgraph of hetergraph with nodes and edges ids.
Note that the eid or edges should be a dict for different types of graph.
"""
if eid is None and edges is None:
raise ValueError("Eid and edges can't be None at the same time.")
reindex = {}
for ind, node in enumerate(nodes):
reindex[node] = ind
if edges is None:
edges = {}
for edge_type, v in eid.items():
edges[edge_type] = self._multi_graph[edge_type].edges[np.array(
v, dtype="int64")]
else:
for edge_type, v in edges.items():
edges[edge_type] = np.array(v, dtype="int64")
sub_edges = {}
for edge_type, value in edges.items():
sub_edges[edge_type] = graph_kernel.map_edges(
np.arange(
len(value), dtype="int64"),
edges[edge_type],
reindex)
        sub_edge_feat = {}
        if edge_feats is None:
            if with_edge_feat:
                # Slicing edge features requires `eid`; it is None only when
                # the caller passed explicit `edges` instead.
                for edge_type in sub_edges.keys():
                    value = self._edge_feat[edge_type]
                    sub_edge_feat[edge_type] = {}
                    for k, v in value.items():
                        sub_edge_feat[edge_type][k] = v[eid[edge_type]]
        else:
            sub_edge_feat = edge_feats
sub_node_feat = {}
if with_node_feat:
for key, value in self._node_feat.items():
sub_node_feat[key] = value[nodes]
sub_node_types = self.node_types[nodes]
subgraph = SubHeterGraph(
num_nodes=len(nodes),
edges=sub_edges,
node_types=sub_node_types,
node_feat=sub_node_feat,
edge_feat=sub_edge_feat,
reindex=reindex)
return subgraph
def node_feat_info(self):
"""Return the information of node feature for HeterGraphWrapper.
@@ -393,10 +455,10 @@ class HeterGraph(object):
def edge_types_info(self):
"""Return the information of all edge types.
Return:
A list of all edge types.
"""
edge_types_info = []
for key, _ in self._edges_dict.items():
@@ -408,7 +470,7 @@ class HeterGraph(object):
class SubHeterGraph(HeterGraph):
"""Implementation of SubHeterGraph in pgl.
    SubHeterGraph inherits from :code:`HeterGraph`.
Args:
num_nodes: number of nodes in a heterogeneous graph
......
@@ -136,6 +136,35 @@ class HeterGraphTest(unittest.TestCase):
for n in nodes:
self.assertIn(n, ground)
def test_subgraph(self):
print()
eids = {}
edges = {}
eids['c2p'] = [0, 1, 5, 8]
eids['p2c'] = eids['c2p']
eids['p2a'] = [1, 2, 3, 4]
eids['a2p'] = eids['p2a']
edges['c2p'] = [(1, 4), (0, 5), (2, 5), (3, 4)]
edges['p2c'] = [(v, u) for u, v in edges['c2p']]
edges['p2a'] = [(4, 11), (4, 12), (4, 14), (4, 13)]
edges['a2p'] = [(v, u) for u, v in edges['p2a']]
nodes = set()
for edge in edges.values():
for tup in edge:
nodes.add(tup[0])
nodes.add(tup[1])
g = self.graph.subgraph(
nodes=sorted(list(nodes)),
#edges=edges
eid=eids)
print(g._from_reindex)
print('subgraph', g['c2p'].edges)
print(g['p2c'].edges)
print('subgraph', g['p2a'].edges)
print(g['a2p'].edges)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module uses multiprocessing to perform the following process:
`
for data in reader():
yield func(data)
`
"""
#encoding=utf8
import numpy as np
import multiprocessing as mp
import traceback
from pgl.utils.logger import log
def mp_reader_mapper(reader, func, num_works=4):
"""
This function aims to use multiprocessing to do following process.
`
for data in reader():
yield func(data)
`
The data in_stream is the `reader`, the mapper is map the in_stream to
an out_stream.
Please ensure the `func` have specific return value, not `None`!
:param reader: the data iterator
:param func: the map func
:param num_works: number of works
:return: an new iterator
"""
def _read_into_pipe(func, conn):
"""
read into pipe, and use the `func` to get final data.
"""
while True:
data = conn.recv()
if data is None:
conn.send(None)
conn.close()
break
conn.send(func(data))
def pipe_reader():
"""pipe_reader"""
conns = []
all_process = []
for w in range(num_works):
parent_conn, child_conn = mp.Pipe()
conns.append(parent_conn)
p = mp.Process(target=_read_into_pipe, args=(func, child_conn))
p.start()
all_process.append(p)
data_iter = reader()
if not hasattr(data_iter, "__next__"):
__next__ = data_iter.next
else:
__next__ = data_iter.__next__
def next_data():
"""next_data"""
_next = None
try:
_next = __next__()
except StopIteration:
# log.debug(traceback.format_exc())
pass
except Exception:
log.debug(traceback.format_exc())
return _next
for i in range(num_works):
conns[i].send(next_data())
finish_num = 0
finish_flag = np.zeros(len(conns), dtype="int32")
while finish_num < num_works:
for conn_id, conn in enumerate(conns):
if finish_flag[conn_id] > 0:
continue
sample = conn.recv()
if sample is None:
finish_num += 1
conn.close()
finish_flag[conn_id] = 1
else:
yield sample
conns[conn_id].send(next_data())
return pipe_reader
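# Minimal usage sketch (toy reader and func, hypothetical): the mapper applies
# `func` to every item of `reader` in worker processes and yields results as
# they become ready, so the output order need not match the input order.
def _example_mp_reader_mapper():
    def reader():
        for i in range(8):
            yield i
    squared = mp_reader_mapper(reader, func=lambda x: x * x, num_works=4)
    for item in squared():
        print(item)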
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The source code of anonymousmemmap is
from: https://github.com/rainwoodman/sharedmem
Many thanks!
"""
import numpy
import mmap
try:
# numpy >= 1.16
_unpickle_ctypes_type = numpy.ctypeslib.as_ctypes_type(numpy.dtype('|u1'))
except Exception:
# older version numpy < 1.16
_unpickle_ctypes_type = numpy.ctypeslib._typecodes['|u1']
def __unpickle__(ai, dtype):
dtype = numpy.dtype(dtype)
tp = _unpickle_ctypes_type * 1
# if there are strides, use strides, otherwise the stride is the itemsize of dtype
if ai['strides']:
tp *= ai['strides'][-1]
else:
tp *= dtype.itemsize
for i in numpy.asarray(ai['shape'])[::-1]:
tp *= i
    # grab a flat char array at the shared-memory address, long enough to
    # contain what `ai` requires
ra = tp.from_address(ai['data'][0])
buffer = numpy.ctypeslib.as_array(ra).ravel()
# view it as what it should look like
shm = numpy.ndarray(
buffer=buffer, dtype=dtype, strides=ai['strides'],
shape=ai['shape']).view(type=anonymousmemmap)
return shm
class anonymousmemmap(numpy.memmap):
""" Arrays allocated on shared memory.
    The array is stored in an anonymous memory map that is shared between child processes.
"""
def __new__(subtype, shape, dtype=numpy.uint8, order='C'):
descr = numpy.dtype(dtype)
_dbytes = descr.itemsize
shape = numpy.atleast_1d(shape)
size = 1
for k in shape:
size *= k
bytes = int(size * _dbytes)
if bytes > 0:
mm = mmap.mmap(-1, bytes)
else:
mm = numpy.empty(0, dtype=descr)
self = numpy.ndarray.__new__(
subtype, shape, dtype=descr, buffer=mm, order=order)
self._mmap = mm
return self
def __array_wrap__(self, outarr, context=None):
# after ufunc this won't be on shm!
return numpy.ndarray.__array_wrap__(
self.view(numpy.ndarray), outarr, context)
def __reduce__(self):
return __unpickle__, (self.__array_interface__, self.dtype)
def copy_to_shm(a):
""" Copy an array to the shared memory.
Notes
-----
    A copy is not always necessary because private memory is copy-on-write.
    Use :code:`a = copy_to_shm(a)` to immediately drop the old 'a' held in private memory.
"""
shared = anonymousmemmap(a.shape, dtype=a.dtype)
shared[:] = a[:]
return shared
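# Hedged usage sketch: arrays copied with `copy_to_shm` live in an anonymous
# shared mapping, so on platforms where processes fork (e.g. Linux) writes
# made by a child process are visible to the parent without re-pickling.
def _example_copy_to_shm():
    import multiprocessing as mp
    a = copy_to_shm(numpy.arange(4))
    def child():
        a[0] = 100  # this write lands in the shared mapping
    p = mp.Process(target=child)
    p.start()
    p.join()
    print(a[0])  # 100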
def ToShareMemGraph(graph):
"""Copy the graph object to anonymous shared memory.
"""
def share_feat(feat):
for key in feat:
feat[key] = copy_to_shm(feat[key])
def share_adj_index(index):
if index is not None:
index._degree = copy_to_shm(index._degree)
index._sorted_u = copy_to_shm(index._sorted_u)
index._sorted_v = copy_to_shm(index._sorted_v)
index._sorted_eid = copy_to_shm(index._sorted_eid)
index._indptr = copy_to_shm(index._indptr)
graph._edges = copy_to_shm(graph._edges)
share_adj_index(graph._adj_src_index)
share_adj_index(graph._adj_dst_index)
share_feat(graph._node_feat)
share_feat(graph._edge_feat)