Unverified commit 2fd999d9, authored by niuliling123 and committed by GitHub

Optimized the adaptive_avg_pool2d op when output_size == 1 (#31197)

* Optimized the adaptive_avg_pool2d op when output_size == 1
Parent aebf2234
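For reference: with output_size == (1, 1), adaptive average pooling of an NCHW tensor is simply the per-(N, C) mean over the whole H x W plane, which is why the GPU branch added below can express it as a cub::Sum reduction over dims {2, 3} with DivideFunctor supplying the 1/(H*W) factor, instead of going through the generic Pool2dFunctor. A minimal CPU sketch of the value that fast path must reproduce (the function name and flat-buffer layout here are illustrative, not part of the patch):

```cpp
#include <cstddef>
#include <vector>

// Reference semantics of adaptive_avg_pool2d with output_size == 1 on an
// NCHW tensor: out[n][c] = mean of in[n][c][:][:].
std::vector<float> GlobalAvgPoolNCHW(const std::vector<float>& in, int n,
                                     int c, int h, int w) {
  std::vector<float> out(static_cast<size_t>(n) * c, 0.f);
  const float inv_hw = 1.f / static_cast<float>(h * w);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < c; ++j) {
      const float* plane = in.data() + (static_cast<size_t>(i) * c + j) * h * w;
      float sum = 0.f;
      for (int k = 0; k < h * w; ++k) sum += plane[k];
      out[static_cast<size_t>(i) * c + j] = sum * inv_hw;  // reduce, then scale
    }
  }
  return out;
}
```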
@@ -22,8 +22,20 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#ifdef __NVCC__
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif
namespace paddle {
namespace operators {
// Multiplies its argument by 1/n; used together with the sum reduction below
// so that the pooled sum becomes an average.
template <typename T>
struct DivideFunctor {
  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}

  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }

 private:
  T n_inv;
};
using Tensor = framework::Tensor;
@@ -124,6 +136,26 @@ inline void UpdateKsize(std::vector<T>* ksize,
}
}
// If the adaptive average pool collapses to a global pool (1 x 1 output in
// NCHW layout), fills reduce_dim with the spatial axes and returns the number
// of elements reduced per output value; returns 0 otherwise.
inline int getReduceNum(const framework::Tensor& input,
                        const framework::Tensor* output,
                        const std::string data_format,
                        std::vector<int>* reduce_dim) {
  // Only the NCHW data format is supported by this fast path.
  bool channel_last = (data_format == "NHWC");
  if (channel_last) {
    return 0;
  }
int reduce_num = 0;
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
if ((output_height == 1) && (output_width == 1)) {
reduce_dim->push_back(2);
reduce_dim->push_back(3);
reduce_num = input.dims()[2] * input.dims()[3];
}
return reduce_num;
}
template <typename DeviceContext, typename T>
class PoolKernel : public framework::OpKernel<T> {
public:
@@ -164,7 +196,6 @@ class PoolKernel : public framework::OpKernel<T> {
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
auto& dev_ctx = context.template device_context<DeviceContext>();
switch (ksize.size()) {
case 2: {
@@ -177,12 +208,32 @@ class PoolKernel : public framework::OpKernel<T> {
pool_process, true, false, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, exclusive, adaptive, out);
std::vector<int> reduce_dim;
int reduce_num = getReduceNum(*in_x, out, data_format, &reduce_dim);
if (reduce_num > 0 &&
adaptive) { // for adaptive_avg_pool2d && output_size == 1
#ifdef __NVCC__
auto stream = dev_ctx.stream();
TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
*in_x, out, reduce_dim, static_cast<T>(0), cub::Sum(),
DivideFunctor<T>(reduce_num), stream);
#else // for cpu
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings,
data_format, pool_process, exclusive, adaptive, out);
#endif
} else { // avgpool_2d or adaptive_avg_pool2d && output_size != 1
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings,
data_format, pool_process, exclusive, adaptive, out);
}
}
} break;
case 3: {
......
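For readers unfamiliar with the reduce helper used in the GPU branch: TensorReduce (from reduce_ops/cub_reduce.h) receives the input, the output, the reduce dims, cub::Sum, and a DivideFunctor, and its net effect here is a sum over the H*W plane scaled by 1/(H*W). Below is a standalone sketch of that reduce-then-scale pattern for a single device-resident plane, using plain CUB rather than Paddle's helper; applying the 1/(H*W) factor once after the sum (as done here) is mathematically equivalent to applying it per element before the sum.

```cpp
#include <cuda_runtime.h>
#include <cub/cub.cuh>

// Scales the single reduced value by 1/(H*W) -- the role DivideFunctor plays
// in the TensorReduce call above.
__global__ void ScaleByInv(float* x, float inv_n) { *x *= inv_n; }

// Sketch only: global average of one H*W plane, not Paddle's TensorReduce.
void GlobalAvgOnePlane(const float* d_plane, float* d_out, int hw,
                       cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call only queries the temporary-storage size CUB needs.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_plane, d_out, hw, stream);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_plane, d_out, hw, stream);
  ScaleByInv<<<1, 1, 0, stream>>>(d_out, 1.f / static_cast<float>(hw));
  cudaFree(d_temp);
}
```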