提交 de212ae2 编写于 作者: M minqiyang

Polish code

test=develop
上级 d17bb4e6
......@@ -516,6 +516,7 @@ class GRUUnit(layers.Layer):
Args:
input (Variable): The fc transformed input value of current step.
name_scope (str): See base class.
hidden (Variable): The hidden value of gru unit from previous step.
size (integer): The input dimension value.
param_attr(ParamAttr|None): The parameter attribute for the learnable
......
......@@ -407,19 +407,26 @@ class OpTest(unittest.TestCase):
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place))
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place) + " in imperative mode")
if check_imperative:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place) + " in imperative mode")
if isinstance(expect, tuple):
self.assertListEqual(
actual.recursive_sequence_lengths(), expect[1],
"Output (" + sub_out_name +
") has different lod at " + str(place))
if check_imperative:
self.assertListEqual(
imperative_actual._ivar.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in imperative mode")
else:
if check_imperative:
imperative_actual = imperative_outs[out_name][0]
......@@ -436,16 +443,17 @@ class OpTest(unittest.TestCase):
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(actual_t) + " in class " + self.__class__.__name__)
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(imperative_actual_t) + " in class " +
self.__class__.__name__)
if check_imperative:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(imperative_actual_t) + " in class " +
self.__class__.__name__)
if isinstance(expect, tuple):
self.assertListEqual(actual.recursive_sequence_lengths(),
expect[1], "Output (" + out_name +
......@@ -453,8 +461,9 @@ class OpTest(unittest.TestCase):
if check_imperative:
self.assertListEqual(
imperative_actual._ivar.value().get_tensor()
.recursive_sequence_lengths(), expect[1], "Output ("
+ out_name + ") has different lod at " + str(place))
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in imperative mode")
def _get_places(self):
if self.dtype == np.float16:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册