未验证 提交 8b611f04 编写于 作者: C cyber-pioneer 提交者: GitHub

remove fluid.layer.gather_tree (#48480)

上级 f62b3fc8
...@@ -128,7 +128,6 @@ __all__ = [ ...@@ -128,7 +128,6 @@ __all__ = [
'shard_index', 'shard_index',
'hard_swish', 'hard_swish',
'mish', 'mish',
'gather_tree',
'uniform_random', 'uniform_random',
'unbind', 'unbind',
] ]
...@@ -7928,70 +7927,6 @@ def mish(x, threshold=20, name=None): ...@@ -7928,70 +7927,6 @@ def mish(x, threshold=20, name=None):
return out return out
def gather_tree(ids, parents):
r"""
To be used after beam search. After beam search, we get selected ids at
each time step and the corresponding parents in the search tree. Both ids
and parents have the layout :attr:`[max_time, batch_size, beam_size]`. Then
:attr:`gather_tree` is used to backtrace from the last time step and
generate the full sequences by collecting selected ids.
Here is an example:
.. code-block:: text
Given:
ids = [[[2 2]
[6 1]]
[[3 9]
[6 1]]
[[0 1]
[9 0]]]
parents = [[[0 0]
[1 1]]
[[1 0]
[1 0]]
[[0 0]
[0 1]]]
Then:
gather_tree(ids, parents)
= [[[2 2]
[1 6]]
[[3 3]
[6 1]]
[[0 1]
[9 0]]]
Args:
ids(Tensor): A Tensor with shape :attr:`[length, batch_size, beam_size]`
and data type :attr:`int32` or :attr:`int64`. It contains the selected
ids of all time steps.
parents(Tensor): A Tensor with the same shape and data type as :attr:`ids`,
It contains the parents corresponding to selected ids when searching
among beams.
Returns:
A Tensor with the same shape and data type as :attr:`ids`. \
It contains the full sequences. The sequences are collected from \
:attr:`ids` by backtracing according to :attr:`parents`.
Examples:
.. code-block:: python
import paddle
ids = paddle.to_tensor([[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]])
parents = paddle.to_tensor([[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]])
final_sequences = paddle.nn.functional.gather_tree(ids, parents)
# [[[2, 2], [1, 6]], [[3, 3], [6, 1]], [[0, 1], [9, 0]]]
"""
return paddle.nn.functional.gather_tree(ids, parents)
@deprecated(since="2.0.0", update_to="paddle.uniform") @deprecated(since="2.0.0", update_to="paddle.uniform")
@templatedoc() @templatedoc()
def uniform_random( def uniform_random(
......
...@@ -1427,7 +1427,7 @@ class BeamSearchDecoder(Decoder): ...@@ -1427,7 +1427,7 @@ class BeamSearchDecoder(Decoder):
`[time_step, batch_size, beam_size]`. `final_states` is the same \ `[time_step, batch_size, beam_size]`. `final_states` is the same \
as the input argument `final_states`. as the input argument `final_states`.
""" """
predicted_ids = nn.gather_tree( predicted_ids = paddle.nn.functional.gather_tree(
outputs.predicted_ids, outputs.parent_ids outputs.predicted_ids, outputs.parent_ids
) )
# TODO: use FinalBeamSearchDecoderOutput as output # TODO: use FinalBeamSearchDecoderOutput as output
......
...@@ -502,7 +502,9 @@ class BaseModel(fluid.dygraph.Layer): ...@@ -502,7 +502,9 @@ class BaseModel(fluid.dygraph.Layer):
predicted_ids = paddle.stack(predicted_ids) predicted_ids = paddle.stack(predicted_ids)
parent_ids = paddle.stack(parent_ids) parent_ids = paddle.stack(parent_ids)
predicted_ids = fluid.layers.gather_tree(predicted_ids, parent_ids) predicted_ids = paddle.nn.functional.gather_tree(
predicted_ids, parent_ids
)
predicted_ids = self._transpose_batch_time(predicted_ids) predicted_ids = self._transpose_batch_time(predicted_ids)
return predicted_ids return predicted_ids
......
...@@ -884,7 +884,7 @@ class Transformer(Layer): ...@@ -884,7 +884,7 @@ class Transformer(Layer):
predict_ids = paddle.stack(predict_ids, axis=0) predict_ids = paddle.stack(predict_ids, axis=0)
parent_ids = paddle.stack(parent_ids, axis=0) parent_ids = paddle.stack(parent_ids, axis=0)
finished_seq = paddle.transpose( finished_seq = paddle.transpose(
layers.gather_tree(predict_ids, parent_ids), [1, 2, 0] paddle.nn.functional.gather_tree(predict_ids, parent_ids), [1, 2, 0]
) )
finished_scores = topk_scores finished_scores = topk_scores
......
...@@ -67,7 +67,7 @@ class TestGatherTreeOpAPI(unittest.TestCase): ...@@ -67,7 +67,7 @@ class TestGatherTreeOpAPI(unittest.TestCase):
dtype='int64', dtype='int64',
append_batch_size=False, append_batch_size=False,
) )
final_sequences = fluid.layers.gather_tree(ids, parents) final_sequences = paddle.nn.functional.gather_tree(ids, parents)
paddle.disable_static() paddle.disable_static()
def test_case2(self): def test_case2(self):
...@@ -100,14 +100,14 @@ class TestGatherTreeOpError(unittest.TestCase): ...@@ -100,14 +100,14 @@ class TestGatherTreeOpError(unittest.TestCase):
def test_Variable_ids(): def test_Variable_ids():
# the input type must be Variable # the input type must be Variable
np_ids = np.random.random((5, 2, 2), dtype='int64') np_ids = np.random.random((5, 2, 2), dtype='int64')
fluid.layers.gather_tree(np_ids, parents) paddle.nn.functional.gather_tree(np_ids, parents)
self.assertRaises(TypeError, test_Variable_ids) self.assertRaises(TypeError, test_Variable_ids)
def test_Variable_parents(): def test_Variable_parents():
# the input type must be Variable # the input type must be Variable
np_parents = np.random.random((5, 2, 2), dtype='int64') np_parents = np.random.random((5, 2, 2), dtype='int64')
fluid.layers.gather_tree(ids, np_parents) paddle.nn.functional.gather_tree(ids, np_parents)
self.assertRaises(TypeError, test_Variable_parents) self.assertRaises(TypeError, test_Variable_parents)
...@@ -119,7 +119,7 @@ class TestGatherTreeOpError(unittest.TestCase): ...@@ -119,7 +119,7 @@ class TestGatherTreeOpError(unittest.TestCase):
dtype='float32', dtype='float32',
append_batch_size=False, append_batch_size=False,
) )
fluid.layers.gather_tree(bad_ids, parents) paddle.nn.functional.gather_tree(bad_ids, parents)
self.assertRaises(TypeError, test_type_ids) self.assertRaises(TypeError, test_type_ids)
...@@ -131,7 +131,7 @@ class TestGatherTreeOpError(unittest.TestCase): ...@@ -131,7 +131,7 @@ class TestGatherTreeOpError(unittest.TestCase):
dtype='float32', dtype='float32',
append_batch_size=False, append_batch_size=False,
) )
fluid.layers.gather_tree(ids, bad_parents) paddle.nn.functional.gather_tree(ids, bad_parents)
self.assertRaises(TypeError, test_type_parents) self.assertRaises(TypeError, test_type_parents)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册