提交 192ef9c0 编写于 作者: W wanghaoshuang

Refine code according comments:

1. Remove 'ocr_ctc' directory to 'ocr'.
2. Init README.md
3. Fix learning rate and l2
4. Refine training log format
5. Reduce arguments of train function
6. Set filter_size of im2sequence dynamicly
7. Add fc op before GRU op
上级 c43a107d
# OCR Model
This model built with paddle fluid is still under active development and is not
the final version. We welcome feedbacks.
......@@ -12,7 +12,6 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import numpy as np
......@@ -25,13 +24,14 @@ from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 16, "Minibatch size.")
add_arg('batch_size', int, 2, "Minibatch size.")
add_arg('pass_num', int, 16, "# of training epochs.")
add_arg('learning_rate', float, 1.0e-3, "Learning rate.")
add_arg('l2', float, 0.0005, "L2 regularizer.")
add_arg('max_clip', float, 10.0, "Max clip threshold.")
add_arg('min_clip', float, -10.0, "Min clip threshold.")
add_arg('momentum', float, 0.9, "Momentum.")
add_arg('rnn_hidden_size',int, 200, "Hidden size of rnn layers.")
add_arg('device', int, -1, "Device id.'-1' means running on CPU"
"while '0' means GPU-0.")
# yapf: disable
......@@ -79,15 +79,17 @@ def _ocr_conv(input, num, with_bn, param_attrs):
conv4 = _conv_block(conv3, 128, (num / 4), with_bn)
return conv4
def _ocr_ctc_net(images, num_classes, param_attrs):
def _ocr_ctc_net(images, num_classes, param_attrs, rnn_hidden_size=200):
conv_features = _ocr_conv(images, 8, True, param_attrs)
sliced_feature = fluid.layers.im2sequence(
input=conv_features, stride=[1, 1], filter_size=[1, 3])
input=conv_features, stride=[1, 1], filter_size=[conv_features.shape[2], 1])
hidden_size = rnn_hidden_size
fc_1 = fluid.layers.fc(input=sliced_feature, size=hidden_size * 3, param_attr=param_attrs)
fc_2 = fluid.layers.fc(input=sliced_feature, size=hidden_size * 3, param_attr=param_attrs)
gru_forward = fluid.layers.dynamic_gru(
input=sliced_feature, size=128, param_attr=param_attrs)
input=fc_1, size=hidden_size, param_attr=param_attrs)
gru_backward = fluid.layers.dynamic_gru(
input=sliced_feature, size=128, is_reverse=True, param_attr=param_attrs)
input=fc_2, size=hidden_size, is_reverse=True, param_attr=param_attrs)
fc_out = fluid.layers.fc(input=[gru_forward, gru_backward],
size=num_classes + 1,
......@@ -96,26 +98,18 @@ def _ocr_ctc_net(images, num_classes, param_attrs):
def train(l2=0.0005,
min_clip=-10,
max_clip=10,
data_reader=dummy_reader,
learning_rate=1.0e-3,
momentum=0.9,
batch_size=16,
pass_num=2,
device=0):
def train(args, data_reader=dummy_reader):
"""OCR CTC training"""
num_classes = data_reader.num_classes()
# define network
param_attrs = fluid.ParamAttr(
regularizer=fluid.regularizer.L2Decay(l2 * batch_size),
gradient_clip=fluid.clip.GradientClipByValue(max_clip, min_clip))
regularizer=fluid.regularizer.L2Decay(args.l2),
gradient_clip=fluid.clip.GradientClipByValue(args.max_clip, args.min_clip))
data_shape = data_reader.data_shape()
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1)
fc_out = _ocr_ctc_net(images, num_classes, param_attrs)
fc_out = _ocr_ctc_net(images, num_classes, param_attrs, rnn_hidden_size=args.rnn_hidden_size)
# define cost and optimizer
cost = fluid.layers.warpctc(
......@@ -126,9 +120,8 @@ def train(l2=0.0005,
norm_by_times=True)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate / batch_size, momentum=momentum)
learning_rate=args.learning_rate, momentum=args.momentum)
opts = optimizer.minimize(cost)
# decoder and evaluator
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
......@@ -136,30 +129,30 @@ def train(l2=0.0005,
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
# data reader
train_reader = paddle.batch(data_reader.train(), batch_size=batch_size)
test_reader = paddle.batch(data_reader.test(), batch_size=batch_size)
train_reader = paddle.batch(data_reader.train(), batch_size=args.batch_size)
test_reader = paddle.batch(data_reader.test(), batch_size=args.batch_size)
# prepare environment
place = fluid.CPUPlace()
if device >= 0:
place = fluid.CUDAPlace(device)
if args.device >= 0:
place = fluid.CUDAPlace(args.device)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
inference_program = fluid.io.get_inference_program(error_evaluator)
for pass_id in range(pass_num):
for pass_id in range(args.pass_num):
error_evaluator.reset(exe)
batch_id = 0
# train a pass
for data in train_reader():
loss, batch_edit_distance, _, _ = exe.run(
loss, batch_edit_distance = exe.run(
fluid.default_main_program(),
feed=_get_feeder_data(data, place),
fetch_list=[avg_cost] + error_evaluator.metrics)
print "Pass[%d], batch[%d]; loss: %s; edit distance: %s." % (
print "Pass[%d]-batch[%d]; Loss: %s; Word error: %s." % (
pass_id, batch_id, loss[0], batch_edit_distance[0])
batch_id += 1
train_edit_distance = error_evaluator.eval(exe)
print "End pass[%d]; train data edit_distance: %s." % (
print "End pass[%d]; Train word error: %s." % (
pass_id, str(train_edit_distance[0]))
# evaluate model on test data
......@@ -167,21 +160,14 @@ def train(l2=0.0005,
for data in test_reader():
exe.run(inference_program, feed=_get_feeder_data(data, place))
test_edit_distance = error_evaluator.eval(exe)
print "End pass[%d]; test data edit_distance: %s." % (
print "End pass[%d]; Test word error: %s." % (
pass_id, str(test_edit_distance[0]))
def main():
args = parser.parse_args()
print_arguments(args)
train(l2=args.l2,
min_clip=args.min_clip,
max_clip=args.max_clip,
learning_rate=args.learning_rate,
momentum=args.momentum,
batch_size=args.batch_size,
pass_num=args.pass_num,
device=args.device)
train(args)
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册