diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h
index 9ee8eab1a7922b2cc00b1164507460e5643a8bda..e0242da0c5fabb5d12088c60628f39e0bd7f3d07 100644
--- a/paddle/fluid/operators/pool_op.h
+++ b/paddle/fluid/operators/pool_op.h
@@ -23,23 +23,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/pooling.h"
 #if defined(__HIPCC__) || defined(__NVCC__)
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #endif
 
 namespace paddle {
 namespace operators {
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-
-  template <typename U>
-  HOSTDEVICE inline U operator()(const U& x) const {
-    return x * static_cast<U>(n_inv);
-  }
-
- private:
-  T n_inv;
-};
 
 using Tensor = framework::Tensor;
 
@@ -219,9 +208,8 @@ class PoolKernel : public framework::OpKernel<T> {
             adaptive) {  // for adaptive_avg_pool2d && output_size == 1
 #if defined(__HIPCC__) || defined(__NVCC__)
           auto stream = dev_ctx.stream();
-          TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
-              *in_x, out, reduce_dim, static_cast<T>(0), cub::Sum(),
-              DivideFunctor<T>(reduce_num), stream);
+          TensorReduceFunctorImpl<T, T, CustomMean>(*in_x, out, reduce_dim,
+                                                    stream);
 #else  // for cpu
           paddle::operators::math::Pool2dFunctor<
               DeviceContext, paddle::operators::math::AvgPool<T>, T>
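
For context: the removed code expressed the adaptive average pool's mean as a sum reduction (cub::Sum) followed by a multiply-by-1/n transform (DivideFunctor), while the replacement delegates the whole mean reduction to TensorReduceFunctorImpl. Below is a minimal host-side C++ sketch of the pattern the removed code implemented; PrecomputedDivide is a hypothetical stand-in that only mirrors the removed DivideFunctor's cached-reciprocal trick and is not part of the Paddle API.

// Sketch: mean as "sum, then multiply by a precomputed 1/n",
// the pattern the removed DivideFunctor implemented.
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

template <typename T>
struct PrecomputedDivide {
  explicit PrecomputedDivide(int n) : n_inv(static_cast<T>(1.0 / n)) {}

  // Applied once to the reduced value; a multiplication, not a division.
  T operator()(const T& x) const { return x * n_inv; }

 private:
  T n_inv;  // reciprocal cached at construction time
};

int main() {
  std::vector<float> in = {1.f, 2.f, 3.f, 4.f};
  // Reduce (sum) ...
  float sum = std::accumulate(in.begin(), in.end(), 0.f);
  // ... then apply the divide transform to obtain the mean.
  float mean = PrecomputedDivide<float>(static_cast<int>(in.size()))(sum);
  std::cout << "mean = " << mean << "\n";  // prints 2.5
  return 0;
}

Caching n_inv at construction turns the per-result division into a multiplication, which is cheaper on GPU; the templated operator() in the removed functor additionally let the same transform apply to the reduction's intermediate type.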