未验证 提交 f1fe2ad4 编写于 作者: X xiongkun 提交者: GitHub

add support for concat and variadic tensor list (#40229)

上级 47d1d5af
...@@ -722,13 +722,17 @@ class OpTest(unittest.TestCase): ...@@ -722,13 +722,17 @@ class OpTest(unittest.TestCase):
def assumption_assert_and_transform(args, argvs): def assumption_assert_and_transform(args, argvs):
""" """
currently only support "X" is [Tensor], don't support multi-tensor in "X" transform by the following rules:
1. [Tensor] -> Tensor
2. [Tensor, Tensor, ...] -> list of Tensors
only support "X" is list of Tensor, currently don't support other structure like dict.
""" """
for inp in args: for inp in args:
assert isinstance(inp, list) and len( assert isinstance(
inp inp, list
) == 1, "currently only support `X` is [Tensor], don't support multi-tensor in `X`" ), "currently only support `X` is [Tensor], don't support other structure."
args = [inp[0] for inp in args] args = [inp[0] if len(inp) == 1 else inp for inp in args]
return args, argvs return args, argvs
def cal_python_api(python_api, args, argvs, kernel_sig): def cal_python_api(python_api, args, argvs, kernel_sig):
...@@ -1239,15 +1243,16 @@ class OpTest(unittest.TestCase): ...@@ -1239,15 +1243,16 @@ class OpTest(unittest.TestCase):
dygraph_outs = self._calc_dygraph_output( dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set) place, no_check_set=no_check_set)
if check_eager:
with _test_eager_guard():
eager_dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
# we only check end2end api when check_eager=True
if hasattr(self, "python_api"): if hasattr(self, "python_api"):
api_outs = self._calc_python_api_output(place) api_outs = self._calc_python_api_output(place)
self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs, self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs,
place) place)
if check_eager:
with _test_eager_guard():
eager_dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set) outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
for out_name, out_dup in Operator.get_op_outputs(self.op_type): for out_name, out_dup in Operator.get_op_outputs(self.op_type):
......
...@@ -25,6 +25,7 @@ import paddle ...@@ -25,6 +25,7 @@ import paddle
class TestConcatOp(OpTest): class TestConcatOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "concat" self.op_type = "concat"
self.python_api = paddle.concat
self.dtype = self.get_dtype() self.dtype = self.get_dtype()
self.init_test_data() self.init_test_data()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]} self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册