From 018884829ecaa35754deb5e70b7d334aed3f8f9f Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Tue, 13 Sep 2022 20:01:51 +0800 Subject: [PATCH] add softmax infer kernel (#45955) * add softmax infer kernel --- paddle/fluid/framework/ir/is_test_pass.cc | 23 +++++---- .../phi/kernels/cpu/gumbel_softmax_kernel.cc | 7 +++ .../phi/kernels/gpu/gumbel_softmax_kernel.cu | 7 +++ paddle/phi/kernels/gumbel_softmax_kernel.h | 8 +++ .../kernels/impl/gumbel_softmax_kernel_impl.h | 50 +++++++++++++------ paddle/phi/ops/compat/gumbel_softmax_sig.cc | 18 +++++++ 6 files changed, 88 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc index a97873e82f..47a3e46d07 100644 --- a/paddle/fluid/framework/ir/is_test_pass.cc +++ b/paddle/fluid/framework/ir/is_test_pass.cc @@ -25,17 +25,18 @@ class Graph; void IsTestPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it " "for activations and pooling."; - auto op_list = {"pool2d", "sigmoid", "logsigmoid", - "softshrink", "exp", "brelu", - "pow", "leaky_relu", "stanh", - "relu", "tanh", "tanh_shrink", - "sqrt", "abs", "ceil", - "elu", "floor", "cos", - "sin", "round", "reciprocal", - "hard_shrink", "hard_sigmoid", "relu6", - "soft_relu", "swish", "thresholded_relu", - "log", "square", "softplus", - "softsign", "silu", "mish"}; + auto op_list = {"pool2d", "sigmoid", "logsigmoid", + "softshrink", "exp", "brelu", + "pow", "leaky_relu", "stanh", + "relu", "tanh", "tanh_shrink", + "sqrt", "abs", "ceil", + "elu", "floor", "cos", + "sin", "round", "reciprocal", + "hard_shrink", "hard_sigmoid", "relu6", + "soft_relu", "swish", "thresholded_relu", + "log", "square", "softplus", + "softsign", "silu", "mish", + "gumbel_softmax"}; for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); diff --git a/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc b/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc index 7638ca3aa7..d9a6df2794 100644 --- a/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc +++ b/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc @@ -119,3 +119,10 @@ struct OneHotGenerator { PD_REGISTER_KERNEL( gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {} + +PD_REGISTER_KERNEL(gumbel_softmax_infer, + CPU, + ALL_LAYOUT, + phi::GumbelSoftmaxInferKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu index 33bf0eba38..c9ee74f0dd 100644 --- a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu +++ b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu @@ -170,3 +170,10 @@ struct GumbleNoiseGenerator { PD_REGISTER_KERNEL( gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {} + +PD_REGISTER_KERNEL(gumbel_softmax_infer, + GPU, + ALL_LAYOUT, + phi::GumbelSoftmaxInferKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gumbel_softmax_kernel.h b/paddle/phi/kernels/gumbel_softmax_kernel.h index 46edb9750d..4ba1e56142 100644 --- a/paddle/phi/kernels/gumbel_softmax_kernel.h +++ b/paddle/phi/kernels/gumbel_softmax_kernel.h @@ -25,4 +25,12 @@ void GumbelSoftmaxKernel(const Context& dev_ctx, int axis, DenseTensor* out); +template +void GumbelSoftmaxInferKernel(const Context& dev_ctx, + const DenseTensor& x, + float temperature, + bool hard, + int axis, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h index 655634e319..ed800e70f5 100644 --- a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h +++ b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h @@ -43,12 +43,13 @@ template struct OneHotGenerator; template -void GumbelSoftmaxKernel(const Context& ctx, - const DenseTensor& x, - float temperature, - bool hard, - int axis, - DenseTensor* out) { +void GumbelSoftmaxKernelHelper(const Context& ctx, + const DenseTensor& x, + float temperature, + bool hard, + int axis, + DenseTensor* out, + bool is_test) { const int rank = x.dims().size(); axis = funcs::CanonicalAxis(axis, rank); int axis_dim = x.dims()[axis]; @@ -80,18 +81,39 @@ void GumbelSoftmaxKernel(const Context& ctx, size_to_axis, size_from_axis, temperature); - -#ifdef PADDLE_ON_INFERENCE - paddle::operators::math::SoftmaxFunctor()( - ctx, axis_dim, &x_noise_2d, &out_2d); -#else - paddle::operators::math::SoftmaxFunctor()( - ctx, axis_dim, &x_noise_2d, &out_2d); -#endif + if (is_test) { + paddle::operators::math::SoftmaxFunctor()( + ctx, axis_dim, &x_noise_2d, &out_2d); + } else { + paddle::operators::math::SoftmaxFunctor()( + ctx, axis_dim, &x_noise_2d, &out_2d); + } if (hard) { OneHotGenerator::Transform(ctx, x, out, axis); } } +template +void GumbelSoftmaxKernel(const Context& ctx, + const DenseTensor& x, + float temperature, + bool hard, + int axis, + DenseTensor* out) { + GumbelSoftmaxKernelHelper( + ctx, x, temperature, hard, axis, out, false); +} + +template +void GumbelSoftmaxInferKernel(const Context& ctx, + const DenseTensor& x, + float temperature, + bool hard, + int axis, + DenseTensor* out) { + GumbelSoftmaxKernelHelper( + ctx, x, temperature, hard, axis, out, true); +} + } // namespace phi diff --git a/paddle/phi/ops/compat/gumbel_softmax_sig.cc b/paddle/phi/ops/compat/gumbel_softmax_sig.cc index 65537f8c89..b4afa64c1d 100644 --- a/paddle/phi/ops/compat/gumbel_softmax_sig.cc +++ b/paddle/phi/ops/compat/gumbel_softmax_sig.cc @@ -16,6 +16,23 @@ limitations under the License. */ namespace phi { +KernelSignature GumbelSoftmaxOpArgumentMapping( + const ArgumentMappingContext& ctx) { + bool is_test = false; + if (ctx.HasAttr("is_test")) { + is_test = paddle::any_cast(ctx.Attr("is_test")); + } + if (is_test) { + return KernelSignature("gumbel_softmax_infer", + {"X"}, + {"temperature", "hard", "axis"}, + {"Out"}); + } else { + return KernelSignature( + "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"}); + } +} + KernelSignature GumbelSoftmaxGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( @@ -24,5 +41,6 @@ KernelSignature GumbelSoftmaxGradOpArgumentMapping( } // namespace phi +PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax, phi::GumbelSoftmaxOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad, phi::GumbelSoftmaxGradOpArgumentMapping); -- GitLab