未验证 提交 cabb9501 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix leaky_relu op when alpha is zero, test=develop (#19833)

上级 9cbc1eff
......@@ -1073,8 +1073,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 =
static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>();
auto temp2 = (out >= static_cast<T>(0)).template cast<T>();
static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
auto temp2 = (out > static_cast<T>(0)).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
......@@ -1418,11 +1418,11 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) =
ddx *
((out >= static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>())
.template cast<T>();
ddout.device(*d) = ddx *
((out > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) *
(out <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
......
......@@ -22,5 +22,10 @@ TEST(leaky_relu_grad_grad, test_cpu) {
TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.02));
}
// Zero-alpha case on CPU: calls TestLeakyReluGradGradMain<float> with
// {32, 64}, platform::CPUPlace() and a last argument of 0.0, and requires the
// call to return true.
// NOTE(review): {32, 64} is presumably the tensor dims and 0.0 the alpha
// value (per the test name) -- confirm against the helper's signature.
TEST(leaky_relu_grad_grad, test_cpu_zero_alpha) {
ASSERT_TRUE(
TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.0));
}
} // namespace operators
} // namespace paddle
......@@ -22,5 +22,10 @@ TEST(leaky_relu_grad_grad, test_gpu) {
TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.15));
}
// Zero-alpha case on GPU: calls TestLeakyReluGradGradMain<float> with
// {32, 64}, platform::CUDAPlace(0) and a last argument of 0.0, and requires
// the call to return true.
// NOTE(review): {32, 64} is presumably the tensor dims, 0.0 the alpha value
// (per the test name), and CUDAPlace(0) device 0 -- confirm against the
// helper's signature.
TEST(leaky_relu_grad_grad, test_gpu_zero_alpha) {
ASSERT_TRUE(
TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.0));
}
} // namespace operators
} // namespace paddle
......@@ -46,7 +46,7 @@ struct LeakyReluGradGradEachElementFunctor {
: ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {}
HOSTDEVICE void operator()(int idx) {
if (out_[idx] >= 0) {
if (out_[idx] > 0) {
ddout_[idx] = ddx_[idx];
} else {
ddout_[idx] = ddx_[idx] * alpha_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册