未验证 提交 2fd999d9 编写于 作者: N niuliling123 提交者: GitHub

Optimized the adaptive_avg_pool2d op when output_size == 1 (#31197)

* Optimized the adaptive_avg_pool2d op when output_size == 1
上级 aebf2234
...@@ -22,8 +22,20 @@ limitations under the License. */ ...@@ -22,8 +22,20 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/operators/math/pooling.h"
#ifdef __NVCC__
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Functor that divides its argument by a fixed integer `n`.
//
// Used as the post-reduction transform when lowering adaptive_avg_pool2d
// with a 1x1 output to a sum-reduce: the summed value is multiplied by
// 1/n to produce the average.
template <typename T>
struct DivideFunctor {
  // Precompute the reciprocal once on the host so the per-element call is a
  // single multiply (cheaper than a divide on device).
  HOSTDEVICE explicit inline DivideFunctor(int n)
      : n_inv(static_cast<T>(1.0 / n)) {}

  // Returns x / n, implemented as x * (1/n).
  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }

 private:
  T n_inv;  // Reciprocal of the divisor.
};
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
...@@ -124,6 +136,26 @@ inline void UpdateKsize(std::vector<T>* ksize, ...@@ -124,6 +136,26 @@ inline void UpdateKsize(std::vector<T>* ksize,
} }
} }
// Decides whether an adaptive average pool can be lowered to a full
// reduction over the spatial dimensions.
//
// When `data_format` is NCHW and the output is 1x1, fills `reduce_dim`
// with the spatial dim indices {2, 3} and returns the number of input
// elements reduced per (N, C) slice, i.e. H * W of the input.
// Returns 0 (leaving `reduce_dim` untouched) for NHWC layout or when the
// output is larger than 1x1; callers then fall back to the generic
// pooling path.
inline int getReduceNum(const framework::Tensor& input,
                        const framework::Tensor* output,
                        const std::string& data_format,
                        std::vector<int>* reduce_dim) {
  // The fast reduce path only supports channel-first (NCHW) layout.
  if (data_format == "NHWC") {
    return 0;
  }
  // NCHW: dims are [N, C, H, W].
  const int output_height = output->dims()[2];
  const int output_width = output->dims()[3];
  if (output_height != 1 || output_width != 1) {
    return 0;
  }
  reduce_dim->push_back(2);
  reduce_dim->push_back(3);
  return input.dims()[2] * input.dims()[3];
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class PoolKernel : public framework::OpKernel<T> { class PoolKernel : public framework::OpKernel<T> {
public: public:
...@@ -164,7 +196,6 @@ class PoolKernel : public framework::OpKernel<T> { ...@@ -164,7 +196,6 @@ class PoolKernel : public framework::OpKernel<T> {
if (global_pooling) { if (global_pooling) {
UpdateKsize(&ksize, data_dims); UpdateKsize(&ksize, data_dims);
} }
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
switch (ksize.size()) { switch (ksize.size()) {
case 2: { case 2: {
...@@ -177,12 +208,32 @@ class PoolKernel : public framework::OpKernel<T> { ...@@ -177,12 +208,32 @@ class PoolKernel : public framework::OpKernel<T> {
pool_process, true, false, out); pool_process, true, false, out);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor< std::vector<int> reduce_dim;
DeviceContext, paddle::operators::math::AvgPool<T>, T> int reduce_num = getReduceNum(*in_x, out, data_format, &reduce_dim);
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process; if (reduce_num > 0 &&
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format, adaptive) { // for adaptive_avg_pool2d && output_size == 1
pool_process, exclusive, adaptive, out); #ifdef __NVCC__
auto stream = dev_ctx.stream();
TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
*in_x, out, reduce_dim, static_cast<T>(0), cub::Sum(),
DivideFunctor<T>(reduce_num), stream);
#else // for cpu
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings,
data_format, pool_process, exclusive, adaptive, out);
#endif
} else { // avgpool_2d or adaptive_avg_pool2d && output_size != 1
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings,
data_format, pool_process, exclusive, adaptive, out);
}
} }
} break; } break;
case 3: { case 3: {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册