Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
fef2faa7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fef2faa7
编写于
11月 05, 2018
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
limit CUDA kernel parallel threads max number to 4096. test=develop
上级
34bfae24
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
34 addition
and
19 deletion
+34
-19
paddle/fluid/operators/interpolate_op.cu
paddle/fluid/operators/interpolate_op.cu
+18
-12
python/paddle/fluid/tests/unittests/test_interpolate_op.py
python/paddle/fluid/tests/unittests/test_interpolate_op.py
+16
-7
未找到文件。
paddle/fluid/operators/interpolate_op.cu
浏览文件 @
fef2faa7
...
...
@@ -26,7 +26,8 @@ __global__ void KeNearestNeighborInterpFw(
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_w
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(;
tid
<
nthreads
;
tid
+=
stride
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
...
...
@@ -52,7 +53,8 @@ __global__ void KeNearestNeighborInterpBw(
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_w
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(;
tid
<
nthreads
;
tid
+=
stride
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
...
...
@@ -80,7 +82,8 @@ __global__ void KeBilinearInterpFw(
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_w
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(;
tid
<
nthreads
;
tid
+=
stride
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
...
...
@@ -118,7 +121,8 @@ __global__ void KeBilinearInterpBw(
const
size_t
num_channels
,
const
T
ratio_h
,
const
T
ratio_w
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(;
tid
<
nthreads
;
tid
+=
stride
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
...
...
@@ -194,17 +198,18 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
return
;
}
int
threadNum
=
n
*
out_chw
;
int
blocks
=
(
threadNum
+
1024
-
1
)
/
1024
;
int
pixelNum
=
n
*
out_chw
;
int
grid_dim
=
(
pixelNum
+
512
-
1
)
/
512
;
grid_dim
=
grid_dim
>
8
?
8
:
grid_dim
;
if
(
"nearest"
==
interp_method
)
{
KeNearestNeighborInterpFw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
}
else
if
(
"bilinear"
==
interp_method
)
{
KeBilinearInterpFw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
}
...
...
@@ -257,17 +262,18 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
return
;
}
int
threadNum
=
n
*
out_chw
;
int
blocks
=
(
threadNum
+
1024
-
1
)
/
1024
;
int
pixelNum
=
n
*
out_chw
;
int
grid_dim
=
(
pixelNum
+
512
-
1
)
/
512
;
grid_dim
=
grid_dim
>
8
?
8
:
grid_dim
;
if
(
"nearest"
==
interp_method
)
{
KeNearestNeighborInterpBw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_grad_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_grad_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
}
else
if
(
"bilinear"
==
interp_method
)
{
KeBilinearInterpBw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_grad_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_grad_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
}
...
...
python/paddle/fluid/tests/unittests/test_interpolate_op.py
浏览文件 @
fef2faa7
...
...
@@ -167,13 +167,13 @@ class TestBilinearInterpCase6(TestInterpolateOp):
self
.
out_size
=
np
.
array
([
65
,
129
]).
astype
(
"int32"
)
#
class TestBilinearInterpBigScale(TestInterpolateOp):
#
def init_test_case(self):
#
self.interp_method = 'bilinear'
# self.input_shape = [32, 16, 128, 64
]
# self.out_h = 2
00
# self.out_w = 10
0
# self.out_size = np.array([201, 10
1]).astype('int32')
class
TestBilinearInterpBigScale
(
TestInterpolateOp
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
4
,
4
,
64
,
32
]
self
.
out_h
=
1
00
self
.
out_w
=
5
0
self
.
out_size
=
np
.
array
([
101
,
5
1
]).
astype
(
'int32'
)
class
TestInterpolateOpUint8
(
OpTest
):
...
...
@@ -273,6 +273,15 @@ class TestNearestNeighborInterpCase6(TestInterpolateOp):
self
.
out_size
=
np
.
array
([
65
,
129
]).
astype
(
"int32"
)
class
TestNearestNeighborInterpBigScale
(
TestInterpolateOp
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
4
,
4
,
64
,
32
]
self
.
out_h
=
100
self
.
out_w
=
50
self
.
out_size
=
np
.
array
([
101
,
51
]).
astype
(
'int32'
)
class
TestNearestNeighborInterpCase1Uint8
(
TestInterpolateOpUint8
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录