未验证 提交 ff717d51 编写于 作者: W wangchaochaohu 提交者: GitHub

Add support for tuple of concat Op test=develop (#25800)

上级 e5514935
...@@ -207,6 +207,7 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad, ...@@ -207,6 +207,7 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>, concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
...@@ -215,6 +216,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -215,6 +216,7 @@ REGISTER_OP_CPU_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
......
...@@ -20,6 +20,7 @@ namespace plat = paddle::platform; ...@@ -20,6 +20,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, double>, concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, double>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>); ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>);
...@@ -27,6 +28,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -27,6 +28,7 @@ REGISTER_OP_CUDA_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>); ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>);
...@@ -1958,7 +1958,7 @@ class Operator(object): ...@@ -1958,7 +1958,7 @@ class Operator(object):
in_proto.name) in_proto.name)
if found: if found:
in_args = inputs[in_proto.name] in_args = inputs[in_proto.name]
if not isinstance(in_args, list): if not isinstance(in_args, (list, tuple)):
in_args = [in_args] in_args = [in_args]
if not in_proto.duplicable and len(in_args) > 1: if not in_proto.duplicable and len(in_args) > 1:
raise ValueError( raise ValueError(
......
...@@ -266,8 +266,8 @@ def concat(input, axis=0, name=None): ...@@ -266,8 +266,8 @@ def concat(input, axis=0, name=None):
This OP concatenates the input along the axis. This OP concatenates the input along the axis.
Args: Args:
input(list): List of input Tensors with data type float16, float32, float64, int32, input(list|tuple|Tensor): ``input`` can be Tensor, Tensor list or Tensor tuple which is with data type
int64. All the Tensors in ``input`` must have the same data type. bool, float16, float32, float64, int32, int64. All the Tensors in ``input`` must have the same data type.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors. axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64. It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64.
The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way
...@@ -276,7 +276,8 @@ def concat(input, axis=0, name=None): ...@@ -276,7 +276,8 @@ def concat(input, axis=0, name=None):
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
Raises: Raises:
TypeError: The dtype of ``input`` must be one of float16, float32, float64, int32 and int64. TypeError: ``input`` must be one of list, tuple or Tensor.
TypeError: The data type of ``input`` must be one of bool, float16, float32, float64, int32 and int64.
TypeError: The ``axis`` must be int or Tensor. The dtype of ``axis`` must be int32 or int64 when it's a Tensor. TypeError: The ``axis`` must be int or Tensor. The dtype of ``axis`` must be int32 or int64 when it's a Tensor.
TypeError: All the Tensors in ``input`` must have the same data type. TypeError: All the Tensors in ``input`` must have the same data type.
...@@ -289,20 +290,20 @@ def concat(input, axis=0, name=None): ...@@ -289,20 +290,20 @@ def concat(input, axis=0, name=None):
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
in1 = np.array([[1,2,3], in1 = np.array([[1, 2, 3],
[4,5,6]]) [4, 5, 6]])
in2 = np.array([[11,12,13], in2 = np.array([[11, 12, 13],
[14,15,16]]) [14, 15, 16]])
in3 = np.array([[21,22], in3 = np.array([[21, 22],
[23,24]]) [23, 24]])
with fluid.dygraph.guard(): with fluid.dygraph.guard():
x1 = fluid.dygraph.to_variable(in1) x1 = fluid.dygraph.to_variable(in1)
x2 = fluid.dygraph.to_variable(in2) x2 = fluid.dygraph.to_variable(in2)
x3 = fluid.dygraph.to_variable(in3) x3 = fluid.dygraph.to_variable(in3)
# When the axis is negative, the real axis is (axis + Rank(x)). # When the axis is negative, the real axis is (axis + Rank(x)).
# As follows, axis is -1, Rank(x) is 2, the real axis is 1 # As follows, axis is -1, Rank(x) is 2, the real axis is 1
out1 = fluid.layers.concat(input=[x1,x2,x3], axis=-1) out1 = fluid.layers.concat(input=[x1, x2, x3], axis=-1)
out2 = fluid.layers.concat(input=[x1,x2], axis=0) out2 = fluid.layers.concat(input=[x1, x2], axis=0)
print(out1.numpy()) print(out1.numpy())
# [[ 1 2 3 11 12 13 21 22] # [[ 1 2 3 11 12 13 21 22]
# [ 4 5 6 14 15 16 23 24]] # [ 4 5 6 14 15 16 23 24]]
...@@ -319,18 +320,18 @@ def concat(input, axis=0, name=None): ...@@ -319,18 +320,18 @@ def concat(input, axis=0, name=None):
axis = axis[0] axis = axis[0]
return core.ops.concat(input, 'axis', axis) return core.ops.concat(input, 'axis', axis)
if not isinstance(input, list): check_type(input, 'input', (list, tuple, Variable), 'concat')
warnings.warn( if not isinstance(input, Variable):
"The type of input in concat should be list, but received %s." %
(type(input)))
input = [input]
for id, x in enumerate(input): for id, x in enumerate(input):
check_variable_and_dtype( check_variable_and_dtype(
x, 'input[' + str(id) + ']', x, 'input[' + str(id) + ']',
['float16', 'float32', 'float64', 'int32', 'int64'], 'concat') ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'concat')
if x.dtype != input[0].dtype: if x.dtype != input[0].dtype:
raise TypeError( raise TypeError(
"All the Tensors in the input must have the same data type.") "All the Tensors in the input must have the same data type.")
else:
input = [input]
check_type(axis, 'axis', (int, Variable), 'concat') check_type(axis, 'axis', (int, Variable), 'concat')
if isinstance(axis, Variable): if isinstance(axis, Variable):
...@@ -343,7 +344,7 @@ def concat(input, axis=0, name=None): ...@@ -343,7 +344,7 @@ def concat(input, axis=0, name=None):
if input[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY: if input[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
assert len(input) == 1, "If the elements of 'input' in concat are Variable(LoDTensorArray), " \ assert len(input) == 1, "If the elements of 'input' in concat are Variable(LoDTensorArray), " \
"number of the elements must be 1, but received %s." % len(x) "number of the elements must be 1, but received %s." % len(input)
out_index = helper.create_variable_for_type_inference(dtype="int32") out_index = helper.create_variable_for_type_inference(dtype="int32")
helper.append_op( helper.append_op(
type='tensor_array_to_tensor', type='tensor_array_to_tensor',
...@@ -1045,8 +1046,7 @@ def ones(shape, dtype, force_cpu=False): ...@@ -1045,8 +1046,7 @@ def ones(shape, dtype, force_cpu=False):
Returns: Returns:
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1. Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Raises: Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64.
and the data type of out Tensor must be the same as the dtype.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor. be int32 or int64 when it's a Tensor.
...@@ -1082,8 +1082,7 @@ def zeros(shape, dtype, force_cpu=False, name=None): ...@@ -1082,8 +1082,7 @@ def zeros(shape, dtype, force_cpu=False, name=None):
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0. Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
Raises: Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64.
and the data type of out Tensor must be the same as the dtype.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor. be int32 or int64 when it's a Tensor.
Examples: Examples:
......
...@@ -136,8 +136,7 @@ def ones(shape, dtype=None, name=None): ...@@ -136,8 +136,7 @@ def ones(shape, dtype=None, name=None):
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1. Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Raises: Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None.
and the data type of out Tensor must be the same as the dtype.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor. be int32 or int64 when it's a Tensor.
...@@ -229,8 +228,7 @@ def zeros(shape, dtype=None, name=None): ...@@ -229,8 +228,7 @@ def zeros(shape, dtype=None, name=None):
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0. Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
Raises: Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None.
and the data type of out Tensor must be the same as the dtype.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor. be int32 or int64 when it's a Tensor.
......
...@@ -59,8 +59,8 @@ def concat(x, axis=0, name=None): ...@@ -59,8 +59,8 @@ def concat(x, axis=0, name=None):
This OP concatenates the input along the axis. This OP concatenates the input along the axis.
Args: Args:
x(list): List of input Tensors with data type float16, float32, float64, int32, int64. x(list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16,
All the Tensors in ``x`` must have same data type. float32, float64, int32, int64. All the Tensors in ``x`` must have same data type.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors. axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32 It's a scalar with data type int or a Tensor with shape [1] and data type int32
or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``,
...@@ -69,7 +69,8 @@ def concat(x, axis=0, name=None): ...@@ -69,7 +69,8 @@ def concat(x, axis=0, name=None):
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
Raises: Raises:
TypeError: The dtype of ``x`` must be one of float16, float32, float64, int32 and int64. TypeError: ``x`` must be list or tuple.
TypeError: The data type of ``x`` must be one of bool, float16, float32, float64, int32 and int64.
TypeError: The ``axis`` must be int or Tensor. The dtype of ``axis`` must be int32 or int64 when it's a Tensor. TypeError: The ``axis`` must be int or Tensor. The dtype of ``axis`` must be int32 or int64 when it's a Tensor.
TypeError: All the Tensors in ``x`` must have the same data type. TypeError: All the Tensors in ``x`` must have the same data type.
...@@ -83,21 +84,21 @@ def concat(x, axis=0, name=None): ...@@ -83,21 +84,21 @@ def concat(x, axis=0, name=None):
import numpy as np import numpy as np
paddle.enable_imperative() # Now we are in imperative mode paddle.enable_imperative() # Now we are in imperative mode
in1 = np.array([[1,2,3], in1 = np.array([[1, 2, 3],
[4,5,6]]) [4, 5, 6]])
in2 = np.array([[11,12,13], in2 = np.array([[11, 12, 13],
[14,15,16]]) [14, 15, 16]])
in3 = np.array([[21,22], in3 = np.array([[21, 22],
[23,24]]) [23, 24]])
x1 = paddle.imperative.to_variable(in1) x1 = paddle.imperative.to_variable(in1)
x2 = paddle.imperative.to_variable(in2) x2 = paddle.imperative.to_variable(in2)
x3 = paddle.imperative.to_variable(in3) x3 = paddle.imperative.to_variable(in3)
zero = paddle.full(shape=[1], dtype='int32', fill_value=0) zero = paddle.full(shape=[1], dtype='int32', fill_value=0)
# When the axis is negative, the real axis is (axis + Rank(x)) # When the axis is negative, the real axis is (axis + Rank(x))
# As follow, axis is -1, Rank(x) is 2, the real axis is 1 # As follow, axis is -1, Rank(x) is 2, the real axis is 1
out1 = paddle.concat(x=[x1,x2,x3], axis=-1) out1 = paddle.concat(x=[x1, x2, x3], axis=-1)
out2 = paddle.concat(x=[x1,x2], axis=0) out2 = paddle.concat(x=[x1, x2], axis=0)
out3 = paddle.concat(x=[x1,x2], axis=zero) out3 = paddle.concat(x=[x1, x2], axis=zero)
# out1 # out1
# [[ 1 2 3 11 12 13 21 22] # [[ 1 2 3 11 12 13 21 22]
# [ 4 5 6 14 15 16 23 24]] # [ 4 5 6 14 15 16 23 24]]
...@@ -107,6 +108,7 @@ def concat(x, axis=0, name=None): ...@@ -107,6 +108,7 @@ def concat(x, axis=0, name=None):
# [11 12 13] # [11 12 13]
# [14 15 16]] # [14 15 16]]
""" """
check_type(x, 'x', (list, tuple), 'concat')
return paddle.fluid.layers.concat(input=x, axis=axis, name=name) return paddle.fluid.layers.concat(input=x, axis=axis, name=name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册