From f1fe2ad45d2b4cd013ce83194192b1fb7bc72957 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 8 Mar 2022 14:33:28 +0800 Subject: [PATCH] add support for concat and variadic tensor list (#40229) --- .../paddle/fluid/tests/unittests/op_test.py | 23 +++++++++++-------- .../fluid/tests/unittests/test_concat_op.py | 1 + 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 0c7f269a087..6455da92475 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -722,13 +722,17 @@ class OpTest(unittest.TestCase): def assumption_assert_and_transform(args, argvs): """ - currently only support "X" is [Tensor], don't support multi-tensor in "X" + transform by the following rules: + 1. [Tensor] -> Tensor + 2. [Tensor, Tensor, ...] -> list of Tensors + + only support "X" is list of Tensor, currently don't support other structure like dict. """ for inp in args: - assert isinstance(inp, list) and len( - inp - ) == 1, "currently only support `X` is [Tensor], don't support multi-tensor in `X`" - args = [inp[0] for inp in args] + assert isinstance( + inp, list + ), "currently only support `X` is [Tensor], don't support other structure." + args = [inp[0] if len(inp) == 1 else inp for inp in args] return args, argvs def cal_python_api(python_api, args, argvs, kernel_sig): @@ -1239,15 +1243,16 @@ class OpTest(unittest.TestCase): dygraph_outs = self._calc_dygraph_output( place, no_check_set=no_check_set) + if check_eager: + with _test_eager_guard(): + eager_dygraph_outs = self._calc_dygraph_output( + place, no_check_set=no_check_set) + # we only check end2end api when check_eager=True if hasattr(self, "python_api"): api_outs = self._calc_python_api_output(place) self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs, place) - if check_eager: - with _test_eager_guard(): - eager_dygraph_outs = self._calc_dygraph_output( - place, no_check_set=no_check_set) outs, fetch_list = self._calc_output(place, no_check_set=no_check_set) for out_name, out_dup in Operator.get_op_outputs(self.op_type): diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py index 10b7e13dcc3..4feca1b9250 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_op.py @@ -25,6 +25,7 @@ import paddle class TestConcatOp(OpTest): def setUp(self): self.op_type = "concat" + self.python_api = paddle.concat self.dtype = self.get_dtype() self.init_test_data() self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]} -- GitLab