From de212ae274db34207e39181bd622751d114685a4 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 11 Mar 2019 23:00:35 +0800 Subject: [PATCH] Polish code test=develop --- python/paddle/fluid/imperative/nn.py | 1 + .../paddle/fluid/tests/unittests/op_test.py | 49 +++++++++++-------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py index 6681b423415..995cbc1aa43 100644 --- a/python/paddle/fluid/imperative/nn.py +++ b/python/paddle/fluid/imperative/nn.py @@ -516,6 +516,7 @@ class GRUUnit(layers.Layer): Args: input (Variable): The fc transformed input value of current step. + name_scope (str): See base class. hidden (Variable): The hidden value of gru unit from previous step. size (integer): The input dimension value. param_attr(ParamAttr|None): The parameter attribute for the learnable diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 9fa62a692ee..b84ce2b3aea 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -407,19 +407,26 @@ class OpTest(unittest.TestCase): actual_t, expect_t, atol=atol, equal_nan=equal_nan), "Output (" + sub_out_name + ") has diff at " + str(place)) - self.assertTrue( - np.allclose( - imperative_actual_t, - expect_t, - atol=atol, - equal_nan=equal_nan), - "Output (" + sub_out_name + ") has diff at " + - str(place) + " in imperative mode") + if check_imperative: + self.assertTrue( + np.allclose( + imperative_actual_t, + expect_t, + atol=atol, + equal_nan=equal_nan), + "Output (" + sub_out_name + ") has diff at " + + str(place) + " in imperative mode") if isinstance(expect, tuple): self.assertListEqual( actual.recursive_sequence_lengths(), expect[1], "Output (" + sub_out_name + ") has different lod at " + str(place)) + if check_imperative: + self.assertListEqual( + imperative_actual._ivar.value().get_tensor() + .recursive_sequence_lengths(), expect[1], + "Output (" + out_name + ") has different lod at " + + str(place) + " in imperative mode") else: if check_imperative: imperative_actual = imperative_outs[out_name][0] @@ -436,16 +443,17 @@ class OpTest(unittest.TestCase): "Output (" + out_name + ") has diff at " + str(place) + "\nExpect " + str(expect_t) + "\n" + "But Got" + str(actual_t) + " in class " + self.__class__.__name__) - self.assertTrue( - np.allclose( - imperative_actual_t, - expect_t, - atol=atol, - equal_nan=equal_nan), - "Output (" + out_name + ") has diff at " + str(place) + - "\nExpect " + str(expect_t) + "\n" + "But Got" + - str(imperative_actual_t) + " in class " + - self.__class__.__name__) + if check_imperative: + self.assertTrue( + np.allclose( + imperative_actual_t, + expect_t, + atol=atol, + equal_nan=equal_nan), + "Output (" + out_name + ") has diff at " + str(place) + + "\nExpect " + str(expect_t) + "\n" + "But Got" + + str(imperative_actual_t) + " in class " + + self.__class__.__name__) if isinstance(expect, tuple): self.assertListEqual(actual.recursive_sequence_lengths(), expect[1], "Output (" + out_name + @@ -453,8 +461,9 @@ class OpTest(unittest.TestCase): if check_imperative: self.assertListEqual( imperative_actual._ivar.value().get_tensor() - .recursive_sequence_lengths(), expect[1], "Output (" - + out_name + ") has different lod at " + str(place)) + .recursive_sequence_lengths(), expect[1], + "Output (" + out_name + ") has different lod at " + + str(place) + " in imperative mode") def _get_places(self): if self.dtype == np.float16: -- GitLab