提交 8e3da976 编写于 作者: J JiabinYang

test=develop, polish code

上级 f364b722
...@@ -226,6 +226,9 @@ class TestImperativePtbRnn(unittest.TestCase): ...@@ -226,6 +226,9 @@ class TestImperativePtbRnn(unittest.TestCase):
sgd = SGDOptimizer(learning_rate=1e-3) sgd = SGDOptimizer(learning_rate=1e-3)
dy_param_updated = dict() dy_param_updated = dict()
dy_param_init = dict() dy_param_init = dict()
dy_loss = None
last_hidden = None
last_cell = None
for i in range(2): for i in range(2):
x_data = np.arange(12).reshape(4, 3).astype('int64') x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64') y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
...@@ -288,7 +291,9 @@ class TestImperativePtbRnn(unittest.TestCase): ...@@ -288,7 +291,9 @@ class TestImperativePtbRnn(unittest.TestCase):
fetch_list=static_param_name_list) fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)): for i in range(len(static_param_name_list)):
static_param_init[static_param_name_list[i]] = out[i] static_param_init[static_param_name_list[i]] = out[i]
static_loss_value = None
static_last_cell_value = None
static_last_hidden_value = None
for i in range(2): for i in range(2):
x_data = np.arange(12).reshape(4, 3).astype('int64') x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64') y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
...@@ -311,11 +316,9 @@ class TestImperativePtbRnn(unittest.TestCase): ...@@ -311,11 +316,9 @@ class TestImperativePtbRnn(unittest.TestCase):
static_loss_value = out[0] static_loss_value = out[0]
static_last_cell_value = out[1] static_last_cell_value = out[1]
static_last_hidden_value = out[2] static_last_hidden_value = out[2]
# print("static_loss is {}".format(out[0])) for k in range(3, len(out)):
# print("last_hidden is {}".format(out[1])) static_param_updated[static_param_name_list[k - 3]] = out[k]
# print("last_cell is {}".format(out[2]))
for i in range(3, len(out)):
static_param_updated[static_param_name_list[i - 3]] = out[i]
self.assertTrue( self.assertTrue(
np.allclose(static_loss_value.all(), dy_loss._numpy().all())) np.allclose(static_loss_value.all(), dy_loss._numpy().all()))
self.assertTrue( self.assertTrue(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册