未验证 提交 0903020d 编写于 作者: J JingZhuangzhuang 提交者: GitHub

cherry pick softmax infer kernel (#45957)

上级 29c44eb2
......@@ -25,17 +25,18 @@ class Graph;
void IsTestPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
"for activations and pooling.";
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
"softshrink", "exp", "brelu",
"pow", "leaky_relu", "stanh",
"relu", "tanh", "tanh_shrink",
"sqrt", "abs", "ceil",
"elu", "floor", "cos",
"sin", "round", "reciprocal",
"hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus",
"softsign", "silu", "mish"};
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
"softshrink", "exp", "brelu",
"pow", "leaky_relu", "stanh",
"relu", "tanh", "tanh_shrink",
"sqrt", "abs", "ceil",
"elu", "floor", "cos",
"sin", "round", "reciprocal",
"hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus",
"softsign", "silu", "mish",
"gumbel_softmax"};
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
......
......@@ -119,3 +119,10 @@ struct OneHotGenerator<CPUContext, T> {
PD_REGISTER_KERNEL(
gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
PD_REGISTER_KERNEL(gumbel_softmax_infer,
CPU,
ALL_LAYOUT,
phi::GumbelSoftmaxInferKernel,
float,
double) {}
......@@ -170,3 +170,10 @@ struct GumbleNoiseGenerator<GPUContext, T> {
PD_REGISTER_KERNEL(
gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
PD_REGISTER_KERNEL(gumbel_softmax_infer,
GPU,
ALL_LAYOUT,
phi::GumbelSoftmaxInferKernel,
float,
double) {}
......@@ -25,4 +25,12 @@ void GumbelSoftmaxKernel(const Context& dev_ctx,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void GumbelSoftmaxInferKernel(const Context& dev_ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out);
} // namespace phi
......@@ -43,12 +43,13 @@ template <typename Context, typename T>
struct OneHotGenerator;
template <typename T, typename Context>
void GumbelSoftmaxKernel(const Context& ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out) {
void GumbelSoftmaxKernelHelper(const Context& ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out,
bool is_test) {
const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[axis];
......@@ -80,18 +81,39 @@ void GumbelSoftmaxKernel(const Context& ctx,
size_to_axis,
size_from_axis,
temperature);
#ifdef PADDLE_ON_INFERENCE
paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
#else
paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
#endif
if (is_test) {
paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
} else {
paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
}
if (hard) {
OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
}
}
template <typename T, typename Context>
void GumbelSoftmaxKernel(const Context& ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out) {
GumbelSoftmaxKernelHelper<T, Context>(
ctx, x, temperature, hard, axis, out, false);
}
template <typename T, typename Context>
void GumbelSoftmaxInferKernel(const Context& ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out) {
GumbelSoftmaxKernelHelper<T, Context>(
ctx, x, temperature, hard, axis, out, true);
}
} // namespace phi
......@@ -16,6 +16,23 @@ limitations under the License. */
namespace phi {
KernelSignature GumbelSoftmaxOpArgumentMapping(
const ArgumentMappingContext& ctx) {
bool is_test = false;
if (ctx.HasAttr("is_test")) {
is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
}
if (is_test) {
return KernelSignature("gumbel_softmax_infer",
{"X"},
{"temperature", "hard", "axis"},
{"Out"});
} else {
return KernelSignature(
"gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
}
}
KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
......@@ -24,5 +41,6 @@ KernelSignature GumbelSoftmaxGradOpArgumentMapping(
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax, phi::GumbelSoftmaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad,
phi::GumbelSoftmaxGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册