未验证 提交 f1fe2ad4 编写于 作者: X xiongkun 提交者: GitHub

add support for concat and variadic tensor list (#40229)

上级 47d1d5af
......@@ -722,13 +722,17 @@ class OpTest(unittest.TestCase):
def assumption_assert_and_transform(args, argvs):
"""
currently only support "X" is [Tensor], don't support multi-tensor in "X"
transform by the following rules:
1. [Tensor] -> Tensor
2. [Tensor, Tensor, ...] -> list of Tensors
only support "X" is list of Tensor, currently don't support other structure like dict.
"""
for inp in args:
assert isinstance(inp, list) and len(
inp
) == 1, "currently only support `X` is [Tensor], don't support multi-tensor in `X`"
args = [inp[0] for inp in args]
assert isinstance(
inp, list
), "currently only support `X` is [Tensor], don't support other structure."
args = [inp[0] if len(inp) == 1 else inp for inp in args]
return args, argvs
def cal_python_api(python_api, args, argvs, kernel_sig):
......@@ -1239,15 +1243,16 @@ class OpTest(unittest.TestCase):
dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
if check_eager:
with _test_eager_guard():
eager_dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
# we only check end2end api when check_eager=True
if hasattr(self, "python_api"):
api_outs = self._calc_python_api_output(place)
self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs,
place)
if check_eager:
with _test_eager_guard():
eager_dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
for out_name, out_dup in Operator.get_op_outputs(self.op_type):
......
......@@ -25,6 +25,7 @@ import paddle
class TestConcatOp(OpTest):
def setUp(self):
self.op_type = "concat"
self.python_api = paddle.concat
self.dtype = self.get_dtype()
self.init_test_data()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册