Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
82bd82c1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
82bd82c1
编写于
3月 05, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments and refine code
上级
00e596ed
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
88 addition
and
75 deletion
+88
-75
paddle/fluid/operators/concat_op.h
paddle/fluid/operators/concat_op.h
+2
-0
paddle/fluid/operators/math/concat.cc
paddle/fluid/operators/math/concat.cc
+8
-11
paddle/fluid/operators/math/concat.cu
paddle/fluid/operators/math/concat.cu
+57
-64
paddle/fluid/operators/math/concat.h
paddle/fluid/operators/math/concat.h
+21
-0
未找到文件。
paddle/fluid/operators/concat_op.h
浏览文件 @
82bd82c1
...
...
@@ -33,6 +33,7 @@ class ConcatKernel : public framework::OpKernel<T> {
auto
place
=
ctx
.
GetPlace
();
out
->
mutable_data
<
T
>
(
place
);
// TODO(zcd): Sometimes direct copies will be faster
std
::
vector
<
framework
::
Tensor
>
inputs
(
ins
.
size
());
for
(
size_t
j
=
0
;
j
<
ins
.
size
();
++
j
)
{
inputs
[
j
]
=
*
ins
[
j
];
...
...
@@ -51,6 +52,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto
outs
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
// TODO(zcd): Sometimes direct copies will be faster
std
::
vector
<
framework
::
Tensor
>
outputs
(
outs
.
size
());
for
(
size_t
j
=
0
;
j
<
outs
.
size
();
++
j
)
{
outs
[
j
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
paddle/fluid/operators/math/concat.cc
浏览文件 @
82bd82c1
...
...
@@ -19,7 +19,8 @@ namespace operators {
namespace
math
{
/*
* All tensors' dimension should be the same.
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template
<
typename
T
>
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
...
...
@@ -27,12 +28,9 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
)
{
// assume the the max size of input is less than 8 and see the performance
// save origin dim
// TODO(zcd): Add input data validity checking
int
num
=
input
.
size
();
std
::
vector
<
paddle
::
framework
::
DDim
>
origin_dim
(
num
);
// get the matrix size
int
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
...
...
@@ -40,7 +38,6 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
}
int
out_rows
=
rows
,
out_cols
=
0
;
// get input's cols
std
::
vector
<
int64_t
>
input_cols
(
input
.
size
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
input
[
i
].
numel
()
/
rows
;
...
...
@@ -64,18 +61,19 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
}
};
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template
<
typename
T
>
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
>&
outputs
)
{
// assume the the max size of input is less than 8 and see the performance
// save origin dim
// TODO(zcd): Add input data validity checking
int
num
=
outputs
.
size
();
std
::
vector
<
paddle
::
framework
::
DDim
>
origin_dim
(
num
);
// get the matrix size
int
input_rows
=
1
;
auto
dim_0
=
outputs
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
...
...
@@ -83,7 +81,6 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
}
int
input_cols
=
0
;
// get outputs' cols
std
::
vector
<
int64_t
>
output_cols
(
outputs
.
size
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
outputs
[
i
].
numel
()
/
input_rows
;
...
...
paddle/fluid/operators/math/concat.cu
浏览文件 @
82bd82c1
...
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"
...
...
@@ -19,16 +20,6 @@ namespace paddle {
namespace
operators
{
namespace
math
{
// TODO(zcd): This can be replaced by tensor,
// if that, maybe we should add int8 to VarType::Type.
// Or replaced by tensorArray.
static
constexpr
int
MaxSize
=
8
;
template
<
typename
T
>
struct
CUDADeviceArray
{
T
data
[
MaxSize
];
int
size
;
};
template
<
typename
T
>
__device__
T
upper_bound
(
const
T
*
first
,
T
count
,
T
val
)
{
const
T
*
orig
=
first
;
...
...
@@ -49,25 +40,24 @@ __device__ T upper_bound(const T* first, T count, T val) {
}
template
<
typename
T
>
__global__
void
KernelConcat
(
const
CUDADeviceArray
<
const
T
*>
inputs
,
const
CUDADeviceArray
<
int
>
input_cols
,
__global__
void
KernelConcat
(
T
**
inputs
,
const
int
*
input_cols
,
int
col_size
,
const
int
output_rows
,
const
int
output_cols
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
segment
=
upper_bound
<
int
>
(
input_cols
.
data
,
input_cols
.
size
,
tid_x
)
-
1
;
int
segment
=
upper_bound
<
int
>
(
input_cols
,
col_
size
,
tid_x
)
-
1
;
int
curr_offset
=
input_cols
.
data
[
segment
];
int
curr_offset
=
input_cols
[
segment
];
int
curr_segment
=
segment
;
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
curr_col_offset
;
while
((
curr_col_offset
=
input_cols
.
data
[
curr_segment
+
1
])
<=
tid_x
)
{
while
((
curr_col_offset
=
input_cols
[
curr_segment
+
1
])
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
const
T
*
input_ptr
=
inputs
.
data
[
curr_segment
];
T
*
input_ptr
=
inputs
[
curr_segment
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output
[
tid_y
*
output_cols
+
tid_x
]
=
...
...
@@ -76,41 +66,41 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
}
template
<
typename
T
>
__global__
void
KernelConcat
(
const
CUDADeviceArray
<
const
T
*>
inputs
,
const
int
input_col
,
const
int
output_row
s
,
const
int
output_cols
,
T
*
output
)
{
__global__
void
KernelConcat
(
T
**
inputs
,
const
int
input_col
,
const
int
output_rows
,
const
int
output_col
s
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
float
inv_input_col
=
1.0
/
input_col
;
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
inv_input_col
;
int
in_offset
=
tid_x
-
split
*
input_col
;
const
T
*
input_ptr
=
inputs
.
data
[
split
];
T
*
input_ptr
=
inputs
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
{
output
[
tid_y
*
output_cols
+
tid_x
]
=
input_ptr
[
tid_y
*
input_col
+
in_offset
];
}
}
}
template
<
typename
T
>
__global__
void
KernelConcatGrad
(
const
T
*
input
,
const
int
input_row
,
const
int
input_col
,
CUDADeviceArray
<
int
>
output_cols
,
CUDADeviceArray
<
T
*>
outputs
)
{
const
int
input_col
,
const
int
*
output_cols
,
int
col_size
,
T
**
outputs
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
segment
=
upper_bound
<
int
>
(
output_cols
.
data
,
output_cols
.
size
,
tid_x
)
-
1
;
int
curr_offset
=
output_cols
.
data
[
segment
];
int
segment
=
upper_bound
<
int
>
(
output_cols
,
col_
size
,
tid_x
)
-
1
;
int
curr_offset
=
output_cols
[
segment
];
int
curr_segment
=
segment
;
for
(;
tid_x
<
input_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
curr_col_offset
;
while
((
curr_col_offset
=
output_cols
.
data
[
curr_segment
+
1
])
<=
tid_x
)
{
while
((
curr_col_offset
=
output_cols
[
curr_segment
+
1
])
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
T
*
output_ptr
=
outputs
.
data
[
curr_segment
];
T
*
output_ptr
=
outputs
[
curr_segment
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
input_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
segment_width
+
local_col
]
=
...
...
@@ -121,13 +111,13 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
template
<
typename
T
>
__global__
void
KernelConcatGrad
(
const
T
*
input
,
const
int
input_row
,
const
int
input_col
,
const
int
output_cols
,
CUDADeviceArray
<
T
*>
outputs
)
{
T
**
outputs
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
float
inv_input_col
=
1.0
/
input_col
;
for
(;
tid_x
<
input_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
inv_input_col
;
int
in_offset
=
tid_x
-
split
*
input_col
;
T
*
output_ptr
=
outputs
.
data
[
split
];
T
*
output_ptr
=
outputs
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
input_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
output_cols
+
in_offset
]
=
...
...
@@ -136,7 +126,8 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
}
/*
* All tensors' dimension should be the same.
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template
<
typename
T
>
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
...
...
@@ -144,12 +135,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
)
{
// assume the the max size of input is less than 8 and see the performance
// save origin dim
// TODO(zcd): Add input data validity checking
int
num
=
input
.
size
();
PADDLE_ENFORCE_LT
(
num
,
MaxSize
,
"input number should be less than %d"
,
MaxSize
);
// get the matrix size
int
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
...
...
@@ -157,25 +144,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
}
int
cols
=
input
[
0
].
numel
()
/
rows
;
int
out_rows
=
rows
,
out_cols
=
0
;
bool
sameShape
=
true
;
CUDADeviceArray
<
const
T
*>
inputs_data
;
CUDADeviceArray
<
int
>
inputs_cols
;
inputs_data
.
size
=
num
;
inputs_cols
.
size
=
num
+
1
;
inputs_cols
.
data
[
0
]
=
0
;
// reshape to matrix
// check input shape is valid
paddle
::
framework
::
Vector
<
int16_t
>
inputs_data
(
num
*
sizeof
(
T
*
)
/
2
);
paddle
::
framework
::
Vector
<
int
>
inputs_cols
(
num
+
1
);
inputs_cols
[
0
]
=
0
;
T
**
inputs_ptr
=
reinterpret_cast
<
T
**>
(
inputs_data
.
data
());
bool
sameShape
=
true
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
input
[
i
].
numel
()
/
rows
;
if
(
sameShape
)
{
if
(
t_cols
!=
cols
)
sameShape
=
false
;
}
out_cols
+=
t_cols
;
inputs_cols
.
data
[
i
+
1
]
=
out_cols
;
inputs_
data
.
data
[
i
]
=
input
[
i
].
data
<
T
>
(
);
inputs_cols
[
i
+
1
]
=
out_cols
;
inputs_
ptr
[
i
]
=
const_cast
<
T
*>
(
input
[
i
].
data
<
T
>
()
);
}
T
**
ins_gpu
=
reinterpret_cast
<
T
**>
(
inputs_data
.
CUDAMutableData
(
context
.
GetPlace
()));
const
int
*
ins_col_gpu
=
inputs_cols
.
CUDAData
(
context
.
GetPlace
());
// computation
// set the thread block and grid according to CurrentDeviceId
const
int
kThreadsPerBlock
=
1024
;
...
...
@@ -198,27 +187,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
if
(
sameShape
)
{
KernelConcat
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
in
puts_data
,
cols
,
out_rows
,
out_cols
,
output
->
data
<
T
>
());
in
s_gpu
,
cols
,
out_rows
,
out_cols
,
output
->
data
<
T
>
());
}
else
{
KernelConcat
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
inputs_data
,
inputs_cols
,
out_rows
,
out_cols
,
output
->
data
<
T
>
());
ins_gpu
,
ins_col_gpu
,
static_cast
<
int
>
(
inputs_cols
.
size
()),
out_rows
,
out_cols
,
output
->
data
<
T
>
());
}
}
};
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template
<
typename
T
>
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
>&
outputs
)
{
// assume the the max size of input is less than 8 and see the performance
// save origin dim
// TODO(zcd): Add input data validity checking
int
num
=
outputs
.
size
();
PADDLE_ENFORCE_LT
(
num
,
MaxSize
,
"input number should be less than %d"
,
MaxSize
);
// get the matrix size
int
input_row
=
1
;
auto
dim_0
=
outputs
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
...
...
@@ -229,11 +218,10 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
int
input_col
=
0
;
bool
sameShape
=
true
;
CUDADeviceArray
<
T
*>
outputs_data
;
CUDADeviceArray
<
int
>
outputs_cols
;
outputs_data
.
size
=
num
;
outputs_cols
.
size
=
num
+
1
;
outputs_cols
.
data
[
0
]
=
0
;
paddle
::
framework
::
Vector
<
int16_t
>
outputs_data
(
num
*
sizeof
(
T
*
)
/
2
);
paddle
::
framework
::
Vector
<
int
>
outputs_cols
(
num
+
1
);
outputs_cols
[
0
]
=
0
;
T
**
outputs_ptr
=
reinterpret_cast
<
T
**>
(
outputs_data
.
data
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_col
=
outputs
[
i
].
numel
()
/
input_row
;
...
...
@@ -241,12 +229,16 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
if
(
t_col
!=
output_col_0
)
sameShape
=
false
;
}
input_col
+=
t_col
;
outputs_cols
.
data
[
i
+
1
]
=
input_col
;
outputs_
data
.
data
[
i
]
=
outputs
[
i
].
data
<
T
>
();
outputs_cols
[
i
+
1
]
=
input_col
;
outputs_
ptr
[
i
]
=
outputs
[
i
].
data
<
T
>
();
}
T
**
outs_gpu
=
reinterpret_cast
<
T
**>
(
outputs_data
.
CUDAMutableData
(
context
.
GetPlace
()));
const
int
*
outs_col_gpu
=
outputs_cols
.
CUDAData
(
context
.
GetPlace
());
// computation
const
int
kThreadsPerBlock
=
256
;
const
int
kThreadsPerBlock
=
1024
;
int
block_cols
=
std
::
min
(
input_col
,
kThreadsPerBlock
);
int
block_rows
=
std
::
max
(
kThreadsPerBlock
/
block_cols
,
1
);
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
...
...
@@ -257,10 +249,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
if
(
sameShape
)
{
KernelConcatGrad
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
input_row
,
input_col
,
output_col_0
,
out
puts_data
);
input
.
data
<
T
>
(),
input_row
,
input_col
,
output_col_0
,
out
s_gpu
);
}
else
{
KernelConcatGrad
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
input_row
,
input_col
,
outputs_cols
,
outputs_data
);
input
.
data
<
T
>
(),
input_row
,
input_col
,
outs_col_gpu
,
static_cast
<
int
>
(
outputs_cols
.
size
()),
outs_gpu
);
}
}
};
...
...
paddle/fluid/operators/math/concat.h
浏览文件 @
82bd82c1
...
...
@@ -20,7 +20,16 @@ namespace operators {
namespace
math
{
/*
* \brief Concatenate the input tensors along the dimension axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
*
* Output = [[1,2],
* [3,4],
* [5,6]]
*/
template
<
typename
DeviceContext
,
typename
T
>
class
ConcatFunctor
{
...
...
@@ -30,6 +39,18 @@ class ConcatFunctor {
framework
::
Tensor
*
output
);
};
/*
* \brief Split the input tensors along the dimension axis into outputs.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input = [[1,2],
* [3,4],
* [5,6]]
* axis = 0
*
* Output[0] = [[1,2],[3,4]]
* Output[1] = [[5,6]]
*/
template
<
typename
DeviceContext
,
typename
T
>
class
ConcatGradFunctor
{
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录