未验证 提交 231501fe 编写于 作者: D Double_V 提交者: GitHub

fix elugradgrad test fail & error message opt (#30171)

* fix elugradgrad test fail and error message opt

* fix unitest,test=develop

* Update prroi_pool_op.h

fix error message

* opt message,test=develop

* fix ci fail,test=develop
上级 fb49ea38
...@@ -1598,7 +1598,7 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1598,7 +1598,7 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
auto dout = framework::EigenVector<T>::Flatten( auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad")); GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() * dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
(x < static_cast<T>(0)).template cast<T>(); (x <= static_cast<T>(0)).template cast<T>();
} }
if (ddOut) { if (ddOut) {
......
...@@ -293,19 +293,24 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -293,19 +293,24 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_ENFORCE_EQ(rois->lod().empty(), false, PADDLE_ENFORCE_EQ(rois->lod().empty(), false,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"the lod of Input ROIs should not be empty when " "The lod of Input ROIs should not be empty when "
"BatchRoINums is None!")); "BatchRoINums is None!"));
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(rois_batch_size, batch_size,
rois_batch_size, batch_size, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("the rois_batch_size and input(X) " "The rois_batch_size and input(X)'s "
"batch_size should be the same.")); "batch_size should be the same but received"
"rois_batch_size: %d and batch_size: %d",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size]; int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_num_with_lod, rois_num, rois_num_with_lod, rois_num,
platform::errors::InvalidArgument( platform::errors::InvalidArgument("The rois_num from input should be "
"the rois_num from input and lod must be the same")); "equal to the rois_num from lod, "
"but received rois_num from input: "
"%d and the rois_num from lod: %d.",
rois_num_with_lod, rois_num));
// calculate batch id index for each roi according to LoD // calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) { for (int n = 0; n < rois_batch_size; ++n) {
......
...@@ -78,9 +78,9 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase): ...@@ -78,9 +78,9 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
class TestELUDoubleGradCheck(unittest.TestCase): class TestELUDoubleGradCheck(unittest.TestCase):
@prog_scope() @prog_scope()
def func(self, place): def func(self, place):
shape = [2, 3, 6, 6] shape = [2, 4, 4, 4]
eps = 1e-6 eps = 1e-6
alpha = 1.1 alpha = 0.2
dtype = np.float64 dtype = np.float64
SEED = 0 SEED = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册