未验证 提交 5023482a 编写于 作者: H hong 提交者: GitHub

Fix imperative orc attention unitest (#25797)

* reduce hidden size and loop number; test=develop

* change loop number; remove useless code; test=develop
上级 dc42e3c4
...@@ -28,11 +28,11 @@ class Config(object): ...@@ -28,11 +28,11 @@ class Config(object):
config for training config for training
''' '''
# encoder rnn hidden_size # encoder rnn hidden_size
encoder_size = 200 encoder_size = 64
# decoder size for decoder stage # decoder size for decoder stage
decoder_size = 128 decoder_size = 64
# size for word embedding # size for word embedding
word_vector_dim = 128 word_vector_dim = 64
# max length for label padding # max length for label padding
max_length = 5 max_length = 5
# optimizer setting # optimizer setting
...@@ -373,12 +373,11 @@ class OCRAttention(fluid.dygraph.Layer): ...@@ -373,12 +373,11 @@ class OCRAttention(fluid.dygraph.Layer):
class TestDygraphOCRAttention(unittest.TestCase): class TestDygraphOCRAttention(unittest.TestCase):
def test_while_op(self): def test_while_op(self):
seed = 90 seed = 90
epoch_num = 2 epoch_num = 1
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
batch_num = 20 batch_num = 10
else: else:
print("in CPU") batch_num = 4
batch_num = 2
np.random.seed = seed np.random.seed = seed
image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0], image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0],
Config.DATA_SHAPE[1], Config.DATA_SHAPE[1],
...@@ -457,7 +456,6 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -457,7 +456,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
with new_program_scope(): with new_program_scope():
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
# print("static start")
exe = fluid.Executor(fluid.CPUPlace( exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
ocr_attention = OCRAttention() ocr_attention = OCRAttention()
...@@ -523,7 +521,6 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -523,7 +521,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
static_param_value = {} static_param_value = {}
static_grad_value = {} static_grad_value = {}
static_out = out[0] static_out = out[0]
# static_test_grad = out[1]
for i in range(1, len(static_param_name_list) + 1): for i in range(1, len(static_param_name_list) + 1):
static_param_value[static_param_name_list[i - 1]] = out[ static_param_value[static_param_name_list[i - 1]] = out[
i] i]
...@@ -533,13 +530,13 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -533,13 +530,13 @@ class TestDygraphOCRAttention(unittest.TestCase):
static_grad_value[static_grad_name_list[ static_grad_value[static_grad_name_list[
i - grad_start_pos]] = out[i] i - grad_start_pos]] = out[i]
self.assertTrue(np.array_equal(static_out, dy_out)) self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_init_value): for key, value in six.iteritems(static_param_init_value):
self.assertTrue(np.array_equal(value, dy_param_init_value[key])) self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
for key, value in six.iteritems(static_param_value): for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-20)) self.assertTrue(np.allclose(value, dy_param_value[key]))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册