diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 2691acbbb45e402311454b3f0cbafb58eba2ad73..a27fea7f45d675cd707a65f8a47bfe0b39b61435 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14968,46 +14968,47 @@ def gather_tree(ids, parents): [9 0]]] Args: - ids(Variable): A Tensor with shape :attr:`[length, batch_size, beam_size]` + ids(Tensor): A Tensor with shape :attr:`[length, batch_size, beam_size]` and data type :attr:`int32` or :attr:`int64`. It contains the selected ids of all time steps. - parents(Variable): A Tensor with the same shape and data type as :attr:`ids`, + parents(Tensor): A Tensor with the same shape and data type as :attr:`ids`, It contains the parents corresponding to selected ids when searching among beams. Returns: - Variable: A Tensor with the same shape and data type as :attr:`ids`. \ + A Tensor with the same shape and data type as :attr:`ids`. \ It contains the full sequences. The sequences are collected from \ :attr:`ids` by backtracing according to :attr:`parents`. Examples: .. code-block:: python - import paddle.fluid as fluid + import paddle - ids = fluid.layers.data(name='ids', - shape=[5, 2, 2], - dtype='int64', - append_batch_size=False) - parents = fluid.layers.data(name='parents', - shape=[5, 2, 2], - dtype='int64', - append_batch_size=False) - final_sequences = fluid.layers.gather_tree(ids, parents) - """ - helper = LayerHelper('gather_tree', **locals()) - check_variable_and_dtype(ids, 'ids', ['int32', 'int64'], 'gather_tree') - check_variable_and_dtype(parents, 'parents', ['int32', 'int64'], - 'gather_tree') - out = helper.create_variable_for_type_inference(dtype=ids.dtype) + ids = paddle.to_tensor([[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]]) - helper.append_op( - type="gather_tree", - inputs={"Ids": ids, - "Parents": parents}, - outputs={"Out": out}) + parents = paddle.to_tensor([[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]]) - return out + final_sequences = paddle.nn.functional.gather_tree(ids, parents) + # [[[2, 2], [1, 6]], [[3, 3], [6, 1]], [[0, 1], [9, 0]]] + + """ + if in_dygraph_mode(): + return core.ops.gather_tree(ids, parents) + else: + helper = LayerHelper('gather_tree', **locals()) + check_variable_and_dtype(ids, 'ids', ['int32', 'int64'], 'gather_tree') + check_variable_and_dtype(parents, 'parents', ['int32', 'int64'], + 'gather_tree') + out = helper.create_variable_for_type_inference(dtype=ids.dtype) + + helper.append_op( + type="gather_tree", + inputs={"Ids": ids, + "Parents": parents}, + outputs={"Out": out}) + + return out @deprecated(since="2.0.0", update_to="paddle.uniform") diff --git a/python/paddle/fluid/tests/unittests/test_gather_tree_op.py b/python/paddle/fluid/tests/unittests/test_gather_tree_op.py index f23d2c68c66b9daa16dd6bdd6db52cf6585724b3..74e2cd9f741441ecec07bfca65b95645b71f5b54 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_tree_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_tree_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid from paddle.fluid.framework import program_guard, Program @@ -52,6 +53,7 @@ class TestGatherTreeOp(OpTest): class TestGatherTreeOpAPI(unittest.TestCase): def test_case(self): + paddle.enable_static() ids = fluid.layers.data( name='ids', shape=[5, 2, 2], dtype='int64', append_batch_size=False) parents = fluid.layers.data( @@ -60,10 +62,19 @@ class TestGatherTreeOpAPI(unittest.TestCase): dtype='int64', append_batch_size=False) final_sequences = fluid.layers.gather_tree(ids, parents) + paddle.disable_static() + + def test_case2(self): + ids = paddle.to_tensor( + [[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]]) + parents = paddle.to_tensor( + [[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]]) + final_sequences = paddle.nn.functional.gather_tree(ids, parents) class TestGatherTreeOpError(unittest.TestCase): def test_errors(self): + paddle.enable_static() with program_guard(Program(), Program()): ids = fluid.layers.data( name='ids', @@ -111,6 +122,7 @@ class TestGatherTreeOpError(unittest.TestCase): fluid.layers.gather_tree(ids, bad_parents) self.assertRaises(TypeError, test_type_parents) + paddle.disable_static() if __name__ == "__main__":