提交 bff7fbe3 编写于 作者: W wanghaoshuang

Restructure code.

1. Split data reader and train script.
2. Wrap some functions.
上级 4e37cccb
import numpy as np
DATA_SHAPE = [1, 512, 512]
def _read_creater(num_sample=1024, num_class=20, min_seq_len=1, max_seq_len=10):
def reader():
for i in range(num_sample):
sequence_len = np.random.randint(min_seq_len, max_seq_len)
x = np.random.uniform(0.1, 1, DATA_SHAPE).astype("float32")
y = np.random.randint(0, num_class + 1,
[sequence_len]).astype("int32")
yield x, y
return reader
def train(num_sample=16):
    """Return a reader that yields ``num_sample`` random training samples."""
    train_reader = _read_creater(num_sample=num_sample)
    return train_reader
def test(num_sample=16):
    """Return a reader that yields ``num_sample`` random test samples."""
    test_reader = _read_creater(num_sample=num_sample)
    return test_reader
def data_shape():
    """Return the module-level ``DATA_SHAPE`` list describing one image."""
    shape = DATA_SHAPE
    return shape
...@@ -12,22 +12,29 @@ ...@@ -12,22 +12,29 @@
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
import sys import sys
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
from paddle.v2.fluid import core
import numpy as np import numpy as np
import dummy_reader
def to_lodtensor(data, place):
    """Pack variable-length sequences into a single ``core.LoDTensor``.

    Args:
        data: Iterable of sequences (anything ``len()`` and
            ``np.concatenate`` accept), one per sample.
        place: Device place handed to ``LoDTensor.set``.

    Returns:
        A ``core.LoDTensor`` holding all sequences concatenated along axis 0
        as an int32 column of shape ``[total_len, 1]``, with one LoD level
        made of the cumulative sequence offsets (starting at 0).
    """
    # Materialise once: the input is walked twice (lengths, then
    # concatenation), which would silently drain a one-shot iterator.
    data = list(data)

    # Cumulative offsets: [0, len(s0), len(s0) + len(s1), ...].
    lod = [0]
    for seq in data:
        lod.append(lod[-1] + len(seq))

    flattened_data = np.concatenate(data, axis=0).astype("int32")
    flattened_data = flattened_data.reshape([len(flattened_data), 1])

    res = core.LoDTensor()
    res.set(flattened_data, place)
    res.set_lod([lod])
    return res
def ocr_conv(input, num, with_bn): def ocr_conv(input, num, with_bn, param_attrs):
assert (num % 4 == 0) assert (num % 4 == 0)
def conv_block(input, filter_size, group_size, with_bn): def conv_block(input, filter_size, group_size, with_bn):
...@@ -40,7 +47,8 @@ def ocr_conv(input, num, with_bn): ...@@ -40,7 +47,8 @@ def ocr_conv(input, num, with_bn):
conv_filter_size=3, conv_filter_size=3,
conv_act='relu', conv_act='relu',
conv_with_batchnorm=with_bn, conv_with_batchnorm=with_bn,
pool_type='max') pool_type='max',
param_attr=param_attrs)
conv1 = conv_block(input, 16, (num / 4), with_bn) conv1 = conv_block(input, 16, (num / 4), with_bn)
conv2 = conv_block(conv1, 32, (num / 4), with_bn) conv2 = conv_block(conv1, 32, (num / 4), with_bn)
...@@ -49,62 +57,101 @@ def ocr_conv(input, num, with_bn): ...@@ -49,62 +57,101 @@ def ocr_conv(input, num, with_bn):
return conv4 return conv4
def ocr_ctc_net(images, num_classes, param_attrs):
    """Build the conv + bidirectional GRU network used for CTC training.

    Args:
        images: Input image variable fed to ``ocr_conv``.
        num_classes: Number of label classes; the output layer has
            ``num_classes + 1`` units.
        param_attrs: Parameter attribute passed to the conv, GRU and fc layers.

    Returns:
        The output variable of the final fully connected layer.
    """
    # Encoder: convolutional features (num=8, with_bn=True).
    features = ocr_conv(images, 8, True, param_attrs)

    # Slice the feature map into a sequence using a 1x3 window, stride 1.
    sequence = fluid.layers.im2sequence(
        input=features, stride=[1, 1], filter_size=[1, 3])

    # One GRU per direction over the same sequence.
    forward_gru = fluid.layers.dynamic_gru(
        input=sequence, size=128, param_attr=param_attrs)
    backward_gru = fluid.layers.dynamic_gru(
        input=sequence, size=128, is_reverse=True, param_attr=param_attrs)

    return fluid.layers.fc(input=[forward_gru, backward_gru],
                           size=num_classes + 1,
                           param_attr=param_attrs)
def get_feeder_data(data, place):
    """Convert one batch of (image, label) samples into a feed dict.

    Args:
        data: Batch of samples; ``sample[0]`` is the image array and
            ``sample[1]`` the label sequence.
        place: Device place on which the tensors are created.

    Returns:
        ``{"pixel": <LoDTensor of stacked float32 images>,
        "label": <LoDTensor built by to_lodtensor>}``.
    """
    # Real lists instead of map(): np.concatenate needs a sequence, and the
    # batch is traversed twice (images, then labels).
    images = [sample[0][np.newaxis, :] for sample in data]
    labels = [sample[1] for sample in data]

    pixel_tensor = core.LoDTensor()
    pixel_data = np.concatenate(images, axis=0).astype("float32")
    pixel_tensor.set(pixel_data, place)

    label_tensor = to_lodtensor(labels, place)
    return {"pixel": pixel_tensor, "label": label_tensor}
def train(num_classes=20,
          l2=0.0005 * 16,
          clip_threshold=10,
          data_reader=dummy_reader,
          learning_rate=((1.0e-3) / 16),
          momentum=0.9,
          batch_size=4,
          pass_num=2):
    """Build the CTC network, then train and evaluate it for ``pass_num`` passes.

    Args:
        num_classes: Number of label classes; index ``num_classes`` is used
            as the CTC blank.
        l2: Coefficient given to ``fluid.regularizer.L2Decay``.
        clip_threshold: Value given to ``fluid.clip.GradientClipByValue``.
        data_reader: Object providing ``data_shape()``, ``train()`` and
            ``test()``.
        learning_rate: Learning rate of the Momentum optimizer.
        momentum: Momentum of the Momentum optimizer.
        batch_size: Number of samples per batch for both readers.
        pass_num: Number of passes over the training data.
    """
    param_attrs = fluid.ParamAttr(
        regularizer=fluid.regularizer.L2Decay(l2),
        gradient_clip=fluid.clip.GradientClipByValue(clip_threshold))

    # Network inputs: one image, plus a variable-length label sequence.
    data_shape = data_reader.data_shape()
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int32', lod_level=1)

    fc_out = ocr_ctc_net(images, num_classes, param_attrs)

    cost = fluid.layers.warpctc(
        input=fc_out,
        label=label,
        size=num_classes + 1,
        blank=num_classes,
        norm_by_times=True)
    avg_cost = fluid.layers.mean(x=cost)

    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate, momentum=momentum)
    # NOTE(review): the optimizer minimizes `cost`, while `avg_cost` is only
    # fetched for reporting -- confirm this is intended.
    optimizer.minimize(cost)

    # Edit distance between the greedy CTC decoding and the int64 labels.
    decoded_out = fluid.layers.ctc_greedy_decoder(
        input=fc_out, blank=num_classes)
    casted_label = fluid.layers.cast(x=label, dtype='int64')
    error_evaluator = fluid.evaluator.EditDistance(
        input=decoded_out, label=casted_label)

    train_reader = paddle.batch(data_reader.train(), batch_size=batch_size)
    test_reader = paddle.batch(data_reader.test(), batch_size=batch_size)

    # Runs on CUDA device 0; use fluid.CPUPlace() instead for a CPU-only run.
    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    # Not used below (feeds are built by get_feeder_data); kept as in the
    # original script.
    feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
    exe.run(fluid.default_startup_program())

    inference_program = fluid.io.get_inference_program(error_evaluator)

    for pass_id in range(pass_num):
        # train
        error_evaluator.reset(exe)
        for batch_id, data in enumerate(train_reader()):
            loss, batch_edit_distance, _, _ = exe.run(
                fluid.default_main_program(),
                feed=get_feeder_data(data, place),
                fetch_list=[avg_cost] + error_evaluator.metrics)
            print("Pass[%d], batch[%d]; loss: %s; edit distance: %s" % (
                pass_id, batch_id, loss[0], batch_edit_distance[0]))
        train_edit_distance = error_evaluator.eval(exe)
        print("End pass[%d]; train data edit_distance: %s" % (
            pass_id, str(train_edit_distance)))

        # test
        error_evaluator.reset(exe)
        for data in test_reader():
            exe.run(inference_program, feed=get_feeder_data(data, place))
        test_edit_distance = error_evaluator.eval(exe)
        print("End pass[%d]; test data edit_distance: %s" % (
            pass_id, str(test_edit_distance)))
if __name__ == "__main__":
    # Script entry point: train with all default settings.
    train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册