From 8e3da976f4c34f086c7213739d4839cacabf3c98 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 28 Jan 2019 02:35:44 +0000 Subject: [PATCH] test=develop, polish code --- .../tests/unittests/test_imperative_ptb_rnn.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py index 1610d49d8..9c6ec331e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py @@ -226,6 +226,9 @@ class TestImperativePtbRnn(unittest.TestCase): sgd = SGDOptimizer(learning_rate=1e-3) dy_param_updated = dict() dy_param_init = dict() + dy_loss = None + last_hidden = None + last_cell = None for i in range(2): x_data = np.arange(12).reshape(4, 3).astype('int64') y_data = np.arange(1, 13).reshape(4, 3).astype('int64') @@ -288,7 +291,9 @@ class TestImperativePtbRnn(unittest.TestCase): fetch_list=static_param_name_list) for i in range(len(static_param_name_list)): static_param_init[static_param_name_list[i]] = out[i] - + static_loss_value = None + static_last_cell_value = None + static_last_hidden_value = None for i in range(2): x_data = np.arange(12).reshape(4, 3).astype('int64') y_data = np.arange(1, 13).reshape(4, 3).astype('int64') @@ -311,11 +316,9 @@ class TestImperativePtbRnn(unittest.TestCase): static_loss_value = out[0] static_last_cell_value = out[1] static_last_hidden_value = out[2] - # print("static_loss is {}".format(out[0])) - # print("last_hidden is {}".format(out[1])) - # print("last_cell is {}".format(out[2])) - for i in range(3, len(out)): - static_param_updated[static_param_name_list[i - 3]] = out[i] + for k in range(3, len(out)): + static_param_updated[static_param_name_list[k - 3]] = out[k] + self.assertTrue( np.allclose(static_loss_value.all(), dy_loss._numpy().all())) self.assertTrue( -- GitLab