Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0e563da6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0e563da6
编写于
9月 20, 2022
作者:
5
5u13
提交者:
GitHub
9月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimization of max_pool3d grad (#45934)
上级
6d067860
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
61 addition
and
84 deletion
+61
-84
paddle/phi/kernels/funcs/pooling.cu
paddle/phi/kernels/funcs/pooling.cu
+61
-84
未找到文件。
paddle/phi/kernels/funcs/pooling.cu
浏览文件 @
0e563da6
...
...
@@ -2319,7 +2319,8 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd,
}
template
<
typename
T1
,
typename
T2
>
__global__
void
KernelMaxPool3DWithIdxGrad
(
const
int
nthreads
,
__global__
void
KernelMaxPool3DWithIdxGrad
(
const
int
ncd
,
const
T1
*
output_grad
,
const
T2
*
mask
,
const
int
channels
,
...
...
@@ -2339,67 +2340,31 @@ __global__ void KernelMaxPool3DWithIdxGrad(const int nthreads,
const
int
padding_height
,
const
int
padding_width
,
bool
adaptive
,
T1
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_offset
=
index
%
input_width
;
int
h_offset
=
(
index
/
input_width
)
%
input_height
;
int
d_offset
=
(
index
/
input_width
/
input_height
)
%
input_depth
;
int
c_offset
=
(
index
/
input_width
/
input_height
/
input_depth
)
%
channels
;
int
batch_idx
=
index
/
input_width
/
input_height
/
input_depth
/
channels
;
int
pdstart
,
pdend
;
int
phstart
,
phend
;
int
pwstart
,
pwend
;
if
(
adaptive
)
{
pdstart
=
d_offset
*
output_depth
/
input_depth
;
pdend
=
min
((
d_offset
+
1
)
*
output_depth
/
input_depth
+
1
,
output_depth
);
phstart
=
h_offset
*
output_height
/
input_height
;
phend
=
min
((
h_offset
+
1
)
*
output_height
/
input_height
+
1
,
output_height
);
pwstart
=
w_offset
*
output_width
/
input_width
;
pwend
=
min
((
w_offset
+
1
)
*
output_width
/
input_width
+
1
,
output_width
);
}
else
{
pdstart
=
(
d_offset
+
padding_depth
<
ksize_depth
)
?
0
:
(
d_offset
+
padding_depth
-
ksize_depth
)
/
stride_depth
+
1
;
phstart
=
(
h_offset
+
padding_height
<
ksize_height
)
?
0
:
(
h_offset
+
padding_height
-
ksize_height
)
/
stride_height
+
1
;
pwstart
=
(
w_offset
+
padding_width
<
ksize_width
)
?
0
:
(
w_offset
+
padding_width
-
ksize_width
)
/
stride_width
+
1
;
pdend
=
min
((
d_offset
+
padding_depth
)
/
stride_depth
+
1
,
output_depth
);
phend
=
min
((
h_offset
+
padding_height
)
/
stride_height
+
1
,
output_height
);
pwend
=
min
((
w_offset
+
padding_width
)
/
stride_width
+
1
,
output_width
);
}
T1
*
input_grad
,
FastDivModForPooling3D
divmods_output
)
{
int
w_offset
,
h_offset
,
d_offset
,
nc_offset
;
T1
input_grad_data
=
0
;
int
input_current_feature_map_idx
=
(
d_offset
*
input_height
+
h_offset
)
*
input_width
+
w_offset
;
int
output_idx
=
(
batch_idx
*
channels
+
c_offset
)
*
output_depth
*
output_height
*
output_width
;
mask
+=
output_idx
;
output_grad
+=
output_idx
;
w_offset
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
h_offset
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(
int
pd
=
pdstart
;
pd
<
pdend
;
++
pd
)
{
for
(
int
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
int
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
if
(
mask
[(
pd
*
output_height
+
ph
)
*
output_width
+
pw
]
==
input_current_feature_map_idx
)
input_grad_data
+=
output_grad
[(
pd
*
output_height
+
ph
)
*
output_width
+
pw
];
}
if
(
w_offset
<
output_width
&&
h_offset
<
output_height
)
{
for
(
int
index_z
=
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
;
index_z
<
ncd
;
index_z
+=
gridDim
.
z
*
blockDim
.
z
)
{
auto
output_depth_divmod
=
divmods_output
.
depth
.
Divmod
(
index_z
);
d_offset
=
output_depth_divmod
.
val
[
1
];
nc_offset
=
output_depth_divmod
.
val
[
0
];
int
output_index
=
nc_offset
*
output_depth
*
output_height
*
output_width
+
d_offset
*
output_height
*
output_width
+
h_offset
*
output_width
+
w_offset
;
int
max_index
=
mask
[
output_index
];
if
(
max_index
!=
-
1
)
{
paddle
::
platform
::
CudaAtomicAdd
(
&
input_grad
[
nc_offset
*
input_depth
*
input_height
*
input_width
+
max_index
],
output_grad
[
output_index
]);
}
}
input_grad
[
index
]
=
input_grad_data
;
}
}
...
...
@@ -2523,14 +2488,25 @@ class MaxPool3dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
const
T2
*
mask_data
=
mask
.
data
<
T2
>
();
T1
*
input_grad_data
=
context
.
template
Alloc
<
T1
>(
input_grad
);
int
nthreads
=
batch_size
*
input_channels
*
input_depth
*
input_height
*
input_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
int
ncd
=
batch_size
*
input_channels
*
output_depth
;
int
thread_x
=
32
;
int
thread_y
=
8
;
int
thread_z
=
1
;
dim3
threads
(
thread_x
,
thread_y
,
thread_z
);
std
::
array
<
int
,
3
>
max_grid_dim
=
context
.
GetCUDAMaxGridDimSize
();
int
block_x
=
(
output_width
+
threads
.
x
-
1
)
/
threads
.
x
;
int
block_y
=
(
output_height
+
threads
.
y
-
1
)
/
threads
.
y
;
int
block_z
=
(
ncd
>
max_grid_dim
[
2
]
*
threads
.
z
)
?
max_grid_dim
[
2
]
:
(
ncd
+
threads
.
z
-
1
)
/
threads
.
z
;
dim3
grid
(
block_x
,
block_y
,
block_z
);
auto
pool_divmods_output
=
FastDivModForPooling3D
(
input_channels
,
output_width
,
output_height
,
output_depth
);
KernelMaxPool3DWithIdxGrad
<
T1
,
T2
>
<<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
n
threads
,
<<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
n
cd
,
output_grad_data
,
mask_data
,
input_channels
,
...
...
@@ -2550,7 +2526,8 @@ class MaxPool3dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
padding_height
,
padding_width
,
adaptive
,
input_grad_data
);
input_grad_data
,
pool_divmods_output
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录