提交 7166db92 编写于 作者: Y Yelrose

Add Batch GraphWrapper

上级 6933c683
......@@ -25,7 +25,7 @@ from pgl.utils import op
from pgl.utils import paddle_helper
from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper", "BatchGraphWrapper"]
def send(src, dst, nfeat, efeat, message_func):
"""Send message from src to dst.
......@@ -101,7 +101,6 @@ class BaseGraphWrapper(object):
self._indegree = None
self._edge_uniq_dst = None
self._edge_uniq_dst_count = None
self._node_ids = None
self._graph_lod = None
self._num_graph = None
self._num_edges = None
......@@ -416,13 +415,6 @@ class StaticGraphWrapper(BaseGraphWrapper):
value=graph_lod)
self._initializers.append(init)
node_ids_value = np.arange(0, graph.num_nodes, dtype="int64")
self._node_ids, init = paddle_helper.constant(
name=self._data_name_prefix + "/node_ids",
dtype="int64",
value=node_ids_value)
self._initializers.append(init)
self._indegree, init = paddle_helper.constant(
name=self._data_name_prefix + "/indegree",
dtype="int64",
......@@ -601,12 +593,6 @@ class GraphWrapper(BaseGraphWrapper):
dtype="int32",
stop_gradient=True)
self._node_ids = L.data(
self._data_name_prefix + "/node_ids",
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._indegree = L.data(
self._data_name_prefix + "/indegree",
shape=[None],
......@@ -619,7 +605,6 @@ class GraphWrapper(BaseGraphWrapper):
self._num_nodes,
self._edge_uniq_dst,
self._edge_uniq_dst_count,
self._node_ids,
self._indegree,
self._graph_lod,
self._num_graph,
......@@ -700,7 +685,6 @@ class GraphWrapper(BaseGraphWrapper):
[graph.num_nodes], dtype="int64")
feed_dict[self._data_name_prefix + '/uniq_dst'] = uniq_dst
feed_dict[self._data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
feed_dict[self._data_name_prefix + '/node_ids'] = graph.nodes
feed_dict[self._data_name_prefix + '/indegree'] = indegree
feed_dict[self._data_name_prefix + '/graph_lod'] = graph_lod
feed_dict[self._data_name_prefix + '/num_graph'] = np.array(
......@@ -746,7 +730,6 @@ class DropEdgeWrapper(BaseGraphWrapper):
self._num_nodes = graph_wrapper.num_nodes
self._graph_lod = graph_wrapper.graph_lod
self._num_graph = graph_wrapper.num_graph
self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")
# Dropout Edges
src, dst = graph_wrapper.edges
......@@ -780,3 +763,89 @@ class DropEdgeWrapper(BaseGraphWrapper):
self._edge_uniq_dst_count = L.concat([uniq_count, last])
self._edge_uniq_dst_count.stop_gradient=True
self._indegree = get_degree(self._edges_dst, self._num_nodes)
class BatchGraphWrapper(BaseGraphWrapper):
    """Graph wrapper built on tensors that the user provides directly.

    The wrapper accepts a batch of graphs packed into flat tensors and merges
    them into a single graph, which is convenient for data-parallel
    algorithms.  Node ids of every graph are shifted by the number of nodes
    in the graphs that precede it, and the merged edges are sorted by
    destination node.  Edge features are reordered with the same permutation
    so that row ``i`` of every edge feature still describes edge ``i`` of the
    wrapper.

    Args:
        num_nodes (int32 or int64): Shape [ num_graph ].
        num_edges (int32 or int64): Shape [ num_graph ].
        edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ].
            Column 0 is read as the source and column 1 as the destination;
            ids are local to each graph (they are shifted internally).
        node_feats: A dictionary for node features. Each value should be a
            tensor with shape [ total_num_nodes_in_the_graphs, feature_size ].
        edge_feats: A dictionary for edge features. Each value should be a
            tensor with shape [ total_num_edges_in_the_graphs, feature_size ],
            given in the same row order as ``edges``.
    """

    def __init__(self, num_nodes, num_edges, edges, node_feats=None,
                 edge_feats=None):
        super(BatchGraphWrapper, self).__init__()

        node_shift, edge_lod = self.__build_meta_data(num_nodes, num_edges)
        # ``edge_order`` is the permutation that sorted the edges by dst.
        edge_order = self.__build_edges(edges, node_shift, edge_lod)

        # assign node features
        if node_feats is not None:
            for key, value in node_feats.items():
                self.node_feat_tensor_dict[key] = value

        # assign edge features
        if edge_feats is not None:
            for key, value in edge_feats.items():
                # The edges were reordered in __build_edges; apply the same
                # permutation here, otherwise features and edges are
                # misaligned.
                self.edge_feat_tensor_dict[key] = L.gather(
                    value, edge_order, overwrite=False)

        # other meta-data
        self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(
            self._edges_dst, dtype="int32")
        self._edge_uniq_dst.stop_gradient = True
        # Exclusive prefix sum of the counts followed by the grand total.
        last = L.reduce_sum(uniq_count, keep_dim=True)
        uniq_count = L.cumsum(uniq_count, exclusive=True)
        self._edge_uniq_dst_count = L.concat([uniq_count, last])
        self._edge_uniq_dst_count.stop_gradient = True
        self._indegree = get_degree(self._edges_dst, self._num_nodes)

    def __build_meta_data(self, num_nodes, num_edges):
        """Merge the per-graph node and edge counts.

        Sets ``_num_nodes``, ``_num_edges``, ``_num_graph`` and
        ``_graph_lod`` on the wrapper.

        Args:
            num_nodes: Per-graph node counts; flattened and cast to int32.
            num_edges: Per-graph edge counts; flattened and cast to int32.

        Returns:
            A tuple ``(node_shift, edge_lod)``: ``node_shift`` is the
            exclusive cumulative sum of ``num_nodes`` and ``edge_lod`` is the
            exclusive cumulative sum of ``num_edges`` followed by the total
            number of edges.
        """
        num_nodes = L.reshape(num_nodes, [-1])
        num_edges = L.reshape(num_edges, [-1])
        num_nodes = paddle_helper.ensure_dtype(num_nodes, dtype="int32")
        num_edges = paddle_helper.ensure_dtype(num_edges, dtype="int32")

        num_graph = L.shape(num_nodes)[0]
        sum_num_nodes = L.reduce_sum(num_nodes)
        sum_num_edges = L.reduce_sum(num_edges)

        edge_lod = L.concat(
            [L.cumsum(num_edges, exclusive=True), sum_num_edges])
        node_shift = L.cumsum(num_nodes, exclusive=True)
        graph_lod = L.concat([node_shift, sum_num_nodes])

        self._num_nodes = sum_num_nodes
        self._num_edges = sum_num_edges
        self._num_graph = num_graph
        self._graph_lod = graph_lod
        return node_shift, edge_lod

    def __build_edges(self, edges, node_shift, edge_lod):
        """Merge the sub-graph edges into one edge list sorted by dst.

        Sets ``_edges_src`` and ``_edges_dst`` on the wrapper.

        Args:
            edges: Tensor indexed as ``edges[:, 0]`` (src) and
                ``edges[:, 1]`` (dst).
            node_shift: Per-graph node id offset from ``__build_meta_data``.
            edge_lod: Per-graph edge boundaries from ``__build_meta_data``.

        Returns:
            The index tensor produced by ``L.argsort`` on the shifted
            destinations, i.e. the permutation that maps the caller's edge
            order to the stored edge order.
        """
        src = edges[:, 0]
        dst = edges[:, 1]
        src = L.reshape(src, [-1])
        dst = L.reshape(dst, [-1])
        src = paddle_helper.ensure_dtype(src, dtype="int32")
        dst = paddle_helper.ensure_dtype(dst, dtype="int32")

        # preprocess edges: expand the per-graph shift to one value per edge
        # (using edge_lod as the sequence boundaries), then add it to both
        # end points.
        lod_dst = L.lod_reset(dst, edge_lod)
        node_shift = L.reshape(node_shift, [-1, 1])
        node_shift = L.sequence_expand_as(node_shift, lod_dst)
        node_shift = L.reshape(node_shift, [-1])
        src = src + node_shift
        dst = dst + node_shift

        # sort edges
        self._edges_dst, index = L.argsort(dst)
        self._edges_src = L.gather(src, index, overwrite=False)
        return index
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册