未验证 提交 5023482a 编写于 作者: H hong 提交者: GitHub

Fix imperative orc attention unitest (#25797)

* reduce hidden size and loop number; test=develop

* change loop number; remove useless code; test=develop
上级 dc42e3c4
......@@ -28,11 +28,11 @@ class Config(object):
config for training
'''
# encoder rnn hidden_size
encoder_size = 200
encoder_size = 64
# decoder size for decoder stage
decoder_size = 128
decoder_size = 64
# size for word embedding
word_vector_dim = 128
word_vector_dim = 64
# max length for label padding
max_length = 5
# optimizer setting
......@@ -373,12 +373,11 @@ class OCRAttention(fluid.dygraph.Layer):
class TestDygraphOCRAttention(unittest.TestCase):
def test_while_op(self):
seed = 90
epoch_num = 2
epoch_num = 1
if core.is_compiled_with_cuda():
batch_num = 20
batch_num = 10
else:
print("in CPU")
batch_num = 2
batch_num = 4
np.random.seed = seed
image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0],
Config.DATA_SHAPE[1],
......@@ -457,7 +456,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
# print("static start")
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
ocr_attention = OCRAttention()
......@@ -523,7 +521,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
static_param_value = {}
static_grad_value = {}
static_out = out[0]
# static_test_grad = out[1]
for i in range(1, len(static_param_name_list) + 1):
static_param_value[static_param_name_list[i - 1]] = out[
i]
......@@ -533,13 +530,13 @@ class TestDygraphOCRAttention(unittest.TestCase):
static_grad_value[static_grad_name_list[
i - grad_start_pos]] = out[i]
self.assertTrue(np.array_equal(static_out, dy_out))
self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_init_value):
self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-20))
self.assertTrue(np.allclose(value, dy_param_value[key]))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册