提交 4e37cccb 编写于 作者: W wanghaoshuang

Fix issues

上级 a7d6b1af
...@@ -15,6 +15,16 @@ import sys ...@@ -15,6 +15,16 @@ import sys
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
import numpy as np
def random_reader(num_class):
    """Build a reader creator that yields one synthetic OCR sample.

    This is a stand-in for a real data reader: it produces random pixels and
    a random label sequence so the training pipeline can be exercised.

    Args:
        num_class: Number of real (non-blank) label classes. Must be >= 1.

    Returns:
        A zero-argument generator function. Each call yields a single
        ``(image, label)`` tuple where ``image`` is an array of shape
        ``[1, 512, 512]`` with values drawn uniformly from ``[0.1, 1)`` and
        ``label`` is an integer array whose length is drawn from ``[5, 10)``.

    Raises:
        ValueError: If ``num_class`` is smaller than 1.
    """
    if num_class < 1:
        raise ValueError("num_class must be >= 1, got %r" % (num_class, ))

    def reader():
        sequence_len = np.random.randint(5, 10)
        image = np.random.uniform(0.1, 1, [1, 512, 512])
        # The upper bound of randint is exclusive. Labels are kept inside
        # [0, num_class) because index `num_class` is used as the CTC blank
        # by the training script and must not appear in target sequences.
        label = np.random.randint(0, num_class, [sequence_len])
        yield image, label

    return reader
def ocr_conv(input, num, with_bn): def ocr_conv(input, num, with_bn):
...@@ -23,7 +33,7 @@ def ocr_conv(input, num, with_bn): ...@@ -23,7 +33,7 @@ def ocr_conv(input, num, with_bn):
def conv_block(input, filter_size, group_size, with_bn): def conv_block(input, filter_size, group_size, with_bn):
return fluid.nets.img_conv_group( return fluid.nets.img_conv_group(
input=input, input=input,
conv_num_filter=[num_filter] * groups, conv_num_filter=[filter_size] * group_size,
pool_size=2, pool_size=2,
pool_stride=2, pool_stride=2,
conv_padding=1, conv_padding=1,
...@@ -46,23 +56,21 @@ images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') ...@@ -46,23 +56,21 @@ images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# encoder part # encoder part
conv_features = ocr_convs(imges, 8, True) conv_features = ocr_conv(images, 8, True)
sliced_feature = fluid.layers.im2sequence( sliced_feature = fluid.layers.im2sequence(
input=conv_features, input=conv_features, stride=[1, 1], filter_size=[1, 3])
stride_x=1,
stride_y=1,
block_x=1,
block_y=3, )
# TODO(wanghaoshuang): replaced by GRU # TODO(wanghaoshuang): replaced by GRU
gru_forward = fluid.layers.lstm(input=sliced_feature, size=200, act="relu") gru_forward, _ = fluid.layers.dynamic_lstm(input=sliced_feature, size=3 * 128)
gru_backward = fluid.layers.lstm( gru_backward, _ = fluid.layers.dynamic_lstm(
input=sliced_feature, size=200, reverse=True, act="relu") input=sliced_feature, size=3 * 128, is_reverse=True)
fc_out = fluid.layers.fc(input=[gru_forward, gru_backward],
size=num_classes + 1)
out = fluid.layers.fc(input=[gru_forward, gru_backward], size=num_classes + 1)
cost = fluid.layers.warpctc( cost = fluid.layers.warpctc(
input=out, input=fc_out,
label=label, label=label,
size=num_classes + 1, size=num_classes + 1,
blank=num_classes, blank=num_classes,
...@@ -74,7 +82,7 @@ optimizer = fluid.optimizer.Momentum( ...@@ -74,7 +82,7 @@ optimizer = fluid.optimizer.Momentum(
learning_rate=((1.0e-3) / 16), momentum=0.9) learning_rate=((1.0e-3) / 16), momentum=0.9)
opts = optimizer.minimize(cost) opts = optimizer.minimize(cost)
decoded_out = fluid.layers.ctc_greedy_decoder(input=output, blank=class_num) decoded_out = fluid.layers.ctc_greedy_decoder(input=fc_out, blank=num_classes)
error_evaluator = fluid.evaluator.EditDistance(input=decoded_out, label=label) error_evaluator = fluid.evaluator.EditDistance(input=decoded_out, label=label)
BATCH_SIZE = 16 BATCH_SIZE = 16
...@@ -83,7 +91,7 @@ PASS_NUM = 1 ...@@ -83,7 +91,7 @@ PASS_NUM = 1
# TODO(wanghaoshuang): replaced by correct data reader # TODO(wanghaoshuang): replaced by correct data reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=128 * 10), random_reader(num_classes), buf_size=128 * 10),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -96,7 +104,7 @@ for pass_id in range(PASS_NUM): ...@@ -96,7 +104,7 @@ for pass_id in range(PASS_NUM):
for data in train_reader(): for data in train_reader():
loss, error = exe.run(fluid.default_main_program(), loss, error = exe.run(fluid.default_main_program(),
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[avg_cost] + error.metrics) fetch_list=[avg_cost] + error_evaluator.metrics)
pass_error = error_evaluator.eval(exe) pass_error = error_evaluator.eval(exe)
print "loss: %s; distance error: %s; pass_dis_error: %s;" % ( print "loss: %s; distance error: %s; pass_dis_error: %s;" % (
str(loss), str(error), str(pass_error)) str(loss), str(error), str(pass_error))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册