提交 c820ca0a 编写于 作者: Y Yelrose

fixed pgl for pslib

上级 a584651d
...@@ -820,9 +820,11 @@ class BatchGraphWrapper(BaseGraphWrapper): ...@@ -820,9 +820,11 @@ class BatchGraphWrapper(BaseGraphWrapper):
sum_num_nodes = L.reduce_sum(num_nodes) sum_num_nodes = L.reduce_sum(num_nodes)
sum_num_edges = L.reduce_sum(num_edges) sum_num_edges = L.reduce_sum(num_edges)
edge_lod = L.concat([L.cumsum(num_edges, exclusive=True), sum_num_edges]) edge_lod = L.concat([L.cumsum(num_edges, exclusive=True), sum_num_edges])
edge_lod = paddle_helper.lod_remove(edge_lod)
node_shift = L.cumsum(num_nodes, exclusive=True) node_shift = L.cumsum(num_nodes, exclusive=True)
graph_lod = L.concat([node_shift, sum_num_nodes]) graph_lod = L.concat([node_shift, sum_num_nodes])
graph_lod = paddle_helper.lod_remove(graph_lod)
self._num_nodes = sum_num_nodes self._num_nodes = sum_num_nodes
self._num_edges = sum_num_edges self._num_edges = sum_num_edges
self._num_graph = num_graph self._num_graph = num_graph
......
...@@ -86,7 +86,7 @@ def graph_gather(gw, feature, index): ...@@ -86,7 +86,7 @@ def graph_gather(gw, feature, index):
A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size) A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size)
""" """
shape = L.shape(index) shape = L.shape(index)
output_dim = feature.shape[-1] output_dim = int(feature.shape[-1])
index = index + gw.graph_lod[:-1] index = index + gw.graph_lod[:-1]
index = L.reshape(index, [-1]) index = L.reshape(index, [-1])
feature = L.gather(feature, index, overwrite=False) feature = L.gather(feature, index, overwrite=False)
......
...@@ -24,7 +24,7 @@ import numpy as np ...@@ -24,7 +24,7 @@ import numpy as np
import paddle.fluid as F import paddle.fluid as F
import paddle.fluid.layers as L import paddle.fluid.layers as L
from pgl.layers.conv import gin from pgl.layers.conv import gcn
from pgl import graph from pgl import graph
from pgl import graph_wrapper from pgl import graph_wrapper
...@@ -33,13 +33,13 @@ class BatchedGraphWrapper(unittest.TestCase): ...@@ -33,13 +33,13 @@ class BatchedGraphWrapper(unittest.TestCase):
"""BatchedGraphWrapper """BatchedGraphWrapper
""" """
def test_batched_graph_wrapper(self): def test_batched_graph_wrapper(self):
"""test_gin """test_batch_graph_wrapper
""" """
np.random.seed(1) np.random.seed(1)
graph_list = [] graph_list = []
num_graph = 10 num_graph = 5
feed_num_nodes = [] feed_num_nodes = []
feed_num_edges = [] feed_num_edges = []
feed_edges = [] feed_edges = []
...@@ -74,14 +74,12 @@ class BatchedGraphWrapper(unittest.TestCase): ...@@ -74,14 +74,12 @@ class BatchedGraphWrapper(unittest.TestCase):
place=place, place=place,
node_feat=[("feature", [-1, 4], "float32")]) node_feat=[("feature", [-1, 4], "float32")])
output = gin(gw, output = gcn(gw,
gw.node_feat['feature'], gw.node_feat['feature'],
hidden_size=hidden_size, hidden_size=hidden_size,
activation='relu', activation='relu',
name='gin', name='gcn')
init_eps=1,
train_eps=True)
# BatchGraphWrapper # BatchGraphWrapper
num_nodes = L.data(name="num_nodes", shape=[-1], dtype="int32") num_nodes = L.data(name="num_nodes", shape=[-1], dtype="int32")
num_edges= L.data(name="num_edges", shape=[-1], dtype="int32") num_edges= L.data(name="num_edges", shape=[-1], dtype="int32")
...@@ -92,13 +90,11 @@ class BatchedGraphWrapper(unittest.TestCase): ...@@ -92,13 +90,11 @@ class BatchedGraphWrapper(unittest.TestCase):
edges=edges, edges=edges,
node_feats={"feature": node_feat}) node_feats={"feature": node_feat})
output2 = gin(batch_gw, output2 = gcn(batch_gw,
batch_gw.node_feat['feature'], batch_gw.node_feat['feature'],
hidden_size=hidden_size, hidden_size=hidden_size,
activation='relu', activation='relu',
name='gin', name='gcn')
init_eps=1,
train_eps=True)
exe = F.Executor(place) exe = F.Executor(place)
...@@ -110,11 +106,12 @@ class BatchedGraphWrapper(unittest.TestCase): ...@@ -110,11 +106,12 @@ class BatchedGraphWrapper(unittest.TestCase):
feed_dict["node_feats"] = np.array(np.concatenate(feed_node_feats, 0), dtype="float32").reshape([-1, 4]) feed_dict["node_feats"] = np.array(np.concatenate(feed_node_feats, 0), dtype="float32").reshape([-1, 4])
# Run # Run
o1, o2 = exe.run(prog, feed=feed_dict, fetch_list=[output, output2]) O1, O2 = exe.run(prog, feed=feed_dict, fetch_list=[output, output2])
# The output from two kind of models should be same. # The output from two kind of models should be same.
dist = np.sum((o1 - o2) ** 2) for o1, o2 in zip(O1, O2):
self.assertLess(dist, 1e-15) dist = np.sum((o1 - o2) ** 2)
self.assertLess(dist, 1e-15)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -22,13 +22,14 @@ import paddle ...@@ -22,13 +22,14 @@ import paddle
from paddle.fluid import core from paddle.fluid import core
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layer_helper as layer_helper import paddle.fluid.layer_helper as layer_helper
import paddle.fluid.layers as L
from pgl.utils.logger import log from pgl.utils.logger import log
def gather(input, index): def gather(input, index):
"""Gather input from given index. """Gather input from given index.
Slicing input data with given index. This function rewrite paddle.fluid.layers.gather Slicing input data with given index. This function rewrite paddle.L.gather
to fix issue: https://github.com/PaddlePaddle/Paddle/issues/17509 when paddlepaddle's to fix issue: https://github.com/PaddlePaddle/Paddle/issues/17509 when paddlepaddle's
version is less than 1.5. version is less than 1.5.
...@@ -42,16 +43,16 @@ def gather(input, index): ...@@ -42,16 +43,16 @@ def gather(input, index):
""" """
try: try:
# PaddlePaddle 1.5 # PaddlePaddle 1.5
output = fluid.layers.gather(input, index, overwrite=False) output = L.gather(input, index, overwrite=False)
return output return output
except TypeError as e: except TypeError as e:
warnings.warn("Your paddle version is less than 1.5" warnings.warn("Your paddle version is less than 1.5"
" gather may be slower.") " gather may be slower.")
if index.dtype == core.VarDesc.VarType.INT32: if index.dtype == core.VarDesc.VarType.INT32:
index = fluid.layers.cast(index, "int64") index = L.cast(index, "int64")
if index.shape[-1] != 1: if index.shape[-1] != 1:
index = fluid.layers.reshape(index, shape=[-1, 1]) index = L.reshape(index, shape=[-1, 1])
index.stop_gradient = True index.stop_gradient = True
helper = layer_helper.LayerHelper("gather", **locals()) #**locals()) helper = layer_helper.LayerHelper("gather", **locals()) #**locals())
...@@ -112,7 +113,7 @@ def constant(name, value, dtype, hide_batch_size=True): ...@@ -112,7 +113,7 @@ def constant(name, value, dtype, hide_batch_size=True):
raise TypeError("value should be Numpy array.") raise TypeError("value should be Numpy array.")
value = value.astype(dtype) value = value.astype(dtype)
data = fluid.layers.create_global_var( data = L.create_global_var(
shape=value.shape, shape=value.shape,
value=0, value=0,
dtype=value.dtype, dtype=value.dtype,
...@@ -181,7 +182,7 @@ def lod_constant(name, value, lod, dtype): ...@@ -181,7 +182,7 @@ def lod_constant(name, value, lod, dtype):
_lod = [0] _lod = [0]
for l in lod: for l in lod:
_lod.append(_lod[-1] + l) _lod.append(_lod[-1] + l)
output = fluid.layers.lod_reset(data, target_lod=_lod) output = L.lod_reset(data, target_lod=_lod)
return output, data_initializer return output, data_initializer
...@@ -189,7 +190,7 @@ def sequence_softmax(x, beta=None): ...@@ -189,7 +190,7 @@ def sequence_softmax(x, beta=None):
"""Compute sequence softmax over paddle LodTensor """Compute sequence softmax over paddle LodTensor
This function compute softmax normalization along with the length of sequence. This function compute softmax normalization along with the length of sequence.
This function is an extention of :code:`fluid.layers.sequence_softmax` which can only This function is an extention of :code:`L.sequence_softmax` which can only
deal with LodTensor whose last dimension is 1. deal with LodTensor whose last dimension is 1.
Args: Args:
...@@ -203,12 +204,12 @@ def sequence_softmax(x, beta=None): ...@@ -203,12 +204,12 @@ def sequence_softmax(x, beta=None):
if beta is not None: if beta is not None:
x = x * beta x = x * beta
x_max = fluid.layers.sequence_pool(x, "max") x_max = L.sequence_pool(x, "max")
x_max = fluid.layers.sequence_expand_as(x_max, x) x_max = L.sequence_expand_as(x_max, x)
x = x - x_max x = x - x_max
exp_x = fluid.layers.exp(x) exp_x = L.exp(x)
sum_exp_x = fluid.layers.sequence_pool(exp_x, "sum") sum_exp_x = L.sequence_pool(exp_x, "sum")
sum_exp_x = fluid.layers.sequence_expand_as(sum_exp_x, exp_x) sum_exp_x = L.sequence_expand_as(sum_exp_x, exp_x)
return exp_x / sum_exp_x return exp_x / sum_exp_x
...@@ -228,7 +229,7 @@ def scatter_add(input, index, updates): ...@@ -228,7 +229,7 @@ def scatter_add(input, index, updates):
Same type and shape as input. Same type and shape as input.
""" """
output = fluid.layers.scatter(input, index, updates, overwrite=False) output = L.scatter(input, index, updates, overwrite=False)
return output return output
...@@ -248,7 +249,7 @@ def scatter_max(input, index, updates): ...@@ -248,7 +249,7 @@ def scatter_max(input, index, updates):
Same type and shape as input. Same type and shape as input.
""" """
output = fluid.layers.scatter(input, index, updates, mode='max') output = L.scatter(input, index, updates, mode='max')
return output return output
def masked_select(input, mask): def masked_select(input, mask):
...@@ -264,12 +265,41 @@ def masked_select(input, mask): ...@@ -264,12 +265,41 @@ def masked_select(input, mask):
Return: Return:
Part of inputs where mask is True. Part of inputs where mask is True.
""" """
index = fluid.layers.where(mask) index = L.where(mask)
return fluid.layers.gather(input, index) return L.gather(input, index)
def ensure_dtype(input, dtype): def ensure_dtype(input, dtype):
"""ensure_dtype
If input is dtype, return input
else cast input into dtype
Args:
input: Input tensor
dtype: a string of type
Return:
If input is dtype, return input, else cast input into dtype
"""
if str(input.dtype) == dtype: if str(input.dtype) == dtype:
return input return input
else: else:
return fluid.layers.cast(input, dtype=dtype) return L.cast(input, dtype=dtype)
def lod_remove(input):
"""Lod Remove
Remove the lod for LodTensor and Flatten the data into 1D-Tensor.
Args:
input: A tensor to be flattend
Return:
A 1D input
"""
return L.reshape(L.reshape(input, [1, -1]), [-1])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册