提交 9af384f1 编写于 作者: P phlrain

try to fix imperative orc unitest error; test=develop

上级 95cceb2d
...@@ -29,19 +29,19 @@ class Config(object): ...@@ -29,19 +29,19 @@ class Config(object):
config for training config for training
''' '''
# encoder rnn hidden_size # encoder rnn hidden_size
encoder_size = 16 encoder_size = 8
# decoder size for decoder stage # decoder size for decoder stage
decoder_size = 16 decoder_size = 8
# size for word embedding # size for word embedding
word_vector_dim = 16 word_vector_dim = 8
# max length for label padding # max length for label padding
max_length = 5 max_length = 3
# optimizer setting # optimizer setting
LR = 1.0 LR = 1.0
learning_rate_decay = None learning_rate_decay = None
# batch size to train # batch size to train
batch_size = 8 batch_size = 2
# class number to classify # class number to classify
num_classes = 64 num_classes = 64
...@@ -55,7 +55,7 @@ class Config(object): ...@@ -55,7 +55,7 @@ class Config(object):
TRAIN_LIST_FILE_NAME = "train.list" TRAIN_LIST_FILE_NAME = "train.list"
# data shape for input image # data shape for input image
DATA_SHAPE = [1, 48, 384] DATA_SHAPE = [1, 16, 64]
class ConvBNPool(fluid.dygraph.Layer): class ConvBNPool(fluid.dygraph.Layer):
...@@ -124,13 +124,13 @@ class OCRConv(fluid.dygraph.Layer): ...@@ -124,13 +124,13 @@ class OCRConv(fluid.dygraph.Layer):
def __init__(self, is_test=False, use_cudnn=True): def __init__(self, is_test=False, use_cudnn=True):
super(OCRConv, self).__init__() super(OCRConv, self).__init__()
self.conv_bn_pool_1 = ConvBNPool( self.conv_bn_pool_1 = ConvBNPool(
2, [16, 16], [1, 16], is_test=is_test, use_cudnn=use_cudnn) 2, [8, 8], [1, 8], is_test=is_test, use_cudnn=use_cudnn)
self.conv_bn_pool_2 = ConvBNPool( self.conv_bn_pool_2 = ConvBNPool(
2, [32, 32], [16, 32], is_test=is_test, use_cudnn=use_cudnn) 2, [8, 8], [8, 8], is_test=is_test, use_cudnn=use_cudnn)
self.conv_bn_pool_3 = ConvBNPool( self.conv_bn_pool_3 = ConvBNPool(
2, [64, 64], [32, 64], is_test=is_test, use_cudnn=use_cudnn) 2, [8, 8], [8, 8], is_test=is_test, use_cudnn=use_cudnn)
self.conv_bn_pool_4 = ConvBNPool( self.conv_bn_pool_4 = ConvBNPool(
2, [128, 128], [64, 128], 2, [16, 16], [8, 16],
is_test=is_test, is_test=is_test,
pool=False, pool=False,
use_cudnn=use_cudnn) use_cudnn=use_cudnn)
...@@ -212,9 +212,9 @@ class EncoderNet(fluid.dygraph.Layer): ...@@ -212,9 +212,9 @@ class EncoderNet(fluid.dygraph.Layer):
self.ocr_convs = OCRConv(is_test=is_test, use_cudnn=use_cudnn) self.ocr_convs = OCRConv(is_test=is_test, use_cudnn=use_cudnn)
self.fc_1_layer = Linear( self.fc_1_layer = Linear(
768, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False) 32, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
self.fc_2_layer = Linear( self.fc_2_layer = Linear(
768, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False) 32, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
self.gru_forward_layer = DynamicGRU( self.gru_forward_layer = DynamicGRU(
size=rnn_hidden_size, size=rnn_hidden_size,
h_0=h_0, h_0=h_0,
...@@ -241,10 +241,9 @@ class EncoderNet(fluid.dygraph.Layer): ...@@ -241,10 +241,9 @@ class EncoderNet(fluid.dygraph.Layer):
transpose_conv_features = fluid.layers.transpose( transpose_conv_features = fluid.layers.transpose(
conv_features, perm=[0, 3, 1, 2]) conv_features, perm=[0, 3, 1, 2])
sliced_feature = fluid.layers.reshape( sliced_feature = fluid.layers.reshape(
transpose_conv_features, [ transpose_conv_features, [
-1, 48, transpose_conv_features.shape[2] * -1, 8, transpose_conv_features.shape[2] *
transpose_conv_features.shape[3] transpose_conv_features.shape[3]
], ],
inplace=False) inplace=False)
...@@ -376,9 +375,9 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -376,9 +375,9 @@ class TestDygraphOCRAttention(unittest.TestCase):
seed = 90 seed = 90
epoch_num = 1 epoch_num = 1
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
batch_num = 6 batch_num = 3
else: else:
batch_num = 4 batch_num = 2
np.random.seed = seed np.random.seed = seed
image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0], image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0],
Config.DATA_SHAPE[1], Config.DATA_SHAPE[1],
...@@ -536,8 +535,9 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -536,8 +535,9 @@ class TestDygraphOCRAttention(unittest.TestCase):
self.assertTrue(np.array_equal(value, dy_param_init_value[key])) self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
for key, value in six.iteritems(static_param_value): for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key])) self.assertTrue(np.allclose(value, dy_param_value[key], rtol=1e-05))
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册