Unverified commit 860b4b20, authored by Huang Zhengjie and committed by GitHub

Merge pull request #114 from Yelrose/master

add graph_gather op; rename graph_pool.py as graph_op.py
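For orientation, here is a minimal sketch of how the new op is meant to be called after this change, mirroring the unit test added below; the graph wrapper `gw` and its "feature" node feature are assumed to be built elsewhere:

import pgl
import paddle.fluid.layers as L

# one node index per graph; graph_gather resolves it against the batched node tensor
index = L.data(name="index", dtype="int32", shape=[-1])
feats = pgl.layers.graph_gather(gw, gw.node_feat["feature"], index)  # (num_graph, hidden_size)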
@@ -72,9 +72,9 @@ class GAT(object):
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers):
ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
@@ -113,9 +113,9 @@ class APPNP(object):
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers):
feature = L.dropout(
@@ -169,9 +169,9 @@ class GCNII(object):
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers):
feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
@@ -191,5 +191,3 @@ class GCNII(object):
feature = L.fc(feature, self.num_class, act=None, name="output")
return feature
@@ -25,7 +25,7 @@ from pgl.utils import op
from pgl.utils import paddle_helper
from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper", "BatchGraphWrapper"]
def send(src, dst, nfeat, efeat, message_func):
"""Send message from src to dst.
@@ -101,7 +101,6 @@ class BaseGraphWrapper(object):
self._indegree = None
self._edge_uniq_dst = None
self._edge_uniq_dst_count = None
self._node_ids = None
self._graph_lod = None
self._num_graph = None
self._num_edges = None
@@ -416,13 +415,6 @@ class StaticGraphWrapper(BaseGraphWrapper):
value=graph_lod)
self._initializers.append(init)
node_ids_value = np.arange(0, graph.num_nodes, dtype="int64")
self._node_ids, init = paddle_helper.constant(
name=self._data_name_prefix + "/node_ids",
dtype="int64",
value=node_ids_value)
self._initializers.append(init)
self._indegree, init = paddle_helper.constant(
name=self._data_name_prefix + "/indegree",
dtype="int64",
@@ -601,12 +593,6 @@ class GraphWrapper(BaseGraphWrapper):
dtype="int32",
stop_gradient=True)
self._node_ids = L.data(
self._data_name_prefix + "/node_ids",
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._indegree = L.data(
self._data_name_prefix + "/indegree",
shape=[None],
@@ -619,7 +605,6 @@ class GraphWrapper(BaseGraphWrapper):
self._num_nodes,
self._edge_uniq_dst,
self._edge_uniq_dst_count,
self._node_ids,
self._indegree,
self._graph_lod,
self._num_graph,
@@ -700,7 +685,6 @@ class GraphWrapper(BaseGraphWrapper):
[graph.num_nodes], dtype="int64")
feed_dict[self._data_name_prefix + '/uniq_dst'] = uniq_dst
feed_dict[self._data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
feed_dict[self._data_name_prefix + '/node_ids'] = graph.nodes
feed_dict[self._data_name_prefix + '/indegree'] = indegree
feed_dict[self._data_name_prefix + '/graph_lod'] = graph_lod
feed_dict[self._data_name_prefix + '/num_graph'] = np.array(
@@ -746,7 +730,6 @@ class DropEdgeWrapper(BaseGraphWrapper):
self._num_nodes = graph_wrapper.num_nodes
self._graph_lod = graph_wrapper.graph_lod
self._num_graph = graph_wrapper.num_graph
self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")
# Dropout Edges
src, dst = graph_wrapper.edges
@@ -780,3 +763,96 @@
self._edge_uniq_dst_count = L.concat([uniq_count, last])
self._edge_uniq_dst_count.stop_gradient=True
self._indegree = get_degree(self._edges_dst, self._num_nodes)
class BatchGraphWrapper(BaseGraphWrapper):
"""Implement a graph wrapper that lets users plug in their own data holders.
This graph wrapper also supports batching multiple graphs, which is useful for data-parallel algorithms.
Args:
num_nodes (int32 or int64): Shape [ num_graph ]; the number of nodes in each graph.
num_edges (int32 or int64): Shape [ num_graph ]; the number of edges in each graph.
edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ]
or a tuple (src, dst).
node_feats: A dictionary of node features. Each value should be a tensor
with shape [ total_num_nodes_in_the_graphs, feature_size ].
edge_feats: A dictionary of edge features. Each value should be a tensor
with shape [ total_num_edges_in_the_graphs, feature_size ].
"""
def __init__(self, num_nodes, num_edges, edges, node_feats=None, edge_feats=None):
super(BatchGraphWrapper, self).__init__()
node_shift, edge_lod = self.__build_meta_data(num_nodes, num_edges)
self.__build_edges(edges, node_shift, edge_lod)
# assign node features
if node_feats is not None:
for key, value in node_feats.items():
self.node_feat_tensor_dict[key] = value
# assign edge features
if edge_feats is not None:
for key, value in edge_feats.items():
self.edge_feat_tensor_dict[key] = value
# other meta-data
self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(self._edges_dst, dtype="int32")
self._edge_uniq_dst.stop_gradient=True
last = L.reduce_sum(uniq_count, keep_dim=True)
uniq_count = L.cumsum(uniq_count, exclusive=True)
self._edge_uniq_dst_count = L.concat([uniq_count, last])
self._edge_uniq_dst_count.stop_gradient=True
self._indegree = get_degree(self._edges_dst, self._num_nodes)
def __build_meta_data(self, num_nodes, num_edges):
""" Merge information for nodes and edges.
"""
num_nodes = L.reshape(num_nodes, [-1])
num_edges = L.reshape(num_edges, [-1])
num_nodes = paddle_helper.ensure_dtype(num_nodes, dtype="int32")
num_edges = paddle_helper.ensure_dtype(num_edges, dtype="int32")
num_graph = L.shape(num_nodes)[0]
sum_num_nodes = L.reduce_sum(num_nodes)
sum_num_edges = L.reduce_sum(num_edges)
edge_lod = L.concat([L.cumsum(num_edges, exclusive=True), sum_num_edges])
edge_lod = paddle_helper.lod_remove(edge_lod)
node_shift = L.cumsum(num_nodes, exclusive=True)
graph_lod = L.concat([node_shift, sum_num_nodes])
graph_lod = paddle_helper.lod_remove(graph_lod)
self._num_nodes = sum_num_nodes
self._num_edges = sum_num_edges
self._num_graph = num_graph
self._graph_lod = graph_lod
return node_shift, edge_lod
def __build_edges(self, edges, node_shift, edge_lod):
""" Merge subgraph edges.
"""
if isinstance(edges, tuple):
src, dst = edges
else:
src = edges[:, 0]
dst = edges[:, 1]
src = L.reshape(src, [-1])
dst = L.reshape(dst, [-1])
src = paddle_helper.ensure_dtype(src, dtype="int32")
dst = paddle_helper.ensure_dtype(dst, dtype="int32")
# preprocess edges
lod_dst = L.lod_reset(dst, edge_lod)
node_shift = L.reshape(node_shift, [-1, 1])
node_shift = L.sequence_expand_as(node_shift, lod_dst)
node_shift = L.reshape(node_shift, [-1])
src = src + node_shift
dst = dst + node_shift
# sort edges
self._edges_dst, index = L.argsort(dst)
self._edges_src = L.gather(src, index, overwrite=False)
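For reference, a minimal sketch of how `BatchGraphWrapper` is meant to be fed from plain data holders, following the unit test added further down; the placeholder names are illustrative:

import paddle.fluid.layers as L
from pgl import graph_wrapper

num_nodes = L.data(name="num_nodes", shape=[-1], dtype="int32")     # nodes per graph
num_edges = L.data(name="num_edges", shape=[-1], dtype="int32")     # edges per graph
edges = L.data(name="edges", shape=[-1, 2], dtype="int32")          # concatenated (src, dst) pairs
node_feat = L.data(name="node_feats", shape=[-1, 4], dtype="float32")

gw = graph_wrapper.BatchGraphWrapper(num_nodes=num_nodes,
                                     num_edges=num_edges,
                                     edges=edges,
                                     node_feats={"feature": node_feat})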
@@ -18,10 +18,10 @@ from pgl.layers import conv
from pgl.layers.conv import *
from pgl.layers import set2set
from pgl.layers.set2set import *
from pgl.layers import graph_pool
from pgl.layers.graph_pool import *
from pgl.layers import graph_op
from pgl.layers.graph_op import *
__all__ = []
__all__ += conv.__all__
__all__ += set2set.__all__
__all__ += graph_pool.__all__
__all__ += graph_op.__all__
@@ -14,12 +14,13 @@
"""This package implements common layers to help building
graph neural networks.
"""
import paddle.fluid as fluid
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl import graph_wrapper
from pgl.utils import paddle_helper
from pgl.utils import op
__all__ = ['graph_pooling', 'graph_norm']
__all__ = ['graph_pooling', 'graph_norm', 'graph_gather']
def graph_pooling(gw, node_feat, pool_type):
@@ -38,7 +39,7 @@ def graph_pooling(gw, node_feat, pool_type):
A tensor with shape (num_graph, hidden_size)
"""
graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod)
graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type)
graph_feat = L.sequence_pool(graph_feat, pool_type)
return graph_feat
@@ -57,11 +58,42 @@ def graph_norm(gw, feature):
Return:
A tensor with shape (num_nodes, hidden_size)
"""
nodes = fluid.layers.fill_constant(
nodes = L.fill_constant(
[gw.num_nodes, 1], dtype="float32", value=1.0)
norm = graph_pooling(gw, nodes, pool_type="sum")
norm = fluid.layers.sqrt(norm)
norm = L.sqrt(norm)
feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
norm = fluid.layers.sequence_expand_as(norm, feature_lod)
norm = L.sequence_expand_as(norm, feature_lod)
norm.stop_gradient = True
return feature_lod / norm
def graph_gather(gw, feature, index):
"""Implementation of graph gather
Gather node features at the given per-graph node indices.
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, hidden_size).
index (int32): A rank-K tensor whose first dimension indexes the graphs.
Shape (num_graph, ) or (num_graph, k1, k2, k3, ..., kn).
WARNING: Negative indices are not supported.
Return:
A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size)
"""
shape = L.shape(index)
output_dim = int(feature.shape[-1])
index = index + gw.graph_lod[:-1]
index = L.reshape(index, [-1])
feature = L.gather(feature, index, overwrite=False)
new_shape = []
for i in range(shape.shape[0]):
new_shape.append(shape[i])
new_shape.append(output_dim)
feature = L.reshape(feature, new_shape)
return feature
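To make the index arithmetic above concrete: `gw.graph_lod[:-1]` holds the node offset of each graph inside the batched node tensor, so adding it turns per-graph local indices into global row ids before the gather. A small NumPy illustration with toy numbers (not part of the PR):

import numpy as np

# a batch of 3 graphs with 4, 2 and 3 nodes -> graph_lod = [0, 4, 6, 9]
graph_lod = np.array([0, 4, 6, 9])
index = np.array([1, 0, 2])              # local node index requested from each graph
global_index = index + graph_lod[:-1]    # -> [1, 4, 8], rows of the packed feature tensor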
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file is for testing BatchGraphWrapper.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl.layers.conv import gcn
from pgl import graph
from pgl import graph_wrapper
class BatchedGraphWrapper(unittest.TestCase):
"""BatchedGraphWrapper
"""
def test_batched_graph_wrapper(self):
"""test_batch_graph_wrapper
"""
np.random.seed(1)
graph_list = []
num_graph = 5
feed_num_nodes = []
feed_num_edges = []
feed_edges = []
feed_node_feats = []
for _ in range(num_graph):
num_nodes = np.random.randint(5, 20)
edges = np.random.randint(low=0, high=num_nodes, size=(10, 2))
node_feat = {"feature": np.random.rand(num_nodes, 4).astype("float32")}
single_graph = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)
feed_num_nodes.append(num_nodes)
feed_num_edges.append(len(edges))
feed_edges.append(edges)
feed_node_feats.append(node_feat["feature"])
graph_list.append(single_graph)
multi_graph = graph.MultiGraph(graph_list)
np.random.seed(1)
hidden_size = 8
num_nodes = 10
place = F.CUDAPlace(0)  # hard-coded to GPU here; switch to F.CPUPlace() if CUDA is unavailable
prog = F.Program()
startup_prog = F.Program()
with F.program_guard(prog, startup_prog):
with F.unique_name.guard():
# Standard Graph Wrapper
gw = graph_wrapper.GraphWrapper(
name='graph',
place=place,
node_feat=[("feature", [-1, 4], "float32")])
output = gcn(gw,
gw.node_feat['feature'],
hidden_size=hidden_size,
activation='relu',
name='gcn')
# BatchGraphWrapper
num_nodes = L.data(name="num_nodes", shape=[-1], dtype="int32")
num_edges= L.data(name="num_edges", shape=[-1], dtype="int32")
edges = L.data(name="edges", shape=[-1, 2], dtype="int32")
node_feat = L.data(name="node_feats", shape=[-1, 4], dtype="float32")
batch_gw = graph_wrapper.BatchGraphWrapper(num_nodes=num_nodes,
num_edges=num_edges,
edges=edges,
node_feats={"feature": node_feat})
output2 = gcn(batch_gw,
batch_gw.node_feat['feature'],
hidden_size=hidden_size,
activation='relu',
name='gcn')
exe = F.Executor(place)
exe.run(startup_prog)
feed_dict = gw.to_feed(multi_graph)
feed_dict["num_nodes"] = np.array(feed_num_nodes, dtype="int32")
feed_dict["num_edges"] = np.array(feed_num_edges, dtype="int32")
feed_dict["edges"] = np.array(np.concatenate(feed_edges, 0), dtype="int32").reshape([-1, 2])
feed_dict["node_feats"] = np.array(np.concatenate(feed_node_feats, 0), dtype="float32").reshape([-1, 4])
# Run
O1, O2 = exe.run(prog, feed=feed_dict, fetch_list=[output, output2])
# The outputs of the two models should be the same.
for o1, o2 in zip(O1, O2):
dist = np.sum((o1 - o2) ** 2)
self.assertLess(dist, 1e-15)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file is for testing the graph_gather op.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
import pgl
from pgl import graph
from pgl import graph_wrapper
class GraphGatherTest(unittest.TestCase):
"""GraphGatherTest
"""
def test_graph_gather(self):
"""test_graph_gather
"""
np.random.seed(1)
graph_list = []
num_graph = 10
for _ in range(num_graph):
num_nodes = np.random.randint(5, 20)
edges = np.random.randint(low=0, high=num_nodes, size=(10, 2))
node_feat = {"feature": np.random.rand(num_nodes, 4).astype("float32")}
g = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)
graph_list.append(g)
gg = graph.MultiGraph(graph_list)
use_cuda = False
place = F.CUDAPlace(0) if use_cuda else F.CPUPlace()
prog = F.Program()
startup_prog = F.Program()
with F.program_guard(prog, startup_prog):
gw = graph_wrapper.GraphWrapper(
name='graph',
place=place,
node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info())
index = L.data(name="index", dtype="int32", shape=[-1])
feats = pgl.layers.graph_gather(gw, gw.node_feat["feature"], index)
exe = F.Executor(place)
exe.run(startup_prog)
feed_dict = gw.to_feed(gg)
feed_dict["index"] = np.zeros(num_graph, dtype="int32")
ret = exe.run(prog, feed=feed_dict, fetch_list=[feats])
self.assertEqual(list(ret[0].shape), [num_graph, 4])
for i in range(num_graph):
dist = (ret[0][i] - graph_list[i].node_feat["feature"][0])
dist = np.sum(dist ** 2)
self.assertLess(dist, 1e-15)
if __name__ == "__main__":
unittest.main()
@@ -22,13 +22,14 @@ import paddle
from paddle.fluid import core
import paddle.fluid as fluid
import paddle.fluid.layer_helper as layer_helper
import paddle.fluid.layers as L
from pgl.utils.logger import log
def gather(input, index):
"""Gather input from given index.
Slicing input data with given index. This function rewrites paddle.fluid.layers.gather
Slicing input data with given index. This function rewrites paddle.L.gather
to fix issue: https://github.com/PaddlePaddle/Paddle/issues/17509 when paddlepaddle's
version is less than 1.5.
@@ -42,16 +43,16 @@ def gather(input, index):
"""
try:
# PaddlePaddle 1.5
output = fluid.layers.gather(input, index, overwrite=False)
output = L.gather(input, index, overwrite=False)
return output
except TypeError as e:
warnings.warn("Your paddle version is less than 1.5"
" gather may be slower.")
if index.dtype == core.VarDesc.VarType.INT32:
index = fluid.layers.cast(index, "int64")
index = L.cast(index, "int64")
if index.shape[-1] != 1:
index = fluid.layers.reshape(index, shape=[-1, 1])
index = L.reshape(index, shape=[-1, 1])
index.stop_gradient = True
helper = layer_helper.LayerHelper("gather", **locals()) #**locals())
@@ -112,7 +113,7 @@ def constant(name, value, dtype, hide_batch_size=True):
raise TypeError("value should be Numpy array.")
value = value.astype(dtype)
data = fluid.layers.create_global_var(
data = L.create_global_var(
shape=value.shape,
value=0,
dtype=value.dtype,
@@ -181,7 +182,7 @@ def lod_constant(name, value, lod, dtype):
_lod = [0]
for l in lod:
_lod.append(_lod[-1] + l)
output = fluid.layers.lod_reset(data, target_lod=_lod)
output = L.lod_reset(data, target_lod=_lod)
return output, data_initializer
@@ -189,7 +190,7 @@ def sequence_softmax(x, beta=None):
"""Compute sequence softmax over paddle LodTensor
This function compute softmax normalization along with the length of sequence.
This function is an extension of :code:`fluid.layers.sequence_softmax` which can only
This function is an extension of :code:`L.sequence_softmax` which can only
deal with LodTensor whose last dimension is 1.
Args:
@@ -203,12 +204,12 @@
if beta is not None:
x = x * beta
x_max = fluid.layers.sequence_pool(x, "max")
x_max = fluid.layers.sequence_expand_as(x_max, x)
x_max = L.sequence_pool(x, "max")
x_max = L.sequence_expand_as(x_max, x)
x = x - x_max
exp_x = fluid.layers.exp(x)
sum_exp_x = fluid.layers.sequence_pool(exp_x, "sum")
sum_exp_x = fluid.layers.sequence_expand_as(sum_exp_x, exp_x)
exp_x = L.exp(x)
sum_exp_x = L.sequence_pool(exp_x, "sum")
sum_exp_x = L.sequence_expand_as(sum_exp_x, exp_x)
return exp_x / sum_exp_x
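As a rough illustration of what the sequence-wise normalization above computes, here is a NumPy sketch over a toy LoD partition (not part of the PR):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 0.5, 0.5])    # two sequences packed into one tensor
lod = [0, 3, 5]                            # sequence boundaries: x[0:3] and x[3:5]
segments = [x[s:e] for s, e in zip(lod[:-1], lod[1:])]
# subtract the per-sequence max, exponentiate, and normalize within each sequence
out = np.concatenate([np.exp(seg - seg.max()) / np.exp(seg - seg.max()).sum() for seg in segments])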
@@ -228,7 +229,7 @@ def scatter_add(input, index, updates):
Same type and shape as input.
"""
output = fluid.layers.scatter(input, index, updates, overwrite=False)
output = L.scatter(input, index, updates, overwrite=False)
return output
@@ -248,7 +249,7 @@ def scatter_max(input, index, updates):
Same type and shape as input.
"""
output = fluid.layers.scatter(input, index, updates, mode='max')
output = L.scatter(input, index, updates, mode='max')
return output
def masked_select(input, mask):
@@ -264,6 +265,41 @@ def masked_select(input, mask):
Return:
Part of inputs where mask is True.
"""
index = fluid.layers.where(mask)
return fluid.layers.gather(input, index)
index = L.where(mask)
return L.gather(input, index)
def ensure_dtype(input, dtype):
"""ensure_dtype
If the input already has the given dtype, return it unchanged;
otherwise cast it to that dtype.
Args:
input: Input tensor
dtype: The target dtype, given as a string.
Return:
The input itself if it already has the given dtype, otherwise the input cast to that dtype.
"""
if str(input.dtype) == dtype:
return input
else:
return L.cast(input, dtype=dtype)
def lod_remove(input):
"""Lod Remove
Remove the LoD information from a LoDTensor and flatten the data into a 1-D tensor.
Args:
input: The tensor to be flattened.
Return:
A 1-D tensor.
"""
return L.reshape(L.reshape(input, [1, -1]), [-1])
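A brief sketch of how the two new helpers are used by `BatchGraphWrapper` above, with a hypothetical int64 data holder `x`:

import paddle.fluid.layers as L
from pgl.utils import paddle_helper

x = L.data(name="x", shape=[-1], dtype="int64")
x32 = paddle_helper.ensure_dtype(x, dtype="int32")   # cast only when the dtype differs
flat = paddle_helper.lod_remove(x32)                 # drop the LoD and keep a flat 1-D tensor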