Unverified commit cabb9501, authored by Zeng Jinle, committed by GitHub

fix leaky_relu op when alpha is zero, test=develop (#19833)

Parent 9cbc1eff
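Why the comparisons flip: the backward pass is computed from the forward output `out`. With `alpha == 0`, leaky_relu degenerates to relu, so `out == 0` for every non-positive input; under the old `out >= 0` test those elements incorrectly received a gradient of 1 instead of `alpha`. Moving the `out == 0` case onto the alpha branch fixes this. Below is a minimal standalone sketch (illustrative only, not code from this repository) showing the difference for a single element:

#include <cstdio>

// Old rule: out == 0 takes the identity branch.
float grad_old(float out, float dout, float alpha) {
  return (out >= 0.0f) ? dout : alpha * dout;
}

// New rule: out == 0 takes the alpha branch.
float grad_new(float out, float dout, float alpha) {
  return (out > 0.0f) ? dout : alpha * dout;
}

int main() {
  // With alpha == 0 the forward output is 0 for any input x <= 0.
  float out = 0.0f, dout = 1.0f, alpha = 0.0f;
  std::printf("old: %g  new: %g\n", grad_old(out, dout, alpha),
              grad_new(out, dout, alpha));  // old: 1 (wrong), new: 0 (matches relu)
  return 0;
}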
@@ -1073,8 +1073,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp1 =
-        static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>();
-    auto temp2 = (out >= static_cast<T>(0)).template cast<T>();
+        static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
+    auto temp2 = (out > static_cast<T>(0)).template cast<T>();
     dx.device(d) = dout * (temp1 + temp2).template cast<T>();
   }
@@ -1418,11 +1418,11 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
     auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
     auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
     auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
-    ddout.device(*d) =
-        ddx *
-        ((out >= static_cast<T>(0)).template cast<T>() +
-         static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>())
-            .template cast<T>();
+    ddout.device(*d) = ddx *
+                       ((out > static_cast<T>(0)).template cast<T>() +
+                        static_cast<T>(alpha) *
+                            (out <= static_cast<T>(0)).template cast<T>())
+                           .template cast<T>();
     }
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...
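The double-grad kernel applies the same mask to `ddx`: ddout = ddx * (1[out > 0] + alpha * 1[out <= 0]). A minimal scalar sketch of that relation (illustrative only, not code from this repository):

// Scalar form of the patched second-order gradient; with alpha == 0 the
// element at out == 0 contributes nothing, matching relu's double-backward.
float leaky_relu_ddout(float ddx, float out, float alpha) {
  return ddx * ((out > 0.0f ? 1.0f : 0.0f) + alpha * (out <= 0.0f ? 1.0f : 0.0f));
}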
@@ -22,5 +22,10 @@ TEST(leaky_relu_grad_grad, test_cpu) {
       TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.02));
 }
 
+TEST(leaky_relu_grad_grad, test_cpu_zero_alpha) {
+  ASSERT_TRUE(
+      TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.0));
+}
+
 }  // namespace operators
 }  // namespace paddle
@@ -22,5 +22,10 @@ TEST(leaky_relu_grad_grad, test_gpu) {
       TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.15));
 }
 
+TEST(leaky_relu_grad_grad, test_gpu_zero_alpha) {
+  ASSERT_TRUE(
+      TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.0));
+}
+
 }  // namespace operators
 }  // namespace paddle
@@ -46,7 +46,7 @@ struct LeakyReluGradGradEachElementFunctor {
       : ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {}
 
   HOSTDEVICE void operator()(int idx) {
-    if (out_[idx] >= 0) {
+    if (out_[idx] > 0) {
       ddout_[idx] = ddx_[idx];
     } else {
       ddout_[idx] = ddx_[idx] * alpha_;
...