Unverified commit 860b4b20, authored by Huang Zhengjie, committed by GitHub

Merge pull request #114 from Yelrose/master

add graph_gather op; rename graph_pool.py as graph_op.py
@@ -72,9 +72,9 @@ class GAT(object):
     def forward(self, graph_wrapper, feature, phase):
         if phase == "train":
-            edge_dropout = 0
-        else:
             edge_dropout = self.edge_dropout
+        else:
+            edge_dropout = 0

         for i in range(self.num_layers):
             ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
@@ -113,9 +113,9 @@ class APPNP(object):
     def forward(self, graph_wrapper, feature, phase):
         if phase == "train":
-            edge_dropout = 0
-        else:
             edge_dropout = self.edge_dropout
+        else:
+            edge_dropout = 0

         for i in range(self.num_layers):
             feature = L.dropout(
@@ -169,9 +169,9 @@ class GCNII(object):
     def forward(self, graph_wrapper, feature, phase):
         if phase == "train":
-            edge_dropout = 0
-        else:
             edge_dropout = self.edge_dropout
+        else:
+            edge_dropout = 0

         for i in range(self.num_layers):
             feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
@@ -191,5 +191,3 @@ class GCNII(object):
         feature = L.fc(feature, self.num_class, act=None, name="output")
         return feature
@@ -25,7 +25,7 @@ from pgl.utils import op
 from pgl.utils import paddle_helper
 from pgl.utils.logger import log

-__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
+__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper", "BatchGraphWrapper"]


 def send(src, dst, nfeat, efeat, message_func):
     """Send message from src to dst.
@@ -101,7 +101,6 @@ class BaseGraphWrapper(object):
         self._indegree = None
         self._edge_uniq_dst = None
         self._edge_uniq_dst_count = None
-        self._node_ids = None
         self._graph_lod = None
         self._num_graph = None
         self._num_edges = None
@@ -416,13 +415,6 @@ class StaticGraphWrapper(BaseGraphWrapper):
             value=graph_lod)
         self._initializers.append(init)

-        node_ids_value = np.arange(0, graph.num_nodes, dtype="int64")
-        self._node_ids, init = paddle_helper.constant(
-            name=self._data_name_prefix + "/node_ids",
-            dtype="int64",
-            value=node_ids_value)
-        self._initializers.append(init)
-
         self._indegree, init = paddle_helper.constant(
             name=self._data_name_prefix + "/indegree",
             dtype="int64",
@@ -601,12 +593,6 @@ class GraphWrapper(BaseGraphWrapper):
             dtype="int32",
             stop_gradient=True)
-        self._node_ids = L.data(
-            self._data_name_prefix + "/node_ids",
-            shape=[None],
-            append_batch_size=False,
-            dtype="int64",
-            stop_gradient=True)
         self._indegree = L.data(
             self._data_name_prefix + "/indegree",
             shape=[None],
@@ -619,7 +605,6 @@ class GraphWrapper(BaseGraphWrapper):
             self._num_nodes,
             self._edge_uniq_dst,
             self._edge_uniq_dst_count,
-            self._node_ids,
             self._indegree,
             self._graph_lod,
             self._num_graph,
@@ -700,7 +685,6 @@ class GraphWrapper(BaseGraphWrapper):
             [graph.num_nodes], dtype="int64")
         feed_dict[self._data_name_prefix + '/uniq_dst'] = uniq_dst
         feed_dict[self._data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
-        feed_dict[self._data_name_prefix + '/node_ids'] = graph.nodes
         feed_dict[self._data_name_prefix + '/indegree'] = indegree
         feed_dict[self._data_name_prefix + '/graph_lod'] = graph_lod
         feed_dict[self._data_name_prefix + '/num_graph'] = np.array(
@@ -746,7 +730,6 @@ class DropEdgeWrapper(BaseGraphWrapper):
         self._num_nodes = graph_wrapper.num_nodes
         self._graph_lod = graph_wrapper.graph_lod
         self._num_graph = graph_wrapper.num_graph
-        self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")

         # Dropout Edges
         src, dst = graph_wrapper.edges
@@ -780,3 +763,96 @@ class DropEdgeWrapper(BaseGraphWrapper):
         self._edge_uniq_dst_count = L.concat([uniq_count, last])
         self._edge_uniq_dst_count.stop_gradient = True
         self._indegree = get_degree(self._edges_dst, self._num_nodes)
+
+
+class BatchGraphWrapper(BaseGraphWrapper):
+    """Implement a graph wrapper that users can feed with their own data holders.
+
+    This wrapper also supports batching multiple graphs together, which is
+    useful for data-parallel algorithms.
+
+    Args:
+        num_nodes (int32 or int64): Shape [ num_graph ], the number of nodes per graph.
+
+        num_edges (int32 or int64): Shape [ num_graph ], the number of edges per graph.
+
+        edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ],
+            or a tuple (src, dst).
+
+        node_feats: A dictionary of node features. Each value should be a tensor
+            with shape [ total_num_nodes_in_the_graphs, feature_size ].
+
+        edge_feats: A dictionary of edge features. Each value should be a tensor
+            with shape [ total_num_edges_in_the_graphs, feature_size ].
+    """
+
+    def __init__(self, num_nodes, num_edges, edges, node_feats=None, edge_feats=None):
+        super(BatchGraphWrapper, self).__init__()
+
+        node_shift, edge_lod = self.__build_meta_data(num_nodes, num_edges)
+        self.__build_edges(edges, node_shift, edge_lod)
+
+        # assign node features
+        if node_feats is not None:
+            for key, value in node_feats.items():
+                self.node_feat_tensor_dict[key] = value
+
+        # assign edge features
+        if edge_feats is not None:
+            for key, value in edge_feats.items():
+                self.edge_feat_tensor_dict[key] = value
+
+        # other meta-data
+        self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(self._edges_dst, dtype="int32")
+        self._edge_uniq_dst.stop_gradient = True
+        last = L.reduce_sum(uniq_count, keep_dim=True)
+        uniq_count = L.cumsum(uniq_count, exclusive=True)
+        self._edge_uniq_dst_count = L.concat([uniq_count, last])
+        self._edge_uniq_dst_count.stop_gradient = True
+        self._indegree = get_degree(self._edges_dst, self._num_nodes)
+
+    def __build_meta_data(self, num_nodes, num_edges):
+        """ Merge information for nodes and edges.
+        """
+        num_nodes = L.reshape(num_nodes, [-1])
+        num_edges = L.reshape(num_edges, [-1])
+        num_nodes = paddle_helper.ensure_dtype(num_nodes, dtype="int32")
+        num_edges = paddle_helper.ensure_dtype(num_edges, dtype="int32")
+
+        num_graph = L.shape(num_nodes)[0]
+        sum_num_nodes = L.reduce_sum(num_nodes)
+        sum_num_edges = L.reduce_sum(num_edges)
+        edge_lod = L.concat([L.cumsum(num_edges, exclusive=True), sum_num_edges])
+        edge_lod = paddle_helper.lod_remove(edge_lod)
+
+        node_shift = L.cumsum(num_nodes, exclusive=True)
+        graph_lod = L.concat([node_shift, sum_num_nodes])
+        graph_lod = paddle_helper.lod_remove(graph_lod)
+
+        self._num_nodes = sum_num_nodes
+        self._num_edges = sum_num_edges
+        self._num_graph = num_graph
+        self._graph_lod = graph_lod
+        return node_shift, edge_lod
+
+    def __build_edges(self, edges, node_shift, edge_lod):
+        """ Merge subgraph edges.
+        """
+        if isinstance(edges, tuple):
+            src, dst = edges
+        else:
+            src = edges[:, 0]
+            dst = edges[:, 1]
+
+        src = L.reshape(src, [-1])
+        dst = L.reshape(dst, [-1])
+        src = paddle_helper.ensure_dtype(src, dtype="int32")
+        dst = paddle_helper.ensure_dtype(dst, dtype="int32")
+
+        # preprocess edges: shift graph-local node ids by each graph's node offset
+        lod_dst = L.lod_reset(dst, edge_lod)
+        node_shift = L.reshape(node_shift, [-1, 1])
+        node_shift = L.sequence_expand_as(node_shift, lod_dst)
+        node_shift = L.reshape(node_shift, [-1])
+        src = src + node_shift
+        dst = dst + node_shift
+
+        # sort edges by destination node
+        self._edges_dst, index = L.argsort(dst)
+        self._edges_src = L.gather(src, index, overwrite=False)
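A minimal usage sketch of the new BatchGraphWrapper, adapted from the unit test added later in this PR; the data-holder names, the "feature" key, and the hidden size are illustrative assumptions, not part of the diff:

# Hedged sketch, assuming data holders defined by the user (names illustrative).
import paddle.fluid.layers as L
import pgl
from pgl import graph_wrapper

# Per-graph node/edge counts, graph-local edges, and flattened node features.
num_nodes = L.data(name="num_nodes", shape=[-1], dtype="int32")
num_edges = L.data(name="num_edges", shape=[-1], dtype="int32")
edges = L.data(name="edges", shape=[-1, 2], dtype="int32")
node_feat = L.data(name="node_feats", shape=[-1, 4], dtype="float32")

# Batch several graphs into one wrapper; graph-local edge indices are shifted
# internally by the cumulative node counts (node_shift) computed in __build_meta_data.
batch_gw = graph_wrapper.BatchGraphWrapper(num_nodes=num_nodes,
                                           num_edges=num_edges,
                                           edges=edges,
                                           node_feats={"feature": node_feat})
output = pgl.layers.gcn(batch_gw, batch_gw.node_feat["feature"],
                        hidden_size=8, activation="relu", name="gcn")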
@@ -18,10 +18,10 @@ from pgl.layers import conv
 from pgl.layers.conv import *
 from pgl.layers import set2set
 from pgl.layers.set2set import *
-from pgl.layers import graph_pool
-from pgl.layers.graph_pool import *
+from pgl.layers import graph_op
+from pgl.layers.graph_op import *

 __all__ = []
 __all__ += conv.__all__
 __all__ += set2set.__all__
-__all__ += graph_pool.__all__
+__all__ += graph_op.__all__
@@ -14,12 +14,13 @@
 """This package implements common layers to help building
 graph neural networks.
 """
-import paddle.fluid as fluid
+import paddle.fluid as F
+import paddle.fluid.layers as L

 from pgl import graph_wrapper
 from pgl.utils import paddle_helper
 from pgl.utils import op

-__all__ = ['graph_pooling', 'graph_norm']
+__all__ = ['graph_pooling', 'graph_norm', 'graph_gather']


 def graph_pooling(gw, node_feat, pool_type):
@@ -38,7 +39,7 @@ def graph_pooling(gw, node_feat, pool_type):
         A tensor with shape (num_graph, hidden_size)
     """
     graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod)
-    graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type)
+    graph_feat = L.sequence_pool(graph_feat, pool_type)
     return graph_feat
@@ -57,11 +58,42 @@ def graph_norm(gw, feature):
     Return:
         A tensor with shape (num_nodes, hidden_size)
     """
-    nodes = fluid.layers.fill_constant(
+    nodes = L.fill_constant(
         [gw.num_nodes, 1], dtype="float32", value=1.0)
     norm = graph_pooling(gw, nodes, pool_type="sum")
-    norm = fluid.layers.sqrt(norm)
+    norm = L.sqrt(norm)
     feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
-    norm = fluid.layers.sequence_expand_as(norm, feature_lod)
+    norm = L.sequence_expand_as(norm, feature_lod)
     norm.stop_gradient = True
     return feature_lod / norm
+
+
+def graph_gather(gw, feature, index):
+    """Implementation of graph gather.
+
+    For each graph in the batch, gather the node features at the given
+    graph-local indices.
+
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+
+        feature: A tensor with shape (num_nodes, hidden_size).
+
+        index (int32): A K-rank tensor whose first dimension indexes the graphs.
+            Shape (num_graph, ) or (num_graph, k1, k2, k3, ..., kn).
+            WARNING: Negative indices are not supported.
+
+    Return:
+        A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size)
+    """
+    shape = L.shape(index)
+    output_dim = int(feature.shape[-1])
+    index = index + gw.graph_lod[:-1]
+    index = L.reshape(index, [-1])
+    feature = L.gather(feature, index, overwrite=False)
+    new_shape = []
+    for i in range(shape.shape[0]):
+        new_shape.append(shape[i])
+    new_shape.append(output_dim)
+    feature = L.reshape(feature, new_shape)
+    return feature
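As a quick illustration (not part of the diff), graph_gather can pull one node per graph out of a batch. A minimal sketch, assuming a GraphWrapper with a "feature" node feature of size 4 and an "index" data holder, both illustrative names:

# Hedged sketch: fetch the feature of graph-local node index[i] from graph i.
import paddle.fluid as F
import paddle.fluid.layers as L
import pgl
from pgl import graph_wrapper

place = F.CPUPlace()
gw = graph_wrapper.GraphWrapper(
    name="graph", place=place,
    node_feat=[("feature", [-1, 4], "float32")])
index = L.data(name="index", dtype="int32", shape=[-1])       # shape: (num_graph,)
feats = pgl.layers.graph_gather(gw, gw.node_feat["feature"], index)
# feats has shape (num_graph, 4); feeding index = zeros selects the first node of every graph.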
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file is for testing BatchGraphWrapper.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl.layers.conv import gcn
from pgl import graph
from pgl import graph_wrapper
class BatchedGraphWrapper(unittest.TestCase):
    """BatchedGraphWrapper
    """

    def test_batched_graph_wrapper(self):
        """test_batch_graph_wrapper
        """
        np.random.seed(1)
        graph_list = []

        num_graph = 5
        feed_num_nodes = []
        feed_num_edges = []
        feed_edges = []
        feed_node_feats = []

        for _ in range(num_graph):
            num_nodes = np.random.randint(5, 20)
            edges = np.random.randint(low=0, high=num_nodes, size=(10, 2))
            node_feat = {"feature": np.random.rand(num_nodes, 4).astype("float32")}
            single_graph = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)

            feed_num_nodes.append(num_nodes)
            feed_num_edges.append(len(edges))
            feed_edges.append(edges)
            feed_node_feats.append(node_feat["feature"])
            graph_list.append(single_graph)

        multi_graph = graph.MultiGraph(graph_list)

        np.random.seed(1)
        hidden_size = 8
        num_nodes = 10

        use_cuda = False
        place = F.CUDAPlace(0) if use_cuda else F.CPUPlace()

        prog = F.Program()
        startup_prog = F.Program()
        with F.program_guard(prog, startup_prog):
            with F.unique_name.guard():
                # Standard GraphWrapper
                gw = graph_wrapper.GraphWrapper(
                    name='graph',
                    place=place,
                    node_feat=[("feature", [-1, 4], "float32")])

                output = gcn(gw,
                             gw.node_feat['feature'],
                             hidden_size=hidden_size,
                             activation='relu',
                             name='gcn')

                # BatchGraphWrapper
                num_nodes = L.data(name="num_nodes", shape=[-1], dtype="int32")
                num_edges = L.data(name="num_edges", shape=[-1], dtype="int32")
                edges = L.data(name="edges", shape=[-1, 2], dtype="int32")
                node_feat = L.data(name="node_feats", shape=[-1, 4], dtype="float32")
                batch_gw = graph_wrapper.BatchGraphWrapper(num_nodes=num_nodes,
                                                           num_edges=num_edges,
                                                           edges=edges,
                                                           node_feats={"feature": node_feat})

                output2 = gcn(batch_gw,
                              batch_gw.node_feat['feature'],
                              hidden_size=hidden_size,
                              activation='relu',
                              name='gcn')

        exe = F.Executor(place)
        exe.run(startup_prog)

        feed_dict = gw.to_feed(multi_graph)
        feed_dict["num_nodes"] = np.array(feed_num_nodes, dtype="int32")
        feed_dict["num_edges"] = np.array(feed_num_edges, dtype="int32")
        feed_dict["edges"] = np.array(np.concatenate(feed_edges, 0), dtype="int32").reshape([-1, 2])
        feed_dict["node_feats"] = np.array(np.concatenate(feed_node_feats, 0), dtype="float32").reshape([-1, 4])

        # Run both models on the same batch of graphs.
        O1, O2 = exe.run(prog, feed=feed_dict, fetch_list=[output, output2])

        # The outputs of the two models should be the same.
        for o1, o2 in zip(O1, O2):
            dist = np.sum((o1 - o2)**2)
            self.assertLess(dist, 1e-15)


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file is for testing the graph_gather operator.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
import pgl
from pgl import graph
from pgl import graph_wrapper
class GraphGatherTest(unittest.TestCase):
    """GraphGatherTest
    """

    def test_graph_gather(self):
        """test_graph_gather
        """
        np.random.seed(1)
        graph_list = []

        num_graph = 10
        for _ in range(num_graph):
            num_nodes = np.random.randint(5, 20)
            edges = np.random.randint(low=0, high=num_nodes, size=(10, 2))
            node_feat = {"feature": np.random.rand(num_nodes, 4).astype("float32")}
            g = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)
            graph_list.append(g)

        gg = graph.MultiGraph(graph_list)

        use_cuda = False
        place = F.CUDAPlace(0) if use_cuda else F.CPUPlace()

        prog = F.Program()
        startup_prog = F.Program()
        with F.program_guard(prog, startup_prog):
            gw = graph_wrapper.GraphWrapper(
                name='graph',
                place=place,
                node_feat=g.node_feat_info(),
                edge_feat=g.edge_feat_info())

            index = L.data(name="index", dtype="int32", shape=[-1])
            feats = pgl.layers.graph_gather(gw, gw.node_feat["feature"], index)

        exe = F.Executor(place)
        exe.run(startup_prog)
        feed_dict = gw.to_feed(gg)
        feed_dict["index"] = np.zeros(num_graph, dtype="int32")
        ret = exe.run(prog, feed=feed_dict, fetch_list=[feats])

        # Gathering index 0 should return the first node's feature of each graph.
        self.assertEqual(list(ret[0].shape), [num_graph, 4])
        for i in range(num_graph):
            dist = (ret[0][i] - graph_list[i].node_feat["feature"][0])
            dist = np.sum(dist**2)
            self.assertLess(dist, 1e-15)


if __name__ == "__main__":
    unittest.main()
@@ -22,13 +22,14 @@ import paddle
 from paddle.fluid import core
 import paddle.fluid as fluid
 import paddle.fluid.layer_helper as layer_helper
+import paddle.fluid.layers as L

 from pgl.utils.logger import log


 def gather(input, index):
     """Gather input from given index.

-    Slicing input data with given index. This function rewrite paddle.fluid.layers.gather
+    Slicing input data with given index. This function rewrites fluid.layers.gather
     to fix issue: https://github.com/PaddlePaddle/Paddle/issues/17509 when paddlepaddle's
     version is less than 1.5.
@@ -42,16 +43,16 @@ def gather(input, index):
     """
     try:
         # PaddlePaddle 1.5
-        output = fluid.layers.gather(input, index, overwrite=False)
+        output = L.gather(input, index, overwrite=False)
         return output
     except TypeError as e:
         warnings.warn("Your paddle version is less than 1.5"
                       " gather may be slower.")
         if index.dtype == core.VarDesc.VarType.INT32:
-            index = fluid.layers.cast(index, "int64")
+            index = L.cast(index, "int64")
         if index.shape[-1] != 1:
-            index = fluid.layers.reshape(index, shape=[-1, 1])
+            index = L.reshape(index, shape=[-1, 1])
         index.stop_gradient = True
         helper = layer_helper.LayerHelper("gather", **locals())  #**locals())
@@ -112,7 +113,7 @@ def constant(name, value, dtype, hide_batch_size=True):
         raise TypeError("value should be Numpy array.")
     value = value.astype(dtype)
-    data = fluid.layers.create_global_var(
+    data = L.create_global_var(
         shape=value.shape,
         value=0,
         dtype=value.dtype,
@@ -181,7 +182,7 @@ def lod_constant(name, value, lod, dtype):
     _lod = [0]
     for l in lod:
         _lod.append(_lod[-1] + l)
-    output = fluid.layers.lod_reset(data, target_lod=_lod)
+    output = L.lod_reset(data, target_lod=_lod)

     return output, data_initializer
@@ -189,7 +190,7 @@ def sequence_softmax(x, beta=None):
     """Compute sequence softmax over paddle LodTensor

     This function compute softmax normalization along with the length of sequence.
-    This function is an extention of :code:`fluid.layers.sequence_softmax` which can only
+    This function is an extension of :code:`L.sequence_softmax` which can only
     deal with LodTensor whose last dimension is 1.

     Args:
@@ -203,12 +204,12 @@
     if beta is not None:
         x = x * beta

-    x_max = fluid.layers.sequence_pool(x, "max")
-    x_max = fluid.layers.sequence_expand_as(x_max, x)
+    x_max = L.sequence_pool(x, "max")
+    x_max = L.sequence_expand_as(x_max, x)
     x = x - x_max
-    exp_x = fluid.layers.exp(x)
-    sum_exp_x = fluid.layers.sequence_pool(exp_x, "sum")
-    sum_exp_x = fluid.layers.sequence_expand_as(sum_exp_x, exp_x)
+    exp_x = L.exp(x)
+    sum_exp_x = L.sequence_pool(exp_x, "sum")
+    sum_exp_x = L.sequence_expand_as(sum_exp_x, exp_x)

     return exp_x / sum_exp_x
@@ -228,7 +229,7 @@ def scatter_add(input, index, updates):
         Same type and shape as input.
     """
-    output = fluid.layers.scatter(input, index, updates, overwrite=False)
+    output = L.scatter(input, index, updates, overwrite=False)

    return output
@@ -248,7 +249,7 @@ def scatter_max(input, index, updates):
         Same type and shape as input.
     """
-    output = fluid.layers.scatter(input, index, updates, mode='max')
+    output = L.scatter(input, index, updates, mode='max')

     return output


 def masked_select(input, mask):
@@ -264,6 +265,41 @@ def masked_select(input, mask):
     Return:
         Part of inputs where mask is True.
     """
-    index = fluid.layers.where(mask)
-    return fluid.layers.gather(input, index)
+    index = L.where(mask)
+    return L.gather(input, index)
+
+
+def ensure_dtype(input, dtype):
+    """ensure_dtype
+
+    If the input tensor already has the given dtype, return it unchanged;
+    otherwise cast it to `dtype`.
+
+    Args:
+        input: Input tensor
+
+        dtype: A string naming the target type
+
+    Return:
+        The input tensor, cast to `dtype` if necessary.
+    """
+    if str(input.dtype) == dtype:
+        return input
+    else:
+        return L.cast(input, dtype=dtype)
+
+
+def lod_remove(input):
+    """Lod Remove
+
+    Remove the LoD information from a LoDTensor and flatten it into a 1-D tensor.
+
+    Args:
+        input: A tensor to be flattened
+
+    Return:
+        A 1-D tensor without LoD
+    """
+    return L.reshape(L.reshape(input, [1, -1]), [-1])
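A brief sketch of how the two new helpers fit together, mirroring BatchGraphWrapper.__build_meta_data above; the "num_nodes" data holder is an illustrative stand-in, not part of the diff:

# Hedged sketch: build per-batch graph offsets from per-graph node counts.
import paddle.fluid.layers as L
from pgl.utils import paddle_helper

num_nodes = L.data(name="num_nodes", shape=[-1], dtype="int64")   # per-graph node counts
num_nodes = paddle_helper.ensure_dtype(num_nodes, dtype="int32")  # cast only if dtype differs
graph_lod = L.concat([L.cumsum(num_nodes, exclusive=True),
                      L.reduce_sum(num_nodes)])                   # cumulative node offsets
graph_lod = paddle_helper.lod_remove(graph_lod)                   # strip LoD, keep a flat 1-D tensor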