From fef3654b4e76f5e2cc9a5f71c1c047cef82192e5 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Wed, 27 Jan 2021 22:22:11 +0800 Subject: [PATCH] upgrade gather_tree to core.ops (#30697) * upgrade gather_tree to core.ops * update gather_tree unittests --- python/paddle/fluid/layers/nn.py | 25 +++++++++++-------- .../tests/unittests/test_gather_tree_op.py | 12 +++++++++ 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index fcf5dd0d4b..85972687b5 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -15011,19 +15011,22 @@ def gather_tree(ids, parents): append_batch_size=False) final_sequences = fluid.layers.gather_tree(ids, parents) """ - helper = LayerHelper('gather_tree', **locals()) - check_variable_and_dtype(ids, 'ids', ['int32', 'int64'], 'gather_tree') - check_variable_and_dtype(parents, 'parents', ['int32', 'int64'], - 'gather_tree') - out = helper.create_variable_for_type_inference(dtype=ids.dtype) + if in_dygraph_mode(): + return core.ops.gather_tree(ids, parents) + else: + helper = LayerHelper('gather_tree', **locals()) + check_variable_and_dtype(ids, 'ids', ['int32', 'int64'], 'gather_tree') + check_variable_and_dtype(parents, 'parents', ['int32', 'int64'], + 'gather_tree') + out = helper.create_variable_for_type_inference(dtype=ids.dtype) - helper.append_op( - type="gather_tree", - inputs={"Ids": ids, - "Parents": parents}, - outputs={"Out": out}) + helper.append_op( + type="gather_tree", + inputs={"Ids": ids, + "Parents": parents}, + outputs={"Out": out}) - return out + return out @deprecated(since="2.0.0", update_to="paddle.uniform") diff --git a/python/paddle/fluid/tests/unittests/test_gather_tree_op.py b/python/paddle/fluid/tests/unittests/test_gather_tree_op.py index f23d2c68c6..74e2cd9f74 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_tree_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_tree_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid from paddle.fluid.framework import program_guard, Program @@ -52,6 +53,7 @@ class TestGatherTreeOp(OpTest): class TestGatherTreeOpAPI(unittest.TestCase): def test_case(self): + paddle.enable_static() ids = fluid.layers.data( name='ids', shape=[5, 2, 2], dtype='int64', append_batch_size=False) parents = fluid.layers.data( @@ -60,10 +62,19 @@ class TestGatherTreeOpAPI(unittest.TestCase): dtype='int64', append_batch_size=False) final_sequences = fluid.layers.gather_tree(ids, parents) + paddle.disable_static() + + def test_case2(self): + ids = paddle.to_tensor( + [[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]]) + parents = paddle.to_tensor( + [[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]]) + final_sequences = paddle.nn.functional.gather_tree(ids, parents) class TestGatherTreeOpError(unittest.TestCase): def test_errors(self): + paddle.enable_static() with program_guard(Program(), Program()): ids = fluid.layers.data( name='ids', @@ -111,6 +122,7 @@ class TestGatherTreeOpError(unittest.TestCase): fluid.layers.gather_tree(ids, bad_parents) self.assertRaises(TypeError, test_type_parents) + paddle.disable_static() if __name__ == "__main__": -- GitLab