diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py index 3b602303ae9a183c7b66f5613321f58898fdfcc2..460ba65a48c863315cda4847aee1b4e2366bba96 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py @@ -59,7 +59,7 @@ class SimpleLSTMRNN(fluid.imperative.Layer): dtype="float32", default_initializer=fluid.initializer.UniformInitializer( low=-self._init_scale, high=self._init_scale)) - self.weight_1_arr.append(weight_1) + self.weight_1_arr.append(self.add_parameter('w_%d' % i, weight_1)) bias_1 = self.create_parameter( attr=fluid.ParamAttr( initializer=fluid.initializer.UniformInitializer( @@ -67,7 +67,7 @@ class SimpleLSTMRNN(fluid.imperative.Layer): shape=[self._hidden_size * 4], dtype="float32", default_initializer=fluid.initializer.Constant(0.0)) - self.bias_arr.append(bias_1) + self.bias_arr.append(self.add_parameter('b_%d' % i, bias_1)) def forward(self, input_embedding, init_hidden=None, init_cell=None): self.cell_array = [] @@ -242,7 +242,7 @@ class TestImperativePtbRnn(unittest.TestCase): dy_loss = None last_hidden = None last_cell = None - batch_num = 50 + batch_num = 200 for i in range(batch_num): x_data = np.arange(12).reshape(4, 3).astype('int64') @@ -264,8 +264,10 @@ class TestImperativePtbRnn(unittest.TestCase): dy_param_init[param.name] = param._numpy() dy_loss._backward() sgd.minimize(dy_loss) - for param in ptb_model.parameters(): - dy_param_updated[param.name] = param._numpy() + ptb_model.clear_gradients() + if i == batch_num - 1: + for param in ptb_model.parameters(): + dy_param_updated[param.name] = param._numpy() with new_program_scope(): fluid.default_startup_program().random_seed = seed @@ -323,25 +325,28 @@ class TestImperativePtbRnn(unittest.TestCase): }, fetch_list=fetch_list) static_loss_value = out[0] - static_last_cell_value = out[1] - static_last_hidden_value = out[2] - for k in range(3, len(out)): - static_param_updated[static_param_name_list[k - 3]] = out[k] + static_last_hidden_value = out[1] + static_last_cell_value = out[2] + if i == batch_num - 1: + for k in range(3, len(out)): + static_param_updated[static_param_name_list[k - + 3]] = out[k] + + self.assertTrue(np.allclose(static_loss_value, dy_loss._numpy())) + self.assertTrue(np.allclose(static_last_cell_value, last_cell._numpy())) + self.assertTrue( + np.allclose(static_last_hidden_value, last_hidden._numpy())) + for key, value in six.iteritems(static_param_init): + # print("static_init name: {}, value {}".format(key, value)) + # print("dy_init name: {}, value {}".format(key, dy_param_init[key])) + self.assertTrue(np.allclose(value, dy_param_init[key], atol=1e-5)) + for key, value in six.iteritems(static_param_updated): + # print("static name: {}, value {}".format(key, value)) + # print("dy name: {}, value {}".format(key, dy_param_updated[key])) self.assertTrue( - np.allclose(static_loss_value.all(), dy_loss._numpy().all())) - self.assertTrue( - np.allclose(static_last_cell_value.all(), - last_cell._numpy().all())) - self.assertTrue( - np.allclose(static_last_hidden_value.all(), - last_hidden._numpy().all())) - for key, value in six.iteritems(static_param_init): - self.assertTrue( - np.allclose(value.all(), dy_param_init[key].all())) - for key, value in six.iteritems(static_param_updated): - self.assertTrue( - np.allclose(value.all(), dy_param_updated[key].all())) + np.allclose( + value, dy_param_updated[key], atol=1e-5)) if __name__ == '__main__':