Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
330b1a0a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
330b1a0a
编写于
9月 23, 2022
作者:
Z
Zhang Zheng
提交者:
GitHub
9月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize performance of depthwise_conv_fwd (#46287)
* Optimize performance of depthwise_conv_fwd * fix
上级
22fe4f03
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
55 addition
and
51 deletion
+55
-51
paddle/phi/kernels/gpu/depthwise_conv.h
paddle/phi/kernels/gpu/depthwise_conv.h
+55
-51
未找到文件。
paddle/phi/kernels/gpu/depthwise_conv.h
浏览文件 @
330b1a0a
...
...
@@ -139,56 +139,53 @@ __forceinline__ __device__ T BlockReduceSum(T val) {
// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
template
<
typename
T
,
bool
fuse_relu_before_conv
>
template
<
typename
T
,
int
c_filter
,
bool
fuse_relu_before_conv
>
__device__
__inline__
void
KernelDepthwiseConvNCHW
(
ARG_DEFINE_KernelDepthwiseConv
)
{
const
int
fw_size
=
c_filter
!=
-
1
?
c_filter
:
filter_width
;
const
int
fh_size
=
c_filter
!=
-
1
?
c_filter
:
filter_height
;
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
>=
(
output_channels
*
batch_size
*
output_height
*
output_width
))
return
;
const
int
w_out
=
idx
%
output_width
;
const
int
h_out
=
(
idx
/
output_width
)
%
output_height
;
const
int
c_out
=
(
idx
/
output_width
/
output_height
)
%
output_channels
;
const
int
batch
=
idx
/
output_width
/
output_height
/
output_channels
;
int
tmp_1
=
idx
/
output_width
;
const
int
w_out
=
idx
-
tmp_1
*
output_width
;
int
tmp_2
=
tmp_1
/
output_height
;
const
int
h_out
=
tmp_1
-
tmp_2
*
output_height
;
tmp_1
=
tmp_2
;
tmp_2
=
tmp_1
/
output_channels
;
const
int
c_out
=
tmp_1
-
tmp_2
*
output_channels
;
const
int
batch
=
tmp_2
;
const
int
c_in
=
c_out
/
filter_multiplier
;
const
T
*
weight
=
filter_data
+
c_out
*
filter_height
*
filter_width
;
T
value
(
0
);
const
int
h_in_start
=
-
padding_height
+
h_out
*
stride_height
;
const
int
w_in_start
=
-
padding_width
+
w_out
*
stride_width
;
const
int
h_in_end
=
h_in_start
+
filter_height
*
dilate_height
;
const
int
w_in_end
=
w_in_start
+
filter_width
*
dilate_width
;
int
in_offset
=
((
batch
*
input_channels
+
c_in
)
*
input_height
)
*
input_width
;
const
int
h_end
=
h_in_end
<
input_height
?
h_in_end
:
input_height
;
const
int
w_end
=
w_in_end
<
input_width
?
w_in_end
:
input_width
;
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
const
int
w_start
=
w_in_start
>
0
?
w_in_start
:
0
;
int
weight_offset
=
0
;
int
weight_offset
=
c_out
*
filter_height
*
filter_width
;
int
h_in_start
=
-
padding_height
+
h_out
*
stride_height
;
int
w_in_start
=
-
padding_width
+
w_out
*
stride_width
;
#pragma unroll
for
(
int
h_in
=
h_in_start
;
h_in
<
h_in_end
;
h_in
+=
dilate_height
)
{
for
(
int
fh
=
0
,
h_in
=
h_in_start
;
fh
<
fh_size
;
fh
++
,
h_in
+=
dilate_height
)
{
#pragma unroll
for
(
int
w_in
=
w_in_start
;
w_in
<
w_in_end
;
w_in
+=
dilate_width
)
{
if
(
h_in
>=
h_start
&&
h_in
<
h_end
&&
w_in
>=
w_start
&&
w_in
<
w_end
)
{
for
(
int
fw
=
0
,
w_in
=
w_in_start
;
fw
<
fw_size
;
fw
++
,
w_in
+=
dilate_width
)
{
if
(
h_in
>=
0
&&
h_in
<
input_height
&&
w_in
>=
0
&&
w_in
<
input_width
)
{
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
T
in_data
=
input_data
[
offset
];
if
(
fuse_relu_before_conv
)
{
value
+=
weight
[
weight_offset
]
*
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
in_data
)));
value
+=
filter_data
[
weight_offset
]
*
static_cast
<
T
>
(
max
(
0.0
f
,
static_cast
<
double
>
(
in_data
)));
}
else
{
value
+=
weight
[
weight_offset
]
*
in_data
;
value
+=
filter_data
[
weight_offset
]
*
in_data
;
}
}
weight_offset
++
;
}
}
int
index
=
batch
*
output_channels
*
output_height
*
output_width
+
c_out
*
output_height
*
output_width
+
h_out
*
output_width
+
w_out
;
output_data
[
index
]
=
value
;
output_data
[
idx
]
=
value
;
}
// A Cuda kernel to compute the depthwise convolution forward pass
...
...
@@ -229,7 +226,8 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
T
in_data
=
input_data
[
offset
];
const
T
*
weight
=
filter_data
+
weight_offset
*
output_channels
+
c_out
;
if
(
fuse_relu_before_conv
)
{
value
+=
weight
[
0
]
*
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
in_data
)));
value
+=
weight
[
0
]
*
static_cast
<
T
>
(
max
(
0.0
f
,
static_cast
<
double
>
(
in_data
)));
}
else
{
value
+=
weight
[
0
]
*
in_data
;
}
...
...
@@ -282,7 +280,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
if
(
fuse_relu_before_conv
)
{
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
offset
])));
static_cast
<
T
>
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
offset
])));
}
else
{
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
input_data
[
offset
];
}
...
...
@@ -338,7 +337,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
in_offset
+
(
h_in
*
input_width
+
w_in
)
*
input_channels
+
c_in
;
if
(
fuse_relu_before_conv
)
{
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
T
(
max
(
0.0
,
static_cast
<
double
>
(
input_data
[
offset
])));
static_cast
<
T
>
(
max
(
0.0
,
static_cast
<
double
>
(
input_data
[
offset
])));
}
else
{
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
input_data
[
offset
];
}
...
...
@@ -368,25 +368,26 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
}
if
(
c_filter
==
-
1
)
{
if
(
data_layout
!=
DataLayout
::
kNHWC
)
{
KernelDepthwiseConvNCHW
<
T
,
fuse_relu_before_conv
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
final_filter_multiplier
,
filter_height
,
filter_width
,
h_stride
,
w_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
);
KernelDepthwiseConvNCHW
<
T
,
c_filter
,
fuse_relu_before_conv
>
(
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
final_filter_multiplier
,
filter_height
,
filter_width
,
h_stride
,
w_stride
,
padding_height
,
padding_width
,
dilate_height
,
dilate_width
,
output_data
);
}
else
{
KernelDepthwiseConvNHWC
<
T
,
fuse_relu_before_conv
>
(
input_data
,
filter_data
,
...
...
@@ -881,7 +882,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
image_wk
;
if
(
fuse_relu_before_conv
)
{
s
+=
output_grad_data
[
gaid
(
bid
,
kernel_id
,
image_h
,
image_w
)]
*
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
static_cast
<
T
>
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
}
else
{
s
+=
output_grad_data
[
gaid
(
bid
,
kernel_id
,
image_h
,
image_w
)]
*
input_data
[
input_id
];
...
...
@@ -942,7 +944,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
kernel_id
/
filter_multiplier
;
if
(
fuse_relu_before_conv
)
{
s
+=
output_grad_data
[
gaid
(
bid
,
image_h
,
image_w
,
kernel_id
)]
*
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
static_cast
<
T
>
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
}
else
{
s
+=
output_grad_data
[
gaid
(
bid
,
image_h
,
image_w
,
kernel_id
)]
*
input_data
[
input_id
];
...
...
@@ -1014,7 +1017,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
T
s
(
0
);
if
(
fuse_relu_before_conv
)
{
s
=
output_grad_data
[
output_id
]
*
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
static_cast
<
T
>
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
}
else
{
s
=
output_grad_data
[
output_id
]
*
input_data
[
input_id
];
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录