Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
9c90dc97
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9c90dc97
编写于
6月 19, 2018
作者:
Q
qingqing01
提交者:
GitHub
6月 19, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make the CUDA kernel of concat correct and fix unit tests. (#11541)
* Make the CUDA kernel of concat correct and fix unit tests.
上级
9988f8ec
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
23 addition
and
31 deletion
+23
-31
paddle/fluid/operators/math/concat.cu
paddle/fluid/operators/math/concat.cu
+11
-30
python/paddle/fluid/tests/unittests/test_concat_op.py
python/paddle/fluid/tests/unittests/test_concat_op.py
+12
-1
未找到文件。
paddle/fluid/operators/math/concat.cu
浏览文件 @
9c90dc97
...
...
@@ -22,43 +22,24 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
>
__device__
T
upper_bound
(
const
T
*
first
,
T
count
,
T
val
)
{
const
T
*
orig
=
first
;
const
T
*
it
=
nullptr
;
T
step
=
0
;
while
(
count
>
0
)
{
it
=
first
;
step
=
count
/
2
;
it
+=
step
;
if
(
!
(
val
<
*
it
))
{
first
=
++
it
;
count
-=
step
+
1
;
}
else
{
count
=
step
;
}
}
return
first
-
orig
;
}
template
<
typename
T
>
__global__
void
KernelConcat
(
T
**
inputs
,
const
int
*
input_cols
,
int
col_size
,
const
int
output_rows
,
const
int
output_cols
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
segment
=
upper_bound
<
int
>
(
input_cols
,
col_size
,
tid_x
)
-
1
;
int
curr_offset
=
input_cols
[
segment
];
int
curr_segment
=
segment
;
int
curr_segment
=
0
;
int
curr_offset
=
input_cols
[
0
];
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
curr_col_offset
;
while
(
(
curr_col_offset
=
input_cols
[
curr_segment
+
1
])
<=
tid_x
)
{
int
curr_col_offset
=
input_cols
[
curr_segment
+
1
]
;
while
(
curr_col_offset
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
curr_col_offset
=
input_cols
[
curr_segment
+
1
];
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
T
*
input_ptr
=
inputs
[
curr_segment
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
...
...
@@ -89,14 +70,14 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
const
int
in_col
,
const
int
*
out_cols
,
int
out_cols_size
,
T
**
outputs_data
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
segment
=
upper_bound
<
int
>
(
out_cols
,
out_cols_size
,
tid_x
)
-
1
;
int
curr_offset
=
out_cols
[
segment
];
int
curr_segment
=
segment
;
int
curr_segment
=
0
;
int
curr_offset
=
out_cols
[
0
];
for
(;
tid_x
<
in_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
curr_col_offset
;
while
(
(
curr_col_offset
=
out_cols
[
curr_segment
+
1
])
<=
tid_x
)
{
int
curr_col_offset
=
out_cols
[
curr_segment
+
1
]
;
while
(
curr_col_offset
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
curr_col_offset
=
out_cols
[
curr_segment
+
1
];
}
int
local_col
=
tid_x
-
curr_offset
;
...
...
python/paddle/fluid/tests/unittests/test_concat_op.py
浏览文件 @
9c90dc97
...
...
@@ -43,7 +43,7 @@ class TestConcatOp(OpTest):
self
.
axis
=
1
class
TestConcatOp2
(
OpTest
):
class
TestConcatOp2
(
TestConcatOp
):
def
init_test_data
(
self
):
self
.
x0
=
np
.
random
.
random
((
2
,
3
,
4
,
5
)).
astype
(
'float32'
)
self
.
x1
=
np
.
random
.
random
((
2
,
3
,
4
,
5
)).
astype
(
'float32'
)
...
...
@@ -51,5 +51,16 @@ class TestConcatOp2(OpTest):
self
.
axis
=
1
class
TestConcatOp3
(
TestConcatOp
):
def
init_test_data
(
self
):
self
.
x0
=
np
.
random
.
random
((
1
,
256
,
170
,
256
)).
astype
(
'float32'
)
self
.
x1
=
np
.
random
.
random
((
1
,
128
,
170
,
256
)).
astype
(
'float32'
)
self
.
x2
=
np
.
random
.
random
((
1
,
128
,
170
,
256
)).
astype
(
'float32'
)
self
.
axis
=
1
def
test_check_grad
(
self
):
pass
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录