Commit a87e0568 — authored by: wanghaoshuang

Add arguments parser.

Parent commit: bff7fbe3
"""Trainer for OCR CTC model."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...@@ -11,15 +12,30 @@ ...@@ -11,15 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard library.
import sys
import argparse
import functools

# Third-party: numerics and the Paddle framework.
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid import core

# Project-local helpers.
import dummy_reader
from utility import add_arguments, print_arguments
# Command-line interface: every training hyper-parameter below can be
# overridden from the shell, e.g. `python train.py --batch_size 32 --device 0`.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',     int,   16,         "Minibatch size.")
add_arg('pass_num',       int,   16,         "# of training epochs.")
add_arg('learning_rate',  float, 1.0e-3,     "Learning rate.")
add_arg('l2',             float, 0.0005,     "L2 regularizer.")
add_arg('max_clip',       float, 10.0,       "Max clip threshold.")
add_arg('min_clip',       float, -10.0,      "Min clip threshold.")
add_arg('momentum',       float, 0.9,        "Momentum.")
add_arg('device',         int,   -1,         "Device id.'-1' means running on CPU"
                                             "while '0' means GPU-0.")
# yapf: enable
def _to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data] seq_lens = [len(seq) for seq in data]
cur_len = 0 cur_len = 0
lod = [cur_len] lod = [cur_len]
...@@ -33,11 +49,18 @@ def to_lodtensor(data, place): ...@@ -33,11 +49,18 @@ def to_lodtensor(data, place):
res.set_lod([lod]) res.set_lod([lod])
return res return res
def _get_feeder_data(data, place):
    """Pack one minibatch into the feed dict expected by the executor.

    Args:
        data: list of (image, label) samples produced by the batched reader.
        place: device (CPU/GPU place) on which the tensors are created.

    Returns:
        dict mapping the input layer names 'pixel' and 'label' to tensors.
    """
    # Stack all images along a new leading batch axis into one float32 array.
    image_arrays = [sample[0][np.newaxis, :] for sample in data]
    batch_pixels = np.concatenate(image_arrays, axis=0).astype("float32")
    pixel_tensor = core.LoDTensor()
    pixel_tensor.set(batch_pixels, place)
    # Labels are variable-length sequences, hence packed as a LoDTensor.
    label_tensor = _to_lodtensor([sample[1] for sample in data], place)
    return {"pixel": pixel_tensor, "label": label_tensor}
def ocr_conv(input, num, with_bn, param_attrs): def _ocr_conv(input, num, with_bn, param_attrs):
assert (num % 4 == 0) assert (num % 4 == 0)
def conv_block(input, filter_size, group_size, with_bn): def _conv_block(input, filter_size, group_size, with_bn):
return fluid.nets.img_conv_group( return fluid.nets.img_conv_group(
input=input, input=input,
conv_num_filter=[filter_size] * group_size, conv_num_filter=[filter_size] * group_size,
...@@ -50,15 +73,15 @@ def ocr_conv(input, num, with_bn, param_attrs): ...@@ -50,15 +73,15 @@ def ocr_conv(input, num, with_bn, param_attrs):
pool_type='max', pool_type='max',
param_attr=param_attrs) param_attr=param_attrs)
conv1 = conv_block(input, 16, (num / 4), with_bn) conv1 = _conv_block(input, 16, (num / 4), with_bn)
conv2 = conv_block(conv1, 32, (num / 4), with_bn) conv2 = _conv_block(conv1, 32, (num / 4), with_bn)
conv3 = conv_block(conv2, 64, (num / 4), with_bn) conv3 = _conv_block(conv2, 64, (num / 4), with_bn)
conv4 = conv_block(conv3, 128, (num / 4), with_bn) conv4 = _conv_block(conv3, 128, (num / 4), with_bn)
return conv4 return conv4
def ocr_ctc_net(images, num_classes, param_attrs): def _ocr_ctc_net(images, num_classes, param_attrs):
conv_features = ocr_conv(images, 8, True, param_attrs) conv_features = _ocr_conv(images, 8, True, param_attrs)
sliced_feature = fluid.layers.im2sequence( sliced_feature = fluid.layers.im2sequence(
input=conv_features, stride=[1, 1], filter_size=[1, 3]) input=conv_features, stride=[1, 1], filter_size=[1, 3])
gru_forward = fluid.layers.dynamic_gru( gru_forward = fluid.layers.dynamic_gru(
...@@ -72,34 +95,29 @@ def ocr_ctc_net(images, num_classes, param_attrs): ...@@ -72,34 +95,29 @@ def ocr_ctc_net(images, num_classes, param_attrs):
return fc_out return fc_out
def get_feeder_data(data, place):
pixel_tensor = core.LoDTensor()
pixel_data = np.concatenate(
map(lambda x: x[0][np.newaxis, :], data), axis=0).astype("float32")
pixel_tensor.set(pixel_data, place)
label_tensor = to_lodtensor(map(lambda x: x[1], data), place)
return {"pixel": pixel_tensor, "label": label_tensor}
def train(l2=0.0005,
def train(num_classes=20, min_clip=-10,
l2=0.0005 * 16, max_clip=10,
clip_threshold=10,
data_reader=dummy_reader, data_reader=dummy_reader,
learning_rate=((1.0e-3) / 16), learning_rate=1.0e-3,
momentum=0.9, momentum=0.9,
batch_size=4, batch_size=16,
pass_num=2): pass_num=2,
device=0):
"""OCR CTC training"""
num_classes = data_reader.num_classes()
# define network
param_attrs = fluid.ParamAttr( param_attrs = fluid.ParamAttr(
regularizer=fluid.regularizer.L2Decay(l2), regularizer=fluid.regularizer.L2Decay(l2 * batch_size),
gradient_clip=fluid.clip.GradientClipByValue(clip_threshold)) gradient_clip=fluid.clip.GradientClipByValue(max_clip, min_clip))
data_shape = data_reader.data_shape() data_shape = data_reader.data_shape()
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data( label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1) name='label', shape=[1], dtype='int32', lod_level=1)
fc_out = _ocr_ctc_net(images, num_classes, param_attrs)
fc_out = ocr_ctc_net(images, num_classes, param_attrs) # define cost and optimizer
cost = fluid.layers.warpctc( cost = fluid.layers.warpctc(
input=fc_out, input=fc_out,
label=label, label=label,
...@@ -107,51 +125,63 @@ def train(num_classes=20, ...@@ -107,51 +125,63 @@ def train(num_classes=20,
blank=num_classes, blank=num_classes,
norm_by_times=True) norm_by_times=True)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate, momentum=momentum) learning_rate=learning_rate / batch_size, momentum=momentum)
opts = optimizer.minimize(cost) opts = optimizer.minimize(cost)
# decoder and evaluator
decoded_out = fluid.layers.ctc_greedy_decoder( decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes) input=fc_out, blank=num_classes)
casted_label = fluid.layers.cast(x=label, dtype='int64') casted_label = fluid.layers.cast(x=label, dtype='int64')
error_evaluator = fluid.evaluator.EditDistance( error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label) input=decoded_out, label=casted_label)
# data reader
train_reader = paddle.batch(data_reader.train(), batch_size=batch_size) train_reader = paddle.batch(data_reader.train(), batch_size=batch_size)
test_reader = paddle.batch(data_reader.test(), batch_size=batch_size) test_reader = paddle.batch(data_reader.test(), batch_size=batch_size)
# prepare environment
#place = fluid.CPUPlace() place = fluid.CPUPlace()
place = fluid.CUDAPlace(0) if device >= 0:
place = fluid.CUDAPlace(device)
exe = fluid.Executor(place) exe = fluid.Executor(place)
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
inference_program = fluid.io.get_inference_program(error_evaluator) inference_program = fluid.io.get_inference_program(error_evaluator)
for pass_id in range(pass_num): for pass_id in range(pass_num):
error_evaluator.reset(exe) error_evaluator.reset(exe)
batch_id = 0 batch_id = 0
# train a pass
for data in train_reader(): for data in train_reader():
loss, batch_edit_distance, _, _ = exe.run( loss, batch_edit_distance, _, _ = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed=get_feeder_data(data, place), feed=_get_feeder_data(data, place),
fetch_list=[avg_cost] + error_evaluator.metrics) fetch_list=[avg_cost] + error_evaluator.metrics)
print "Pass[%d], batch[%d]; loss: %s; edit distance: %s" % ( print "Pass[%d], batch[%d]; loss: %s; edit distance: %s." % (
pass_id, batch_id, loss[0], batch_edit_distance[0]) pass_id, batch_id, loss[0], batch_edit_distance[0])
batch_id += 1 batch_id += 1
train_edit_distance = error_evaluator.eval(exe) train_edit_distance = error_evaluator.eval(exe)
print "End pass[%d]; train data edit_distance: %s" % ( print "End pass[%d]; train data edit_distance: %s." % (
pass_id, str(train_edit_distance)) pass_id, str(train_edit_distance[0]))
# test # evaluate model on test data
error_evaluator.reset(exe) error_evaluator.reset(exe)
for data in test_reader(): for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place)) exe.run(inference_program, feed=_get_feeder_data(data, place))
test_edit_distance = error_evaluator.eval(exe) test_edit_distance = error_evaluator.eval(exe)
print "End pass[%d]; test data edit_distance: %s" % ( print "End pass[%d]; test data edit_distance: %s." % (
pass_id, str(test_edit_distance)) pass_id, str(test_edit_distance[0]))
def main():
    """Entry point: parse CLI arguments, echo them, and launch training."""
    config = parser.parse_args()
    print_arguments(config)
    # Forward every CLI hyper-parameter to train() by keyword.
    hyper_params = dict(
        l2=config.l2,
        min_clip=config.min_clip,
        max_clip=config.max_clip,
        learning_rate=config.learning_rate,
        momentum=config.momentum,
        batch_size=config.batch_size,
        pass_num=config.pass_num,
        device=config.device)
    train(**hyper_params)
# Script entry guard: only run training when executed directly.
if __name__ == "__main__":
    main()
Markdown is supported.
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to comment.