未验证 提交 a44b65de 编写于 作者: L liu zhengxi 提交者: GitHub

[cherry-pick] Update gather_tree (#30784)

* upgrade gather_tree to core.ops (#30697)

* upgrade gather_tree to core.ops

* update gather_tree unittests

* update gather_tree doc (#30693)

* update gather_tree doc, test=document_fix

* update sample code, test=document_fix

* remove tensor type, test=document_fix
上级 b4be9717
...@@ -14968,46 +14968,47 @@ def gather_tree(ids, parents): ...@@ -14968,46 +14968,47 @@ def gather_tree(ids, parents):
[9 0]]] [9 0]]]
Args: Args:
ids(Variable): A Tensor with shape :attr:`[length, batch_size, beam_size]` ids(Tensor): A Tensor with shape :attr:`[length, batch_size, beam_size]`
and data type :attr:`int32` or :attr:`int64`. It contains the selected and data type :attr:`int32` or :attr:`int64`. It contains the selected
ids of all time steps. ids of all time steps.
parents(Variable): A Tensor with the same shape and data type as :attr:`ids`, parents(Tensor): A Tensor with the same shape and data type as :attr:`ids`,
It contains the parents corresponding to selected ids when searching It contains the parents corresponding to selected ids when searching
among beams. among beams.
Returns: Returns:
Variable: A Tensor with the same shape and data type as :attr:`ids`. \ A Tensor with the same shape and data type as :attr:`ids`. \
It contains the full sequences. The sequences are collected from \ It contains the full sequences. The sequences are collected from \
:attr:`ids` by backtracing according to :attr:`parents`. :attr:`ids` by backtracing according to :attr:`parents`.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
ids = fluid.layers.data(name='ids', ids = paddle.to_tensor([[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]])
shape=[5, 2, 2],
dtype='int64',
append_batch_size=False)
parents = fluid.layers.data(name='parents',
shape=[5, 2, 2],
dtype='int64',
append_batch_size=False)
final_sequences = fluid.layers.gather_tree(ids, parents)
"""
helper = LayerHelper('gather_tree', **locals())
check_variable_and_dtype(ids, 'ids', ['int32', 'int64'], 'gather_tree')
check_variable_and_dtype(parents, 'parents', ['int32', 'int64'],
'gather_tree')
out = helper.create_variable_for_type_inference(dtype=ids.dtype)
helper.append_op( parents = paddle.to_tensor([[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]])
type="gather_tree",
inputs={"Ids": ids,
"Parents": parents},
outputs={"Out": out})
return out final_sequences = paddle.nn.functional.gather_tree(ids, parents)
# [[[2, 2], [1, 6]], [[3, 3], [6, 1]], [[0, 1], [9, 0]]]
"""
if in_dygraph_mode():
return core.ops.gather_tree(ids, parents)
else:
helper = LayerHelper('gather_tree', **locals())
check_variable_and_dtype(ids, 'ids', ['int32', 'int64'], 'gather_tree')
check_variable_and_dtype(parents, 'parents', ['int32', 'int64'],
'gather_tree')
out = helper.create_variable_for_type_inference(dtype=ids.dtype)
helper.append_op(
type="gather_tree",
inputs={"Ids": ids,
"Parents": parents},
outputs={"Out": out})
return out
@deprecated(since="2.0.0", update_to="paddle.uniform") @deprecated(since="2.0.0", update_to="paddle.uniform")
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import program_guard, Program from paddle.fluid.framework import program_guard, Program
...@@ -52,6 +53,7 @@ class TestGatherTreeOp(OpTest): ...@@ -52,6 +53,7 @@ class TestGatherTreeOp(OpTest):
class TestGatherTreeOpAPI(unittest.TestCase): class TestGatherTreeOpAPI(unittest.TestCase):
def test_case(self): def test_case(self):
paddle.enable_static()
ids = fluid.layers.data( ids = fluid.layers.data(
name='ids', shape=[5, 2, 2], dtype='int64', append_batch_size=False) name='ids', shape=[5, 2, 2], dtype='int64', append_batch_size=False)
parents = fluid.layers.data( parents = fluid.layers.data(
...@@ -60,10 +62,19 @@ class TestGatherTreeOpAPI(unittest.TestCase): ...@@ -60,10 +62,19 @@ class TestGatherTreeOpAPI(unittest.TestCase):
dtype='int64', dtype='int64',
append_batch_size=False) append_batch_size=False)
final_sequences = fluid.layers.gather_tree(ids, parents) final_sequences = fluid.layers.gather_tree(ids, parents)
paddle.disable_static()
def test_case2(self):
ids = paddle.to_tensor(
[[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]])
parents = paddle.to_tensor(
[[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]])
final_sequences = paddle.nn.functional.gather_tree(ids, parents)
class TestGatherTreeOpError(unittest.TestCase): class TestGatherTreeOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
ids = fluid.layers.data( ids = fluid.layers.data(
name='ids', name='ids',
...@@ -111,6 +122,7 @@ class TestGatherTreeOpError(unittest.TestCase): ...@@ -111,6 +122,7 @@ class TestGatherTreeOpError(unittest.TestCase):
fluid.layers.gather_tree(ids, bad_parents) fluid.layers.gather_tree(ids, bad_parents)
self.assertRaises(TypeError, test_type_parents) self.assertRaises(TypeError, test_type_parents)
paddle.disable_static()
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册