Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3c21f26b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3c21f26b
编写于
9月 01, 2021
作者:
W
wangguanzhong
提交者:
GitHub
9月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Stablize depthwise conv (#35161)
* stablize depthwise conv * clean commend
上级
7ca28bb6
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
38 addition
and
9 deletion
+38
-9
paddle/fluid/operators/math/depthwise_conv.cu
paddle/fluid/operators/math/depthwise_conv.cu
+38
-9
未找到文件。
paddle/fluid/operators/math/depthwise_conv.cu
浏览文件 @
3c21f26b
...
...
@@ -31,18 +31,43 @@ namespace operators {
namespace
math
{
template
<
typename
T
>
__device__
__inline__
void
CudaAtomicAddWithWarp
(
T
*
sum
,
T
valu
e
)
{
static
__forceinline__
__device__
T
WarpReduceSum
(
T
val
,
int
warp_siz
e
)
{
typedef
cub
::
WarpReduce
<
T
>
WarpReduce
;
typename
WarpReduce
::
TempStorage
temp_storage
;
val
=
WarpReduce
(
temp_storage
).
Sum
(
val
,
warp_size
);
return
val
;
}
#ifdef __HIPCC__
int
block_size
=
min
(
blockDim
.
x
*
blockDim
.
y
*
blockDim
.
z
,
warpSize
);
value
=
WarpReduce
(
temp_storage
).
Sum
(
value
,
block_size
);
#else
value
=
WarpReduce
(
temp_storage
).
Sum
(
value
);
#endif
template
<
typename
T
>
__forceinline__
__device__
T
BlockReduceSum
(
T
val
)
{
static
__shared__
T
shared
[
32
];
int
thread_id
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
z
*
blockDim
.
x
*
blockDim
.
y
;
int
warp_size
=
min
(
blockDim
.
x
*
blockDim
.
y
*
blockDim
.
z
,
warpSize
);
int
lane
=
thread_id
%
warp_size
;
int
wid
=
thread_id
/
warp_size
;
val
=
WarpReduceSum
(
val
,
warp_size
);
// Each warp performs partial reduction
if
(
lane
==
0
)
shared
[
wid
]
=
val
;
// Write reduced value to shared memory
__syncthreads
();
// Wait for all partial reductions
// read from shared memory only if that warp existed
int
block_size
=
blockDim
.
x
*
blockDim
.
y
*
blockDim
.
z
;
if
(
thread_id
<
(
block_size
-
1
)
/
warp_size
+
1
)
{
val
=
shared
[
lane
];
}
else
{
val
=
static_cast
<
T
>
(
0
);
}
if
(
cub
::
LaneId
()
==
0
)
platform
::
CudaAtomicAdd
(
sum
,
value
);
if
(
wid
==
0
)
{
val
=
WarpReduceSum
(
val
,
warp_size
);
// Final reduce within first warp
}
__syncthreads
();
if
(
thread_id
!=
0
)
{
val
=
static_cast
<
T
>
(
0
);
}
return
val
;
}
#define ARG_DEFINE_KernelDepthwiseConv \
...
...
@@ -665,7 +690,9 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
}
}
}
CudaAtomicAddWithWarp
(
&
filter_grad_data
[
gbid
],
s
);
T
val
=
BlockReduceSum
(
s
);
platform
::
CudaAtomicAdd
(
&
filter_grad_data
[
gbid
],
val
);
}
template
<
typename
T
,
bool
fuse_relu_before_conv
>
...
...
@@ -892,6 +919,7 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
int
blocks
;
dim3
threads
;
dim3
grid
;
if
(
data_layout
!=
DataLayout
::
kNHWC
)
{
if
(
output_width
>
1024
&&
output_width
<=
2048
)
thread
=
(
output_width
-
1
)
/
2
+
1
;
...
...
@@ -1034,6 +1062,7 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
int
blocks
;
dim3
threads
;
dim3
grid
;
if
(
data_layout
!=
DataLayout
::
kNHWC
)
{
if
(
input_width
>
1024
&&
input_width
<=
2048
)
{
thread
=
(
input_width
-
1
)
/
2
+
1
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录