Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2632d77d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2632d77d
编写于
9月 09, 2022
作者:
5
5u13
提交者:
GitHub
9月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimization of max_pool3d forward (#45820)
上级
a001f263
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
91 addition
and
58 deletion
+91
-58
paddle/phi/kernels/funcs/pooling.cu
paddle/phi/kernels/funcs/pooling.cu
+91
-58
未找到文件。
paddle/phi/kernels/funcs/pooling.cu
浏览文件 @
2632d77d
...
...
@@ -38,6 +38,24 @@ struct FastDivModForPooling {
}
};
struct
FastDivModForPooling3D
{
public:
paddle
::
platform
::
FastDivMod
channel
;
paddle
::
platform
::
FastDivMod
width
;
paddle
::
platform
::
FastDivMod
height
;
paddle
::
platform
::
FastDivMod
depth
;
explicit
HOSTDEVICE
FastDivModForPooling3D
(
const
int
channels
,
const
int
output_width
,
const
int
output_height
,
const
int
output_depth
)
{
channel
=
paddle
::
platform
::
FastDivMod
(
channels
);
width
=
paddle
::
platform
::
FastDivMod
(
output_width
);
height
=
paddle
::
platform
::
FastDivMod
(
output_height
);
depth
=
paddle
::
platform
::
FastDivMod
(
output_depth
);
}
};
struct
FastDivModForPoolingWithMoreStaff
{
public:
paddle
::
platform
::
FastDivMod
channel
;
...
...
@@ -2003,7 +2021,7 @@ template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
template
class
MaxPool2dWithIndexGradFunctor
<
phi
::
GPUContext
,
double
,
int
>;
template
<
typename
T1
,
typename
T2
>
__global__
void
KernelMaxPool3DWithIdx
(
const
int
n
threads
,
__global__
void
KernelMaxPool3DWithIdx
(
const
int
n
cd
,
const
T1
*
input_data
,
const
int
channels
,
const
int
input_depth
,
...
...
@@ -2023,57 +2041,65 @@ __global__ void KernelMaxPool3DWithIdx(const int nthreads,
const
int
padding_width
,
bool
adaptive
,
T1
*
output_data
,
T2
*
mask_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
int
ph
=
(
index
/
output_width
)
%
output_height
;
int
pd
=
(
index
/
output_width
/
output_height
)
%
output_depth
;
int
c
=
(
index
/
output_width
/
output_height
/
output_depth
)
%
channels
;
int
batch_idx
=
index
/
output_width
/
output_height
/
output_depth
/
channels
;
int
dstart
,
dend
;
int
hstart
,
hend
;
int
wstart
,
wend
;
if
(
adaptive
)
{
dstart
=
AdaptStartIndex
(
pd
,
input_depth
,
output_depth
);
dend
=
AdaptEndIndex
(
pd
,
input_depth
,
output_depth
);
hstart
=
AdaptStartIndex
(
ph
,
input_height
,
output_height
);
hend
=
AdaptEndIndex
(
ph
,
input_height
,
output_height
);
wstart
=
AdaptStartIndex
(
pw
,
input_width
,
output_width
);
wend
=
AdaptEndIndex
(
pw
,
input_width
,
output_width
);
}
else
{
dstart
=
pd
*
stride_depth
-
padding_depth
;
hstart
=
ph
*
stride_height
-
padding_height
;
wstart
=
pw
*
stride_width
-
padding_width
;
dend
=
min
(
dstart
+
ksize_depth
,
input_depth
);
hend
=
min
(
hstart
+
ksize_height
,
input_height
);
wend
=
min
(
wstart
+
ksize_width
,
input_width
);
dstart
=
max
(
dstart
,
0
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
}
T1
ele
=
-
FLT_MAX
;
int
max_index
=
-
1
;
input_data
+=
(
batch_idx
*
channels
+
c
)
*
input_depth
*
input_height
*
input_width
;
T2
*
mask_data
,
FastDivModForPooling3D
divmods_output
)
{
int
w_offset
,
h_offset
,
d_offset
,
nc_offset
;
int
dstart
,
dend
,
hstart
,
hend
,
wstart
,
wend
;
const
T1
*
input_data_cur
;
w_offset
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
h_offset
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
w_offset
<
output_width
&&
h_offset
<
output_height
)
{
for
(
int
index_z
=
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
;
index_z
<
ncd
;
index_z
+=
gridDim
.
z
*
blockDim
.
z
)
{
auto
output_depth_divmod
=
divmods_output
.
depth
.
Divmod
(
index_z
);
d_offset
=
output_depth_divmod
.
val
[
1
];
nc_offset
=
output_depth_divmod
.
val
[
0
];
int
output_index
=
nc_offset
*
output_depth
*
output_height
*
output_width
+
d_offset
*
output_height
*
output_width
+
h_offset
*
output_width
+
w_offset
;
int
input_offset
=
nc_offset
*
input_depth
*
input_height
*
input_width
;
input_data_cur
=
input_data
+
input_offset
;
if
(
adaptive
)
{
dstart
=
AdaptStartIndex
(
d_offset
,
input_depth
,
output_depth
);
dend
=
AdaptEndIndex
(
d_offset
,
input_depth
,
output_depth
);
hstart
=
AdaptStartIndex
(
h_offset
,
input_height
,
output_height
);
hend
=
AdaptEndIndex
(
h_offset
,
input_height
,
output_height
);
wstart
=
AdaptStartIndex
(
w_offset
,
input_width
,
output_width
);
wend
=
AdaptEndIndex
(
w_offset
,
input_width
,
output_width
);
}
else
{
dstart
=
d_offset
*
stride_depth
-
padding_depth
;
hstart
=
h_offset
*
stride_height
-
padding_height
;
wstart
=
w_offset
*
stride_width
-
padding_width
;
dend
=
min
(
dstart
+
ksize_depth
,
input_depth
);
hend
=
min
(
hstart
+
ksize_height
,
input_height
);
wend
=
min
(
wstart
+
ksize_width
,
input_width
);
dstart
=
max
(
dstart
,
0
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
}
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
ele
<
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
])
{
max_index
=
(
d
*
input_height
+
h
)
*
input_width
+
w
;
ele
=
input_data
[
max_index
];
T1
ele
=
-
FLT_MAX
;
int
max_index
=
-
1
;
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
ele
<
input_data_cur
[(
d
*
input_height
+
h
)
*
input_width
+
w
])
{
max_index
=
(
d
*
input_height
+
h
)
*
input_width
+
w
;
ele
=
input_data_cur
[
max_index
];
}
}
}
}
output_data
[
output_index
]
=
ele
;
mask_data
[
output_index
]
=
max_index
;
}
output_data
[
index
]
=
ele
;
mask_data
[
index
]
=
max_index
;
}
}
...
...
@@ -2201,19 +2227,25 @@ class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
T1
*
output_data
=
context
.
template
Alloc
<
T1
>(
output
);
T2
*
mask_data
=
context
.
template
Alloc
<
T2
>(
mask
);
int
nthreads
=
batch_size
*
output_channels
*
output_depth
*
output_height
*
output_width
;
int
thread_num
=
1024
;
#ifdef WITH_NV_JETSON
backends
::
gpu
::
ChangeThreadNum
(
context
,
&
thread_num
);
#endif
int
ncd
=
batch_size
*
input_channels
*
output_depth
;
int
blocks
=
(
nthreads
+
thread_num
-
1
)
/
thread_num
;
dim3
threads
(
thread_num
,
1
);
dim3
grid
(
blocks
,
1
);
int
thread_x
=
32
;
int
thread_y
=
8
;
int
thread_z
=
1
;
dim3
threads
(
thread_x
,
thread_y
,
thread_z
);
std
::
array
<
int
,
3
>
max_grid_dim
=
context
.
GetCUDAMaxGridDimSize
();
int
block_x
=
(
output_width
+
threads
.
x
-
1
)
/
threads
.
x
;
int
block_y
=
(
output_height
+
threads
.
y
-
1
)
/
threads
.
y
;
int
block_z
=
(
ncd
>
max_grid_dim
[
2
]
*
threads
.
z
)
?
max_grid_dim
[
2
]
:
(
ncd
+
threads
.
z
-
1
)
/
threads
.
z
;
dim3
grid
(
block_x
,
block_y
,
block_z
);
auto
pool_divmods_output
=
FastDivModForPooling3D
(
input_channels
,
output_width
,
output_height
,
output_depth
);
KernelMaxPool3DWithIdx
<
T1
,
T2
>
<<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
n
threads
,
<<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
n
cd
,
input_data
,
input_channels
,
input_depth
,
...
...
@@ -2233,7 +2265,8 @@ class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
padding_width
,
adaptive
,
output_data
,
mask_data
);
mask_data
,
pool_divmods_output
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录