未验证 提交 231501fe 编写于 作者: D Double_V 提交者: GitHub

fix elugradgrad test fail & error message opt (#30171)

* fix elugradgrad test fail and error message opt

* fix unitest,test=develop

* Update prroi_pool_op.h

fix error message

* opt message,test=develop

* fix ci fail,test=develop
上级 fb49ea38
......@@ -1598,7 +1598,7 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
(x < static_cast<T>(0)).template cast<T>();
(x <= static_cast<T>(0)).template cast<T>();
}
if (ddOut) {
......
......@@ -293,19 +293,24 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
} else {
PADDLE_ENFORCE_EQ(rois->lod().empty(), false,
platform::errors::InvalidArgument(
"the lod of Input ROIs should not be empty when "
"The lod of Input ROIs should not be empty when "
"BatchRoINums is None!"));
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument("the rois_batch_size and input(X) "
"batch_size should be the same."));
PADDLE_ENFORCE_EQ(rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The rois_batch_size and input(X)'s "
"batch_size should be the same but received"
"rois_batch_size: %d and batch_size: %d",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(
rois_num_with_lod, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and lod must be the same"));
platform::errors::InvalidArgument("The rois_num from input should be "
"equal to the rois_num from lod, "
"but received rois_num from input: "
"%d and the rois_num from lod: %d.",
rois_num_with_lod, rois_num));
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
......
......@@ -78,9 +78,9 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
class TestELUDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 3, 6, 6]
shape = [2, 4, 4, 4]
eps = 1e-6
alpha = 1.1
alpha = 0.2
dtype = np.float64
SEED = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册