diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc
index a97873e82f455461d21bf2a7cdf293e9d2bcba24..47a3e46d076c318cc709d1403c22cb4920a8648b 100644
--- a/paddle/fluid/framework/ir/is_test_pass.cc
+++ b/paddle/fluid/framework/ir/is_test_pass.cc
@@ -25,17 +25,18 @@ class Graph;
 void IsTestPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
              "for activations and pooling.";
-  auto op_list = {"pool2d",      "sigmoid",      "logsigmoid",
-                  "softshrink",  "exp",          "brelu",
-                  "pow",         "leaky_relu",   "stanh",
-                  "relu",        "tanh",         "tanh_shrink",
-                  "sqrt",        "abs",          "ceil",
-                  "elu",         "floor",        "cos",
-                  "sin",         "round",        "reciprocal",
-                  "hard_shrink", "hard_sigmoid", "relu6",
-                  "soft_relu",   "swish",        "thresholded_relu",
-                  "log",         "square",       "softplus",
-                  "softsign",    "silu",         "mish"};
+  auto op_list = {"pool2d",      "sigmoid",      "logsigmoid",
+                  "softshrink",  "exp",          "brelu",
+                  "pow",         "leaky_relu",   "stanh",
+                  "relu",        "tanh",         "tanh_shrink",
+                  "sqrt",        "abs",          "ceil",
+                  "elu",         "floor",        "cos",
+                  "sin",         "round",        "reciprocal",
+                  "hard_shrink", "hard_sigmoid", "relu6",
+                  "soft_relu",   "swish",        "thresholded_relu",
+                  "log",         "square",       "softplus",
+                  "softsign",    "silu",         "mish",
+                  "gumbel_softmax"};
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp()) {
       auto* op = n->Op();
diff --git a/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc b/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc
index 7638ca3aa7ee63f521d59ebabcd5d2930a2e5d0b..d9a6df2794f59125887da2de47de6e6333f620fb 100644
--- a/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc
+++ b/paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc
@@ -119,3 +119,10 @@ struct OneHotGenerator<CPUContext, T> {
 
 PD_REGISTER_KERNEL(
     gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
+
+PD_REGISTER_KERNEL(gumbel_softmax_infer,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GumbelSoftmaxInferKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
index 33bf0eba380e446104059f35e37fb8ba556af16e..c9ee74f0ddf34a7e2d78606cc478578d1225108d 100644
--- a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
+++ b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
@@ -170,3 +170,10 @@ struct GumbleNoiseGenerator<GPUContext, T> {
 
 PD_REGISTER_KERNEL(
     gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
+
+PD_REGISTER_KERNEL(gumbel_softmax_infer,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GumbelSoftmaxInferKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gumbel_softmax_kernel.h b/paddle/phi/kernels/gumbel_softmax_kernel.h
index 46edb9750dd34832b1c908822f6e322e548db951..4ba1e56142d9bd427fee72187a771db75b69d4c2 100644
--- a/paddle/phi/kernels/gumbel_softmax_kernel.h
+++ b/paddle/phi/kernels/gumbel_softmax_kernel.h
@@ -25,4 +25,12 @@ void GumbelSoftmaxKernel(const Context& dev_ctx,
                          int axis,
                          DenseTensor* out);
 
+template <typename T, typename Context>
+void GumbelSoftmaxInferKernel(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              float temperature,
+                              bool hard,
+                              int axis,
+                              DenseTensor* out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
index 655634e319924d3e82b54c573c1405c470e06554..ed800e70f5a36dd379e898fbe3193c20a8c10397 100644
--- a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
+++ b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
@@ -43,12 +43,13 @@ template <typename T, typename Context>
 struct OneHotGenerator;
 
 template <typename T, typename Context>
-void GumbelSoftmaxKernel(const Context& ctx,
-                         const DenseTensor& x,
-                         float temperature,
-                         bool hard,
-                         int axis,
-                         DenseTensor* out) {
+void GumbelSoftmaxKernelHelper(const Context& ctx,
+                               const DenseTensor& x,
+                               float temperature,
+                               bool hard,
+                               int axis,
+                               DenseTensor* out,
+                               bool is_test) {
   const int rank = x.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
   int axis_dim = x.dims()[axis];
@@ -80,18 +81,39 @@ void GumbelSoftmaxKernel(const Context& ctx,
                                               size_to_axis,
                                               size_from_axis,
                                               temperature);
-
-#ifdef PADDLE_ON_INFERENCE
-  paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
-      ctx, axis_dim, &x_noise_2d, &out_2d);
-#else
-  paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
-      ctx, axis_dim, &x_noise_2d, &out_2d);
-#endif
+  if (is_test) {
+    paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
+        ctx, axis_dim, &x_noise_2d, &out_2d);
+  } else {
+    paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
+        ctx, axis_dim, &x_noise_2d, &out_2d);
+  }
 
   if (hard) {
     OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
   }
 }
 
+template <typename T, typename Context>
+void GumbelSoftmaxKernel(const Context& ctx,
+                         const DenseTensor& x,
+                         float temperature,
+                         bool hard,
+                         int axis,
+                         DenseTensor* out) {
+  GumbelSoftmaxKernelHelper<T, Context>(
+      ctx, x, temperature, hard, axis, out, false);
+}
+
+template <typename T, typename Context>
+void GumbelSoftmaxInferKernel(const Context& ctx,
+                              const DenseTensor& x,
+                              float temperature,
+                              bool hard,
+                              int axis,
+                              DenseTensor* out) {
+  GumbelSoftmaxKernelHelper<T, Context>(
+      ctx, x, temperature, hard, axis, out, true);
+}
+
 }  // namespace phi
diff --git a/paddle/phi/ops/compat/gumbel_softmax_sig.cc b/paddle/phi/ops/compat/gumbel_softmax_sig.cc
index 65537f8c8948a80711566b523c1932c481b9c66d..b4afa64c1d2b39e714296dc234a978617a80e23f 100644
--- a/paddle/phi/ops/compat/gumbel_softmax_sig.cc
+++ b/paddle/phi/ops/compat/gumbel_softmax_sig.cc
@@ -16,6 +16,23 @@ limitations under the License. */
 
 namespace phi {
 
+KernelSignature GumbelSoftmaxOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  bool is_test = false;
+  if (ctx.HasAttr("is_test")) {
+    is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
+  }
+  if (is_test) {
+    return KernelSignature("gumbel_softmax_infer",
+                           {"X"},
+                           {"temperature", "hard", "axis"},
+                           {"Out"});
+  } else {
+    return KernelSignature(
+        "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
+  }
+}
+
 KernelSignature GumbelSoftmaxGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
@@ -24,5 +41,6 @@ KernelSignature GumbelSoftmaxGradOpArgumentMapping(
 
 }  // namespace phi
 
+PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax, phi::GumbelSoftmaxOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad,
                            phi::GumbelSoftmaxGradOpArgumentMapping);