diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 0c7f269a087b8fd0b4ed4749d7b9a433be134e8f..6455da924757b182cc0d47f0307e2bf184af2d7b 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -722,13 +722,17 @@ class OpTest(unittest.TestCase): def assumption_assert_and_transform(args, argvs): """ - currently only support "X" is [Tensor], don't support multi-tensor in "X" + transform by the following rules: + 1. [Tensor] -> Tensor + 2. [Tensor, Tensor, ...] -> list of Tensors + + only support "X" is list of Tensor, currently don't support other structure like dict. """ for inp in args: - assert isinstance(inp, list) and len( - inp - ) == 1, "currently only support `X` is [Tensor], don't support multi-tensor in `X`" - args = [inp[0] for inp in args] + assert isinstance( + inp, list + ), "currently only support `X` is [Tensor], don't support other structure." + args = [inp[0] if len(inp) == 1 else inp for inp in args] return args, argvs def cal_python_api(python_api, args, argvs, kernel_sig): @@ -1239,15 +1243,16 @@ class OpTest(unittest.TestCase): dygraph_outs = self._calc_dygraph_output( place, no_check_set=no_check_set) + if check_eager: + with _test_eager_guard(): + eager_dygraph_outs = self._calc_dygraph_output( + place, no_check_set=no_check_set) + # we only check end2end api when check_eager=True if hasattr(self, "python_api"): api_outs = self._calc_python_api_output(place) self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs, place) - if check_eager: - with _test_eager_guard(): - eager_dygraph_outs = self._calc_dygraph_output( - place, no_check_set=no_check_set) outs, fetch_list = self._calc_output(place, no_check_set=no_check_set) for out_name, out_dup in Operator.get_op_outputs(self.op_type): diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py index 10b7e13dcc334dbc6b2f7b4c614cf888168c34ab..4feca1b92505b60713df245578896a9880b7cf06 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_op.py @@ -25,6 +25,7 @@ import paddle class TestConcatOp(OpTest): def setUp(self): self.op_type = "concat" + self.python_api = paddle.concat self.dtype = self.get_dtype() self.init_test_data() self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}