未验证 提交 0903020d 编写于 作者: J JingZhuangzhuang 提交者: GitHub

cherry pick softmax infer kernel (#45957)

上级 29c44eb2
......@@ -35,7 +35,8 @@ void IsTestPass::ApplyImpl(ir::Graph* graph) const {
"hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus",
"softsign", "silu", "mish"};
"softsign", "silu", "mish",
"gumbel_softmax"};
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
......
......@@ -119,3 +119,10 @@ struct OneHotGenerator<CPUContext, T> {
PD_REGISTER_KERNEL(
gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
PD_REGISTER_KERNEL(gumbel_softmax_infer,
CPU,
ALL_LAYOUT,
phi::GumbelSoftmaxInferKernel,
float,
double) {}
......@@ -170,3 +170,10 @@ struct GumbleNoiseGenerator<GPUContext, T> {
PD_REGISTER_KERNEL(
gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
PD_REGISTER_KERNEL(gumbel_softmax_infer,
GPU,
ALL_LAYOUT,
phi::GumbelSoftmaxInferKernel,
float,
double) {}
......@@ -25,4 +25,12 @@ void GumbelSoftmaxKernel(const Context& dev_ctx,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void GumbelSoftmaxInferKernel(const Context& dev_ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out);
} // namespace phi
......@@ -43,12 +43,13 @@ template <typename Context, typename T>
struct OneHotGenerator;
template <typename T, typename Context>
void GumbelSoftmaxKernel(const Context& ctx,
void GumbelSoftmaxKernelHelper(const Context& ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out) {
DenseTensor* out,
bool is_test) {
const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[axis];
......@@ -80,18 +81,39 @@ void GumbelSoftmaxKernel(const Context& ctx,
size_to_axis,
size_from_axis,
temperature);
#ifdef PADDLE_ON_INFERENCE
if (is_test) {
paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
#else
} else {
paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
#endif
}
if (hard) {
OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
}
}
template <typename T, typename Context>
void GumbelSoftmaxKernel(const Context& ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out) {
GumbelSoftmaxKernelHelper<T, Context>(
ctx, x, temperature, hard, axis, out, false);
}
template <typename T, typename Context>
void GumbelSoftmaxInferKernel(const Context& ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out) {
GumbelSoftmaxKernelHelper<T, Context>(
ctx, x, temperature, hard, axis, out, true);
}
} // namespace phi
......@@ -16,6 +16,23 @@ limitations under the License. */
namespace phi {
KernelSignature GumbelSoftmaxOpArgumentMapping(
const ArgumentMappingContext& ctx) {
bool is_test = false;
if (ctx.HasAttr("is_test")) {
is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
}
if (is_test) {
return KernelSignature("gumbel_softmax_infer",
{"X"},
{"temperature", "hard", "axis"},
{"Out"});
} else {
return KernelSignature(
"gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
}
}
KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
......@@ -24,5 +41,6 @@ KernelSignature GumbelSoftmaxGradOpArgumentMapping(
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax, phi::GumbelSoftmaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad,
phi::GumbelSoftmaxGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册