From 524389ee347c0371fbf95b017f16f5b3d7a00910 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Thu, 16 Dec 2021 10:33:11 +0800
Subject: [PATCH] Add the transformop parameter in TensorReduceFunctorImpl
 (#38135)

* Add the transformop parameter in TensorReduceFunctorImpl
---
 paddle/fluid/operators/clip_by_norm_op.cu     |  29 +----
 .../elementwise/elementwise_add_op.cu         |   7 +-
 .../elementwise/elementwise_sub_op.cu         |   7 +-
 paddle/fluid/operators/fused/attn_gemm.h      |   7 +-
 .../operators/margin_cross_entropy_op.cu      |  22 ++--
 paddle/fluid/operators/p_norm_op.cu           |  73 ++-----
 paddle/fluid/operators/pool_op.h              |   8 +-
 paddle/fluid/operators/prelu_op.cu            |  14 +--
 .../operators/reduce_ops/reduce_all_op.cu     |   3 +-
 .../operators/reduce_ops/reduce_any_op.cu     |   3 +-
 .../operators/reduce_ops/reduce_max_op.cu     |  10 +-
 .../operators/reduce_ops/reduce_mean_op.cu    |   8 +-
 .../operators/reduce_ops/reduce_min_op.cu     |  10 +-
 .../fluid/operators/reduce_ops/reduce_op.cu.h | 100 +++++++++---------
 paddle/fluid/operators/reduce_ops/reduce_op.h |  15 ++-
 .../operators/reduce_ops/reduce_prod_op.cu    |  10 +-
 .../operators/reduce_ops/reduce_sum_op.cu     |  24 ++---
 paddle/fluid/operators/trace_op.cu            |  16 +--
 paddle/fluid/operators/triangular_solve_op.cu |   4 +-
 19 files changed, 138 insertions(+), 232 deletions(-)

diff --git a/paddle/fluid/operators/clip_by_norm_op.cu b/paddle/fluid/operators/clip_by_norm_op.cu
index 5997e467693..368fbe836c2 100644
--- a/paddle/fluid/operators/clip_by_norm_op.cu
+++ b/paddle/fluid/operators/clip_by_norm_op.cu
@@ -18,29 +18,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
-template 
-struct SquareTransformer {
-  HOSTDEVICE explicit inline SquareTransformer(int n) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast(x) * static_cast(x);
-  }
-
-  HOSTDEVICE inline Ty operator()(const Tx* x) const {
-    return static_cast(x[0]) * static_cast(x[0]);
-  }
-};
-
-template 
-struct SquareSum {
-  using Transformer = SquareTransformer;
-
-  inline Ty initial() { return static_cast(0.0f); }
-
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return b + a;
-  }
-};
 
 template <>
 class ClipByNormKernel
@@ -97,8 +74,10 @@ class ClipByNormKernel
       }
       Tensor tmp = context.AllocateTmpTensor(
           {1}, dev_ctx);
-      TensorReduceFunctorImpl(
-          *input, &tmp, reduce_dims, dev_ctx.stream());
+      TensorReduceFunctorImpl>(
+          *input, &tmp, kps::SquareFunctor(),
+          reduce_dims, dev_ctx.stream());
       auto tmp_eigen = EigenVector::Flatten(tmp);
       auto x_norm = tmp_eigen.sqrt();
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index 65bcb2239e4..7b153a4bce8 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -15,7 +15,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
@@ -91,7 +90,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
     }
     std::vector reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl(*dout, dx, reduce_dims, stream);
+    TensorReduceFunctorImpl>(
+        *dout, dx, kps::IdentityFunctor(), reduce_dims, stream);
     }
   }
   // dy
@@ -106,7 +106,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
   } else {
     std::vector reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl(*dout, dy, reduce_dims, stream);
+    TensorReduceFunctorImpl>(
+        *dout, dy, kps::IdentityFunctor(), reduce_dims, stream);
     }
   }
 }
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
index 2b44c81a455..cba261a3947 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
@@ -14,7 +14,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
@@ -69,7 +68,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
     }
     std::vector reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl(*dout, dx, reduce_dims, stream);
+    TensorReduceFunctorImpl>(
+        *dout, dx, kps::IdentityFunctor(), reduce_dims, stream);
     }
   }
   // dy
@@ -90,7 +90,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
   } else {
     std::vector reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl(*dout, dy, reduce_dims, stream);
+    TensorReduceFunctorImpl>(
+        *dout, dy, kps::InverseFunctor(), reduce_dims, stream);
     }
   }
 }
diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h
index 21875cc5214..e0873608fa2 100644
--- a/paddle/fluid/operators/fused/attn_gemm.h
+++ b/paddle/fluid/operators/fused/attn_gemm.h
@@ -16,11 +16,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 
 namespace paddle {
 namespace operators {
-
 // support gemm-nt and gemm-nn, which is used in fused_attention_op.
 template 
 class AttnMatMul {
@@ -165,8 +164,8 @@ class AttnMatMul {
                             (input_dims[2] == output_dims[0]));
     if (support_case_1 || support_case_2) {
       gpuStream_t stream = dev_ctx_.stream();
-      TensorReduceFunctorImpl(*d_output, d_bias, {0, 1},
-                              stream);
+      TensorReduceFunctorImpl>(
+          *d_output, d_bias, kps::IdentityFunctor(), {0, 1}, stream);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Only support reduce when the input dims are [0,1,2,3,4] and "
diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cu b/paddle/fluid/operators/margin_cross_entropy_op.cu
index 1deaa3ef1ee..35035704b7e 100644
--- a/paddle/fluid/operators/margin_cross_entropy_op.cu
+++ b/paddle/fluid/operators/margin_cross_entropy_op.cu
@@ -24,7 +24,6 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/margin_cross_entropy_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/string/string_helper.h"
 
@@ -128,17 +127,6 @@ __global__ void AddMarginToPositiveLogitsKernel(
   }
 }
 
-template 
-struct ExpAndSum {
-  using Transformer = kps::ExpFunctor;
-
-  inline Ty initial() { return static_cast(0.0f); }
-
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return b + a;
-  }
-};
-
 template 
 __global__ void ScaleLogitKernel(T* logits, const float scale, const int64_t N,
                                  const int64_t D) {
@@ -309,8 +297,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel {
     logits_max =
         ctx.AllocateTmpTensor({N, 1}, dev_ctx);
     T* logits_max_buff = logits_max.mutable_data(place);
-    TensorReduceFunctorImpl(softmax_2d, &logits_max, {1},
-                            dev_ctx.stream());
+    TensorReduceFunctorImpl>(
+        softmax_2d, &logits_max, kps::IdentityFunctor(), {1},
+        dev_ctx.stream());
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     if (nranks > 1) {
@@ -330,8 +319,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel {
     sum_exp_logits =
         ctx.AllocateTmpTensor({N, 1}, dev_ctx);
     T* sum_exp_logits_buff = sum_exp_logits.mutable_data(place);
-    TensorReduceFunctorImpl(softmax_2d, &sum_exp_logits, {1},
-                            dev_ctx.stream());
+    TensorReduceFunctorImpl>(
+        softmax_2d, &sum_exp_logits, kps::ExpFunctor(), {1},
+        dev_ctx.stream());
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     if (nranks > 1) {
diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu
index 1a481c1cf5c..11a4928fe1c 100644
--- a/paddle/fluid/operators/p_norm_op.cu
+++ b/paddle/fluid/operators/p_norm_op.cu
@@ -59,28 +59,17 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
   return pow(base, exponent);
 }
 
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-  HOSTDEVICE explicit inline IdentityFunctor(int n) {}
-  template 
-  HOSTDEVICE inline T operator()(const T& x) const {
-    return static_cast(x);
-  }
-};
-
+template 
 struct NonzeroFunctor {
   HOSTDEVICE explicit inline NonzeroFunctor() {}
-  HOSTDEVICE explicit inline NonzeroFunctor(int n) {}
-  template 
   HOSTDEVICE inline T operator()(const T& x) const {
     return static_cast(static_cast(x) != 0);
   }
 };
 
+template 
 struct AbsFunctor {
   HOSTDEVICE explicit inline AbsFunctor() {}
-  HOSTDEVICE explicit inline AbsFunctor(int n) {}
-  template 
   HOSTDEVICE inline T operator()(const T& x) const {
     return static_cast(inline_abs(x));
   }
 };
@@ -106,48 +95,6 @@ struct PowFunctor {
   float porder;
 };
 
-template 
-struct AbsAndMin {
-  using Transformer = AbsFunctor;
-  using MT = typename details::MPTypeTrait::Type;
-  inline Ty initial() {
-    return static_cast(std::numeric_limits::infinity());
-  }
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return (a < b) ? a : b;
-  }
-};
-
-template 
-struct AbsAndMax {
-  using Transformer = AbsFunctor;
-  using MT = typename details::MPTypeTrait::Type;
-  inline Ty initial() {
-    return static_cast(-std::numeric_limits::infinity());
-  }
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return (a > b) ? a : b;
-  }
-};
-
-template 
-struct NonzeroAndSum {
-  using Transformer = NonzeroFunctor;
-  inline Ty initial() { return static_cast(0.0f); }
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return b + a;
-  }
-};
-
-template 
-struct IdentityAndSum {
-  using Transformer = IdentityFunctor;
-  inline Ty initial() { return static_cast(0.0f); }
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return b + a;
-  }
-};
-
 template 
 class PnormCUDAKernel : public framework::OpKernel {
  public:
@@ -167,14 +114,14 @@ class PnormCUDAKernel : public framework::OpKernel {
     using MT = typename details::MPTypeTrait::Type;
     if (porder == 0) {
-      TensorReduceFunctorImpl(*in_x, out_norm, reduce_axis,
-                              stream);
+      TensorReduceFunctorImpl>(
+          *in_x, out_norm, NonzeroFunctor(), reduce_axis, stream);
     } else if (porder == INFINITY) {
-      TensorReduceFunctorImpl(*in_x, out_norm, reduce_axis,
-                              stream);
+      TensorReduceFunctorImpl>(
+          *in_x, out_norm, AbsFunctor(), reduce_axis, stream);
     } else if (porder == -INFINITY) {
-      TensorReduceFunctorImpl(*in_x, out_norm, reduce_axis,
-                              stream);
+      TensorReduceFunctorImpl>(
+          *in_x, out_norm, AbsFunctor(), reduce_axis, stream);
     } else {
       framework::Tensor tmp_x;
       tmp_x.mutable_data(xdim, ctx.GetPlace());
@@ -189,8 +136,8 @@ class PnormCUDAKernel : public framework::OpKernel {
           cuda_ctx, ins, &outs, func);
       framework::Tensor tmp_y;
       tmp_y.mutable_data(ndim, ctx.GetPlace());
-      TensorReduceFunctorImpl(tmp_x, &tmp_y, reduce_axis,
-                              stream);
+      TensorReduceFunctorImpl>(
+          tmp_x, &tmp_y, kps::IdentityFunctor(), reduce_axis, stream);
       const framework::Tensor* tmp_norm = &tmp_y;
       ins = {tmp_norm};
       outs = {out_norm};
diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h
index e0242da0c5f..84c1988e29b 100644
--- a/paddle/fluid/operators/pool_op.h
+++ b/paddle/fluid/operators/pool_op.h
@@ -23,7 +23,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/pooling.h"
 #if defined(__HIPCC__) || defined(__NVCC__)
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #endif
 
@@ -203,13 +202,14 @@ class PoolKernel : public framework::OpKernel {
     } else if (pooling_type == "avg") {
       std::vector reduce_dim;
       int reduce_num = getReduceNum(*in_x, out, data_format, &reduce_dim);
-
       if (reduce_num > 0 &&
           adaptive) {  // for adaptive_avg_pool2d && output_size == 1
 #if defined(__HIPCC__) || defined(__NVCC__)
         auto stream = dev_ctx.stream();
-        TensorReduceFunctorImpl(*in_x, out, reduce_dim,
-                                stream);
+        TensorReduceFunctorImpl>(
+            *in_x, out, kps::DivideFunctor(reduce_num), reduce_dim,
+            stream);
 #else  // for cpu
         paddle::operators::math::Pool2dFunctor<
             DeviceContext, paddle::operators::math::AvgPool, T>
diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu
index 06cc9ed7a96..c6997603bb1 100644
--- a/paddle/fluid/operators/prelu_op.cu
+++ b/paddle/fluid/operators/prelu_op.cu
@@ -15,7 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/prelu.h"
 #include "paddle/fluid/operators/prelu_op.h"
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 
 namespace paddle {
@@ -123,13 +123,6 @@ class PreluOpGradFunctor {
   }
 };
 
-struct IdentityFunctor {
-  template 
-  HOSTDEVICE inline T operator()(const T& x) const {
-    return x;
-  }
-};
-
 template 
 class CUDAPReluGradKernel : public framework::OpKernel {
  public:
@@ -192,9 +185,8 @@ class CUDAPReluGradKernel : public framework::OpKernel {
       reduce_dims.push_back(i);
     }
 
-    TensorReduce(
-        dalpha_tmp, dalpha, reduce_dims, static_cast(0), cub::Sum(),
-        IdentityFunctor(), stream);
+    TensorReduceFunctorImpl>(
+        dalpha_tmp, dalpha, kps::IdentityFunctor(), reduce_dims, stream);
   }
 };
 
diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cu b/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
index 674326f90c5..a1f1a228aeb 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
@@ -13,8 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 
 REGISTER_OP_CUDA_KERNEL(
     reduce_all,
-    ops::ReduceCudaKernel);
+    ops::ReduceCudaKernel);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cu b/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
index b7b0eb59824..2e93e67debb 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
@@ -13,9 +13,8 @@
 // limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h" -#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" REGISTER_OP_CUDA_KERNEL( reduce_any, - ops::ReduceCudaKernel); + ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op.cu b/paddle/fluid/operators/reduce_ops/reduce_max_op.cu index f214fcba199..8194805ddc3 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_max_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cu @@ -11,13 +11,13 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" // reduce_max REGISTER_OP_CUDA_KERNEL( - reduce_max, ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel); + reduce_max, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu index b5d5bb33d0a..a50b09564fd 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu @@ -13,11 +13,11 @@ // limitations under the License. #include -#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" REGISTER_OP_CUDA_KERNEL( - reduce_mean, ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel); + reduce_mean, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_op.cu b/paddle/fluid/operators/reduce_ops/reduce_min_op.cu index 7806df284d8..44548b8d2e7 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_min_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_min_op.cu @@ -11,13 +11,13 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" // reduce_min REGISTER_OP_CUDA_KERNEL( - reduce_min, ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel); + reduce_min, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h index 9c348477963..77fa5768843 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h @@ -44,11 +44,11 @@ namespace cub = hipcub; #define REDUCE_SPLIT_BOUNDARY 512 #define REDUCE_VEC_SIZE 4 +namespace kps = paddle::operators::kernel_primitives; + namespace paddle { namespace operators { -namespace kps = paddle::operators::kernel_primitives; - namespace details { static inline int GetLastPow2(int n) { @@ -722,12 +722,12 @@ __global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer, } } -template +template static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, - const ReduceOp& reducer, MPType init, + const ReduceOp& reducer, + const TransformOp& transform, MPType init, gpuStream_t stream, ReduceConfig config) { - using TransformOp = typename ReduceOp::Transformer; - if (config.reduce_type == kReduceLastDim) { int stride_reduce = 1; int stride_left = config.reduce_num; @@ -743,15 +743,15 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, #ifdef PADDLE_WITH_XPU2 ReduceAnyKernel<<<8, 128, stream>>>( - x_data, config.output_data, reducer, TransformOp(config.reduce_num), - init, config.reduce_num, config.left_num, config.reduce_last_dim, - reduce_index_calculator, left_index_calculator, dim); + x_data, config.output_data, reducer, transform, init, config.reduce_num, + config.left_num, config.reduce_last_dim, reduce_index_calculator, + left_index_calculator, dim); #else ReduceAnyKernel<<>>( - x_data, config.output_data, reducer, TransformOp(config.reduce_num), - init, config.reduce_num, config.left_num, config.reduce_last_dim, - reduce_index_calculator, left_index_calculator, dim); + x_data, config.output_data, reducer, transform, init, config.reduce_num, + config.left_num, config.reduce_last_dim, reduce_index_calculator, + left_index_calculator, dim); #endif } else { @@ -771,15 +771,15 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, #ifdef PADDLE_WITH_XPU2 ReduceAnyKernel<<<8, 128, stream>>>( - x_data, config.output_data, reducer, TransformOp(config.reduce_num), - init, config.reduce_num, config.left_num, config.reduce_last_dim, - reduce_index_calculator, left_index_calculator, dim); + x_data, config.output_data, reducer, transform, init, config.reduce_num, + config.left_num, config.reduce_last_dim, reduce_index_calculator, + left_index_calculator, dim); #else ReduceAnyKernel<<>>( - x_data, config.output_data, reducer, TransformOp(config.reduce_num), - init, config.reduce_num, config.left_num, config.reduce_last_dim, - reduce_index_calculator, left_index_calculator, dim); + x_data, config.output_data, reducer, transform, init, config.reduce_num, + config.left_num, config.reduce_last_dim, reduce_index_calculator, + left_index_calculator, dim); #endif } @@ -802,23 +802,22 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, #ifdef PADDLE_WITH_XPU2 ReduceHigherDimKernel><<<8, 128, stream>>>( - config.output_data, y_data, 
-        kps::IdentityFunctor(config.grid.y), init, config.grid.y,
-        config.left_num, config.grid.y, dim);
+        config.output_data, y_data, reducer, kps::IdentityFunctor(),
+        init, config.grid.y, config.left_num, config.grid.y, dim);
 #else
     ReduceHigherDimKernel<
         Ty, Ty, MPType, ReduceOp,
         kps::IdentityFunctor><<>>(
-        config.output_data, y_data, reducer,
-        kps::IdentityFunctor(config.grid.y), init, config.grid.y,
-        config.left_num, config.grid.y, dim);
+        config.output_data, y_data, reducer, kps::IdentityFunctor(),
+        init, config.grid.y, config.left_num, config.grid.y, dim);
 #endif
   }
 }
 
-template  class ReduceOp>
+template  class ReduceOp,
+          typename TransformOp>
 void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
+                             const TransformOp& transform,
                              std::vector origin_reduce_dims,
                              gpuStream_t stream) {
   auto x_dim = framework::vectorize(x.dims());
@@ -853,10 +852,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                         (!std::is_same::value);
 
   if (use_cub_reduce) {
     // launch CUB::Reduce
-    using TransformOp = typename ReduceOp::Transformer;
-    auto reducer = ReduceOp();
-    cub::TransformInputIterator trans_x(
-        x_data, TransformOp(config.reduce_num));
+    auto reducer = ReduceOp();
+    cub::TransformInputIterator trans_x(x_data,
+                                        transform);
     size_t temp_storage_bytes = 0;
     cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                               config.reduce_num, reducer, reducer.initial(),
@@ -873,7 +871,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }
 
   using MPType = typename details::MPTypeTrait::Type;
-  auto reducer = ReduceOp();
+  auto reducer = ReduceOp();
   // launch ReduceHigherDimKernel
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
@@ -882,7 +880,6 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   // 32
   // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
   if (config.reduce_type == ReduceType::kReduceHigherDim) {
-    using TransformOp = typename ReduceOp::Transformer;
     kps::DimConfig dim =
         kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
                        config.block.x, config.blocking_size, 0);
@@ -890,18 +887,16 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                config.reduce_num % config.blocking_size, 0);
 
 #ifdef PADDLE_WITH_XPU2
-    ReduceHigherDimKernel,
+    ReduceHigherDimKernel,
                           TransformOp><<<8, 128, stream>>>(
-        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-        reducer.initial(), config.reduce_num, config.left_num,
-        config.blocking_size, dim);
+        x_data, config.output_data, reducer, transform, reducer.initial(),
+        config.reduce_num, config.left_num, config.blocking_size, dim);
 #else
     ReduceHigherDimKernel<
-        Tx, Ty, MPType, ReduceOp,
+        Tx, Ty, MPType, ReduceOp,
        TransformOp><<>>(
-        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-        reducer.initial(), config.reduce_num, config.left_num,
-        config.blocking_size, dim);
+        x_data, config.output_data, reducer, transform, reducer.initial(),
+        config.reduce_num, config.left_num, config.blocking_size, dim);
 #endif
 
     if (config.should_reduce_again) {
@@ -913,14 +908,14 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
 #ifdef PADDLE_WITH_XPU2
       ReduceHigherDimKernel<
-          Ty, Ty, MPType, ReduceOp,
+          Ty, Ty, MPType, ReduceOp,
          kps::IdentityFunctor><<<8, 128, stream>>>(
          config.output_data, y_data, reducer,
          kps::IdentityFunctor(config.grid.y), reducer.initial(),
          config.grid.y, config.left_num,
          config.grid.y, dim2);
 #else
      ReduceHigherDimKernel<
-          Ty, Ty, MPType, ReduceOp,
+          Ty, Ty, MPType, ReduceOp,
          kps::IdentityFunctor><<>>(
          config.output_data, y_data, reducer,
          kps::IdentityFunctor(config.grid.y), reducer.initial(),
@@ -933,23 +928,32 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
   // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
   // function will be used
-  LaunchReduceKernel>(
-      x_data, y_data, reducer, reducer.initial(), stream, config);
+  LaunchReduceKernel, TransformOp>(
+      x_data, y_data, reducer, transform, reducer.initial(), stream, config);
 }
 
-template  class ReduceOp>
+template  class ReduceOp,
+          template  class TransformOp>
 struct TensorReduceFunc {
   const framework::Tensor& x;
   framework::Tensor* y;
   std::vector origin_reduce_dims;
   gpuStream_t stream;
+  int reduce_num;
   TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
-                   std::vector origin_reduce_dims, gpuStream_t stream)
-      : x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}
+                   std::vector origin_reduce_dims, int num_reduce,
+                   gpuStream_t stream)
+      : x(x),
+        y(y),
+        origin_reduce_dims(origin_reduce_dims),
+        reduce_num(num_reduce),
+        stream(stream) {}
 
   template 
   void apply() const {
-    TensorReduceFunctorImpl(x, y, origin_reduce_dims, stream);
+    using MPType = typename details::MPTypeTrait::Type;
+    TensorReduceFunctorImpl>(
+        x, y, TransformOp(reduce_num), origin_reduce_dims, stream);
   }
 };
 
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 9259ca0e6f4..ea9b272878c 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -670,7 +670,8 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
 };
 
 #if defined(__HIPCC__) || defined(__NVCC__)
-template  class ReduceOp>
+template  class ReduceOp,
+          template  class TransformOp>
 class ReduceCudaKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -682,15 +683,19 @@ class ReduceCudaKernel : public framework::OpKernel {
 
     std::vector reduce_dims =
         GetReduceDim(dims, input->dims().size(), reduce_all);
-
+    int reduce_num = 1;
+    for (int i = 0; i < input->dims().size(); i++) {
+      reduce_num *= (input->dims())[i];
+    }
     gpuStream_t stream = context.cuda_device_context().stream();
     if (out_dtype >= 0) {
       framework::VisitDataTypeSmall(
           static_cast(out_dtype),
-          TensorReduceFunc(*input, output, reduce_dims, stream));
+          TensorReduceFunc(
+              *input, output, reduce_dims, reduce_num, stream));
     } else {
-      TensorReduceFunctorImpl(*input, output, reduce_dims,
-                              stream);
+      TensorReduceFunctorImpl>(
+          *input, output, TransformOp(reduce_num), reduce_dims, stream);
     }
   }
 };
diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
index 317a6e1d93c..2de647df8b1 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h" REGISTER_OP_CUDA_KERNEL( - reduce_prod, ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel); + reduce_prod, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu index 27a29a5b095..ea9a89bea97 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu @@ -11,18 +11,18 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" REGISTER_OP_CUDA_KERNEL( - reduce_sum, ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - ops::ReduceCudaKernel, - paddle::operators::CustomSum>, - ops::ReduceCudaKernel, - paddle::operators::CustomSum>); + reduce_sum, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, kps::AddFunctor, + kps::IdentityFunctor>, + ops::ReduceCudaKernel, kps::AddFunctor, + kps::IdentityFunctor>); diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu index f3fe32e10a5..98a77637f92 100644 --- a/paddle/fluid/operators/trace_op.cu +++ b/paddle/fluid/operators/trace_op.cu @@ -15,21 +15,12 @@ #include #include #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/trace_op.h" namespace paddle { namespace operators { -struct IdentityFunctor { - HOSTDEVICE explicit inline IdentityFunctor() {} - - template - HOSTDEVICE inline U operator()(const U& x) const { - return x; - } -}; - template class TraceCUDAKernel : public framework::OpKernel { public: @@ -48,9 +39,8 @@ class TraceCUDAKernel : public framework::OpKernel { auto stream = context.cuda_device_context().stream(); std::vector reduce_dims; reduce_dims.push_back(out->dims().size()); - TensorReduce( - diag, out, reduce_dims, static_cast(0), cub::Sum(), - IdentityFunctor(), stream); + TensorReduceFunctorImpl>( + diag, out, kps::IdentityFunctor(), reduce_dims, stream); } else { math::SetConstant functor; functor(context.device_context(), out, static_cast(0)); diff --git a/paddle/fluid/operators/triangular_solve_op.cu b/paddle/fluid/operators/triangular_solve_op.cu index c5218aec03e..dfd48fb47e5 100644 --- a/paddle/fluid/operators/triangular_solve_op.cu +++ b/paddle/fluid/operators/triangular_solve_op.cu @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/operators/triangular_solve_op.h"
 
@@ -44,7 +43,8 @@ struct MatrixReduceSumFunctor {
       }
     }
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl(in, out, out_reduce_dims, stream);
+    TensorReduceFunctorImpl>(
+        in, out, kps::IdentityFunctor(), out_reduce_dims, stream);
   }
 };
-- 
GitLab
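
Note on the API change for readers skimming the patch: TensorReduceFunctorImpl now takes the element-wise transform functor as an explicit template parameter and argument, instead of deriving it from a nested ReduceOp::Transformer typedef. One reducer such as kps::AddFunctor can therefore be paired with kps::SquareFunctor, kps::IdentityFunctor, kps::DivideFunctor, and so on at each call site, which is why per-op reducer structs like SquareSum, ExpAndSum, NonzeroAndSum, and IdentityAndSum are deleted above. The standalone host-side sketch below mirrors that design; it is illustrative only, and HostReduce plus the functor definitions are hypothetical stand-ins, not Paddle's actual API.

// Illustrative sketch only (not part of the patch): a hypothetical host-side
// analogue of passing the transform explicitly instead of deriving it from
// the reducer.
#include <cstdio>
#include <vector>

template <typename T>
struct AddFunctor {  // reducer: combines two partial results
  T initial() const { return static_cast<T>(0); }
  T operator()(const T& a, const T& b) const { return a + b; }
};

template <typename T>
struct SquareFunctor {  // transform: applied to every input element
  T operator()(const T& x) const { return x * x; }
};

template <typename T>
struct IdentityFunctor {
  T operator()(const T& x) const { return x; }
};

// Analogue of the new signature: TransformOp is its own template parameter
// and the functor instance is passed in by the caller.
template <typename Tx, typename Ty, template <typename> class ReduceOp,
          typename TransformOp>
Ty HostReduce(const std::vector<Tx>& x, const TransformOp& transform) {
  ReduceOp<Ty> reducer;
  Ty acc = reducer.initial();
  for (const Tx& v : x) {
    acc = reducer(acc, transform(static_cast<Ty>(v)));
  }
  return acc;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  // One reducer, two different transforms -- the point of the extra parameter.
  float sum = HostReduce<float, float, AddFunctor, IdentityFunctor<float>>(
      x, IdentityFunctor<float>());
  float sum_sq = HostReduce<float, float, AddFunctor, SquareFunctor<float>>(
      x, SquareFunctor<float>());
  std::printf("sum = %g, sum of squares = %g\n", sum, sum_sq);
  return 0;
}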