提交 4e37cccb 编写于 作者: W wanghaoshuang

Fix issues

上级 a7d6b1af
......@@ -15,6 +15,16 @@ import sys
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import numpy as np
def random_reader(num_class):
    """Build a reader that yields one random (image, label) sample per call.

    Stand-in data source for the CTC training script until a real dataset
    reader is available.

    Args:
        num_class: number of real (non-blank) label classes. The training
            script passes the same value as the CTC ``blank`` index, so the
            generated label ids are kept strictly below it.

    Returns:
        A zero-argument generator function. Each invocation yields a single
        tuple ``(image, label)`` where ``image`` is a float array of shape
        ``[1, 512, 512]`` drawn uniformly from ``[0.1, 1)`` and ``label`` is
        an integer array holding 5 to 9 ids in ``[0, num_class)``.

    Raises:
        ValueError: if ``num_class`` is less than 1 (no valid label exists).
    """
    if num_class < 1:
        raise ValueError("num_class must be at least 1, got %r" % (num_class, ))

    def reader():
        # randint's upper bound is exclusive, so lengths range over 5..9.
        sequence_len = np.random.randint(5, 10)
        image = np.random.uniform(0.1, 1, [1, 512, 512])
        # The exclusive upper bound is num_class, not num_class + 1, so a
        # target label can never equal the blank index handed to warpctc.
        label = np.random.randint(0, num_class, [sequence_len])
        yield image, label

    return reader
def ocr_conv(input, num, with_bn):
......@@ -23,7 +33,7 @@ def ocr_conv(input, num, with_bn):
def conv_block(input, filter_size, group_size, with_bn):
return fluid.nets.img_conv_group(
input=input,
conv_num_filter=[num_filter] * groups,
conv_num_filter=[filter_size] * group_size,
pool_size=2,
pool_stride=2,
conv_padding=1,
......@@ -46,23 +56,21 @@ images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# encoder part
conv_features = ocr_convs(imges, 8, True)
conv_features = ocr_conv(images, 8, True)
sliced_feature = fluid.layers.im2sequence(
input=conv_features,
stride_x=1,
stride_y=1,
block_x=1,
block_y=3, )
input=conv_features, stride=[1, 1], filter_size=[1, 3])
# TODO(wanghaoshuang): replace with GRU
gru_forward = fluid.layers.lstm(input=sliced_feature, size=200, act="relu")
gru_backward = fluid.layers.lstm(
input=sliced_feature, size=200, reverse=True, act="relu")
gru_forward, _ = fluid.layers.dynamic_lstm(input=sliced_feature, size=3 * 128)
gru_backward, _ = fluid.layers.dynamic_lstm(
input=sliced_feature, size=3 * 128, is_reverse=True)
fc_out = fluid.layers.fc(input=[gru_forward, gru_backward],
size=num_classes + 1)
out = fluid.layers.fc(input=[gru_forward, gru_backward], size=num_classes + 1)
cost = fluid.layers.warpctc(
input=out,
input=fc_out,
label=label,
size=num_classes + 1,
blank=num_classes,
......@@ -74,7 +82,7 @@ optimizer = fluid.optimizer.Momentum(
learning_rate=((1.0e-3) / 16), momentum=0.9)
opts = optimizer.minimize(cost)
decoded_out = fluid.layers.ctc_greedy_decoder(input=output, blank=class_num)
decoded_out = fluid.layers.ctc_greedy_decoder(input=fc_out, blank=num_classes)
error_evaluator = fluid.evaluator.EditDistance(input=decoded_out, label=label)
BATCH_SIZE = 16
......@@ -83,7 +91,7 @@ PASS_NUM = 1
# TODO(wanghaoshuang): replace with the correct data reader
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=128 * 10),
random_reader(num_classes), buf_size=128 * 10),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
......@@ -96,7 +104,7 @@ for pass_id in range(PASS_NUM):
for data in train_reader():
loss, error = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + error.metrics)
fetch_list=[avg_cost] + error_evaluator.metrics)
pass_error = error_evaluator.eval(exe)
print "loss: %s; distance error: %s; pass_dis_error: %s;" % (
str(loss), str(error), str(pass_error))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册