Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0d8b222b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0d8b222b
编写于
1月 19, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
1月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize the depthwise op test=develop (#22265)
上级
325f0722
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
129 addition
and
79 deletion
+129
-79
paddle/fluid/operators/math/depthwise_conv.cu
paddle/fluid/operators/math/depthwise_conv.cu
+129
-79
未找到文件。
paddle/fluid/operators/math/depthwise_conv.cu
浏览文件 @
0d8b222b
...
...
@@ -45,67 +45,106 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
template
<
typename
T
,
bool
fuse_relu_before_conv
>
__device__
__inline__
void
KernelDepthwiseConv
(
ARG_DEFINE_KernelDepthwiseConv
)
{
for
(
int
w_out
=
threadIdx
.
x
;
w_out
<
output_width
;
w_out
+=
blockDim
.
x
)
{
for
(
int
h_out
=
threadIdx
.
y
;
h_out
<
output_height
;
h_out
+=
blockDim
.
y
)
{
const
int
batch
=
blockIdx
.
y
;
const
int
c_out
=
blockIdx
.
x
;
const
int
c_in
=
c_out
/
filter_multiplier
;
const
T
*
weight
=
filter_data
+
c_out
*
filter_height
*
filter_width
;
T
value
=
0
;
const
int
h_in_start
=
-
padding_height
+
h_out
*
stride_height
;
const
int
w_in_start
=
-
padding_width
+
w_out
*
stride_width
;
const
int
h_in_end
=
h_in_start
+
filter_height
*
dilate_height
;
const
int
w_in_end
=
w_in_start
+
filter_width
*
dilate_width
;
int
in_offset
;
if
(
data_layout
!=
DataLayout
::
kNHWC
)
{
in_offset
=
((
batch
*
input_channels
+
c_in
)
*
input_height
)
*
input_width
;
}
else
{
in_offset
=
batch
*
input_height
*
input_width
*
input_channels
;
__device__
__inline__
void
KernelDepthwiseConvNCHW
(
ARG_DEFINE_KernelDepthwiseConv
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
>=
(
output_channels
*
batch_size
*
output_height
*
output_width
))
return
;
const
int
w_out
=
idx
%
output_width
;
const
int
h_out
=
(
idx
/
output_width
)
%
output_height
;
const
int
c_out
=
(
idx
/
output_width
/
output_height
)
%
output_channels
;
const
int
batch
=
idx
/
output_width
/
output_height
/
output_channels
;
const
int
c_in
=
c_out
/
filter_multiplier
;
const
T
*
weight
=
filter_data
+
c_out
*
filter_height
*
filter_width
;
T
value
=
0
;
const
int
h_in_start
=
-
padding_height
+
h_out
*
stride_height
;
const
int
w_in_start
=
-
padding_width
+
w_out
*
stride_width
;
const
int
h_in_end
=
h_in_start
+
filter_height
*
dilate_height
;
const
int
w_in_end
=
w_in_start
+
filter_width
*
dilate_width
;
int
in_offset
=
((
batch
*
input_channels
+
c_in
)
*
input_height
)
*
input_width
;
const
int
h_end
=
h_in_end
<
input_height
?
h_in_end
:
input_height
;
const
int
w_end
=
w_in_end
<
input_width
?
w_in_end
:
input_width
;
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
const
int
w_start
=
w_in_start
>
0
?
w_in_start
:
0
;
int
weight_offset
=
0
;
#pragma unroll
for
(
int
h_in
=
h_in_start
;
h_in
<
h_in_end
;
h_in
+=
dilate_height
)
{
#pragma unroll
for
(
int
w_in
=
w_in_start
;
w_in
<
w_in_end
;
w_in
+=
dilate_width
)
{
if
(
h_in
>=
h_start
&&
h_in
<
h_end
&&
w_in
>=
w_start
&&
w_in
<
w_end
)
{
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
T
in_data
=
input_data
[
offset
];
if
(
fuse_relu_before_conv
)
{
value
+=
weight
[
weight_offset
]
*
max
(
0.0
f
,
in_data
);
}
else
{
value
+=
weight
[
weight_offset
]
*
in_data
;
}
}
weight_offset
++
;
}
}
int
index
=
batch
*
output_channels
*
output_height
*
output_width
+
c_out
*
output_height
*
output_width
+
h_out
*
output_width
+
w_out
;
output_data
[
index
]
=
value
;
}
const
int
h_end
=
h_in_end
<
input_height
?
h_in_end
:
input_height
;
const
int
w_end
=
w_in_end
<
input_width
?
w_in_end
:
input_width
;
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
const
int
w_start
=
w_in_start
>
0
?
w_in_start
:
0
;
int
weight_offset
=
0
;
for
(
int
h_in
=
h_in_start
;
h_in
<
h_in_end
;
h_in
+=
dilate_height
)
{
for
(
int
w_in
=
w_in_start
;
w_in
<
w_in_end
;
w_in
+=
dilate_width
)
{
if
(
h_in
>=
h_start
&&
h_in
<
h_end
&&
w_in
>=
w_start
&&
w_in
<
w_end
)
{
int
offset
;
if
(
data_layout
!=
DataLayout
::
kNHWC
)
{
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
}
else
{
offset
=
in_offset
+
(
h_in
*
input_width
+
w_in
)
*
input_channels
+
c_in
;
}
if
(
fuse_relu_before_conv
)
{
value
+=
weight
[
weight_offset
]
*
max
(
0.0
f
,
input_data
[
offset
]);
}
else
{
value
+=
weight
[
weight_offset
]
*
input_data
[
offset
];
}
}
weight_offset
++
;
// A Cuda kernel to compute the depthwise convolution forward pass
// in NHWC format.
template
<
typename
T
,
bool
fuse_relu_before_conv
>
__device__
__inline__
void
KernelDepthwiseConvNHWC
(
ARG_DEFINE_KernelDepthwiseConv
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
>=
(
output_channels
*
batch_size
*
output_height
*
output_width
))
return
;
const
int
c_out
=
idx
%
output_channels
;
const
int
w_out
=
(
idx
/
output_channels
)
%
output_width
;
const
int
h_out
=
(
idx
/
output_channels
/
output_width
)
%
output_height
;
const
int
batch
=
idx
/
output_width
/
output_height
/
output_channels
;
const
int
c_in
=
c_out
/
filter_multiplier
;
const
T
*
weight
=
filter_data
+
c_out
*
filter_height
*
filter_width
;
T
value
=
0
;
const
int
h_in_start
=
-
padding_height
+
h_out
*
stride_height
;
const
int
w_in_start
=
-
padding_width
+
w_out
*
stride_width
;
const
int
h_in_end
=
h_in_start
+
filter_height
*
dilate_height
;
const
int
w_in_end
=
w_in_start
+
filter_width
*
dilate_width
;
const
int
h_end
=
h_in_end
<
input_height
?
h_in_end
:
input_height
;
const
int
w_end
=
w_in_end
<
input_width
?
w_in_end
:
input_width
;
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
const
int
w_start
=
w_in_start
>
0
?
w_in_start
:
0
;
int
weight_offset
=
0
;
#pragma unroll
for
(
int
h_in
=
h_in_start
;
h_in
<
h_in_end
;
h_in
+=
dilate_height
)
{
#pragma unroll
for
(
int
w_in
=
w_in_start
;
w_in
<
w_in_end
;
w_in
+=
dilate_width
)
{
if
(
h_in
>=
h_start
&&
h_in
<
h_end
&&
w_in
>=
w_start
&&
w_in
<
w_end
)
{
int
offset
=
((
batch
*
input_height
+
h_in
)
*
input_width
+
w_in
)
*
output_channels
+
c_in
;
T
in_data
=
input_data
[
offset
];
if
(
fuse_relu_before_conv
)
{
value
+=
weight
[
weight_offset
]
*
max
(
0.0
f
,
in_data
);
}
else
{
value
+=
weight
[
weight_offset
]
*
in_data
;
}
}
int
index
;
if
(
data_layout
!=
DataLayout
::
kNHWC
)
{
index
=
((
batch
*
gridDim
.
x
+
c_out
)
*
output_height
+
h_out
)
*
output_width
+
w_out
;
}
else
{
index
=
((
batch
*
output_height
+
h_out
)
*
output_width
+
w_out
)
*
gridDim
.
x
+
c_out
;
}
output_data
[
index
]
=
value
;
weight_offset
++
;
}
}
int
index
=
batch
*
output_channels
*
output_height
*
output_width
+
h_out
*
output_width
*
output_channels
+
w_out
*
output_channels
+
c_out
;
output_data
[
index
]
=
value
;
}
template
<
typename
T
,
int
c_filter
,
bool
fuse_relu_before_conv
>
...
...
@@ -183,36 +222,37 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
template
<
typename
T
,
int
c_filter_multiplier
,
int
c_stride
,
int
c_filter
,
bool
fuse_relu_before_conv
>
__global__
void
KernelDepthwiseConvSp
(
ARG_DEFINE_KernelDepthwiseConv
)
{
if
(
c_filter_multiplier
==
0
)
{
if
(
c_filter
==
-
1
)
KernelDepthwiseConv
<
T
,
fuse_relu_before_conv
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
filter_multiplier
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
,
data_layout
);
else
KernelDepthwiseConvCFilter
<
T
,
c_filter
,
fuse_relu_before_conv
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
filter_multiplier
,
filter_height
,
filter_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
,
data_layout
);
}
else
{
if
(
c_filter
==
-
1
)
KernelDepthwiseConv
<
T
,
fuse_relu_before_conv
>
(
int
final_filter_multiplier
=
filter_multiplier
;
int
h_stride
=
stride_height
;
int
w_stride
=
stride_width
;
if
(
c_filter_multiplier
!=
0
)
{
final_filter_multiplier
=
c_filter_multiplier
;
h_stride
=
c_stride
;
w_stride
=
c_stride
;
}
if
(
c_filter
==
-
1
)
{
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
KernelDepthwiseConvNCHW
<
T
,
fuse_relu_before_conv
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
c_filter_multiplier
,
filter_height
,
filter_height
,
c_stride
,
c
_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
final_filter_multiplier
,
filter_height
,
filter_width
,
h
_stride
,
w_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
,
data_layout
);
else
KernelDepthwiseConv
CFilter
<
T
,
c_filter
,
fuse_relu_before_conv
>
(
}
else
{
KernelDepthwiseConv
NHWC
<
T
,
fuse_relu_before_conv
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
c_filter_multiplier
,
filter_height
,
filter_height
,
c_stride
,
c
_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
final_filter_multiplier
,
filter_height
,
filter_width
,
h
_stride
,
w_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
,
data_layout
);
}
}
else
{
KernelDepthwiseConvCFilter
<
T
,
c_filter
,
fuse_relu_before_conv
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
final_filter_multiplier
,
filter_height
,
filter_width
,
h_stride
,
w_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
,
data_layout
);
}
}
...
...
@@ -564,12 +604,22 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
dim3
threads
(
std
::
min
(
output_width
,
thread
),
blocks
,
1
);
dim3
grid
(
output_channels
,
batch_size
,
1
);
int
filter_multiplier
=
output_channels
/
input_channels
;
int
nums_output
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
block_size
=
512
;
#define check_case(c_filter_multiplier, c_stride, c_filter) \
if (c_filter_multiplier == 0 || \
filter_multiplier == c_filter_multiplier && \
stride_height == stride_width && stride_height == c_stride && \
(ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \
if (c_filter == -1) { \
threads.x = block_size; \
grid.x = (nums_output + block_size - 1) / block_size; \
threads.y = threads.z = grid.y = grid.z = 1; \
} \
KernelDepthwiseConvSp< \
T, c_filter_multiplier, c_stride, c_filter, \
fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录