未验证 提交 7c5319ba 编写于 作者: J Jiabin Yang 提交者: GitHub

Fix/test imperative ptb rnn (#16433)

* test=develop, fix ptb rnn

* test=develop, change cdn to bj to pass ci

* test=develop, fix ci
上级 f735102e
......@@ -59,7 +59,7 @@ class SimpleLSTMRNN(fluid.imperative.Layer):
dtype="float32",
default_initializer=fluid.initializer.UniformInitializer(
low=-self._init_scale, high=self._init_scale))
self.weight_1_arr.append(weight_1)
self.weight_1_arr.append(self.add_parameter('w_%d' % i, weight_1))
bias_1 = self.create_parameter(
attr=fluid.ParamAttr(
initializer=fluid.initializer.UniformInitializer(
......@@ -67,7 +67,7 @@ class SimpleLSTMRNN(fluid.imperative.Layer):
shape=[self._hidden_size * 4],
dtype="float32",
default_initializer=fluid.initializer.Constant(0.0))
self.bias_arr.append(bias_1)
self.bias_arr.append(self.add_parameter('b_%d' % i, bias_1))
def forward(self, input_embedding, init_hidden=None, init_cell=None):
self.cell_array = []
......@@ -242,7 +242,7 @@ class TestImperativePtbRnn(unittest.TestCase):
dy_loss = None
last_hidden = None
last_cell = None
batch_num = 50
batch_num = 200
for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
......@@ -264,8 +264,10 @@ class TestImperativePtbRnn(unittest.TestCase):
dy_param_init[param.name] = param._numpy()
dy_loss._backward()
sgd.minimize(dy_loss)
for param in ptb_model.parameters():
dy_param_updated[param.name] = param._numpy()
ptb_model.clear_gradients()
if i == batch_num - 1:
for param in ptb_model.parameters():
dy_param_updated[param.name] = param._numpy()
with new_program_scope():
fluid.default_startup_program().random_seed = seed
......@@ -323,25 +325,28 @@ class TestImperativePtbRnn(unittest.TestCase):
},
fetch_list=fetch_list)
static_loss_value = out[0]
static_last_cell_value = out[1]
static_last_hidden_value = out[2]
for k in range(3, len(out)):
static_param_updated[static_param_name_list[k - 3]] = out[k]
static_last_hidden_value = out[1]
static_last_cell_value = out[2]
if i == batch_num - 1:
for k in range(3, len(out)):
static_param_updated[static_param_name_list[k -
3]] = out[k]
self.assertTrue(np.allclose(static_loss_value, dy_loss._numpy()))
self.assertTrue(np.allclose(static_last_cell_value, last_cell._numpy()))
self.assertTrue(
np.allclose(static_last_hidden_value, last_hidden._numpy()))
for key, value in six.iteritems(static_param_init):
# print("static_init name: {}, value {}".format(key, value))
# print("dy_init name: {}, value {}".format(key, dy_param_init[key]))
self.assertTrue(np.allclose(value, dy_param_init[key], atol=1e-5))
for key, value in six.iteritems(static_param_updated):
# print("static name: {}, value {}".format(key, value))
# print("dy name: {}, value {}".format(key, dy_param_updated[key]))
self.assertTrue(
np.allclose(static_loss_value.all(), dy_loss._numpy().all()))
self.assertTrue(
np.allclose(static_last_cell_value.all(),
last_cell._numpy().all()))
self.assertTrue(
np.allclose(static_last_hidden_value.all(),
last_hidden._numpy().all()))
for key, value in six.iteritems(static_param_init):
self.assertTrue(
np.allclose(value.all(), dy_param_init[key].all()))
for key, value in six.iteritems(static_param_updated):
self.assertTrue(
np.allclose(value.all(), dy_param_updated[key].all()))
np.allclose(
value, dy_param_updated[key], atol=1e-5))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册