From 6f69fbc8eaf37455ef4759fb954c7db0a8798f25 Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Tue, 25 Aug 2020 09:39:09 +0800
Subject: [PATCH] fix elu grad when alpha less than zero, test=develop (#26543)

---
 paddle/fluid/operators/activation_op.h | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index b411f0f21da..00a7c063c91 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -1134,9 +1134,20 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
-                   dout * static_cast<T>(alpha) * x.exp() *
-                       (x <= static_cast<T>(0)).template cast<T>();
+    auto temp_a_pos = static_cast<T>(alpha > 0);
+    auto temp_a_neg = static_cast<T>(alpha <= 0);
+    auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
+    auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
+
+    // dx = dout, if alpha > 0 and x > 0
+    // dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
+    // dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
+    // dx = 0, if alpha <= 0 and x <= 0
+    dx.device(d) =
+        dout * temp_a_pos * temp_x_pos +
+        dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
+        dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
+            temp_a_neg * temp_x_pos;
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
--
GitLab
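For reference, below is a minimal standalone sketch (not part of the patch) of the piecewise gradient rule described in the new comments, checked against a central finite difference on scalars. It assumes the forward pass is computed as max(0, x) + min(0, alpha * (exp(x) - 1)), which is the form that produces the "1 + alpha * exp(x)" branch for alpha <= 0 and x > 0; the helper names elu_forward and elu_grad are illustrative only.

// Sketch only: scalar check of the piecewise ELU gradient from the patch.
// Assumes forward = max(0, x) + min(0, alpha * (exp(x) - 1)).
#include <cmath>
#include <cstdio>

// Assumed forward pass (see note above).
double elu_forward(double x, double alpha) {
  return std::fmax(0.0, x) + std::fmin(0.0, alpha * (std::exp(x) - 1.0));
}

// Analytic gradient following the four cases in the patch comments.
double elu_grad(double x, double alpha) {
  if (alpha > 0.0) {
    return x > 0.0 ? 1.0 : alpha * std::exp(x);
  }
  return x > 0.0 ? 1.0 + alpha * std::exp(x) : 0.0;
}

int main() {
  const double eps = 1e-6;
  const double alphas[] = {1.0, -0.5};
  const double xs[] = {2.0, -2.0};
  for (double alpha : alphas) {
    for (double x : xs) {
      // Compare analytic gradient with a central finite difference.
      double analytic = elu_grad(x, alpha);
      double numeric =
          (elu_forward(x + eps, alpha) - elu_forward(x - eps, alpha)) /
          (2.0 * eps);
      std::printf("alpha=%5.2f x=%5.2f analytic=%.6f numeric=%.6f\n", alpha, x,
                  analytic, numeric);
    }
  }
  return 0;
}

With alpha = -0.5, the check shows the gradient is 1 + alpha * exp(x) for x > 0 and exactly 0 for x <= 0, matching the two alpha <= 0 cases the patch adds; the original code returned alpha * exp(x) for x <= 0 regardless of the sign of alpha.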