Commit 0bd10e14 authored by liweibin

Speed up sampling: build sampled subgraphs from (src, dst) edge lists instead of edge ids, gather node features once in the reader via a new parent_node_index input, add Graph.dump() plus MemmapGraph/MemmapEdgeIndex for loading graphs memory-mapped, and replace pipe polling in mp_reader with one consumer thread per worker.

Parent 570bf814
@@ -19,8 +19,8 @@ import pgl
import time
from pgl.utils import mp_reader
from pgl.utils.logger import log
import train
import time
import copy
def node_batch_iter(nodes, node_label, batch_size):
@@ -46,12 +46,11 @@ def traverse(item):
yield item
def flat_node_and_edge(nodes, eids):
def flat_node_and_edge(nodes):
"""flat_node_and_edge
"""
nodes = list(set(traverse(nodes)))
eids = list(set(traverse(eids)))
return nodes, eids
return nodes
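
Note: as a minimal, self-contained sketch (a standalone reimplementation for illustration, not an import from this file), the simplified helper now only flattens and deduplicates nested node lists; edge ids are no longer tracked:

def traverse(item):
    """Recursively yield leaf values from arbitrarily nested lists."""
    if isinstance(item, list):
        for sub in item:
            for leaf in traverse(sub):
                yield leaf
    else:
        yield item

def flat_node_and_edge(nodes):
    """Flatten nested node lists and drop duplicates (order not preserved)."""
    return list(set(traverse(nodes)))

# After one hop, `nodes` is nested: [previous_nodes, per-node predecessor lists].
print(flat_node_and_edge([[0, 1], [[1, 2], [3]]]))  # some ordering of [0, 1, 2, 3]
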
def worker(batch_info, graph, graph_wrapper, samples):
@@ -61,31 +60,42 @@ def worker(batch_info, graph, graph_wrapper, samples):
def work():
"""work
"""
first = True
_graph_wrapper = copy.copy(graph_wrapper)
_graph_wrapper.node_feat_tensor_dict = {}
for batch_train_samples, batch_train_labels in batch_info:
start_nodes = batch_train_samples
nodes = start_nodes
eids = []
edges = []
for max_deg in samples:
pred, pred_eid = graph.sample_predecessor(
start_nodes, max_degree=max_deg, return_eids=True)
pred_nodes = graph.sample_predecessor(
start_nodes, max_degree=max_deg)
for dst_node, src_nodes in zip(start_nodes, pred_nodes):
for src_node in src_nodes:
edges.append((src_node, dst_node))
last_nodes = nodes
nodes = [nodes, pred]
eids = [eids, pred_eid]
nodes, eids = flat_node_and_edge(nodes, eids)
nodes = [nodes, pred_nodes]
nodes = flat_node_and_edge(nodes)
# Find new nodes
start_nodes = list(set(nodes) - set(last_nodes))
if len(start_nodes) == 0:
break
subgraph = graph.subgraph(nodes=nodes, eid=eids)
subgraph = graph.subgraph(
nodes=nodes,
edges=edges,
with_node_feat=False,
with_edge_feat=False)
sub_node_index = subgraph.reindex_from_parrent_nodes(
batch_train_samples)
feed_dict = graph_wrapper.to_feed(subgraph)
feed_dict = _graph_wrapper.to_feed(subgraph)
feed_dict["node_label"] = np.expand_dims(
np.array(
batch_train_labels, dtype="int64"), -1)
feed_dict["node_index"] = sub_node_index
feed_dict["parent_node_index"] = np.array(nodes, dtype="int64")
yield feed_dict
return work
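
Note: for intuition, here is a hedged, self-contained sketch of the new edge-based expansion loop; the dict-backed sample_predecessor below is a toy stand-in for graph.sample_predecessor, not the PGL API:

import random

def sample_predecessor(adj, start_nodes, max_degree):
    """Toy stand-in: cap each node's predecessor list at max_degree."""
    return [random.sample(adj.get(n, []), min(max_degree, len(adj.get(n, []))))
            for n in start_nodes]

def multi_hop_sample(adj, seed_nodes, samples):
    """Collect nodes and (src, dst) edges hop by hop, expanding only the
    frontier of newly discovered nodes, as the worker above does."""
    nodes = list(seed_nodes)
    start_nodes = list(seed_nodes)
    edges = []
    for max_deg in samples:
        pred_nodes = sample_predecessor(adj, start_nodes, max_deg)
        for dst_node, src_nodes in zip(start_nodes, pred_nodes):
            edges.extend((src_node, dst_node) for src_node in src_nodes)
        last_nodes = nodes
        nodes = list(set(nodes) | {n for ns in pred_nodes for n in ns})
        start_nodes = list(set(nodes) - set(last_nodes))
        if not start_nodes:
            break
    return nodes, edges

adj = {0: [1, 2], 1: [2, 3], 2: [3], 3: [0]}
print(multi_hop_sample(adj, seed_nodes=[0], samples=[2, 2]))
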
@@ -97,23 +107,25 @@ def multiprocess_graph_reader(graph,
node_index,
batch_size,
node_label,
with_parent_node_index=False,
num_workers=4):
"""multiprocess_graph_reader
"""
def parse_to_subgraph(rd):
def parse_to_subgraph(rd, prefix, node_feat, _with_parent_node_index):
"""parse_to_subgraph
"""
def work():
"""work
"""
last = time.time()
for data in rd():
this = time.time()
feed_dict = data
now = time.time()
last = now
for key in node_feat:
feed_dict[prefix + '/node_feat/' + key] = node_feat[key][
feed_dict["parent_node_index"]]
if not _with_parent_node_index:
del feed_dict["parent_node_index"]
yield feed_dict
return work
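
Note: a small numpy illustration of the gather above. Feature rows are looked up once per batch in the parent process, keyed by the wrapper's name prefix (what repr(graph_wrapper) returns after this commit), so sampling workers never serialize feature arrays; "graph" is an assumed wrapper name for illustration:

import numpy as np

node_feat = {"feature": np.arange(12, dtype="float32").reshape(6, 2)}
feed_dict = {"parent_node_index": np.array([4, 0, 2], dtype="int64")}
prefix = "graph"  # assumed wrapper name, for illustration only
for key in node_feat:
    feed_dict[prefix + "/node_feat/" + key] = node_feat[key][
        feed_dict["parent_node_index"]]
print(feed_dict[prefix + "/node_feat/feature"])  # rows 4, 0 and 2
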
@@ -129,46 +141,17 @@ def multiprocess_graph_reader(graph,
reader_pool.append(
worker(batch_info[block_size * i:block_size * (i + 1)], graph,
graph_wrapper, samples))
multi_process_sample = mp_reader.multiprocess_reader(
reader_pool, use_pipe=True, queue_size=1000)
r = parse_to_subgraph(multi_process_sample)
return paddle.reader.buffered(r, 1000)
return reader()
def graph_reader(graph, graph_wrapper, samples, node_index, batch_size,
node_label):
"""graph_reader"""
def reader():
"""reader"""
for batch_train_samples, batch_train_labels in node_batch_iter(
node_index, node_label, batch_size=batch_size):
start_nodes = batch_train_samples
nodes = start_nodes
eids = []
for max_deg in samples:
pred, pred_eid = graph.sample_predecessor(
start_nodes, max_degree=max_deg, return_eids=True)
last_nodes = nodes
nodes = [nodes, pred]
eids = [eids, pred_eid]
nodes, eids = flat_node_and_edge(nodes, eids)
# Find new nodes
start_nodes = list(set(nodes) - set(last_nodes))
if len(start_nodes) == 0:
break
subgraph = graph.subgraph(nodes=nodes, eid=eids)
feed_dict = graph_wrapper.to_feed(subgraph)
sub_node_index = subgraph.reindex_from_parrent_nodes(
batch_train_samples)
if len(reader_pool) == 1:
r = parse_to_subgraph(reader_pool[0],
repr(graph_wrapper), graph.node_feat,
with_parent_node_index)
else:
multi_process_sample = mp_reader.multiprocess_reader(
reader_pool, use_pipe=True, queue_size=1000)
r = parse_to_subgraph(multi_process_sample,
repr(graph_wrapper), graph.node_feat,
with_parent_node_index)
return paddle.reader.buffered(r, num_workers)
feed_dict["node_label"] = np.expand_dims(
np.array(
batch_train_labels, dtype="int64"), -1)
feed_dict["node_index"] = np.array(sub_node_index, dtype="int32")
yield feed_dict
return paddle.reader.buffered(reader, 1000)
return reader()
@@ -63,10 +63,7 @@ def load_data(normalize=True, symmetry=True):
log.info("Feature shape %s" % (repr(feature.shape)))
graph = pgl.graph.Graph(
num_nodes=feature.shape[0],
edges=list(zip(src, dst)),
node_feat={"index": np.arange(
0, len(feature), dtype="int64")})
num_nodes=feature.shape[0], edges=list(zip(src, dst)))
return {
"graph": graph,
@@ -89,7 +86,13 @@ def build_graph_model(graph_wrapper, num_class, k_hop, graphsage_type,
node_label = fluid.layers.data(
"node_label", shape=[None, 1], dtype="int64", append_batch_size=False)
feature = fluid.layers.gather(feature, graph_wrapper.node_feat['index'])
parent_node_index = fluid.layers.data(
"parent_node_index",
shape=[None],
dtype="int64",
append_batch_size=False)
feature = fluid.layers.gather(feature, parent_node_index)
feature.stop_gradient = True
for i in range(k_hop):
@@ -221,59 +224,35 @@ def main(args):
exe.run(startup_program)
feature_init(place)
if args.sample_workers > 1:
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
else:
train_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
if args.sample_workers > 1:
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
else:
val_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
if args.sample_workers > 1:
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
else:
test_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
with_parent_node_index=True,
node_index=data['train_index'],
node_label=data["train_label"])
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
with_parent_node_index=True,
node_index=data['val_index'],
node_label=data["val_label"])
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
with_parent_node_index=True,
node_index=data['test_index'],
node_label=data["test_label"])
for epoch in range(args.epoch):
run_epoch(
@@ -262,59 +262,32 @@ def main(args):
else:
train_exe = exe
if args.sample_workers > 1:
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
else:
train_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
if args.sample_workers > 1:
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
else:
val_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
if args.sample_workers > 1:
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
else:
test_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
for epoch in range(args.epoch):
run_epoch(
@@ -97,11 +97,7 @@ def load_data(normalize=True, symmetry=True, scale=1):
graph = pgl.graph.Graph(
num_nodes=feature.shape[0],
edges=edges,
node_feat={
"index": np.arange(
0, len(feature), dtype="int64"),
"feature": feature
})
node_feat={"feature": feature})
return {
"graph": graph,
@@ -244,59 +240,32 @@ def main(args):
test_program = train_program.clone(for_test=True)
if args.sample_workers > 1:
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
else:
train_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
if args.sample_workers > 1:
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
else:
val_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
if args.sample_workers > 1:
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
else:
test_iter = reader.graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
train_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['train_index'],
node_label=data["train_label"])
val_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['val_index'],
node_label=data["val_label"])
test_iter = reader.multiprocess_graph_reader(
data['graph'],
graph_wrapper,
samples=samples,
num_workers=args.sample_workers,
batch_size=args.batch_size,
node_index=data['test_index'],
node_label=data["test_label"])
with fluid.program_guard(train_program, startup_program):
adam = fluid.optimizer.Adam(learning_rate=args.lr)
@@ -15,6 +15,7 @@
This package implement Graph structure for handling graph data.
"""
import os
import numpy as np
import pickle as pkl
import time
@@ -77,6 +78,15 @@ class EdgeIndex(object):
"""
return self._sorted_u, self._sorted_v, self._sorted_eid
def dump(self, path):
if not os.path.exists(path):
os.makedirs(path)
np.save(path + '/degree.npy', self._degree)
np.save(path + '/sorted_u.npy', self._sorted_u)
np.save(path + '/sorted_v.npy', self._sorted_v)
np.save(path + '/sorted_eid.npy', self._sorted_eid)
np.save(path + '/indptr.npy', self._indptr)
class Graph(object):
"""Implementation of graph structure in pgl.
@@ -136,6 +146,18 @@ class Graph(object):
self._adj_src_index = None
self._adj_dst_index = None
def dump(self, path):
if not os.path.exists(path):
os.makedirs(path)
np.save(path + '/num_nodes.npy', self._num_nodes)
np.save(path + '/edges.npy', self._edges)
if self._adj_src_index:
self._adj_src_index.dump(path + '/adj_src')
if self._adj_dst_index:
self._adj_dst_index.dump(path + '/adj_dst')
@property
def adj_src_index(self):
"""Return an EdgeIndex object for src.
@@ -506,7 +528,13 @@ class Graph(object):
(key, _hide_num_nodes(value.shape), value.dtype))
return edge_feat_info
def subgraph(self, nodes, eid=None, edges=None):
def subgraph(self,
nodes,
eid=None,
edges=None,
edge_feats=None,
with_node_feat=True,
with_edge_feat=True):
"""Generate subgraph with nodes and edge ids.
This function will generate a :code:`pgl.graph.Subgraph` object and
@@ -521,6 +549,10 @@
eid (optional): Edge ids which will be included in the subgraph.
edges (optional): Edge(src, dst) list which will be included in the subgraph.
with_node_feat: Whether to inherit node features from parent graph.
with_edge_feat: Whether to inherit edge features from parent graph.
Return:
A :code:`pgl.graph.Subgraph` object.
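
Note: a hedged usage sketch of the extended signature (assumes pgl with this commit applied; toy sizes). Passing edges directly and disabling feature inheritance is what lets the sampling workers skip eid bookkeeping:

import numpy as np
import pgl

g = pgl.graph.Graph(
    num_nodes=4,
    edges=[(0, 1), (1, 2), (2, 3)],
    node_feat={"feature": np.random.rand(4, 8).astype("float32")})
sub = g.subgraph(
    nodes=[0, 1, 2],
    edges=[(0, 1), (1, 2)],
    with_node_feat=False,  # features are gathered later via parent_node_index
    with_edge_feat=False)
print(sub.num_nodes)  # 3
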
@@ -543,14 +575,20 @@
len(edges), dtype="int64"), edges, reindex)
sub_edge_feat = {}
for key, value in self._edge_feat.items():
if eid is None:
raise ValueError("Eid can not be None with edge features.")
sub_edge_feat[key] = value[eid]
if edges is None:
if with_edge_feat:
for key, value in self._edge_feat.items():
if eid is None:
raise ValueError(
"Eid can not be None with edge features.")
sub_edge_feat[key] = value[eid]
else:
sub_edge_feat = edge_feats
sub_node_feat = {}
for key, value in self._node_feat.items():
sub_node_feat[key] = value[nodes]
if with_node_feat:
for key, value in self._node_feat.items():
sub_node_feat[key] = value[nodes]
subgraph = SubGraph(
num_nodes=len(nodes),
@@ -779,3 +817,27 @@ class SubGraph(Graph):
A list of node ids in parent graph.
"""
return graph_kernel.map_nodes(nodes, self._to_reindex)
class MemmapEdgeIndex(EdgeIndex):
def __init__(self, path):
self._degree = np.load(path + '/degree.npy', mmap_mode="r")
self._sorted_u = np.load(path + '/sorted_u.npy', mmap_mode="r")
self._sorted_v = np.load(path + '/sorted_v.npy', mmap_mode="r")
self._sorted_eid = np.load(path + '/sorted_eid.npy', mmap_mode="r")
self._indptr = np.load(path + '/indptr.npy', mmap_mode="r")
class MemmapGraph(Graph):
def __init__(self, path):
self._num_nodes = np.load(path + '/num_nodes.npy')
self._edges = np.load(path + '/edges.npy', mmap_mode="r")
if os.path.exists(path + '/adj_src'):
self._adj_src_index = MemmapEdgeIndex(path + '/adj_src')
else:
self._adj_src_index = None
if os.path.exists(path + '/adj_dst'):
self._adj_dst_index = MemmapEdgeIndex(path + '/adj_dst')
else:
self._adj_dst_index = None
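
Note: a hedged round-trip sketch (assumes pgl with this commit applied). dump() persists the graph structure plus any adjacency indexes that have already been built, and MemmapGraph reopens the arrays with mmap_mode="r" so multiple sampling processes can share one on-disk copy; node and edge features are not persisted here:

import pgl

g = pgl.graph.Graph(num_nodes=4, edges=[(0, 1), (1, 2), (2, 3)])
g.adj_src_index  # touch the properties so the indexes exist before dumping
g.adj_dst_index
g.dump("/tmp/toy_graph")

mm = pgl.graph.MemmapGraph("/tmp/toy_graph")  # arrays load memory-mapped
print(mm.num_nodes, mm.edges.shape)
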
@@ -89,8 +89,8 @@ class BaseGraphWrapper(object):
"""
def __init__(self):
self._node_feat_tensor_dict = {}
self._edge_feat_tensor_dict = {}
self.node_feat_tensor_dict = {}
self.edge_feat_tensor_dict = {}
self._edges_src = None
self._edges_dst = None
self._num_nodes = None
@@ -98,6 +102,10 @@
self._edge_uniq_dst = None
self._edge_uniq_dst_count = None
self._node_ids = None
self._data_name_prefix = ""
def __repr__(self):
return self._data_name_prefix
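
Note: why __repr__ matters here: the multiprocess reader passes repr(graph_wrapper) to its parse step and rebuilds feed keys from that string alone, so it never has to pickle Paddle tensors. A toy illustration (the Wrapper class below is hypothetical, for naming only):

class Wrapper(object):
    def __init__(self, name):
        self._data_name_prefix = name

    def __repr__(self):
        return self._data_name_prefix

w = Wrapper("graph")
print(repr(w) + "/node_feat/feature")  # graph/node_feat/feature
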
def send(self, message_func, nfeat_list=None, efeat_list=None):
"""Send message from all src nodes to dst nodes.
@@ -220,7 +224,7 @@
A dictionary whose keys are the feature names and the values
are feature tensor.
"""
return self._edge_feat_tensor_dict
return self.edge_feat_tensor_dict
@property
def node_feat(self):
@@ -230,7 +234,7 @@
A dictionary whose keys are the feature names and the values
are feature tensor.
"""
return self._node_feat_tensor_dict
return self.node_feat_tensor_dict
def indegree(self):
"""Return the indegree tensor for all nodes.
@@ -298,8 +302,8 @@ class StaticGraphWrapper(BaseGraphWrapper):
def __init__(self, name, graph, place):
super(StaticGraphWrapper, self).__init__()
self._data_name_prefix = name
self._initializers = []
self.__data_name_prefix = name
self.__create_graph_attr(graph)
def __create_graph_attr(self, graph):
@@ -326,43 +330,43 @@ class StaticGraphWrapper(BaseGraphWrapper):
self._edges_src, init = paddle_helper.constant(
dtype="int64",
value=src,
name=self.__data_name_prefix + '/edges_src')
name=self._data_name_prefix + '/edges_src')
self._initializers.append(init)
self._edges_dst, init = paddle_helper.constant(
dtype="int64",
value=dst,
name=self.__data_name_prefix + '/edges_dst')
name=self._data_name_prefix + '/edges_dst')
self._initializers.append(init)
self._num_nodes, init = paddle_helper.constant(
dtype="int64",
hide_batch_size=False,
value=np.array([graph.num_nodes]),
name=self.__data_name_prefix + '/num_nodes')
name=self._data_name_prefix + '/num_nodes')
self._initializers.append(init)
self._edge_uniq_dst, init = paddle_helper.constant(
name=self.__data_name_prefix + "/uniq_dst",
name=self._data_name_prefix + "/uniq_dst",
dtype="int64",
value=uniq_dst)
self._initializers.append(init)
self._edge_uniq_dst_count, init = paddle_helper.constant(
name=self.__data_name_prefix + "/uniq_dst_count",
name=self._data_name_prefix + "/uniq_dst_count",
dtype="int32",
value=uniq_dst_count)
self._initializers.append(init)
node_ids_value = np.arange(0, graph.num_nodes, dtype="int64")
self._node_ids, init = paddle_helper.constant(
name=self.__data_name_prefix + "/node_ids",
name=self._data_name_prefix + "/node_ids",
dtype="int64",
value=node_ids_value)
self._initializers.append(init)
self._indegree, init = paddle_helper.constant(
name=self.__data_name_prefix + "/indegree",
name=self._data_name_prefix + "/indegree",
dtype="int64",
value=indegree)
self._initializers.append(init)
@@ -373,9 +377,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
for node_feat_name, node_feat_value in node_feat.items():
node_feat_shape = node_feat_value.shape
node_feat_dtype = node_feat_value.dtype
self._node_feat_tensor_dict[
self.node_feat_tensor_dict[
node_feat_name], init = paddle_helper.constant(
name=self.__data_name_prefix + '/node_feat/' +
name=self._data_name_prefix + '/node_feat/' +
node_feat_name,
dtype=node_feat_dtype,
value=node_feat_value)
@@ -387,9 +391,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
for edge_feat_name, edge_feat_value in edge_feat.items():
edge_feat_shape = edge_feat_value.shape
edge_feat_dtype = edge_feat_value.dtype
self._edge_feat_tensor_dict[
self.edge_feat_tensor_dict[
edge_feat_name], init = paddle_helper.constant(
name=self.__data_name_prefix + '/edge_feat/' +
name=self._data_name_prefix + '/edge_feat/' +
edge_feat_name,
dtype=edge_feat_dtype,
value=edge_feat_value)
@@ -477,8 +481,8 @@ class GraphWrapper(BaseGraphWrapper):
def __init__(self, name, place, node_feat=[], edge_feat=[]):
super(GraphWrapper, self).__init__()
# collect holders for PyReader
self._data_name_prefix = name
self._holder_list = []
self.__data_name_prefix = name
self._place = place
self.__create_graph_attr_holders()
for node_feat_name, node_feat_shape, node_feat_dtype in node_feat:
@@ -493,43 +497,43 @@ class GraphWrapper(BaseGraphWrapper):
"""Create data holders for graph attributes.
"""
self._edges_src = fluid.layers.data(
self.__data_name_prefix + '/edges_src',
self._data_name_prefix + '/edges_src',
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._edges_dst = fluid.layers.data(
self.__data_name_prefix + '/edges_dst',
self._data_name_prefix + '/edges_dst',
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._num_nodes = fluid.layers.data(
self.__data_name_prefix + '/num_nodes',
self._data_name_prefix + '/num_nodes',
shape=[1],
append_batch_size=False,
dtype='int64',
stop_gradient=True)
self._edge_uniq_dst = fluid.layers.data(
self.__data_name_prefix + "/uniq_dst",
self._data_name_prefix + "/uniq_dst",
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._edge_uniq_dst_count = fluid.layers.data(
self.__data_name_prefix + "/uniq_dst_count",
self._data_name_prefix + "/uniq_dst_count",
shape=[None],
append_batch_size=False,
dtype="int32",
stop_gradient=True)
self._node_ids = fluid.layers.data(
self.__data_name_prefix + "/node_ids",
self._data_name_prefix + "/node_ids",
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._indegree = fluid.layers.data(
self.__data_name_prefix + "/indegree",
self._data_name_prefix + "/indegree",
shape=[None],
append_batch_size=False,
dtype="int64",
@@ -545,12 +549,12 @@ class GraphWrapper(BaseGraphWrapper):
"""Create data holders for node features.
"""
feat_holder = fluid.layers.data(
self.__data_name_prefix + '/node_feat/' + node_feat_name,
self._data_name_prefix + '/node_feat/' + node_feat_name,
shape=node_feat_shape,
append_batch_size=False,
dtype=node_feat_dtype,
stop_gradient=True)
self._node_feat_tensor_dict[node_feat_name] = feat_holder
self.node_feat_tensor_dict[node_feat_name] = feat_holder
self._holder_list.append(feat_holder)
def __create_graph_edge_feat_holders(self, edge_feat_name, edge_feat_shape,
@@ -558,12 +562,12 @@ class GraphWrapper(BaseGraphWrapper):
"""Create edge holders for edge features.
"""
feat_holder = fluid.layers.data(
self.__data_name_prefix + '/edge_feat/' + edge_feat_name,
self._data_name_prefix + '/edge_feat/' + edge_feat_name,
shape=edge_feat_shape,
append_batch_size=False,
dtype=edge_feat_dtype,
stop_gradient=True)
self._edge_feat_tensor_dict[edge_feat_name] = feat_holder
self.edge_feat_tensor_dict[edge_feat_name] = feat_holder
self._holder_list.append(feat_holder)
def to_feed(self, graph):
@@ -594,20 +598,21 @@ class GraphWrapper(BaseGraphWrapper):
edge_feat[key] = value[eid]
node_feat = graph.node_feat
feed_dict[self.__data_name_prefix + '/edges_src'] = src
feed_dict[self.__data_name_prefix + '/edges_dst'] = dst
feed_dict[self.__data_name_prefix + '/num_nodes'] = np.array(graph.num_nodes)
feed_dict[self.__data_name_prefix + '/uniq_dst'] = uniq_dst
feed_dict[self.__data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
feed_dict[self.__data_name_prefix + '/node_ids'] = graph.nodes
feed_dict[self.__data_name_prefix + '/indegree'] = indegree
for key in self._node_feat_tensor_dict:
feed_dict[self.__data_name_prefix + '/node_feat/' +
feed_dict[self._data_name_prefix + '/edges_src'] = src
feed_dict[self._data_name_prefix + '/edges_dst'] = dst
feed_dict[self._data_name_prefix + '/num_nodes'] = np.array(
graph.num_nodes)
feed_dict[self._data_name_prefix + '/uniq_dst'] = uniq_dst
feed_dict[self._data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
feed_dict[self._data_name_prefix + '/node_ids'] = graph.nodes
feed_dict[self._data_name_prefix + '/indegree'] = indegree
for key in self.node_feat_tensor_dict:
feed_dict[self._data_name_prefix + '/node_feat/' +
key] = node_feat[key]
for key in self._edge_feat_tensor_dict:
feed_dict[self.__data_name_prefix + '/edge_feat/' +
for key in self.edge_feat_tensor_dict:
feed_dict[self._data_name_prefix + '/edge_feat/' +
key] = edge_feat[key]
return feed_dict
@@ -25,6 +25,8 @@ except:
import numpy as np
import time
import paddle.fluid as fluid
from queue import Queue
import threading
def serialize_data(data):
@@ -129,22 +131,39 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000, pipe_size=10):
p.start()
reader_num = len(readers)
finish_num = 0
conn_to_remove = []
finish_flag = np.zeros(len(conns), dtype="int32")
start = time.time()
def queue_worker(sub_conn, que):
while True:
buff = sub_conn.recv()
sample = deserialize_data(buff)
if sample is None:
que.put(None)
sub_conn.close()
break
que.put(sample)
thread_pool = []
output_queue = Queue(maxsize=reader_num)
for i in range(reader_num):
t = threading.Thread(
target=queue_worker, args=(conns[i], output_queue))
t.daemon = True
t.start()
thread_pool.append(t)
finish_num = 0
while finish_num < reader_num:
for conn_id, conn in enumerate(conns):
if finish_flag[conn_id] > 0:
continue
if conn.poll(0.01):
buff = conn.recv()
sample = deserialize_data(buff)
if sample is None:
finish_num += 1
conn.close()
finish_flag[conn_id] = 1
else:
yield sample
sample = output_queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for thread in thread_pool:
thread.join()
if use_pipe:
return pipe_reader
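
Note: the old loop polled every pipe with conn.poll(0.01); the new code dedicates a daemon thread per connection that pushes into one shared queue. A self-contained sketch of that pattern, with plain iterables standing in for pipe connections:

import threading
from queue import Queue

def merged(sources):
    """Drain several sources through one queue; None marks a finished source."""
    q = Queue(maxsize=len(sources))

    def pump(src):
        for item in src:
            q.put(item)
        q.put(None)  # sentinel: this source is exhausted

    threads = [threading.Thread(target=pump, args=(s,), daemon=True)
               for s in sources]
    for t in threads:
        t.start()
    finished = 0
    while finished < len(sources):
        item = q.get()
        if item is None:
            finished += 1
        else:
            yield item
    for t in threads:
        t.join()

print(sorted(merged([range(3), range(10, 13)])))  # [0, 1, 2, 10, 11, 12]
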