未验证 提交 0903020d 编写于 作者: J JingZhuangzhuang 提交者: GitHub

cherry pick softmax infer kernel (#45957)

上级 29c44eb2
...@@ -35,7 +35,8 @@ void IsTestPass::ApplyImpl(ir::Graph* graph) const { ...@@ -35,7 +35,8 @@ void IsTestPass::ApplyImpl(ir::Graph* graph) const {
"hard_shrink", "hard_sigmoid", "relu6", "hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu", "soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus", "log", "square", "softplus",
"softsign", "silu", "mish"}; "softsign", "silu", "mish",
"gumbel_softmax"};
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp()) {
auto* op = n->Op(); auto* op = n->Op();
......
...@@ -119,3 +119,10 @@ struct OneHotGenerator<CPUContext, T> { ...@@ -119,3 +119,10 @@ struct OneHotGenerator<CPUContext, T> {
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {} gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
// Registers the kernel name `gumbel_softmax_infer` for the CPU backend with
// ALL_LAYOUT, implemented by phi::GumbelSoftmaxInferKernel and instantiated
// for the element types float and double. The trailing `{}` is the macro's
// body and is intentionally left empty here.
PD_REGISTER_KERNEL(gumbel_softmax_infer,
CPU,
ALL_LAYOUT,
phi::GumbelSoftmaxInferKernel,
float,
double) {}
...@@ -170,3 +170,10 @@ struct GumbleNoiseGenerator<GPUContext, T> { ...@@ -170,3 +170,10 @@ struct GumbleNoiseGenerator<GPUContext, T> {
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {} gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
// Registers the kernel name `gumbel_softmax_infer` for the GPU backend with
// ALL_LAYOUT, implemented by phi::GumbelSoftmaxInferKernel and instantiated
// for the element types float and double. The trailing `{}` is the macro's
// body and is intentionally left empty here.
PD_REGISTER_KERNEL(gumbel_softmax_infer,
GPU,
ALL_LAYOUT,
phi::GumbelSoftmaxInferKernel,
float,
double) {}
...@@ -25,4 +25,12 @@ void GumbelSoftmaxKernel(const Context& dev_ctx, ...@@ -25,4 +25,12 @@ void GumbelSoftmaxKernel(const Context& dev_ctx,
int axis, int axis,
DenseTensor* out); DenseTensor* out);
// Declaration of the gumbel-softmax kernel entry point that is registered
// under the `gumbel_softmax_infer` kernel name.
//
// Template parameters:
//   T       - element type of the tensors.
//   Context - device context type.
// Parameters:
//   dev_ctx     - device context the kernel runs on.
//   x           - input tensor (read-only).
//   temperature - float attribute forwarded to the implementation.
//   hard        - bool attribute forwarded to the implementation.
//   axis        - int attribute forwarded to the implementation.
//   out         - tensor that receives the result.
// NOTE(review): the precise meaning of temperature/hard/axis is defined by
// the implementation, which is not visible from this declaration - confirm
// there before relying on it.
template <typename T, typename Context>
void GumbelSoftmaxInferKernel(const Context& dev_ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out);
} // namespace phi } // namespace phi
...@@ -43,12 +43,13 @@ template <typename Context, typename T> ...@@ -43,12 +43,13 @@ template <typename Context, typename T>
struct OneHotGenerator; struct OneHotGenerator;
template <typename T, typename Context> template <typename T, typename Context>
void GumbelSoftmaxKernel(const Context& ctx, void GumbelSoftmaxKernelHelper(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
float temperature, float temperature,
bool hard, bool hard,
int axis, int axis,
DenseTensor* out) { DenseTensor* out,
bool is_test) {
const int rank = x.dims().size(); const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank); axis = funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[axis]; int axis_dim = x.dims()[axis];
...@@ -80,18 +81,39 @@ void GumbelSoftmaxKernel(const Context& ctx, ...@@ -80,18 +81,39 @@ void GumbelSoftmaxKernel(const Context& ctx,
size_to_axis, size_to_axis,
size_from_axis, size_from_axis,
temperature); temperature);
if (is_test) {
#ifdef PADDLE_ON_INFERENCE
paddle::operators::math::SoftmaxFunctor<Context, T, true>()( paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
ctx, axis_dim, &x_noise_2d, &out_2d); ctx, axis_dim, &x_noise_2d, &out_2d);
#else } else {
paddle::operators::math::SoftmaxFunctor<Context, T, false>()( paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
ctx, axis_dim, &x_noise_2d, &out_2d); ctx, axis_dim, &x_noise_2d, &out_2d);
#endif }
if (hard) { if (hard) {
OneHotGenerator<Context, T>::Transform(ctx, x, out, axis); OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
} }
} }
// Default gumbel-softmax kernel entry point.
//
// Thin wrapper: hands every argument through to GumbelSoftmaxKernelHelper
// unchanged and sets the helper's trailing `is_test` flag to false.
template <typename T, typename Context>
void GumbelSoftmaxKernel(const Context& ctx,
                         const DenseTensor& x,
                         float temperature,
                         bool hard,
                         int axis,
                         DenseTensor* out) {
  // Named so the meaning of the last helper argument is visible at the call.
  constexpr bool kIsTest = false;
  GumbelSoftmaxKernelHelper<T, Context>(
      ctx, x, temperature, hard, axis, out, kIsTest);
}
// Inference gumbel-softmax kernel entry point.
//
// Thin wrapper: hands every argument through to GumbelSoftmaxKernelHelper
// unchanged and sets the helper's trailing `is_test` flag to true.
template <typename T, typename Context>
void GumbelSoftmaxInferKernel(const Context& ctx,
                              const DenseTensor& x,
                              float temperature,
                              bool hard,
                              int axis,
                              DenseTensor* out) {
  // Named so the meaning of the last helper argument is visible at the call.
  constexpr bool kIsTest = true;
  GumbelSoftmaxKernelHelper<T, Context>(
      ctx, x, temperature, hard, axis, out, kIsTest);
}
} // namespace phi } // namespace phi
...@@ -16,6 +16,23 @@ limitations under the License. */ ...@@ -16,6 +16,23 @@ limitations under the License. */
namespace phi { namespace phi {
// Builds the kernel signature for the gumbel_softmax op.
//
// The kernel name is "gumbel_softmax_infer" when the op carries an "is_test"
// attribute whose value is true, and "gumbel_softmax" otherwise (including
// when the attribute is absent). Both signatures use input "X", attributes
// "temperature", "hard", "axis", and output "Out".
KernelSignature GumbelSoftmaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  // Short-circuit keeps the attribute lookup from running when it is missing.
  const bool use_infer_kernel =
      ctx.HasAttr("is_test") && paddle::any_cast<bool>(ctx.Attr("is_test"));
  const char* kernel_name =
      use_infer_kernel ? "gumbel_softmax_infer" : "gumbel_softmax";
  return KernelSignature(
      kernel_name, {"X"}, {"temperature", "hard", "axis"}, {"Out"});
}
KernelSignature GumbelSoftmaxGradOpArgumentMapping( KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
...@@ -24,5 +41,6 @@ KernelSignature GumbelSoftmaxGradOpArgumentMapping( ...@@ -24,5 +41,6 @@ KernelSignature GumbelSoftmaxGradOpArgumentMapping(
} // namespace phi } // namespace phi
// Registers phi::GumbelSoftmaxOpArgumentMapping as the argument-mapping
// function for the `gumbel_softmax` op.
PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax, phi::GumbelSoftmaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad, PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad,
phi::GumbelSoftmaxGradOpArgumentMapping); phi::GumbelSoftmaxGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册