Commit 192ef9c0 authored by: W wanghaoshuang

Refine code according to comments:

1. Rename the 'ocr_ctc' directory to 'ocr'.
2. Init README.md
3. Fix learning rate and l2 regularization (see the note after the diff)
4. Refine training log format
5. Reduce the arguments of the train function
6. Set filter_size of im2sequence dynamically
7. Add fc op before GRU op (see the sketch below)
Parent c43a107d
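Points 6 and 7 describe the new sequence head of the network. The sketch below restates it with comments, using only the layer calls that appear in the diff further down; it is a minimal illustration, not code from the repository, and the function name and standalone packaging are assumptions.

import paddle.v2.fluid as fluid

def conv_to_gru_head(conv_features, num_classes, param_attrs, hidden_size=200):
    # Point 6: cut the conv feature map into a sequence of columns. The filter
    # height is taken from the feature map itself instead of a hard-coded [1, 3].
    sliced_feature = fluid.layers.im2sequence(
        input=conv_features,
        stride=[1, 1],
        filter_size=[conv_features.shape[2], 1])
    # Point 7: project every column with an fc op before feeding the GRU.
    # dynamic_gru expects its input to be 3 * hidden_size wide (one slice per
    # gate), which is why the fc size is hidden_size * 3.
    fc_1 = fluid.layers.fc(
        input=sliced_feature, size=hidden_size * 3, param_attr=param_attrs)
    fc_2 = fluid.layers.fc(
        input=sliced_feature, size=hidden_size * 3, param_attr=param_attrs)
    gru_forward = fluid.layers.dynamic_gru(
        input=fc_1, size=hidden_size, param_attr=param_attrs)
    gru_backward = fluid.layers.dynamic_gru(
        input=fc_2, size=hidden_size, is_reverse=True, param_attr=param_attrs)
    # Merge both directions and project to num_classes + 1; the extra class is
    # the CTC blank label.
    return fluid.layers.fc(
        input=[gru_forward, gru_backward], size=num_classes + 1)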
# OCR Model
This model, built with PaddlePaddle Fluid, is still under active development and is
not the final version. We welcome feedback.
@@ -12,7 +12,6 @@
 #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #See the License for the specific language governing permissions and
 #limitations under the License.
 import paddle.v2 as paddle
 import paddle.v2.fluid as fluid
 import numpy as np
@@ -25,13 +24,14 @@ from utility import add_arguments, print_arguments
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('batch_size', int, 16, "Minibatch size.")
+add_arg('batch_size', int, 2, "Minibatch size.")
 add_arg('pass_num', int, 16, "# of training epochs.")
 add_arg('learning_rate', float, 1.0e-3, "Learning rate.")
 add_arg('l2', float, 0.0005, "L2 regularizer.")
 add_arg('max_clip', float, 10.0, "Max clip threshold.")
 add_arg('min_clip', float, -10.0, "Min clip threshold.")
 add_arg('momentum', float, 0.9, "Momentum.")
+add_arg('rnn_hidden_size', int, 200, "Hidden size of rnn layers.")
 add_arg('device', int, -1, "Device id.'-1' means running on CPU"
         "while '0' means GPU-0.")
 # yapf: disable
@@ -79,15 +79,17 @@ def _ocr_conv(input, num, with_bn, param_attrs):
     conv4 = _conv_block(conv3, 128, (num / 4), with_bn)
     return conv4
-def _ocr_ctc_net(images, num_classes, param_attrs):
+def _ocr_ctc_net(images, num_classes, param_attrs, rnn_hidden_size=200):
     conv_features = _ocr_conv(images, 8, True, param_attrs)
     sliced_feature = fluid.layers.im2sequence(
-        input=conv_features, stride=[1, 1], filter_size=[1, 3])
+        input=conv_features, stride=[1, 1], filter_size=[conv_features.shape[2], 1])
+    hidden_size = rnn_hidden_size
+    fc_1 = fluid.layers.fc(input=sliced_feature, size=hidden_size * 3, param_attr=param_attrs)
+    fc_2 = fluid.layers.fc(input=sliced_feature, size=hidden_size * 3, param_attr=param_attrs)
     gru_forward = fluid.layers.dynamic_gru(
-        input=sliced_feature, size=128, param_attr=param_attrs)
+        input=fc_1, size=hidden_size, param_attr=param_attrs)
     gru_backward = fluid.layers.dynamic_gru(
-        input=sliced_feature, size=128, is_reverse=True, param_attr=param_attrs)
+        input=fc_2, size=hidden_size, is_reverse=True, param_attr=param_attrs)
     fc_out = fluid.layers.fc(input=[gru_forward, gru_backward],
                              size=num_classes + 1,
@@ -96,26 +98,18 @@ def _ocr_ctc_net(images, num_classes, param_attrs):
-def train(l2=0.0005,
-          min_clip=-10,
-          max_clip=10,
-          data_reader=dummy_reader,
-          learning_rate=1.0e-3,
-          momentum=0.9,
-          batch_size=16,
-          pass_num=2,
-          device=0):
+def train(args, data_reader=dummy_reader):
     """OCR CTC training"""
     num_classes = data_reader.num_classes()
     # define network
     param_attrs = fluid.ParamAttr(
-        regularizer=fluid.regularizer.L2Decay(l2 * batch_size),
-        gradient_clip=fluid.clip.GradientClipByValue(max_clip, min_clip))
+        regularizer=fluid.regularizer.L2Decay(args.l2),
+        gradient_clip=fluid.clip.GradientClipByValue(args.max_clip, args.min_clip))
     data_shape = data_reader.data_shape()
     images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
     label = fluid.layers.data(
         name='label', shape=[1], dtype='int32', lod_level=1)
-    fc_out = _ocr_ctc_net(images, num_classes, param_attrs)
+    fc_out = _ocr_ctc_net(images, num_classes, param_attrs, rnn_hidden_size=args.rnn_hidden_size)
     # define cost and optimizer
     cost = fluid.layers.warpctc(
@@ -126,9 +120,8 @@ def train(l2=0.0005,
         norm_by_times=True)
     avg_cost = fluid.layers.mean(x=cost)
     optimizer = fluid.optimizer.Momentum(
-        learning_rate=learning_rate / batch_size, momentum=momentum)
+        learning_rate=args.learning_rate, momentum=args.momentum)
     opts = optimizer.minimize(cost)
     # decoder and evaluator
     decoded_out = fluid.layers.ctc_greedy_decoder(
         input=fc_out, blank=num_classes)
@@ -136,30 +129,30 @@ def train(l2=0.0005,
     error_evaluator = fluid.evaluator.EditDistance(
         input=decoded_out, label=casted_label)
     # data reader
-    train_reader = paddle.batch(data_reader.train(), batch_size=batch_size)
-    test_reader = paddle.batch(data_reader.test(), batch_size=batch_size)
+    train_reader = paddle.batch(data_reader.train(), batch_size=args.batch_size)
+    test_reader = paddle.batch(data_reader.test(), batch_size=args.batch_size)
     # prepare environment
     place = fluid.CPUPlace()
-    if device >= 0:
-        place = fluid.CUDAPlace(device)
+    if args.device >= 0:
+        place = fluid.CUDAPlace(args.device)
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
     inference_program = fluid.io.get_inference_program(error_evaluator)
-    for pass_id in range(pass_num):
+    for pass_id in range(args.pass_num):
         error_evaluator.reset(exe)
         batch_id = 0
         # train a pass
         for data in train_reader():
-            loss, batch_edit_distance, _, _ = exe.run(
+            loss, batch_edit_distance = exe.run(
                 fluid.default_main_program(),
                 feed=_get_feeder_data(data, place),
                 fetch_list=[avg_cost] + error_evaluator.metrics)
-            print "Pass[%d], batch[%d]; loss: %s; edit distance: %s." % (
+            print "Pass[%d]-batch[%d]; Loss: %s; Word error: %s." % (
                 pass_id, batch_id, loss[0], batch_edit_distance[0])
             batch_id += 1
         train_edit_distance = error_evaluator.eval(exe)
-        print "End pass[%d]; train data edit_distance: %s." % (
+        print "End pass[%d]; Train word error: %s." % (
             pass_id, str(train_edit_distance[0]))
         # evaluate model on test data
@@ -167,21 +160,14 @@ def train(l2=0.0005,
         for data in test_reader():
            exe.run(inference_program, feed=_get_feeder_data(data, place))
         test_edit_distance = error_evaluator.eval(exe)
-        print "End pass[%d]; test data edit_distance: %s." % (
+        print "End pass[%d]; Test word error: %s." % (
            pass_id, str(test_edit_distance[0]))
 def main():
     args = parser.parse_args()
     print_arguments(args)
-    train(l2=args.l2,
-          min_clip=args.min_clip,
-          max_clip=args.max_clip,
-          learning_rate=args.learning_rate,
-          momentum=args.momentum,
-          batch_size=args.batch_size,
-          pass_num=args.pass_num,
-          device=args.device)
+    train(args)
 if __name__ == "__main__":
     main()
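Point 3 of the commit message corresponds to the setup below: the L2 weight is no longer multiplied by the batch size and the learning rate is no longer divided by it, so both flags are used as given. This is a commented sketch assembled from the diff above; the helper name is an assumption, not a function in the repository.

import paddle.v2.fluid as fluid

def build_param_attrs_and_optimizer(args):
    # L2 decay and gradient clipping applied to the network parameters.
    param_attrs = fluid.ParamAttr(
        regularizer=fluid.regularizer.L2Decay(args.l2),  # was: l2 * batch_size
        gradient_clip=fluid.clip.GradientClipByValue(args.max_clip,
                                                     args.min_clip))
    # Plain momentum SGD; the learning rate flag is used directly.
    optimizer = fluid.optimizer.Momentum(
        learning_rate=args.learning_rate,  # was: learning_rate / batch_size
        momentum=args.momentum)
    return param_attrs, optimizer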