未验证 提交 1e4ab728 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the concat Op for API 2.0 test=develop (#25307)

上级 43f3d0cc
......@@ -208,10 +208,14 @@ REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>);
......@@ -265,24 +265,26 @@ def concat(input, axis=0, name=None):
"""
:alias_main: paddle.concat
:alias: paddle.concat,paddle.tensor.concat,paddle.tensor.manipulation.concat
:old_api: paddle.fluid.layers.concat
**Concat**
This OP concatenates the input along the axis.
Args:
input(list): List of input Tensors with data type float32, float64, int32,
int64.
axis(int32|Variable, optional): A scalar with type ``int32`` or a ``Tensor`` with shape [1] and type ``int32``. Axis to compute indices along. The effective range
is [-R, R), where R is Rank(x). when axis<0, it works the same way
input(list): List of input Tensors with data type float16, float32, float64, int32,
int64. All the Tensors in ``input`` must have the same data type.
axis(int|Variable, optional): Specify the axis to operate on the input Tensors.
It's a scalar with type ``int`` or a ``Tensor`` with shape [1] and data type ``int32`` or ``int64``.
The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way
as axis+R. Default is 0.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Raises:
TypeError: The dtype of input must be one of float16, float32, float64, int32 and int64.
TypeError: The ``axis`` must be int or Variable. The dtype of ``axis`` must be int32 or int64 when it's a Tensor.
TypeError: All the Tensors in ``input`` must have the same data type.
Returns:
Variable: A Tensor with the same data type as input's.
Variable: A Tensor with the same data type as ``input``.
Examples:
.. code-block:: python
......@@ -300,6 +302,8 @@ def concat(input, axis=0, name=None):
x1 = fluid.dygraph.to_variable(in1)
x2 = fluid.dygraph.to_variable(in2)
x3 = fluid.dygraph.to_variable(in3)
# When the axis is negative, the real axis is (axis + Rank(x)).
# As follows, axis is -1, Rank(x) is 2, the real axis is 1
out1 = fluid.layers.concat(input=[x1,x2,x3], axis=-1)
out2 = fluid.layers.concat(input=[x1,x2], axis=0)
print(out1.numpy())
......@@ -315,8 +319,6 @@ def concat(input, axis=0, name=None):
if in_dygraph_mode():
if isinstance(axis, Variable):
axis = axis.numpy()
assert axis.shape == (
1, ), "axis of type Variable should have shape [1]"
axis = axis[0]
return core.ops.concat(input, 'axis', axis)
......@@ -329,8 +331,16 @@ def concat(input, axis=0, name=None):
check_variable_and_dtype(
x, 'input[' + str(id) + ']',
['float16', 'float32', 'float64', 'int32', 'int64'], 'concat')
if x.dtype != input[0].dtype:
raise TypeError(
"All the Tensors in the input must have the same data type.")
check_type(axis, 'axis', (int, Variable), 'concat')
if isinstance(axis, Variable):
check_dtype(
axis.dtype, 'axis', ['int32', 'int64'], 'concat',
"The data type of axis must be int32 or int64 when axis is a Tensor")
helper = LayerHelper('concat', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
......
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
import paddle
class TestConcatOp(OpTest):
......@@ -175,8 +176,6 @@ create_test_AxisTensor(TestConcatOp6)
def create_test_fp16(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestConcatFp16(parent):
def get_dtype(self):
return np.float16
......@@ -206,12 +205,13 @@ class TestConcatOpError(unittest.TestCase):
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.concat, [x2])
# The input dtype of concat_op must be float16(only support on GPU), float32, float64, int32, int64.
# The input dtype of concat_op must be float16, float32, float64, int32, int64.
x4 = fluid.layers.data(shape=[4], dtype='uint8', name='x4')
x5 = fluid.layers.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8')
fluid.layers.concat([x6, x7])
# The type of axis in concat_op should be int or Variable.
......@@ -220,9 +220,14 @@ class TestConcatOpError(unittest.TestCase):
self.assertRaises(TypeError, test_axis_type)
def test_input_same_dtype():
fluid.layers.concat([x7, x8])
self.assertRaises(TypeError, test_input_same_dtype)
class TestConcatAPI(unittest.TestCase):
def test_api(self):
def test_fluid_api(self):
x_1 = fluid.data(shape=[None, 1, 4, 5], dtype='int32', name='x_1')
fluid.layers.concat([x_1, x_1], 0)
......@@ -247,6 +252,77 @@ class TestConcatAPI(unittest.TestCase):
assert np.array_equal(res_2, np.concatenate((input_2, input_3), axis=1))
assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1))
def test_api(self):
x_1 = paddle.data(shape=[None, 1, 4, 5], dtype='int32', name='x_1')
paddle.concat([x_1, x_1], 0)
input_2 = np.random.random([2, 1, 4, 5]).astype("int32")
input_3 = np.random.random([2, 2, 4, 5]).astype("int32")
x_2 = fluid.data(shape=[2, 1, 4, 5], dtype='int32', name='x_2')
x_3 = fluid.data(shape=[2, 2, 4, 5], dtype='int32', name='x_3')
positive_1_int32 = paddle.fill_constant([1], "int32", 1)
positive_1_int64 = paddle.fill_constant([1], "int64", 1)
negative_int64 = paddle.fill_constant([1], "int64", -3)
out_1 = paddle.concat(x=[x_2, x_3], axis=1)
out_2 = paddle.concat(x=[x_2, x_3], axis=positive_1_int32)
out_3 = paddle.concat(x=[x_2, x_3], axis=positive_1_int64)
out_4 = paddle.concat(x=[x_2, x_3], axis=negative_int64)
exe = paddle.Executor(place=paddle.CPUPlace())
[res_1, res_2, res_3, res_4] = exe.run(
paddle.default_main_program(),
feed={"x_1": input_2,
"x_2": input_2,
"x_3": input_3},
fetch_list=[out_1, out_2, out_3, out_4])
assert np.array_equal(res_1, np.concatenate((input_2, input_3), axis=1))
assert np.array_equal(res_2, np.concatenate((input_2, input_3), axis=1))
assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1))
assert np.array_equal(res_4, np.concatenate((input_2, input_3), axis=1))
def test_imperative(self):
in1 = np.array([[1, 2, 3], [4, 5, 6]])
in2 = np.array([[11, 12, 13], [14, 15, 16]])
in3 = np.array([[21, 22], [23, 24]])
with paddle.imperative.guard():
x1 = paddle.imperative.to_variable(in1)
x2 = paddle.imperative.to_variable(in2)
x3 = paddle.imperative.to_variable(in3)
out1 = fluid.layers.concat(input=[x1, x2, x3], axis=-1)
out2 = paddle.concat(x=[x1, x2], axis=0)
np_out1 = np.concatenate([in1, in2, in3], axis=-1)
np_out2 = np.concatenate([in1, in2], axis=0)
self.assertEqual((out1.numpy() == np_out1).all(), True)
self.assertEqual((out2.numpy() == np_out2).all(), True)
def test_errors(self):
with program_guard(Program(), Program()):
# The item in input must be Variable.
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, paddle.concat, [x2])
# The input dtype of concat_op must be float16, float32, float64, int32, int64.
x4 = paddle.data(shape=[4], dtype='uint8', name='x4')
x5 = paddle.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
# The type of axis in concat_op should be int or Variable.
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8')
def test_axis_type():
paddle.concat([x6, x7], 3.2)
self.assertRaises(TypeError, test_axis_type)
def test_input_same_dtype():
paddle.concat([x7, x8])
self.assertRaises(TypeError, test_input_same_dtype)
class TestConcatAPIWithLoDTensorArray(unittest.TestCase):
"""
......
......@@ -242,7 +242,7 @@ def zeros(shape, dtype=None, name=None):
The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0.
Args:
shape(tuple|list): Shape of output tensor.
shape(tuple|list|Variable): Shape of output tensor. The data type of shape is int32 or int64.
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of output tensor, it supports
bool, float16, float32, float64, int32 and int64. Default: if None, the date type is float32.
name(str, optional): The default value is None. Normally there is no need for user to set this
......
......@@ -23,7 +23,6 @@ from ..fluid.layers import utils
import numpy as np
# TODO: define functions to manipulate a tensor
from ..fluid.layers import cast #DEFINE_ALIAS
from ..fluid.layers import concat #DEFINE_ALIAS
from ..fluid.layers import expand #DEFINE_ALIAS
from ..fluid.layers import expand_as #DEFINE_ALIAS
from ..fluid.layers import flatten #DEFINE_ALIAS
......@@ -41,6 +40,7 @@ from ..fluid.layers import scatter_nd #DEFINE_ALIAS
from ..fluid.layers import shard_index #DEFINE_ALIAS
from ..fluid.layers import unique_with_counts #DEFINE_ALIAS
from ..fluid import layers
import paddle
__all__ = [
'cast', 'concat', 'expand', 'expand_as', 'flatten', 'gather', 'gather_nd',
......@@ -51,6 +51,65 @@ __all__ = [
]
def concat(x, axis=0, name=None):
"""
:alias_main: paddle.concat
:alias: paddle.concat,paddle.tensor.concat,paddle.tensor.manipulation.concat
This OP concatenates the input along the axis.
Args:
x(list): List of input Tensors with data type float16, float32, float64, int32, int64.
All the Tensors in ``x`` must have same data type.
axis(int|Variable, optional): Specify the axis to operate on the input Tensors.
It's a scalar with type ``int`` or a ``Tensor`` with shape [1] and data type ``int32``
or ``int64``. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``,
it works the same way as axis+R. Default is 0.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Raises:
TypeError: The dtype of ``x`` must be one of float16, float32, float64, int32 and int64.
TypeError: The ``axis`` must be int or Variable. The dtype of ``axis`` must be int32 or int64 when it's a Tensor.
TypeError: All the Tensors in ``x`` must have the same data type.
Returns:
Variable: A Tensor with the same data type as ``x``.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.enable_imperative() # Now we are in imperative mode
in1 = np.array([[1,2,3],
[4,5,6]])
in2 = np.array([[11,12,13],
[14,15,16]])
in3 = np.array([[21,22],
[23,24]])
x1 = paddle.imperative.to_variable(in1)
x2 = paddle.imperative.to_variable(in2)
x3 = paddle.imperative.to_variable(in3)
zero = paddle.full(shape=[1], dtype='int32', fill_value=0)
# When the axis is negative, the real axis is (axis + Rank(x))
# As follow, axis is -1, Rank(x) is 2, the real axis is 1
out1 = paddle.concat(x=[x1,x2,x3], axis=-1)
out2 = paddle.concat(x=[x1,x2], axis=0)
out3 = paddle.concat(x=[x1,x2], axis=zero)
# out1
# [[ 1 2 3 11 12 13 21 22]
# [ 4 5 6 14 15 16 23 24]]
# out2 out3
# [[ 1 2 3]
# [ 4 5 6]
# [11 12 13]
# [14 15 16]]
"""
return paddle.fluid.layers.concat(input=x, axis=axis, name=name)
def flip(x, axis, name=None):
"""
:alias_main: paddle.flip
......
......@@ -162,6 +162,7 @@ def index_select(x, index, axis=0, name=None):
Examples:
.. code-block:: python
import paddle
import numpy as np
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册