Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
18650db3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
18650db3
编写于
9月 21, 2022
作者:
5
5u13
提交者:
GitHub
9月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimization of depthwise_conv2d grad (#46332)
上级
4839aca2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
9 addition
and
8 deletion
+9
-8
paddle/phi/kernels/gpu/depthwise_conv.h
paddle/phi/kernels/gpu/depthwise_conv.h
+9
-8
未找到文件。
paddle/phi/kernels/gpu/depthwise_conv.h
浏览文件 @
18650db3
...
@@ -176,7 +176,8 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(
...
@@ -176,7 +176,8 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
T
in_data
=
input_data
[
offset
];
T
in_data
=
input_data
[
offset
];
if
(
fuse_relu_before_conv
)
{
if
(
fuse_relu_before_conv
)
{
value
+=
weight
[
weight_offset
]
*
T
(
max
(
0.0
f
,
double
(
in_data
)));
value
+=
weight
[
weight_offset
]
*
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
in_data
)));
}
else
{
}
else
{
value
+=
weight
[
weight_offset
]
*
in_data
;
value
+=
weight
[
weight_offset
]
*
in_data
;
}
}
...
@@ -228,7 +229,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
...
@@ -228,7 +229,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
T
in_data
=
input_data
[
offset
];
T
in_data
=
input_data
[
offset
];
const
T
*
weight
=
filter_data
+
weight_offset
*
output_channels
+
c_out
;
const
T
*
weight
=
filter_data
+
weight_offset
*
output_channels
+
c_out
;
if
(
fuse_relu_before_conv
)
{
if
(
fuse_relu_before_conv
)
{
value
+=
weight
[
0
]
*
T
(
max
(
0.0
f
,
double
(
in_data
)));
value
+=
weight
[
0
]
*
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
in_data
)));
}
else
{
}
else
{
value
+=
weight
[
0
]
*
in_data
;
value
+=
weight
[
0
]
*
in_data
;
}
}
...
@@ -281,7 +282,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
...
@@ -281,7 +282,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
if
(
fuse_relu_before_conv
)
{
if
(
fuse_relu_before_conv
)
{
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
T
(
max
(
0.0
f
,
double
(
input_data
[
offset
])));
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
offset
])));
}
else
{
}
else
{
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
input_data
[
offset
];
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
input_data
[
offset
];
}
}
...
@@ -337,7 +338,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
...
@@ -337,7 +338,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
in_offset
+
(
h_in
*
input_width
+
w_in
)
*
input_channels
+
c_in
;
in_offset
+
(
h_in
*
input_width
+
w_in
)
*
input_channels
+
c_in
;
if
(
fuse_relu_before_conv
)
{
if
(
fuse_relu_before_conv
)
{
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
T
(
max
(
0.0
,
double
(
input_data
[
offset
])));
T
(
max
(
0.0
,
static_cast
<
double
>
(
input_data
[
offset
])));
}
else
{
}
else
{
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
input_data
[
offset
];
value
+=
r_weight
[
h_f
*
c_filter
+
w_f
]
*
input_data
[
offset
];
}
}
...
@@ -880,7 +881,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
...
@@ -880,7 +881,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
image_wk
;
image_wk
;
if
(
fuse_relu_before_conv
)
{
if
(
fuse_relu_before_conv
)
{
s
+=
output_grad_data
[
gaid
(
bid
,
kernel_id
,
image_h
,
image_w
)]
*
s
+=
output_grad_data
[
gaid
(
bid
,
kernel_id
,
image_h
,
image_w
)]
*
T
(
max
(
0.0
f
,
double
(
input_data
[
input_id
])));
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
}
else
{
}
else
{
s
+=
output_grad_data
[
gaid
(
bid
,
kernel_id
,
image_h
,
image_w
)]
*
s
+=
output_grad_data
[
gaid
(
bid
,
kernel_id
,
image_h
,
image_w
)]
*
input_data
[
input_id
];
input_data
[
input_id
];
...
@@ -891,7 +892,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
...
@@ -891,7 +892,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
}
}
T
val
=
BlockReduceSum
(
s
);
T
val
=
BlockReduceSum
(
s
);
platform
::
CudaAtomicAdd
(
&
filter_grad_data
[
gbid
],
val
)
;
if
(
threadIdx
.
y
==
0
&&
threadIdx
.
x
==
0
)
filter_grad_data
[
gbid
]
=
val
;
}
}
template
<
typename
T
,
bool
fuse_relu_before_conv
>
template
<
typename
T
,
bool
fuse_relu_before_conv
>
...
@@ -941,7 +942,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
...
@@ -941,7 +942,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
kernel_id
/
filter_multiplier
;
kernel_id
/
filter_multiplier
;
if
(
fuse_relu_before_conv
)
{
if
(
fuse_relu_before_conv
)
{
s
+=
output_grad_data
[
gaid
(
bid
,
image_h
,
image_w
,
kernel_id
)]
*
s
+=
output_grad_data
[
gaid
(
bid
,
image_h
,
image_w
,
kernel_id
)]
*
T
(
max
(
0.0
f
,
double
(
input_data
[
input_id
])));
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
}
else
{
}
else
{
s
+=
output_grad_data
[
gaid
(
bid
,
image_h
,
image_w
,
kernel_id
)]
*
s
+=
output_grad_data
[
gaid
(
bid
,
image_h
,
image_w
,
kernel_id
)]
*
input_data
[
input_id
];
input_data
[
input_id
];
...
@@ -1013,7 +1014,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
...
@@ -1013,7 +1014,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
T
s
(
0
);
T
s
(
0
);
if
(
fuse_relu_before_conv
)
{
if
(
fuse_relu_before_conv
)
{
s
=
output_grad_data
[
output_id
]
*
s
=
output_grad_data
[
output_id
]
*
T
(
max
(
0.0
f
,
double
(
input_data
[
input_id
])));
T
(
max
(
0.0
f
,
static_cast
<
double
>
(
input_data
[
input_id
])));
}
else
{
}
else
{
s
=
output_grad_data
[
output_id
]
*
input_data
[
input_id
];
s
=
output_grad_data
[
output_id
]
*
input_data
[
input_id
];
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录