From 7f346a76bc4e5fabdba3e54613f711acdeb74045 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Sun, 18 Sep 2022 11:51:48 +0800 Subject: [PATCH] Delete redundant param in SoftmaxFunctor (#46003) * perfect softmax functor * fix compile bugs * fix ci bugs --- paddle/fluid/operators/math/softmax.cc | 6 +-- paddle/fluid/operators/math/softmax.cu | 12 ++--- paddle/fluid/operators/math/softmax.h | 5 +- paddle/fluid/operators/math/softmax_impl.h | 46 +++++-------------- .../phi/kernels/cpu/gumbel_softmax_kernel.cc | 7 --- .../phi/kernels/gpu/gumbel_softmax_kernel.cu | 7 --- paddle/phi/kernels/gumbel_softmax_kernel.h | 8 ---- .../kernels/impl/gumbel_softmax_kernel_impl.h | 26 ++--------- paddle/phi/kernels/impl/softmax_kernel_impl.h | 2 +- paddle/phi/ops/compat/gumbel_softmax_sig.cc | 15 +----- 10 files changed, 25 insertions(+), 109 deletions(-) diff --git a/paddle/fluid/operators/math/softmax.cc b/paddle/fluid/operators/math/softmax.cc index 730dcbf59a..216658b3d7 100644 --- a/paddle/fluid/operators/math/softmax.cc +++ b/paddle/fluid/operators/math/softmax.cc @@ -21,10 +21,8 @@ namespace paddle { namespace operators { namespace math { -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 47621883fd..6729b962f2 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -156,14 +156,10 @@ template class SoftmaxCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; #endif -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/math/softmax.h b/paddle/fluid/operators/math/softmax.h index 0ed3116a55..958244bdbb 100644 --- a/paddle/fluid/operators/math/softmax.h +++ b/paddle/fluid/operators/math/softmax.h @@ -19,10 +19,7 @@ namespace paddle { namespace operators { namespace math { -template +template class SoftmaxFunctor { public: void operator()(const DeviceContext& context, diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index 7cf7b25233..8a0eb2ad7a 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -42,7 +42,7 @@ struct ValueClip { } }; -template +template class SoftmaxEigen { public: void operator()(const DeviceContext& context, @@ -103,8 +103,8 @@ class SoftmaxEigen { } }; -template -class SoftmaxEigen { +template +class SoftmaxEigen { public: void operator()(const DeviceContext& context, const int axis_dim, @@ -161,8 +161,8 @@ class SoftmaxEigen { } }; -template -class SoftmaxEigen { +template +class SoftmaxEigen { public: void operator()(const DeviceContext& context, const int axis_dim, @@ -219,21 +219,21 @@ class SoftmaxEigen { } }; -template -void SoftmaxFunctor::operator()( +template +void SoftmaxFunctor::operator()( const DeviceContext& context, const int axis_dim, const 
framework::Tensor* X,
 framework::Tensor* Y) {
-  SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
+  SoftmaxEigen<DeviceContext, T>()(context, axis_dim, X, Y);
 }
 
 template <class DeviceContext>
 using enable_if_CPU = typename std::enable_if<
     std::is_same::value>::type;
 
-template <typename DeviceContext, typename T, bool is_test>
-class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
+template <typename DeviceContext, typename T>
+class SoftmaxFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
  public:
  void operator()(const DeviceContext& context,
                  const int axis_dim,
@@ -267,35 +267,11 @@ class SoftmaxFunctor> {
       out_data += num_classes;
     }
   } else {
-    SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
+    SoftmaxEigen<DeviceContext, T>()(context, axis_dim, X, Y);
   }
 }
};
-template <typename DeviceContext>
-class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
- public:
- void operator()(const DeviceContext& context,
-                 const int axis_dim,
-                 const framework::Tensor* X,
-                 framework::Tensor* Y) {
-   const auto& in_dims = X->dims();
-   const float* in_data = X->data<float>();
-   float* out_data = Y->data<float>();
-   const int kBatchDim = 0;
-   const int kClassDim = 1;
-   // 2D data. Batch x C
-   auto compute_softmax =
-       jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
-           .At(in_dims[kClassDim]);
-   compute_softmax(in_data,
-                   out_data,
-                   in_dims[kClassDim],
-                   in_dims[kBatchDim],
-                   in_dims[kClassDim] / axis_dim);
- }
-};
-
 template <typename DeviceContext, typename T>
 class SoftmaxGradEigen {
  public:
diff --git a/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc b/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc
index d9a6df2794..7638ca3aa7 100644
--- a/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc
+++ b/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc
@@ -119,10 +119,3 @@ struct OneHotGenerator {
 
 PD_REGISTER_KERNEL(
     gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
-
-PD_REGISTER_KERNEL(gumbel_softmax_infer,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::GumbelSoftmaxInferKernel,
-                   float,
-                   double) {}
diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
index c9ee74f0dd..33bf0eba38 100644
--- a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
+++ b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
@@ -170,10 +170,3 @@ struct GumbleNoiseGenerator {
 
 PD_REGISTER_KERNEL(
     gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
-
-PD_REGISTER_KERNEL(gumbel_softmax_infer,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::GumbelSoftmaxInferKernel,
-                   float,
-                   double) {}
diff --git a/paddle/phi/kernels/gumbel_softmax_kernel.h b/paddle/phi/kernels/gumbel_softmax_kernel.h
index 4ba1e56142..46edb9750d 100644
--- a/paddle/phi/kernels/gumbel_softmax_kernel.h
+++ b/paddle/phi/kernels/gumbel_softmax_kernel.h
@@ -25,12 +25,4 @@ void GumbelSoftmaxKernel(const Context& dev_ctx,
                          int axis,
                          DenseTensor* out);
 
-template <typename T, typename Context>
-void GumbelSoftmaxInferKernel(const Context& dev_ctx,
-                              const DenseTensor& x,
-                              float temperature,
-                              bool hard,
-                              int axis,
-                              DenseTensor* out);
-
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
index ed800e70f5..e310d4a167 100644
--- a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
+++ b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
@@ -48,8 +48,7 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
                                float temperature,
                                bool hard,
                                int axis,
-                               DenseTensor* out,
-                               bool is_test) {
+                               DenseTensor* out) {
   const int rank = x.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
   int axis_dim = x.dims()[axis];
@@ -81,13 +80,8 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
                                 size_to_axis,
                                 size_from_axis,
                                 temperature);
-  if (is_test) {
-    paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
-        ctx, axis_dim, &x_noise_2d, &out_2d);
-  } else {
-    paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
-        ctx, axis_dim, &x_noise_2d, &out_2d);
-  }
+  paddle::operators::math::SoftmaxFunctor<Context, T>()(
+      ctx, axis_dim, &x_noise_2d, &out_2d);
 
   if (hard) {
     OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
@@ -101,19 +95,7 @@ void GumbelSoftmaxKernel(const Context& ctx,
                          bool hard,
                          int axis,
                          DenseTensor* out) {
-  GumbelSoftmaxKernelHelper(
-      ctx, x, temperature, hard, axis, out, false);
-}
-
-template <typename T, typename Context>
-void GumbelSoftmaxInferKernel(const Context& ctx,
-                              const DenseTensor& x,
-                              float temperature,
-                              bool hard,
-                              int axis,
-                              DenseTensor* out) {
-  GumbelSoftmaxKernelHelper(
-      ctx, x, temperature, hard, axis, out, true);
+  GumbelSoftmaxKernelHelper(ctx, x, temperature, hard, axis, out);
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/softmax_kernel_impl.h b/paddle/phi/kernels/impl/softmax_kernel_impl.h
index aa0ebf2570..5f7d097242 100644
--- a/paddle/phi/kernels/impl/softmax_kernel_impl.h
+++ b/paddle/phi/kernels/impl/softmax_kernel_impl.h
@@ -40,7 +40,7 @@ void SoftmaxKernel(const Context& dev_ctx,
   DenseTensor X_2d, Out_2d;
   X_2d.ShareDataWith(x).Resize({n, d});
   Out_2d.ShareDataWith(*out).Resize({n, d});
-  paddle::operators::math::SoftmaxFunctor()(
+  paddle::operators::math::SoftmaxFunctor<Context, T>()(
       dev_ctx, axis_dim, &X_2d, &Out_2d);
 }
diff --git a/paddle/phi/ops/compat/gumbel_softmax_sig.cc b/paddle/phi/ops/compat/gumbel_softmax_sig.cc
index b4afa64c1d..54d3d55bf5 100644
--- a/paddle/phi/ops/compat/gumbel_softmax_sig.cc
+++ b/paddle/phi/ops/compat/gumbel_softmax_sig.cc
@@ -18,19 +18,8 @@ namespace phi {
 KernelSignature GumbelSoftmaxOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  bool is_test = false;
-  if (ctx.HasAttr("is_test")) {
-    is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
-  }
-  if (is_test) {
-    return KernelSignature("gumbel_softmax_infer",
-                           {"X"},
-                           {"temperature", "hard", "axis"},
-                           {"Out"});
-  } else {
-    return KernelSignature(
-        "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
-  }
+  return KernelSignature(
+      "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
 }
 
 KernelSignature GumbelSoftmaxGradOpArgumentMapping(
--
GitLab
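
For context, here is a minimal, self-contained sketch of the interface change this commit makes: SoftmaxFunctor drops its redundant bool is_test template parameter, so each (device context, dtype) pair is instantiated once and call sites read SoftmaxFunctor<Context, T>() with no compile-time flag. The types below (CPUContext, the Tensor alias) are stand-ins rather than the real Paddle classes, and the toy softmax body only approximates what the Eigen/JIT kernels in softmax_impl.h do.

// Sketch only -- stand-in types, not the Paddle headers.
// Before #46003 (roughly):
//   template <typename DeviceContext, typename T, bool is_test,
//             typename Enable = void>
//   class SoftmaxFunctor;
// which forced two instantiations per (device, dtype) and a hard-coded
// true/false at every call site. After the patch the flag is gone:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct CPUContext {};              // stand-in for the device context type
template <typename T>
using Tensor = std::vector<T>;     // stand-in for framework::Tensor

template <typename DeviceContext, typename T>
class SoftmaxFunctor {
 public:
  // Row-wise softmax over `axis_dim` classes on a 2-D [batch, classes] view.
  void operator()(const DeviceContext& /*context*/, int axis_dim,
                  const Tensor<T>* x, Tensor<T>* y) const {
    const int rows = static_cast<int>(x->size()) / axis_dim;
    for (int r = 0; r < rows; ++r) {
      const T* in = x->data() + r * axis_dim;
      T* out = y->data() + r * axis_dim;
      const T max_v = *std::max_element(in, in + axis_dim);
      T sum = T(0);
      for (int c = 0; c < axis_dim; ++c) {
        out[c] = std::exp(in[c] - max_v);  // shift by max for stability
        sum += out[c];
      }
      for (int c = 0; c < axis_dim; ++c) out[c] /= sum;
    }
  }
};

// One explicit instantiation per (device, dtype) pair, mirroring the halved
// instantiation lists in softmax.cc / softmax.cu.
template class SoftmaxFunctor<CPUContext, float>;

int main() {
  CPUContext ctx;
  Tensor<float> x = {1.f, 2.f, 3.f, 1.f, 1.f, 1.f};  // 2 rows x 3 classes
  Tensor<float> y(x.size());
  // Call sites now read SoftmaxFunctor<Context, T>(), as in
  // gumbel_softmax_kernel_impl.h and softmax_kernel_impl.h above.
  SoftmaxFunctor<CPUContext, float>()(ctx, /*axis_dim=*/3, &x, &y);
  std::printf("%.3f %.3f %.3f\n", y[0], y[1], y[2]);
  return 0;
}

With the flag gone, GumbelSoftmaxKernelHelper no longer needs a runtime if (is_test) branch to pick a template instantiation, and the separate gumbel_softmax_infer kernel registrations and argument-mapping branch deleted above have nothing left to dispatch to.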