Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2fd999d9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2fd999d9
编写于
3月 01, 2021
作者:
N
niuliling123
提交者:
GitHub
3月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimized the adaptive_avg_pool2d op when output_size == 1 (#31197)
* Optimized the adaptive_avg_pool2d op when output_size == 1
上级
aebf2234
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
58 addition
and
7 deletion
+58
-7
paddle/fluid/operators/pool_op.cu
paddle/fluid/operators/pool_op.cu
+0
-0
paddle/fluid/operators/pool_op.h
paddle/fluid/operators/pool_op.h
+58
-7
未找到文件。
paddle/fluid/operators/pool_op.cu
.cc
→
paddle/fluid/operators/pool_op.cu
浏览文件 @
2fd999d9
文件已移动
paddle/fluid/operators/pool_op.h
浏览文件 @
2fd999d9
...
...
@@ -22,8 +22,20 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#ifdef __NVCC__
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
DivideFunctor
{
HOSTDEVICE
explicit
inline
DivideFunctor
(
int
n
)
:
n_inv
((
T
)(
1.0
/
n
))
{}
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
x
*
n_inv
;
}
private:
T
n_inv
;
};
using
Tensor
=
framework
::
Tensor
;
...
...
@@ -124,6 +136,26 @@ inline void UpdateKsize(std::vector<T>* ksize,
}
}
inline
int
getReduceNum
(
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
*
output
,
const
std
::
string
data_format
,
std
::
vector
<
int
>*
reduce_dim
)
{
// data_format only can be NCHW
bool
channel_last
=
(
data_format
==
"NHWC"
);
if
(
channel_last
)
{
return
0
;
}
int
reduce_num
=
0
;
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
if
((
output_height
==
1
)
&&
(
output_width
==
1
))
{
reduce_dim
->
push_back
(
2
);
reduce_dim
->
push_back
(
3
);
reduce_num
=
input
.
dims
()[
2
]
*
input
.
dims
()[
3
];
}
return
reduce_num
;
}
template
<
typename
DeviceContext
,
typename
T
>
class
PoolKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -164,7 +196,6 @@ class PoolKernel : public framework::OpKernel<T> {
if
(
global_pooling
)
{
UpdateKsize
(
&
ksize
,
data_dims
);
}
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
switch
(
ksize
.
size
())
{
case
2
:
{
...
...
@@ -177,12 +208,32 @@ class PoolKernel : public framework::OpKernel<T> {
pool_process
,
true
,
false
,
out
);
}
else
if
(
pooling_type
==
"avg"
)
{
std
::
vector
<
int
>
reduce_dim
;
int
reduce_num
=
getReduceNum
(
*
in_x
,
out
,
data_format
,
&
reduce_dim
);
if
(
reduce_num
>
0
&&
adaptive
)
{
// for adaptive_avg_pool2d && output_size == 1
#ifdef __NVCC__
auto
stream
=
dev_ctx
.
stream
();
TensorReduce
<
T
,
T
,
cub
::
Sum
,
DivideFunctor
<
T
>>
(
*
in_x
,
out
,
reduce_dim
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
DivideFunctor
<
T
>
(
reduce_num
),
stream
);
#else // for cpu
paddle
::
operators
::
math
::
Pool2dFunctor
<
DeviceContext
,
paddle
::
operators
::
math
::
AvgPool
<
T
>
,
T
>
pool2d_forward
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
data_format
,
pool_process
,
exclusive
,
adaptive
,
out
);
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
data_format
,
pool_process
,
exclusive
,
adaptive
,
out
);
#endif
}
else
{
// avgpool_2d or adaptive_avg_pool2d && output_size != 1
paddle
::
operators
::
math
::
Pool2dFunctor
<
DeviceContext
,
paddle
::
operators
::
math
::
AvgPool
<
T
>
,
T
>
pool2d_forward
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
data_format
,
pool_process
,
exclusive
,
adaptive
,
out
);
}
}
}
break
;
case
3
:
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录