提交 9af384f1 编写于 作者: P phlrain

try to fix imperative orc unitest error; test=develop

上级 95cceb2d
......@@ -29,19 +29,19 @@ class Config(object):
config for training
'''
# encoder rnn hidden_size
encoder_size = 16
encoder_size = 8
# decoder size for decoder stage
decoder_size = 16
decoder_size = 8
# size for word embedding
word_vector_dim = 16
word_vector_dim = 8
# max length for label padding
max_length = 5
max_length = 3
# optimizer setting
LR = 1.0
learning_rate_decay = None
# batch size to train
batch_size = 8
batch_size = 2
# class number to classify
num_classes = 64
......@@ -55,7 +55,7 @@ class Config(object):
TRAIN_LIST_FILE_NAME = "train.list"
# data shape for input image
DATA_SHAPE = [1, 48, 384]
DATA_SHAPE = [1, 16, 64]
class ConvBNPool(fluid.dygraph.Layer):
......@@ -124,13 +124,13 @@ class OCRConv(fluid.dygraph.Layer):
def __init__(self, is_test=False, use_cudnn=True):
super(OCRConv, self).__init__()
self.conv_bn_pool_1 = ConvBNPool(
2, [16, 16], [1, 16], is_test=is_test, use_cudnn=use_cudnn)
2, [8, 8], [1, 8], is_test=is_test, use_cudnn=use_cudnn)
self.conv_bn_pool_2 = ConvBNPool(
2, [32, 32], [16, 32], is_test=is_test, use_cudnn=use_cudnn)
2, [8, 8], [8, 8], is_test=is_test, use_cudnn=use_cudnn)
self.conv_bn_pool_3 = ConvBNPool(
2, [64, 64], [32, 64], is_test=is_test, use_cudnn=use_cudnn)
2, [8, 8], [8, 8], is_test=is_test, use_cudnn=use_cudnn)
self.conv_bn_pool_4 = ConvBNPool(
2, [128, 128], [64, 128],
2, [16, 16], [8, 16],
is_test=is_test,
pool=False,
use_cudnn=use_cudnn)
......@@ -212,9 +212,9 @@ class EncoderNet(fluid.dygraph.Layer):
self.ocr_convs = OCRConv(is_test=is_test, use_cudnn=use_cudnn)
self.fc_1_layer = Linear(
768, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
32, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
self.fc_2_layer = Linear(
768, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
32, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
self.gru_forward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
......@@ -241,10 +241,9 @@ class EncoderNet(fluid.dygraph.Layer):
transpose_conv_features = fluid.layers.transpose(
conv_features, perm=[0, 3, 1, 2])
sliced_feature = fluid.layers.reshape(
transpose_conv_features, [
-1, 48, transpose_conv_features.shape[2] *
-1, 8, transpose_conv_features.shape[2] *
transpose_conv_features.shape[3]
],
inplace=False)
......@@ -376,9 +375,9 @@ class TestDygraphOCRAttention(unittest.TestCase):
seed = 90
epoch_num = 1
if core.is_compiled_with_cuda():
batch_num = 6
batch_num = 3
else:
batch_num = 4
batch_num = 2
np.random.seed = seed
image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0],
Config.DATA_SHAPE[1],
......@@ -536,8 +535,9 @@ class TestDygraphOCRAttention(unittest.TestCase):
self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key]))
self.assertTrue(np.allclose(value, dy_param_value[key], rtol=1e-05))
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册