Commit 750aff10 — BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Authored March 23, 2018 by chengduoZH

    code refine

Parent: 9075049a

Showing 1 changed file with 74 additions and 74 deletions

paddle/fluid/operators/math/concat.cu  (+74, -74)
@@ -66,60 +66,60 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
 }
 
 template <typename T>
-__global__ void KernelConcat(T** inputs, const int input_col,
-                             const int output_rows, const int output_cols,
-                             T* output) {
+__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
+                             const int out_rows, const int out_cols,
+                             T* output_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x * 1.0 / input_col;
-    int in_offset = tid_x - split * input_col;
-    T* input_ptr = inputs[split];
+  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * 1.0 / fixed_in_col;
+    int in_offset = tid_x - split * fixed_in_col;
+    T* input_ptr = inputs_data[split];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
-      output[tid_y * output_cols + tid_x] =
-          input_ptr[tid_y * input_col + in_offset];
+    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
+      output_data[tid_y * out_cols + tid_x] =
+          input_ptr[tid_y * fixed_in_col + in_offset];
     }
   }
 }
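This first part of the hunk is a pure rename inside the fast-path concat kernel: every input shares the same column width, so the input owning an output column is recovered by division. A CPU restatement of that index mapping (an illustration written for this note, not code from the commit):

```cuda
// Reference semantics of the fixed-width KernelConcat above: n inputs, each
// out_rows x fixed_in_col, concatenated column-wise into
// out_rows x (n * fixed_in_col).
void ConcatSameWidthRef(const float* const* inputs_data, int n, int out_rows,
                        int fixed_in_col, float* output_data) {
  const int out_cols = n * fixed_in_col;
  for (int x = 0; x < out_cols; ++x) {
    const int split = x / fixed_in_col;              // which input owns column x
    const int in_offset = x - split * fixed_in_col;  // column inside that input
    for (int y = 0; y < out_rows; ++y) {
      output_data[y * out_cols + x] =
          inputs_data[split][y * fixed_in_col + in_offset];
    }
  }
}
```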
 
 template <typename T>
-__global__ void KernelConcatGrad(const T* input, const int input_row,
-                                 const int input_col, const int* output_cols,
-                                 int col_size, T** outputs) {
+__global__ void KernelConcatGrad(const T* input_data, const int in_row,
+                                 const int in_col, const int* out_cols,
+                                 int out_cols_size, T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
-  int curr_offset = output_cols[segment];
+  int segment = upper_bound<int>(out_cols, out_cols_size, tid_x) - 1;
+  int curr_offset = out_cols[segment];
   int curr_segment = segment;
-  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
     T curr_col_offset;
-    while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
+    while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) {
       curr_offset = curr_col_offset;
       ++curr_segment;
     }
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
-    T* output_ptr = outputs[curr_segment];
+    T* output_ptr = outputs_data[curr_segment];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
       output_ptr[tid_y * segment_width + local_col] =
-          input[tid_y * input_col + tid_x];
+          input_data[tid_y * in_col + tid_x];
   }
 }
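The general grad kernel locates the output segment owning column tid_x with upper_bound over the prefix-sum array out_cols. That helper is defined earlier in concat.cu and is not part of this hunk; a typical device-side binary search of this shape, shown only as a sketch, is:

```cuda
// Sketch of a device-side upper_bound: index of the first element of
// first[0..count) that is greater than val. The real helper in concat.cu
// may differ in detail.
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
  const T* orig = first;
  while (count > 0) {
    T step = count / 2;
    const T* it = first + step;
    if (!(val < *it)) {  // *it <= val: answer lies to the right
      first = it + 1;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return static_cast<T>(first - orig);
}
```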
 
 template <typename T>
-__global__ void KernelConcatGrad(const T* input, const int input_row,
-                                 const int input_col, const int output_col,
-                                 T** outputs) {
+__global__ void KernelConcatGrad(const T* input_data, const int in_row,
+                                 const int in_col, const int fixed_out_col,
+                                 T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x / output_col;
-    int in_offset = tid_x - split * output_col;
-    T* output_ptr = outputs[split];
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x / fixed_out_col;
+    int in_offset = tid_x - split * fixed_out_col;
+    T* output_ptr = outputs_data[split];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
-      output_ptr[tid_y * output_col + in_offset] =
-          input[tid_y * input_col + tid_x];
+    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * fixed_out_col + in_offset] =
+          input_data[tid_y * in_col + tid_x];
   }
 }
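The two KernelConcatGrad overloads split the same work differently: with a fixed output width the owning segment is a single division, while mixed widths need the prefix-sum search above. A worked example with hypothetical widths:

```cuda
// Fixed width 4: input column 6 goes to output 6 / 4 = 1, offset 6 - 1*4 = 2.
// Mixed widths {3, 5, 4}: out_cols = {0, 3, 8, 12}; for column tid_x = 6,
//   segment       = upper_bound(out_cols, 4, 6) - 1 = 1
//   local_col     = 6 - out_cols[1] = 3
//   segment_width = out_cols[2] - out_cols[1] = 5
// so column 6 of the input lands in column 3 of the second output.
```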
@@ -134,41 +134,40 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
                   const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
-    int num = input.size();
-    int rows = 1;
+    int in_num = input.size();
+    int in_row = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
-      rows *= dim_0[i];
+      in_row *= dim_0[i];
     }
-    int cols = input[0].numel() / rows;
-    int out_rows = rows, out_cols = 0;
+    int in_col = input[0].numel() / in_row;
+    int out_row = in_row, out_col = 0;
 
-    framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
-    framework::Vector<int> inputs_cols(num + 1);
-    inputs_cols[0] = 0;
+    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
+    framework::Vector<int> inputs_col(in_num + 1);
     T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
+    inputs_col[0] = 0;
 
     bool sameShape = true;
-    for (int i = 0; i < num; ++i) {
-      int t_cols = input[i].numel() / rows;
+    for (int i = 0; i < in_num; ++i) {
+      int t_cols = input[i].numel() / in_row;
       if (sameShape) {
-        if (t_cols != cols) sameShape = false;
+        if (t_cols != in_col) sameShape = false;
       }
-      out_cols += t_cols;
-      inputs_cols[i + 1] = out_cols;
+      out_col += t_cols;
+      inputs_col[i + 1] = out_col;
       inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
     }
 
-    T** ins_gpu =
+    T** dev_ins_data =
         reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
-    const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
 
     // computation
+    // set the thread block and grid according to CurrentDeviceId
     const int kThreadsPerBlock = 1024;
     int block_cols = kThreadsPerBlock;
-    if (out_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
-      block_cols = ((out_cols + 31) >> 5) << 5;
+    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((out_col + 31) >> 5) << 5;
     }
     int block_rows = kThreadsPerBlock / block_cols;
     dim3 block_size = dim3(block_cols, block_rows, 1);
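A detail worth noting in the block-size computation: `((out_col + 31) >> 5) << 5` rounds out_col up to the next multiple of 32, one full warp, so narrow concats still launch warp-aligned blocks. The shifts check out:

```cuda
// (x + 31) >> 5 divides by 32 rounding up; << 5 multiplies back.
static_assert((((1 + 31) >> 5) << 5) == 32, "1 rounds up to one warp");
static_assert((((32 + 31) >> 5) << 5) == 32, "32 is already aligned");
static_assert((((33 + 31) >> 5) << 5) == 64, "33 rounds up to two warps");
```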
@@ -177,18 +176,19 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
     int grid_cols =
-        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
+        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
     int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
+        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
     if (sameShape) {
       KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-          ins_gpu, cols, out_rows, out_cols, output->data<T>());
+          dev_ins_data, in_col, out_row, out_col, output->data<T>());
     } else {
+      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
       KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-          ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
-          out_cols, output->data<T>());
+          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
+          out_row, out_col, output->data<T>());
     }
   }
 };
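For concreteness, the launch shape this arithmetic produces, traced with hypothetical sizes (out_col = 100, out_row = 1000, and a device whose max_threads yields max_blocks = 8):

```cuda
// block_cols = ((100 + 31) >> 5) << 5        = 128
// block_rows = 1024 / 128                    = 8
// grid_cols  = min((100 + 127) / 128, 8)     = 1
// grid_rows  = min(8 / 1, max(1000 / 8, 1))  = 8
// => block = (128, 8, 1), grid = (1, 8, 1); the grid-stride loops absorb
//    whatever the grid does not cover in one pass.
```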
@@ -204,41 +204,40 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
                   const framework::Tensor& input, const int axis,
                   std::vector<framework::Tensor>& outputs) {
     // TODO(zcd): Add input data validity checking
-    int num = outputs.size();
-    int input_row = 1;
+    int o_num = outputs.size();
+    int out_row = 1;
     auto dim_0 = outputs[0].dims();
     for (int i = 0; i < axis; ++i) {
-      input_row *= dim_0[i];
+      out_row *= dim_0[i];
     }
 
-    int output_col_0 = outputs[0].numel() / input_row;
-    int input_col = 0;
+    int out_col = outputs[0].numel() / out_row;
+    int in_col = 0, in_row = out_row;
     bool sameShape = true;
 
-    framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
-    framework::Vector<int> outputs_cols(num + 1);
-    outputs_cols[0] = 0;
+    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
+    framework::Vector<int> outputs_cols(o_num + 1);
     T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
 
-    for (int i = 0; i < num; ++i) {
-      int t_col = outputs[i].numel() / input_row;
+    outputs_cols[0] = 0;
+    for (int i = 0; i < o_num; ++i) {
+      int t_col = outputs[i].numel() / out_row;
       if (sameShape) {
-        if (t_col != output_col_0) sameShape = false;
+        if (t_col != out_col) sameShape = false;
       }
-      input_col += t_col;
-      outputs_cols[i + 1] = input_col;
+      in_col += t_col;
+      outputs_cols[i + 1] = in_col;
       outputs_ptr[i] = outputs[i].data<T>();
     }
 
-    T** outs_gpu =
+    T** dev_out_gpu_data =
         reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
-    const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
 
     // computation
     const int kThreadsPerBlock = 1024;
     int block_cols = kThreadsPerBlock;
-    if (input_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
-      block_cols = ((input_col + 31) >> 5) << 5;
+    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((in_col + 31) >> 5) << 5;
     }
     int block_rows = kThreadsPerBlock / block_cols;
     dim3 block_size = dim3(block_cols, block_rows, 1);
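The `framework::Vector<int16_t>` buffers above are a pointer-marshalling trick: the Vector type is not instantiated for T*, so the code reserves o_num * sizeof(T*) bytes as sizeof(T*) / 2 int16_t elements per pointer and reinterprets the storage as a T* array that can then be copied to the device. A minimal restatement with std::vector (illustrative only; names are not from the commit):

```cuda
#include <cstdint>
#include <vector>

void PointerBufferSketch(int o_num) {
  // o_num * sizeof(float*) bytes, expressed as 2-byte int16_t elements.
  std::vector<int16_t> storage(o_num * sizeof(float*) / 2);
  float** ptrs = reinterpret_cast<float**>(storage.data());
  for (int i = 0; i < o_num; ++i) ptrs[i] = nullptr;  // fill with real pointers
}
```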
@@ -247,18 +246,19 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
     int grid_cols =
-        std::min((input_col + block_cols - 1) / block_cols, max_blocks);
+        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
     int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
+        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
     if (sameShape) {
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
+          input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
     } else {
+      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), input_row, input_col, outs_col_gpu,
-          static_cast<int>(outputs_cols.size()), outs_gpu);
+          input.data<T>(), in_row, in_col, dev_outs_col_data,
+          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
     }
   }
 };
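All four kernel launches rely on the same grid-stride idiom, which is why the grid computed above never has to cover the whole problem in one pass; reduced to its skeleton (illustration, not code from this file):

```cuda
__global__ void GridStrideSkeleton(int n, float* data) {
  // Each thread starts at its global index and advances by the total number
  // of launched threads, so any grid size covers any n.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    data[i] = 0.f;  // per-element work goes here
  }
}
```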