未验证 提交 5f84d0b3 编写于 作者: L liym27 提交者: GitHub

Fix bug: delete wrong check_type of paddle.concat and support LoDTensorArray (#29306)

上级 f7cdcefa
......@@ -327,6 +327,10 @@ def concat(input, axis=0, name=None):
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if input[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
# NOTE(liym27): Don't remove this if branch!
# This feature is supported for Dynamic-to-Static, because after transformed, the type of inputs[0]
# is LOD_TENSOR_ARRAY in some scenarios. And this feature can be used in static mode.
assert len(input) == 1, "If the elements of 'input' in concat are Variable(LoDTensorArray), " \
"number of the elements must be 1, but received %s." % len(input)
out_index = helper.create_variable_for_type_inference(dtype="int32")
......
......@@ -228,6 +228,7 @@ class TestConcatOpError(unittest.TestCase):
class TestConcatAPI(unittest.TestCase):
def test_fluid_api(self):
paddle.enable_static()
x_1 = fluid.data(shape=[None, 1, 4, 5], dtype='int32', name='x_1')
fluid.layers.concat([x_1, x_1], 0)
......@@ -253,6 +254,7 @@ class TestConcatAPI(unittest.TestCase):
assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1))
def test_api(self):
paddle.enable_static()
x_1 = paddle.fluid.data(
shape=[None, 1, 4, 5], dtype='int32', name='x_1')
paddle.concat([x_1, x_1], 0)
......@@ -338,21 +340,44 @@ class TestConcatAPIWithLoDTensorArray(unittest.TestCase):
self.x = np.random.random(self.input_shape).astype("float32")
self.place = fluid.CUDAPlace(0) \
if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
self.set_program()
def set_program(self):
self.program = fluid.Program()
with fluid.program_guard(self.program):
input = fluid.layers.assign(self.x)
tensor_array = fluid.layers.create_array(dtype='float32')
zero = fluid.layers.fill_constant(shape=[1], value=0, dtype="int64")
def set_program(self, use_fluid_api):
paddle.enable_static()
if use_fluid_api:
self.program = fluid.Program()
with fluid.program_guard(self.program):
input = fluid.layers.assign(self.x)
tensor_array = fluid.layers.create_array(dtype='float32')
zero = fluid.layers.fill_constant(
shape=[1], value=0, dtype="int64")
for i in range(self.iter_num):
fluid.layers.array_write(input, zero + i, tensor_array)
self.out_var = fluid.layers.concat(tensor_array, axis=self.axis)
else:
self.program = paddle.static.Program()
with paddle.static.program_guard(self.program):
input = paddle.assign(self.x)
tensor_array = fluid.layers.create_array(
dtype='float32'
) # Api create_array is not supported in paddle 2.0 yet.
zero = paddle.zeros(shape=[1], dtype="int64")
for i in range(self.iter_num):
fluid.layers.array_write(input, zero + i, tensor_array)
for i in range(self.iter_num):
# Api array_write is not supported in paddle 2.0 yet.
fluid.layers.array_write(input, zero + i, tensor_array)
self.out_var = paddle.concat(tensor_array, axis=self.axis)
def test_fluid_api(self):
self._run_static_mode(use_fluid_api=True)
self.out_var = fluid.layers.concat(tensor_array, axis=self.axis)
def test_paddle_api(self):
self._run_static_mode(use_fluid_api=False)
def test_case(self):
def _run_static_mode(self, use_fluid_api):
self.set_program(use_fluid_api)
self.assertTrue(self.out_var.shape[self.axis] == -1)
exe = fluid.Executor(self.place)
res = exe.run(self.program, fetch_list=self.out_var)
......
......@@ -71,7 +71,7 @@ def concat(x, axis=0, name=None):
This OP concatenates the input along the axis.
Args:
x(list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16,
x(list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16,
float32, float64, int32, int64. All the Tensors in ``x`` must have same data type.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32
......@@ -110,7 +110,6 @@ def concat(x, axis=0, name=None):
# [11 12 13]
# [14 15 16]]
"""
check_type(x, 'x', (list, tuple), 'concat')
return paddle.fluid.layers.concat(input=x, axis=axis, name=name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册