From f469f176d66b3e65df9756ff3dc8d98b301a0f63 Mon Sep 17 00:00:00 2001
From: Yiqun Liu
Date: Tue, 20 Jun 2023 16:33:51 +0800
Subject: [PATCH] Remove redundant definition of MPTypeTrait. (#54756)

---
 paddle/phi/kernels/funcs/dropout_impl.cu.h    | 10 +++++-----
 paddle/phi/kernels/funcs/reduce_function.h    |  2 +-
 .../gpu/fused_dropout_add_grad_kernel.cu      |  4 ++--
 .../fusion/gpu/fused_dropout_add_kernel.cu    |  6 +++---
 paddle/phi/kernels/gpu/exponential_kernel.cu  |  2 +-
 .../phi/kernels/gpu/group_norm_grad_kernel.cu |  4 ++--
 paddle/phi/kernels/gpu/multinomial_kernel.cu  |  2 +-
 paddle/phi/kernels/gpu/reduce.h               |  6 +++---
 .../phi/kernels/gpu/reduce_amin_amax_common.h |  2 +-
 .../kernels/gpu/reduce_mean_grad_kernel.cu    |  2 +-
 .../phi/kernels/gpu/reduce_sum_grad_kernel.cu |  2 +-
 paddle/phi/kernels/gpu/rrelu_kernel.cu        |  2 +-
 .../phi/kernels/gpu/uniform_inplace_kernel.cu |  2 +-
 paddle/phi/kernels/gpu/uniform_kernel.cu      |  2 +-
 .../phi/kernels/legacy/gpu/uniform_kernel.cu  |  2 +-
 .../kernels/primitive/compute_primitives.h    | 20 +------------------
 16 files changed, 26 insertions(+), 44 deletions(-)

diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h
index 48a7008463c..a1fc2c225ec 100644
--- a/paddle/phi/kernels/funcs/dropout_impl.cu.h
+++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h
@@ -40,7 +40,7 @@ namespace funcs {
 
 template <typename T1, typename T2 = T1, typename OutT = T1>
 struct DstFunctor {
-  using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T1>::Type;
   HOSTDEVICE inline DstFunctor(const float retain_prob,
                                const bool is_upscale_in_train,
@@ -90,7 +90,7 @@ struct MaskFunctor {
 
 template <typename T1, typename T2 = T1>
 struct DstMaskFunctor {
-  using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T1>::Type;
   HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                    const bool is_upscale_in_train)
       : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
@@ -386,7 +386,7 @@ void DropoutFwGPUKernelDriver(
       // y = x
       phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, y);
     } else {
-      using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+      using MT = typename phi::dtype::MPTypeTrait<T>::Type;
       MT factor = static_cast<MT>(1.0f - dropout_prob);
       // y = factor * x
       ScaleByDropoutFactor(dev_ctx, x, y, factor);
@@ -396,7 +396,7 @@ template <typename T>
 struct CudaDropoutGradFunctor {
-  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
 
   explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}
@@ -419,7 +419,7 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                 const phi::DenseTensor& mask,
                                 phi::DenseTensor* grad_x,
                                 bool is_dropout_nd = false) {
-  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   auto stream = dev_ctx.stream();
 
   if (is_test) {
diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index bf110fcdd9e..5e738d431df 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -1047,7 +1047,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
   }
 #endif
 
-  using MPType = typename kps::details::MPTypeTrait<Ty>::Type;
+  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
   auto reducer = ReduceOp<MPType>();
   // launch ReduceHigherDimKernel
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
index 6ea21f3bd48..dce2f8e5247 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
@@ -62,7 +62,7 @@ __global__ void FuseScaleAddGradRateZero(const T* grad,
 template <typename T1, typename T2 = T1>
 struct NoMaskBwFunctor {
   const float retain_prob_;
-  using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T1>::Type;
   MT factor_;
   HOSTDEVICE inline NoMaskBwFunctor(const float retain_prob)
       : retain_prob_(retain_prob) {
@@ -171,7 +171,7 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
   auto* y_grad_data = dev_ctx.template Alloc<T>(y_grad);
   const auto* out_grad_data = out_grad.data<T>();
 
-  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   int blocks = NumBlocks(numel);
   int threads = kNumCUDAThreads;
diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu
index afdef3f4b58..3cb1a674254 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu
@@ -29,7 +29,7 @@ template <typename T1, typename T2 = T1>
 struct NoMaskFwFunctor {
   const float retain_prob_;
   const bool is_upscale_in_train_;
-  using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T1>::Type;
   MT factor;
   HOSTDEVICE inline NoMaskFwFunctor(const float retain_prob,
                                     const bool is_upscale_in_train)
@@ -59,7 +59,7 @@ struct NoMaskFwFunctor {
 
 template <typename T>
 struct ScaleAddFuctor {
-  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   explicit ScaleAddFuctor(const MT factor, bool upscale_in_train)
       : factor_(factor), upscale_in_train_(upscale_in_train) {}
@@ -206,7 +206,7 @@ void FusedDropoutAddKernel(const Context& dev_ctx,
         dst_functor);
 #undef PD_DROPOUT_KERNEL_NAME
   } else {
-    using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     MT factor = static_cast<MT>(1.0f - dropout_rate);
     std::vector<phi::DenseTensor*> outs = {out};
     std::vector<const phi::DenseTensor*> ins = {&x, &y};
diff --git a/paddle/phi/kernels/gpu/exponential_kernel.cu b/paddle/phi/kernels/gpu/exponential_kernel.cu
index 7d6e1d54d1e..3a29e1dd4a2 100644
--- a/paddle/phi/kernels/gpu/exponential_kernel.cu
+++ b/paddle/phi/kernels/gpu/exponential_kernel.cu
@@ -25,7 +25,7 @@ void ExponentialKernel(const Context &dev_ctx,
                        const DenseTensor &x,
                        float lambda,
                        DenseTensor *out) {
-  using MT = typename kps::details::MPTypeTrait<T>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   phi::funcs::uniform_distribution<MT> dist;
   phi::funcs::exponential_transform<MT> trans(lambda);
   phi::funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
diff --git a/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu
index 3cbd1d8191c..a9980f805f8 100644
--- a/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu
@@ -107,7 +107,7 @@ __global__ void GroupNormBackward(const T* x,
                                   int group_size,
                                   float epsilon,
                                   T* d_x) {
-  // using AccT = typename kps::details::MPTypeTrait<T>::Type;
+  // using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
   int gid = blockIdx.y;
   int cid = blockIdx.x;
@@ -279,7 +279,7 @@ void GroupNormGradKernel(const Context& dev_ctx,
                          DenseTensor* d_x,
                          DenseTensor* d_scale,
                          DenseTensor* d_bias) {
-  using AccT = typename kps::details::MPTypeTrait<T>::Type;
+  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
   const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
   const auto scale_ptr = scale.get_ptr();
   const auto bias_ptr = bias.get_ptr();
diff --git a/paddle/phi/kernels/gpu/multinomial_kernel.cu b/paddle/phi/kernels/gpu/multinomial_kernel.cu
index 039a5e2c8b9..effc963cd0a 100644
--- a/paddle/phi/kernels/gpu/multinomial_kernel.cu
+++ b/paddle/phi/kernels/gpu/multinomial_kernel.cu
@@ -132,7 +132,7 @@ void MultinomialKernel(const Context& dev_ctx,
                        const Scalar& num_samples,
                        bool replacement,
                        DenseTensor* out) {
-  using MT = typename kps::details::MPTypeTrait<T>::Type;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   auto int_num_samples = num_samples.to<int>();
   auto* in_data = x.data<T>();
diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h
index 5ceb81eabd8..cc3cad38f46 100644
--- a/paddle/phi/kernels/gpu/reduce.h
+++ b/paddle/phi/kernels/gpu/reduce.h
@@ -55,7 +55,7 @@ void Reduce(const KPDevice& dev_ctx,
         out_dtype,
         "ReduceKernel",
         ([&] {
-          using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
+          using MPType = typename phi::dtype::MPTypeTrait<data_t>::Type;
           phi::funcs::ReduceKernel<data_t,
                                    data_t,
                                    ReduceOp,
@@ -66,7 +66,7 @@ void Reduce(const KPDevice& dev_ctx,
               is_mean);
         }));
   } else {
-    using MPType = typename kps::details::MPTypeTrait<T>::Type;
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
     phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
         dev_ctx,
         x,
@@ -78,7 +78,7 @@ void Reduce(const KPDevice& dev_ctx,
       is_mean);
   }
 #else
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
       dev_ctx,
       x,
diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
index 04befb29b2d..fb0eace755e 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
+++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
@@ -81,7 +81,7 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
   funcs::BroadcastKernel<T>(
       dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor<T>(), 0);
   // 2. equal_count = reduceSum(equal_out)
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   phi::funcs::
       ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T, MPType>>(
           dev_ctx,
diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
index 13683af9cb9..ccf95042b40 100644
--- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
@@ -52,7 +52,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
 
   std::vector<const DenseTensor*> inputs = {&new_out_grad};
   std::vector<DenseTensor*> outputs = {x_grad};
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   funcs::BroadcastKernel<T>(
       dev_ctx, inputs, &outputs, kps::DivideFunctor<T, MPType>(reduce_num), 0);
 }
diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
index 3e88506f723..8083fb1ab2d 100644
--- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -47,7 +47,7 @@ void ReduceSumGradKernel(const Context& dev_ctx,
 
   // call ReduceGrad
   dev_ctx.Alloc(x_grad, x.dtype());
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
       dev_ctx,
       &new_out_grad,
diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu
index b15e525a3bc..78b8696bd10 100644
--- a/paddle/phi/kernels/gpu/rrelu_kernel.cu
+++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu
@@ -93,7 +93,7 @@ void RReluKernel(const Context& ctx,
     RReluTestCudaFunctor<T> functor(x_data, out_data, noise_data, mid_val);
     for_range(functor);
   } else {
-    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     funcs::uniform_distribution<MT> dist;
     funcs::uniform_real_transform<MT> trans(lower, upper);
     funcs::distribution_and_transform<T>(ctx, noise, dist, trans);
diff --git a/paddle/phi/kernels/gpu/uniform_inplace_kernel.cu b/paddle/phi/kernels/gpu/uniform_inplace_kernel.cu
index 5c3a886ad87..653a64b127a 100644
--- a/paddle/phi/kernels/gpu/uniform_inplace_kernel.cu
+++ b/paddle/phi/kernels/gpu/uniform_inplace_kernel.cu
@@ -67,7 +67,7 @@ void UniformInplaceKernel(const Context& ctx,
   ctx.template Alloc<T>(out);
   if (seed == 0) {
     // Use global Generator seed
-    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     funcs::uniform_distribution<MT> dist;
     funcs::uniform_real_transform<MT> trans(min, max);
     funcs::distribution_and_transform<T>(ctx, out, dist, trans);
diff --git a/paddle/phi/kernels/gpu/uniform_kernel.cu b/paddle/phi/kernels/gpu/uniform_kernel.cu
index 1ba5847fa29..04217db0a74 100644
--- a/paddle/phi/kernels/gpu/uniform_kernel.cu
+++ b/paddle/phi/kernels/gpu/uniform_kernel.cu
@@ -65,7 +65,7 @@ void UniformKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   if (seed == 0) {
     // Use global Generator seed
-    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     funcs::uniform_distribution<MT> dist;
     funcs::uniform_real_transform<MT> trans(min.to<float>(), max.to<float>());
     funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
diff --git a/paddle/phi/kernels/legacy/gpu/uniform_kernel.cu b/paddle/phi/kernels/legacy/gpu/uniform_kernel.cu
index 211c7accf6f..609238435c9 100644
--- a/paddle/phi/kernels/legacy/gpu/uniform_kernel.cu
+++ b/paddle/phi/kernels/legacy/gpu/uniform_kernel.cu
@@ -68,7 +68,7 @@ void UniformRawKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   if (seed == 0) {
     // Use global Generator seed
-    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     funcs::uniform_distribution<MT> dist;
     funcs::uniform_real_transform<MT> trans(min.to<float>(), max.to<float>());
     funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h
index 24b961abb9b..30c2636a2bd 100644
--- a/paddle/phi/kernels/primitive/compute_primitives.h
+++ b/paddle/phi/kernels/primitive/compute_primitives.h
@@ -22,7 +22,7 @@
 #endif
 
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
-#include "paddle/phi/common/float16.h"
+#include "paddle/phi/common/amp_type_traits.h"
 
 namespace phi {
 namespace kps {
@@ -40,24 +40,6 @@ constexpr int kWarpSize = 32;
 // kLocalMode: thread reduce, each thread gets an output;
 enum ReduceMode { kGlobalMode, kLocalMode };
 
-template <typename T>
-class MPTypeTrait {
- public:
-  using Type = T;
-};
-
-template <>
-class MPTypeTrait<phi::dtype::float16> {
- public:
-  using Type = float;
-};
-
-template <>
-class MPTypeTrait<phi::dtype::bfloat16> {
- public:
-  using Type = float;
-};
-
 /**
  * @brief Will be used in BlockYReduce, get the index of reduce_num in shared
  * memory.
--
GitLab
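
Note: the trait deleted above from phi::kps::details duplicated phi::dtype::MPTypeTrait, which the patch now pulls in from paddle/phi/common/amp_type_traits.h. The pattern maps a storage type to a wider compute type, so fp16/bf16 kernels accumulate in float. Below is a minimal self-contained C++ sketch of that pattern, not Paddle's actual implementation: half_t stands in for phi::dtype::float16, and ScaleFunctor is a hypothetical example of how the kernels in this patch use MT.

#include <cstdio>

// Stand-in for phi::dtype::float16 (illustrative only).
struct half_t { unsigned short bits; };

// Primary template: the compute type defaults to the storage type.
template <typename T>
struct MPTypeTrait {
  using Type = T;
};

// Specialization: half-precision data is accumulated in float.
template <>
struct MPTypeTrait<half_t> {
  using Type = float;
};

// Hypothetical functor mirroring the usage in this patch:
// cast to MT, compute in MT, cast back to T on store.
template <typename T>
struct ScaleFunctor {
  using MT = typename MPTypeTrait<T>::Type;
  explicit ScaleFunctor(MT factor) : factor_(factor) {}
  T operator()(T x) const {
    return static_cast<T>(static_cast<MT>(x) * factor_);
  }
  MT factor_;
};

int main() {
  // For T = float, MT is also float.
  ScaleFunctor<float> scale(0.5f);
  std::printf("%f\n", scale(2.0f));  // prints 1.000000
  // For T = half_t, MPTypeTrait<half_t>::Type is float, so the
  // arithmetic would run in float even though values are stored
  // in half precision.
  return 0;
}

Keeping a single definition of this trait means every kernel resolves the compute type the same way; before this patch, a float16 input could pick up different accumulation behavior depending on whether a kernel reached the kps copy or the phi::dtype copy.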