Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ff717d51
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ff717d51
编写于
8月 04, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
8月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add support for tuple of concat Op test=develop (#25800)
上级
e5514935
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
47 addition
and
44 deletion
+47
-44
paddle/fluid/operators/concat_op.cc
paddle/fluid/operators/concat_op.cc
+2
-0
paddle/fluid/operators/concat_op.cu.cc
paddle/fluid/operators/concat_op.cu.cc
+2
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+1
-1
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+26
-27
python/paddle/tensor/creation.py
python/paddle/tensor/creation.py
+2
-4
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+14
-12
未找到文件。
paddle/fluid/operators/concat_op.cc
浏览文件 @
ff717d51
...
@@ -207,6 +207,7 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
...
@@ -207,6 +207,7 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
concat
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
concat
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
bool
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
float16
>
,
paddle
::
platform
::
float16
>
,
...
@@ -215,6 +216,7 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -215,6 +216,7 @@ REGISTER_OP_CPU_KERNEL(
concat_grad
,
concat_grad
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
bool
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
float16
>
,
paddle
::
platform
::
float16
>
,
...
...
paddle/fluid/operators/concat_op.cu.cc
浏览文件 @
ff717d51
...
@@ -20,6 +20,7 @@ namespace plat = paddle::platform;
...
@@ -20,6 +20,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
concat
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
concat
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
);
ops
::
ConcatKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
);
...
@@ -27,6 +28,7 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -27,6 +28,7 @@ REGISTER_OP_CUDA_KERNEL(
concat_grad
,
concat_grad
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
);
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
);
python/paddle/fluid/framework.py
浏览文件 @
ff717d51
...
@@ -1958,7 +1958,7 @@ class Operator(object):
...
@@ -1958,7 +1958,7 @@ class Operator(object):
in_proto
.
name
)
in_proto
.
name
)
if
found
:
if
found
:
in_args
=
inputs
[
in_proto
.
name
]
in_args
=
inputs
[
in_proto
.
name
]
if
not
isinstance
(
in_args
,
list
):
if
not
isinstance
(
in_args
,
(
list
,
tuple
)
):
in_args
=
[
in_args
]
in_args
=
[
in_args
]
if
not
in_proto
.
duplicable
and
len
(
in_args
)
>
1
:
if
not
in_proto
.
duplicable
and
len
(
in_args
)
>
1
:
raise
ValueError
(
raise
ValueError
(
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
ff717d51
...
@@ -266,8 +266,8 @@ def concat(input, axis=0, name=None):
...
@@ -266,8 +266,8 @@ def concat(input, axis=0, name=None):
This OP concatenates the input along the axis.
This OP concatenates the input along the axis.
Args:
Args:
input(list
): List of input Tensors with data type float16, float32, float64, int32,
input(list
|tuple|Tensor): ``input`` can be Tensor, Tensor list or Tensor tuple which is with data type
int64. All the Tensors in ``input`` must have the same data type.
bool, float16, float32, float64, int32, int64. All the Tensors in ``input`` must have the same data type.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64.
It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64.
The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way
The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way
...
@@ -276,7 +276,8 @@ def concat(input, axis=0, name=None):
...
@@ -276,7 +276,8 @@ def concat(input, axis=0, name=None):
need for user to set this property. For more information, please
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
refer to :ref:`api_guide_Name`.
Raises:
Raises:
TypeError: The dtype of ``input`` must be one of float16, float32, float64, int32 and int64.
TypeError: ``input`` must be one of list, tuple or Tensor.
TypeError: The data type of ``input`` must be one of bool, float16, float32, float64, int32 and int64.
TypeError: The ``axis`` must be int or Tensor. The dtype of ``axis`` must be int32 or int64 when it's a Tensor.
TypeError: The ``axis`` must be int or Tensor. The dtype of ``axis`` must be int32 or int64 when it's a Tensor.
TypeError: All the Tensors in ``input`` must have the same data type.
TypeError: All the Tensors in ``input`` must have the same data type.
...
@@ -289,20 +290,20 @@ def concat(input, axis=0, name=None):
...
@@ -289,20 +290,20 @@ def concat(input, axis=0, name=None):
import paddle.fluid as fluid
import paddle.fluid as fluid
import numpy as np
import numpy as np
in1 = np.array([[1,
2,
3],
in1 = np.array([[1,
2,
3],
[4,
5,
6]])
[4,
5,
6]])
in2 = np.array([[11,
12,
13],
in2 = np.array([[11,
12,
13],
[14,
15,
16]])
[14,
15,
16]])
in3 = np.array([[21,22],
in3 = np.array([[21,
22],
[23,24]])
[23,
24]])
with fluid.dygraph.guard():
with fluid.dygraph.guard():
x1 = fluid.dygraph.to_variable(in1)
x1 = fluid.dygraph.to_variable(in1)
x2 = fluid.dygraph.to_variable(in2)
x2 = fluid.dygraph.to_variable(in2)
x3 = fluid.dygraph.to_variable(in3)
x3 = fluid.dygraph.to_variable(in3)
# When the axis is negative, the real axis is (axis + Rank(x)).
# When the axis is negative, the real axis is (axis + Rank(x)).
# As follows, axis is -1, Rank(x) is 2, the real axis is 1
# As follows, axis is -1, Rank(x) is 2, the real axis is 1
out1 = fluid.layers.concat(input=[x1,
x2,
x3], axis=-1)
out1 = fluid.layers.concat(input=[x1,
x2,
x3], axis=-1)
out2 = fluid.layers.concat(input=[x1,x2], axis=0)
out2 = fluid.layers.concat(input=[x1,
x2], axis=0)
print(out1.numpy())
print(out1.numpy())
# [[ 1 2 3 11 12 13 21 22]
# [[ 1 2 3 11 12 13 21 22]
# [ 4 5 6 14 15 16 23 24]]
# [ 4 5 6 14 15 16 23 24]]
...
@@ -319,18 +320,18 @@ def concat(input, axis=0, name=None):
...
@@ -319,18 +320,18 @@ def concat(input, axis=0, name=None):
axis
=
axis
[
0
]
axis
=
axis
[
0
]
return
core
.
ops
.
concat
(
input
,
'axis'
,
axis
)
return
core
.
ops
.
concat
(
input
,
'axis'
,
axis
)
if
not
isinstance
(
input
,
list
):
check_type
(
input
,
'input'
,
(
list
,
tuple
,
Variable
),
'concat'
)
warnings
.
warn
(
if
not
isinstance
(
input
,
Variable
):
"The type of input in concat should be list, but received %s."
%
for
id
,
x
in
enumerate
(
input
):
(
type
(
input
)))
check_variable_and_dtype
(
x
,
'input['
+
str
(
id
)
+
']'
,
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'concat'
)
if
x
.
dtype
!=
input
[
0
].
dtype
:
raise
TypeError
(
"All the Tensors in the input must have the same data type."
)
else
:
input
=
[
input
]
input
=
[
input
]
for
id
,
x
in
enumerate
(
input
):
check_variable_and_dtype
(
x
,
'input['
+
str
(
id
)
+
']'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'concat'
)
if
x
.
dtype
!=
input
[
0
].
dtype
:
raise
TypeError
(
"All the Tensors in the input must have the same data type."
)
check_type
(
axis
,
'axis'
,
(
int
,
Variable
),
'concat'
)
check_type
(
axis
,
'axis'
,
(
int
,
Variable
),
'concat'
)
if
isinstance
(
axis
,
Variable
):
if
isinstance
(
axis
,
Variable
):
...
@@ -343,7 +344,7 @@ def concat(input, axis=0, name=None):
...
@@ -343,7 +344,7 @@ def concat(input, axis=0, name=None):
if
input
[
0
].
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
:
if
input
[
0
].
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
:
assert
len
(
input
)
==
1
,
"If the elements of 'input' in concat are Variable(LoDTensorArray), "
\
assert
len
(
input
)
==
1
,
"If the elements of 'input' in concat are Variable(LoDTensorArray), "
\
"number of the elements must be 1, but received %s."
%
len
(
x
)
"number of the elements must be 1, but received %s."
%
len
(
input
)
out_index
=
helper
.
create_variable_for_type_inference
(
dtype
=
"int32"
)
out_index
=
helper
.
create_variable_for_type_inference
(
dtype
=
"int32"
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'tensor_array_to_tensor'
,
type
=
'tensor_array_to_tensor'
,
...
@@ -1045,8 +1046,7 @@ def ones(shape, dtype, force_cpu=False):
...
@@ -1045,8 +1046,7 @@ def ones(shape, dtype, force_cpu=False):
Returns:
Returns:
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Raises:
Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64.
and the data type of out Tensor must be the same as the dtype.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor.
be int32 or int64 when it's a Tensor.
...
@@ -1082,8 +1082,7 @@ def zeros(shape, dtype, force_cpu=False, name=None):
...
@@ -1082,8 +1082,7 @@ def zeros(shape, dtype, force_cpu=False, name=None):
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
Raises:
Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64.
and the data type of out Tensor must be the same as the dtype.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor.
be int32 or int64 when it's a Tensor.
Examples:
Examples:
...
...
python/paddle/tensor/creation.py
浏览文件 @
ff717d51
...
@@ -136,8 +136,7 @@ def ones(shape, dtype=None, name=None):
...
@@ -136,8 +136,7 @@ def ones(shape, dtype=None, name=None):
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Raises:
Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None.
and the data type of out Tensor must be the same as the dtype.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor.
be int32 or int64 when it's a Tensor.
...
@@ -229,8 +228,7 @@ def zeros(shape, dtype=None, name=None):
...
@@ -229,8 +228,7 @@ def zeros(shape, dtype=None, name=None):
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
Raises:
Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64 and None.
and the data type of out Tensor must be the same as the dtype.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor.
be int32 or int64 when it's a Tensor.
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
ff717d51
...
@@ -59,8 +59,8 @@ def concat(x, axis=0, name=None):
...
@@ -59,8 +59,8 @@ def concat(x, axis=0, name=None):
This OP concatenates the input along the axis.
This OP concatenates the input along the axis.
Args:
Args:
x(list
): List of input Tensors with data type float16, float32, float64, int32, int64.
x(list
|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16,
All the Tensors in ``x`` must have same data type.
float32, float64, int32, int64.
All the Tensors in ``x`` must have same data type.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32
It's a scalar with data type int or a Tensor with shape [1] and data type int32
or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``,
or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``,
...
@@ -69,7 +69,8 @@ def concat(x, axis=0, name=None):
...
@@ -69,7 +69,8 @@ def concat(x, axis=0, name=None):
need for user to set this property. For more information, please
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
refer to :ref:`api_guide_Name`.
Raises:
Raises:
TypeError: The dtype of ``x`` must be one of float16, float32, float64, int32 and int64.
TypeError: ``x`` must be list or tuple.
TypeError: The data type of ``x`` must be one of bool, float16, float32, float64, int32 and int64.
TypeError: The ``axis`` must be int or Tensor. The dtype of ``axis`` must be int32 or int64 when it's a Tensor.
TypeError: The ``axis`` must be int or Tensor. The dtype of ``axis`` must be int32 or int64 when it's a Tensor.
TypeError: All the Tensors in ``x`` must have the same data type.
TypeError: All the Tensors in ``x`` must have the same data type.
...
@@ -83,21 +84,21 @@ def concat(x, axis=0, name=None):
...
@@ -83,21 +84,21 @@ def concat(x, axis=0, name=None):
import numpy as np
import numpy as np
paddle.enable_imperative() # Now we are in imperative mode
paddle.enable_imperative() # Now we are in imperative mode
in1 = np.array([[1,
2,
3],
in1 = np.array([[1,
2,
3],
[4,
5,
6]])
[4,
5,
6]])
in2 = np.array([[11,
12,
13],
in2 = np.array([[11,
12,
13],
[14,
15,
16]])
[14,
15,
16]])
in3 = np.array([[21,22],
in3 = np.array([[21,
22],
[23,24]])
[23,
24]])
x1 = paddle.imperative.to_variable(in1)
x1 = paddle.imperative.to_variable(in1)
x2 = paddle.imperative.to_variable(in2)
x2 = paddle.imperative.to_variable(in2)
x3 = paddle.imperative.to_variable(in3)
x3 = paddle.imperative.to_variable(in3)
zero = paddle.full(shape=[1], dtype='int32', fill_value=0)
zero = paddle.full(shape=[1], dtype='int32', fill_value=0)
# When the axis is negative, the real axis is (axis + Rank(x))
# When the axis is negative, the real axis is (axis + Rank(x))
# As follow, axis is -1, Rank(x) is 2, the real axis is 1
# As follow, axis is -1, Rank(x) is 2, the real axis is 1
out1 = paddle.concat(x=[x1,
x2,
x3], axis=-1)
out1 = paddle.concat(x=[x1,
x2,
x3], axis=-1)
out2 = paddle.concat(x=[x1,x2], axis=0)
out2 = paddle.concat(x=[x1,
x2], axis=0)
out3 = paddle.concat(x=[x1,x2], axis=zero)
out3 = paddle.concat(x=[x1,
x2], axis=zero)
# out1
# out1
# [[ 1 2 3 11 12 13 21 22]
# [[ 1 2 3 11 12 13 21 22]
# [ 4 5 6 14 15 16 23 24]]
# [ 4 5 6 14 15 16 23 24]]
...
@@ -107,6 +108,7 @@ def concat(x, axis=0, name=None):
...
@@ -107,6 +108,7 @@ def concat(x, axis=0, name=None):
# [11 12 13]
# [11 12 13]
# [14 15 16]]
# [14 15 16]]
"""
"""
check_type
(
x
,
'x'
,
(
list
,
tuple
),
'concat'
)
return
paddle
.
fluid
.
layers
.
concat
(
input
=
x
,
axis
=
axis
,
name
=
name
)
return
paddle
.
fluid
.
layers
.
concat
(
input
=
x
,
axis
=
axis
,
name
=
name
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录