Commit 0bd10e14 authored by liweibin

speed up sampling

Parent 570bf814
@@ -19,8 +19,8 @@ import pgl
 import time
 from pgl.utils import mp_reader
 from pgl.utils.logger import log
-import train
 import time
+import copy


 def node_batch_iter(nodes, node_label, batch_size):
@@ -46,12 +46,11 @@ def traverse(item):
         yield item


-def flat_node_and_edge(nodes, eids):
+def flat_node_and_edge(nodes):
     """flat_node_and_edge
     """
     nodes = list(set(traverse(nodes)))
-    eids = list(set(traverse(eids)))
-    return nodes, eids
+    return nodes


 def worker(batch_info, graph, graph_wrapper, samples):
@@ -61,31 +60,42 @@ def worker(batch_info, graph, graph_wrapper, samples):
     def work():
         """work
         """
-        first = True
+        _graph_wrapper = copy.copy(graph_wrapper)
+        _graph_wrapper.node_feat_tensor_dict = {}
         for batch_train_samples, batch_train_labels in batch_info:
             start_nodes = batch_train_samples
             nodes = start_nodes
-            eids = []
+            edges = []
             for max_deg in samples:
-                pred, pred_eid = graph.sample_predecessor(
-                    start_nodes, max_degree=max_deg, return_eids=True)
+                pred_nodes = graph.sample_predecessor(
+                    start_nodes, max_degree=max_deg)
+                for dst_node, src_nodes in zip(start_nodes, pred_nodes):
+                    for src_node in src_nodes:
+                        edges.append((src_node, dst_node))
                 last_nodes = nodes
-                nodes = [nodes, pred]
-                eids = [eids, pred_eid]
-                nodes, eids = flat_node_and_edge(nodes, eids)
+                nodes = [nodes, pred_nodes]
+                nodes = flat_node_and_edge(nodes)
                 # Find new nodes
                 start_nodes = list(set(nodes) - set(last_nodes))
                 if len(start_nodes) == 0:
                     break
-            subgraph = graph.subgraph(nodes=nodes, eid=eids)
+            subgraph = graph.subgraph(
+                nodes=nodes,
+                edges=edges,
+                with_node_feat=False,
+                with_edge_feat=False)
             sub_node_index = subgraph.reindex_from_parrent_nodes(
                 batch_train_samples)
-            feed_dict = graph_wrapper.to_feed(subgraph)
+            feed_dict = _graph_wrapper.to_feed(subgraph)
             feed_dict["node_label"] = np.expand_dims(
                 np.array(
                     batch_train_labels, dtype="int64"), -1)
             feed_dict["node_index"] = sub_node_index
+            feed_dict["parent_node_index"] = np.array(nodes, dtype="int64")
             yield feed_dict

     return work
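For orientation: the hunk above changes the worker to collect `(src, dst)` pairs while it expands the k-hop neighborhood, instead of tracking edge ids and slicing them out of the parent graph later. A standalone sketch of that expansion loop, with a toy `sample_predecessor` standing in for the PGL call (names here are illustrative, not the library API):

```python
def sample_predecessor(adj, nodes, max_degree):
    """Toy stand-in for graph.sample_predecessor: up to max_degree preds."""
    return [adj.get(n, [])[:max_degree] for n in nodes]


def expand(adj, start_nodes, samples):
    """Multi-hop expansion that records (src, dst) pairs, like work() above."""
    nodes, edges = list(start_nodes), []
    frontier = list(start_nodes)
    for max_deg in samples:  # one fan-out budget per hop, e.g. [25, 10]
        pred_nodes = sample_predecessor(adj, frontier, max_deg)
        for dst, srcs in zip(frontier, pred_nodes):
            edges.extend((src, dst) for src in srcs)
        last_nodes = nodes
        nodes = list(set(nodes) | {n for preds in pred_nodes for n in preds})
        frontier = list(set(nodes) - set(last_nodes))  # only newly seen nodes
        if not frontier:
            break
    return nodes, edges


adj = {0: [1, 2], 1: [2, 3], 2: [3]}
print(expand(adj, [0], samples=[2, 2]))  # sampled nodes and (src, dst) edges
```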
@@ -97,23 +107,25 @@ def multiprocess_graph_reader(graph,
                               node_index,
                               batch_size,
                               node_label,
+                              with_parent_node_index=False,
                               num_workers=4):
     """multiprocess_graph_reader
     """

-    def parse_to_subgraph(rd):
+    def parse_to_subgraph(rd, prefix, node_feat, _with_parent_node_index):
         """parse_to_subgraph
         """

         def work():
             """work
             """
-            last = time.time()
             for data in rd():
-                this = time.time()
                 feed_dict = data
-                now = time.time()
-                last = now
+                for key in node_feat:
+                    feed_dict[prefix + '/node_feat/' + key] = node_feat[key][
+                        feed_dict["parent_node_index"]]
+                if not _with_parent_node_index:
+                    del feed_dict["parent_node_index"]
                 yield feed_dict

         return work
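With `with_node_feat=False` subgraphs, the sampler processes ship only `parent_node_index`; the consuming process then attaches features with one numpy fancy-indexing gather per feature, exactly as in the loop above. A sketch with made-up shapes and the hypothetical prefix `graph`:

```python
import numpy as np

node_feat = {"feature": np.random.rand(10000, 128)}  # full-graph features
feed_dict = {"parent_node_index": np.array([4, 17, 42], dtype="int64")}

# What parse_to_subgraph does per batch: one gather per feature array,
# done once in the parent instead of in every sampler process.
for key in node_feat:
    feed_dict["graph/node_feat/" + key] = node_feat[key][
        feed_dict["parent_node_index"]]

print(feed_dict["graph/node_feat/feature"].shape)  # (3, 128)
```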
@@ -129,46 +141,17 @@ def multiprocess_graph_reader(graph,
             reader_pool.append(
                 worker(batch_info[block_size * i:block_size * (i + 1)], graph,
                        graph_wrapper, samples))
-        multi_process_sample = mp_reader.multiprocess_reader(
-            reader_pool, use_pipe=True, queue_size=1000)
-        r = parse_to_subgraph(multi_process_sample)
-        return paddle.reader.buffered(r, 1000)
+        if len(reader_pool) == 1:
+            r = parse_to_subgraph(reader_pool[0],
+                                  repr(graph_wrapper), graph.node_feat,
+                                  with_parent_node_index)
+        else:
+            multi_process_sample = mp_reader.multiprocess_reader(
+                reader_pool, use_pipe=True, queue_size=1000)
+            r = parse_to_subgraph(multi_process_sample,
+                                  repr(graph_wrapper), graph.node_feat,
+                                  with_parent_node_index)
+        return paddle.reader.buffered(r, num_workers)

     return reader()
-
-
-def graph_reader(graph, graph_wrapper, samples, node_index, batch_size,
-                 node_label):
-    """graph_reader"""
-
-    def reader():
-        """reader"""
-        for batch_train_samples, batch_train_labels in node_batch_iter(
-                node_index, node_label, batch_size=batch_size):
-            start_nodes = batch_train_samples
-            nodes = start_nodes
-            eids = []
-            for max_deg in samples:
-                pred, pred_eid = graph.sample_predecessor(
-                    start_nodes, max_degree=max_deg, return_eids=True)
-                last_nodes = nodes
-                nodes = [nodes, pred]
-                eids = [eids, pred_eid]
-                nodes, eids = flat_node_and_edge(nodes, eids)
-                # Find new nodes
-                start_nodes = list(set(nodes) - set(last_nodes))
-                if len(start_nodes) == 0:
-                    break
-            subgraph = graph.subgraph(nodes=nodes, eid=eids)
-            feed_dict = graph_wrapper.to_feed(subgraph)
-            sub_node_index = subgraph.reindex_from_parrent_nodes(
-                batch_train_samples)
-            feed_dict["node_label"] = np.expand_dims(
-                np.array(
-                    batch_train_labels, dtype="int64"), -1)
-            feed_dict["node_index"] = np.array(sub_node_index, dtype="int32")
-            yield feed_dict
-
-    return paddle.reader.buffered(reader, 1000)
@@ -63,10 +63,7 @@ def load_data(normalize=True, symmetry=True):
     log.info("Feature shape %s" % (repr(feature.shape)))
     graph = pgl.graph.Graph(
-        num_nodes=feature.shape[0],
-        edges=list(zip(src, dst)),
-        node_feat={"index": np.arange(
-            0, len(feature), dtype="int64")})
+        num_nodes=feature.shape[0], edges=list(zip(src, dst)))

     return {
         "graph": graph,
@@ -89,7 +86,13 @@ def build_graph_model(graph_wrapper, num_class, k_hop, graphsage_type,
     node_label = fluid.layers.data(
         "node_label", shape=[None, 1], dtype="int64", append_batch_size=False)

-    feature = fluid.layers.gather(feature, graph_wrapper.node_feat['index'])
+    parent_node_index = fluid.layers.data(
+        "parent_node_index",
+        shape=[None],
+        dtype="int64",
+        append_batch_size=False)
+
+    feature = fluid.layers.gather(feature, parent_node_index)
     feature.stop_gradient = True

     for i in range(k_hop):
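On the model side, the same ids drive a `fluid.layers.gather` over a program-resident feature table, replacing the old trick of carrying an `index` node feature through every subgraph. In numpy terms (a sketch of what the gather computes, not the fluid program itself):

```python
import numpy as np

feature = np.random.rand(10000, 128)       # full feature table, loaded once
parent_node_index = np.array([4, 17, 42])  # fed per batch by the reader

# Equivalent of fluid.layers.gather(feature, parent_node_index):
batch_feature = feature[parent_node_index]
print(batch_feature.shape)  # (3, 128)
```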
@@ -221,59 +224,35 @@ def main(args):
     exe.run(startup_program)
     feature_init(place)

-    if args.sample_workers > 1:
-        train_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-    else:
-        train_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-
-    if args.sample_workers > 1:
-        val_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-    else:
-        val_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-
-    if args.sample_workers > 1:
-        test_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
-    else:
-        test_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
+    train_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        with_parent_node_index=True,
+        node_index=data['train_index'],
+        node_label=data["train_label"])
+
+    val_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        with_parent_node_index=True,
+        node_index=data['val_index'],
+        node_label=data["val_label"])
+
+    test_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        with_parent_node_index=True,
+        node_index=data['test_index'],
+        node_label=data["test_label"])

     for epoch in range(args.epoch):
         run_epoch(
......
@@ -262,59 +262,32 @@ def main(args):
     else:
         train_exe = exe

-    if args.sample_workers > 1:
-        train_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-    else:
-        train_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-
-    if args.sample_workers > 1:
-        val_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-    else:
-        val_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-
-    if args.sample_workers > 1:
-        test_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
-    else:
-        test_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
+    train_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['train_index'],
+        node_label=data["train_label"])
+
+    val_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['val_index'],
+        node_label=data["val_label"])
+
+    test_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['test_index'],
+        node_label=data["test_label"])

     for epoch in range(args.epoch):
         run_epoch(
......
@@ -97,11 +97,7 @@ def load_data(normalize=True, symmetry=True, scale=1):
     graph = pgl.graph.Graph(
         num_nodes=feature.shape[0],
         edges=edges,
-        node_feat={
-            "index": np.arange(
-                0, len(feature), dtype="int64"),
-            "feature": feature
-        })
+        node_feat={"feature": feature})

     return {
         "graph": graph,
@@ -244,59 +240,32 @@ def main(args):
     test_program = train_program.clone(for_test=True)

-    if args.sample_workers > 1:
-        train_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-    else:
-        train_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-
-    if args.sample_workers > 1:
-        val_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-    else:
-        val_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-
-    if args.sample_workers > 1:
-        test_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
-    else:
-        test_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
+    train_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['train_index'],
+        node_label=data["train_label"])
+
+    val_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['val_index'],
+        node_label=data["val_label"])
+
+    test_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['test_index'],
+        node_label=data["test_label"])

     with fluid.program_guard(train_program, startup_program):
         adam = fluid.optimizer.Adam(learning_rate=args.lr)
......
@@ -15,6 +15,7 @@
 This package implement Graph structure for handling graph data.
 """

+import os
 import numpy as np
 import pickle as pkl
 import time
@@ -77,6 +78,15 @@ class EdgeIndex(object):
         """
         return self._sorted_u, self._sorted_v, self._sorted_eid

+    def dump(self, path):
+        if not os.path.exists(path):
+            os.makedirs(path)
+        np.save(path + '/degree.npy', self._degree)
+        np.save(path + '/sorted_u.npy', self._sorted_u)
+        np.save(path + '/sorted_v.npy', self._sorted_v)
+        np.save(path + '/sorted_eid.npy', self._sorted_eid)
+        np.save(path + '/indptr.npy', self._indptr)
+

 class Graph(object):
     """Implementation of graph structure in pgl.
@@ -136,6 +146,18 @@ class Graph(object):
             self._adj_src_index = None
             self._adj_dst_index = None

+    def dump(self, path):
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        np.save(path + '/num_nodes.npy', self._num_nodes)
+        np.save(path + '/edges.npy', self._edges)
+
+        if self._adj_src_index:
+            self._adj_src_index.dump(path + '/adj_src')
+        if self._adj_dst_index:
+            self._adj_dst_index.dump(path + '/adj_dst')
+
     @property
     def adj_src_index(self):
         """Return an EdgeIndex object for src.
@@ -506,7 +528,13 @@ class Graph(object):
                 (key, _hide_num_nodes(value.shape), value.dtype))
         return edge_feat_info

-    def subgraph(self, nodes, eid=None, edges=None):
+    def subgraph(self,
+                 nodes,
+                 eid=None,
+                 edges=None,
+                 edge_feats=None,
+                 with_node_feat=True,
+                 with_edge_feat=True):
         """Generate subgraph with nodes and edge ids.

         This function will generate a :code:`pgl.graph.Subgraph` object and
@@ -521,6 +549,10 @@ class Graph(object):
             eid (optional): Edge ids which will be included in the subgraph.

             edges (optional): Edge(src, dst) list which will be included in the subgraph.

+            with_node_feat: Whether to inherit node features from parent graph.
+
+            with_edge_feat: Whether to inherit edge features from parent graph.
+
         Return:
             A :code:`pgl.graph.Subgraph` object.
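A hedged usage sketch of the extended signature, assuming a PGL build that includes this commit (toy graph; skipping feature inheritance is what lets sampler processes avoid copying large feature arrays):

```python
import numpy as np
import pgl

g = pgl.graph.Graph(
    num_nodes=5,
    edges=[(0, 1), (1, 2), (2, 3), (3, 4)],
    node_feat={"feature": np.random.rand(5, 8).astype("float32")})

# Build a subgraph from explicit (src, dst) pairs and skip feature slicing;
# the consumer gathers features later via the parent node ids.
sub = g.subgraph(
    nodes=[0, 1, 2],
    edges=[(0, 1), (1, 2)],
    with_node_feat=False,
    with_edge_feat=False)
```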
@@ -543,14 +575,20 @@ class Graph(object):
                     len(edges), dtype="int64"), edges, reindex)

         sub_edge_feat = {}
-        for key, value in self._edge_feat.items():
-            if eid is None:
-                raise ValueError("Eid can not be None with edge features.")
-            sub_edge_feat[key] = value[eid]
+        if edges is None:
+            if with_edge_feat:
+                for key, value in self._edge_feat.items():
+                    if eid is None:
+                        raise ValueError(
+                            "Eid can not be None with edge features.")
+                    sub_edge_feat[key] = value[eid]
+        else:
+            sub_edge_feat = edge_feats

         sub_node_feat = {}
-        for key, value in self._node_feat.items():
-            sub_node_feat[key] = value[nodes]
+        if with_node_feat:
+            for key, value in self._node_feat.items():
+                sub_node_feat[key] = value[nodes]

         subgraph = SubGraph(
             num_nodes=len(nodes),
@@ -779,3 +817,27 @@ class SubGraph(Graph):
             A list of node ids in parent graph.
         """
         return graph_kernel.map_nodes(nodes, self._to_reindex)
+
+
+class MemmapEdgeIndex(EdgeIndex):
+    def __init__(self, path):
+        self._degree = np.load(path + '/degree.npy', mmap_mode="r")
+        self._sorted_u = np.load(path + '/sorted_u.npy', mmap_mode="r")
+        self._sorted_v = np.load(path + '/sorted_v.npy', mmap_mode="r")
+        self._sorted_eid = np.load(path + '/sorted_eid.npy', mmap_mode="r")
+        self._indptr = np.load(path + '/indptr.npy', mmap_mode="r")
+
+
+class MemmapGraph(Graph):
+    def __init__(self, path):
+        self._num_nodes = np.load(path + '/num_nodes.npy')
+        self._edges = np.load(path + '/edges.npy', mmap_mode="r")
+
+        if os.path.exists(path + '/adj_src'):
+            self._adj_src_index = MemmapEdgeIndex(path + '/adj_src')
+        else:
+            self._adj_src_index = None
+
+        if os.path.exists(path + '/adj_dst'):
+            self._adj_dst_index = MemmapEdgeIndex(path + '/adj_dst')
+        else:
+            self._adj_dst_index = None
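A hedged roundtrip sketch, again assuming this commit: dump a graph once (touch `adj_dst_index` first so there is an index to write), then have each sampler process open it memory-mapped rather than rebuilding it in every worker:

```python
import pgl

g = pgl.graph.Graph(num_nodes=4, edges=[(0, 1), (1, 2), (2, 3)])
g.adj_dst_index  # force the EdgeIndex to be built so dump() writes it
g.dump("./graph_snapshot")

# Later, e.g. in a worker process: arrays are memory-mapped, not copied.
mg = pgl.graph.MemmapGraph("./graph_snapshot")
```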
@@ -89,8 +89,8 @@ class BaseGraphWrapper(object):
     """

     def __init__(self):
-        self._node_feat_tensor_dict = {}
-        self._edge_feat_tensor_dict = {}
+        self.node_feat_tensor_dict = {}
+        self.edge_feat_tensor_dict = {}
         self._edges_src = None
         self._edges_dst = None
         self._num_nodes = None
@@ -98,6 +98,10 @@ class BaseGraphWrapper(object):
         self._edge_uniq_dst = None
         self._edge_uniq_dst_count = None
         self._node_ids = None
+        self._data_name_prefix = ""
+
+    def __repr__(self):
+        return self._data_name_prefix

     def send(self, message_func, nfeat_list=None, efeat_list=None):
         """Send message from all src nodes to dst nodes.
@@ -220,7 +224,7 @@ class BaseGraphWrapper(object):
             A dictionary whose keys are the feature names and the values
             are feature tensor.
         """
-        return self._edge_feat_tensor_dict
+        return self.edge_feat_tensor_dict

     @property
     def node_feat(self):
@@ -230,7 +234,7 @@ class BaseGraphWrapper(object):
             A dictionary whose keys are the feature names and the values
             are feature tensor.
         """
-        return self._node_feat_tensor_dict
+        return self.node_feat_tensor_dict

     def indegree(self):
         """Return the indegree tensor for all nodes.
@@ -298,8 +302,8 @@ class StaticGraphWrapper(BaseGraphWrapper):

     def __init__(self, name, graph, place):
         super(StaticGraphWrapper, self).__init__()
+        self._data_name_prefix = name
         self._initializers = []
-        self.__data_name_prefix = name
         self.__create_graph_attr(graph)

     def __create_graph_attr(self, graph):
@@ -326,43 +330,43 @@ class StaticGraphWrapper(BaseGraphWrapper):
         self._edges_src, init = paddle_helper.constant(
             dtype="int64",
             value=src,
-            name=self.__data_name_prefix + '/edges_src')
+            name=self._data_name_prefix + '/edges_src')
         self._initializers.append(init)

         self._edges_dst, init = paddle_helper.constant(
             dtype="int64",
             value=dst,
-            name=self.__data_name_prefix + '/edges_dst')
+            name=self._data_name_prefix + '/edges_dst')
         self._initializers.append(init)

         self._num_nodes, init = paddle_helper.constant(
             dtype="int64",
             hide_batch_size=False,
             value=np.array([graph.num_nodes]),
-            name=self.__data_name_prefix + '/num_nodes')
+            name=self._data_name_prefix + '/num_nodes')
         self._initializers.append(init)

         self._edge_uniq_dst, init = paddle_helper.constant(
-            name=self.__data_name_prefix + "/uniq_dst",
+            name=self._data_name_prefix + "/uniq_dst",
             dtype="int64",
             value=uniq_dst)
         self._initializers.append(init)

         self._edge_uniq_dst_count, init = paddle_helper.constant(
-            name=self.__data_name_prefix + "/uniq_dst_count",
+            name=self._data_name_prefix + "/uniq_dst_count",
             dtype="int32",
             value=uniq_dst_count)
         self._initializers.append(init)

         node_ids_value = np.arange(0, graph.num_nodes, dtype="int64")
         self._node_ids, init = paddle_helper.constant(
-            name=self.__data_name_prefix + "/node_ids",
+            name=self._data_name_prefix + "/node_ids",
             dtype="int64",
             value=node_ids_value)
         self._initializers.append(init)

         self._indegree, init = paddle_helper.constant(
-            name=self.__data_name_prefix + "/indegree",
+            name=self._data_name_prefix + "/indegree",
             dtype="int64",
             value=indegree)
         self._initializers.append(init)
@@ -373,9 +377,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
         for node_feat_name, node_feat_value in node_feat.items():
             node_feat_shape = node_feat_value.shape
             node_feat_dtype = node_feat_value.dtype
-            self._node_feat_tensor_dict[
+            self.node_feat_tensor_dict[
                 node_feat_name], init = paddle_helper.constant(
-                    name=self.__data_name_prefix + '/node_feat/' +
+                    name=self._data_name_prefix + '/node_feat/' +
                     node_feat_name,
                     dtype=node_feat_dtype,
                     value=node_feat_value)
@@ -387,9 +391,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
         for edge_feat_name, edge_feat_value in edge_feat.items():
             edge_feat_shape = edge_feat_value.shape
             edge_feat_dtype = edge_feat_value.dtype
-            self._edge_feat_tensor_dict[
+            self.edge_feat_tensor_dict[
                 edge_feat_name], init = paddle_helper.constant(
-                    name=self.__data_name_prefix + '/edge_feat/' +
+                    name=self._data_name_prefix + '/edge_feat/' +
                     edge_feat_name,
                     dtype=edge_feat_dtype,
                     value=edge_feat_value)
@@ -477,8 +481,8 @@ class GraphWrapper(BaseGraphWrapper):
     def __init__(self, name, place, node_feat=[], edge_feat=[]):
         super(GraphWrapper, self).__init__()
         # collect holders for PyReader
+        self._data_name_prefix = name
         self._holder_list = []
-        self.__data_name_prefix = name
         self._place = place
         self.__create_graph_attr_holders()
         for node_feat_name, node_feat_shape, node_feat_dtype in node_feat:
@@ -493,43 +497,43 @@ class GraphWrapper(BaseGraphWrapper):
         """Create data holders for graph attributes.
         """
         self._edges_src = fluid.layers.data(
-            self.__data_name_prefix + '/edges_src',
+            self._data_name_prefix + '/edges_src',
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
         self._edges_dst = fluid.layers.data(
-            self.__data_name_prefix + '/edges_dst',
+            self._data_name_prefix + '/edges_dst',
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
         self._num_nodes = fluid.layers.data(
-            self.__data_name_prefix + '/num_nodes',
+            self._data_name_prefix + '/num_nodes',
             shape=[1],
             append_batch_size=False,
             dtype='int64',
             stop_gradient=True)
         self._edge_uniq_dst = fluid.layers.data(
-            self.__data_name_prefix + "/uniq_dst",
+            self._data_name_prefix + "/uniq_dst",
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
         self._edge_uniq_dst_count = fluid.layers.data(
-            self.__data_name_prefix + "/uniq_dst_count",
+            self._data_name_prefix + "/uniq_dst_count",
             shape=[None],
             append_batch_size=False,
             dtype="int32",
             stop_gradient=True)
         self._node_ids = fluid.layers.data(
-            self.__data_name_prefix + "/node_ids",
+            self._data_name_prefix + "/node_ids",
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
         self._indegree = fluid.layers.data(
-            self.__data_name_prefix + "/indegree",
+            self._data_name_prefix + "/indegree",
             shape=[None],
             append_batch_size=False,
             dtype="int64",
@@ -545,12 +549,12 @@ class GraphWrapper(BaseGraphWrapper):
         """Create data holders for node features.
         """
         feat_holder = fluid.layers.data(
-            self.__data_name_prefix + '/node_feat/' + node_feat_name,
+            self._data_name_prefix + '/node_feat/' + node_feat_name,
             shape=node_feat_shape,
             append_batch_size=False,
             dtype=node_feat_dtype,
             stop_gradient=True)
-        self._node_feat_tensor_dict[node_feat_name] = feat_holder
+        self.node_feat_tensor_dict[node_feat_name] = feat_holder
         self._holder_list.append(feat_holder)

     def __create_graph_edge_feat_holders(self, edge_feat_name, edge_feat_shape,
@@ -558,12 +562,12 @@ class GraphWrapper(BaseGraphWrapper):
         """Create edge holders for edge features.
         """
         feat_holder = fluid.layers.data(
-            self.__data_name_prefix + '/edge_feat/' + edge_feat_name,
+            self._data_name_prefix + '/edge_feat/' + edge_feat_name,
             shape=edge_feat_shape,
             append_batch_size=False,
             dtype=edge_feat_dtype,
             stop_gradient=True)
-        self._edge_feat_tensor_dict[edge_feat_name] = feat_holder
+        self.edge_feat_tensor_dict[edge_feat_name] = feat_holder
         self._holder_list.append(feat_holder)

     def to_feed(self, graph):
@@ -594,20 +598,21 @@ class GraphWrapper(BaseGraphWrapper):
                 edge_feat[key] = value[eid]
             node_feat = graph.node_feat

-        feed_dict[self.__data_name_prefix + '/edges_src'] = src
-        feed_dict[self.__data_name_prefix + '/edges_dst'] = dst
-        feed_dict[self.__data_name_prefix + '/num_nodes'] = np.array(graph.num_nodes)
-        feed_dict[self.__data_name_prefix + '/uniq_dst'] = uniq_dst
-        feed_dict[self.__data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
-        feed_dict[self.__data_name_prefix + '/node_ids'] = graph.nodes
-        feed_dict[self.__data_name_prefix + '/indegree'] = indegree
+        feed_dict[self._data_name_prefix + '/edges_src'] = src
+        feed_dict[self._data_name_prefix + '/edges_dst'] = dst
+        feed_dict[self._data_name_prefix + '/num_nodes'] = np.array(
+            graph.num_nodes)
+        feed_dict[self._data_name_prefix + '/uniq_dst'] = uniq_dst
+        feed_dict[self._data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
+        feed_dict[self._data_name_prefix + '/node_ids'] = graph.nodes
+        feed_dict[self._data_name_prefix + '/indegree'] = indegree

-        for key in self._node_feat_tensor_dict:
-            feed_dict[self.__data_name_prefix + '/node_feat/' +
+        for key in self.node_feat_tensor_dict:
+            feed_dict[self._data_name_prefix + '/node_feat/' +
                       key] = node_feat[key]

-        for key in self._edge_feat_tensor_dict:
-            feed_dict[self.__data_name_prefix + '/edge_feat/' +
+        for key in self.edge_feat_tensor_dict:
+            feed_dict[self._data_name_prefix + '/edge_feat/' +
                       key] = edge_feat[key]

         return feed_dict
......
@@ -25,6 +25,8 @@ except:
 import numpy as np
 import time
 import paddle.fluid as fluid
+from queue import Queue
+import threading


 def serialize_data(data):
@@ -129,22 +131,39 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000, pipe_size=10):
             p.start()

         reader_num = len(readers)
-        finish_num = 0
         conn_to_remove = []
         finish_flag = np.zeros(len(conns), dtype="int32")
-        start = time.time()
+
+        def queue_worker(sub_conn, que):
+            while True:
+                buff = sub_conn.recv()
+                sample = deserialize_data(buff)
+                if sample is None:
+                    que.put(None)
+                    sub_conn.close()
+                    break
+                que.put(sample)
+
+        thread_pool = []
+        output_queue = Queue(maxsize=reader_num)
+        for i in range(reader_num):
+            t = threading.Thread(
+                target=queue_worker, args=(conns[i], output_queue))
+            t.daemon = True
+            t.start()
+            thread_pool.append(t)
+
+        finish_num = 0
         while finish_num < reader_num:
-            for conn_id, conn in enumerate(conns):
-                if finish_flag[conn_id] > 0:
-                    continue
-                if conn.poll(0.01):
-                    buff = conn.recv()
-                    sample = deserialize_data(buff)
-                    if sample is None:
-                        finish_num += 1
-                        conn.close()
-                        finish_flag[conn_id] = 1
-                    else:
-                        yield sample
+            sample = output_queue.get()
+            if sample is None:
+                finish_num += 1
+            else:
+                yield sample
+
+        for thread in thread_pool:
+            thread.join()

     if use_pipe:
         return pipe_reader
......
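The rewritten loop swaps a 10 ms poll over every connection for one blocking reader thread per pipe feeding a shared queue; the consumer just counts `None` sentinels until every source is drained. A self-contained sketch of the same fan-in pattern with toy iterator sources in place of pipes:

```python
import threading
from queue import Queue


def fan_in(sources):
    """Merge several blocking iterators via one thread per source + a queue."""
    q = Queue(maxsize=len(sources))

    def pump(src):
        for item in src:
            q.put(item)
        q.put(None)  # sentinel: this source is exhausted

    for src in sources:
        t = threading.Thread(target=pump, args=(src,))
        t.daemon = True
        t.start()

    finished = 0
    while finished < len(sources):
        item = q.get()  # blocks; no polling
        if item is None:
            finished += 1
        else:
            yield item


print(sorted(fan_in([iter(range(3)), iter(range(10, 13))])))
```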