diff --git a/fluid/ocr_recognition/README.md b/fluid/ocr_recognition/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e71386a8e9a5c94633d31ce9bf40e26dd483fa87
--- /dev/null
+++ b/fluid/ocr_recognition/README.md
@@ -0,0 +1,4 @@
+# OCR Model
+
+This model, built with PaddlePaddle Fluid, is still under active development
+and is not yet the final version. We welcome feedback.
diff --git a/fluid/ocr_recognition/crnn_ctc_model.py b/fluid/ocr_recognition/crnn_ctc_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..719c0158ec0e28c46a2915e42bd81533f848673c
--- /dev/null
+++ b/fluid/ocr_recognition/crnn_ctc_model.py
@@ -0,0 +1,178 @@
+import paddle.fluid as fluid
+
+
+def conv_bn_pool(input,
+                 group,
+                 out_ch,
+                 act="relu",
+                 param=None,
+                 bias=None,
+                 param_0=None,
+                 is_test=False):
+    tmp = input
+    for i in xrange(group):
+        tmp = fluid.layers.conv2d(
+            input=tmp,
+            num_filters=out_ch[i],
+            filter_size=3,
+            padding=1,
+            param_attr=param if param_0 is None else param_0,
+            act=None,  # LinearActivation
+            use_cudnn=True)
+        tmp = fluid.layers.batch_norm(
+            input=tmp,
+            act=act,
+            param_attr=param,
+            bias_attr=bias,
+            is_test=is_test)
+    tmp = fluid.layers.pool2d(
+        input=tmp, pool_size=2, pool_type='max', pool_stride=2, use_cudnn=True)
+
+    return tmp
+
+
+def ocr_convs(input,
+              num,
+              with_bn,
+              regularizer=None,
+              gradient_clip=None,
+              is_test=False):
+    assert (num % 4 == 0)
+
+    b = fluid.ParamAttr(
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        initializer=fluid.initializer.Normal(0.0, 0.0))
+    w0 = fluid.ParamAttr(
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        initializer=fluid.initializer.Normal(0.0, 0.0005))
+    w1 = fluid.ParamAttr(
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        initializer=fluid.initializer.Normal(0.0, 0.01))
+    tmp = input
+    tmp = conv_bn_pool(
+        tmp, 2, [16, 16], param=w1, bias=b, param_0=w0, is_test=is_test)
+
+    tmp = conv_bn_pool(tmp, 2, [32, 32], param=w1, bias=b, is_test=is_test)
+    tmp = conv_bn_pool(tmp, 2, [64, 64], param=w1, bias=b, is_test=is_test)
+    tmp = conv_bn_pool(tmp, 2, [128, 128], param=w1, bias=b, is_test=is_test)
+    return tmp
+
+
+def encoder_net(images,
+                num_classes,
+                rnn_hidden_size=200,
+                regularizer=None,
+                gradient_clip=None,
+                is_test=False):
+    conv_features = ocr_convs(
+        images,
+        8,
+        True,
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        is_test=is_test)
+    sliced_feature = fluid.layers.im2sequence(
+        input=conv_features,
+        stride=[1, 1],
+        filter_size=[conv_features.shape[2], 1])
+
+    para_attr = fluid.ParamAttr(
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        initializer=fluid.initializer.Normal(0.0, 0.02))
+    bias_attr = fluid.ParamAttr(
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        initializer=fluid.initializer.Normal(0.0, 0.02),
+        learning_rate=2.0)
+    bias_attr_nobias = fluid.ParamAttr(
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        initializer=fluid.initializer.Normal(0.0, 0.02))
+
+    fc_1 = fluid.layers.fc(input=sliced_feature,
+                           size=rnn_hidden_size * 3,
+                           param_attr=para_attr,
+                           bias_attr=bias_attr_nobias)
+    fc_2 = fluid.layers.fc(input=sliced_feature,
+                           size=rnn_hidden_size * 3,
+                           param_attr=para_attr,
+                           bias_attr=bias_attr_nobias)
+
+    gru_forward = fluid.layers.dynamic_gru(
+        input=fc_1,
+        size=rnn_hidden_size,
+        param_attr=para_attr,
+        bias_attr=bias_attr,
+        candidate_activation='relu')
+    gru_backward = fluid.layers.dynamic_gru(
+        input=fc_2,
+        size=rnn_hidden_size,
+        is_reverse=True,
+        param_attr=para_attr,
+        bias_attr=bias_attr,
+        candidate_activation='relu')
+
+    w_attr = fluid.ParamAttr(
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        initializer=fluid.initializer.Normal(0.0, 0.02))
+    b_attr = fluid.ParamAttr(
+        regularizer=regularizer,
+        gradient_clip=gradient_clip,
+        initializer=fluid.initializer.Normal(0.0, 0.0))
+
+    fc_out = fluid.layers.fc(input=[gru_forward, gru_backward],
+                             size=num_classes + 1,
+                             param_attr=w_attr,
+                             bias_attr=b_attr)
+
+    return fc_out
+
+
+def ctc_train_net(images, label, args, num_classes):
+    regularizer = fluid.regularizer.L2Decay(args.l2)
+    gradient_clip = None
+    fc_out = encoder_net(
+        images,
+        num_classes,
+        regularizer=regularizer,
+        gradient_clip=gradient_clip)
+
+    cost = fluid.layers.warpctc(
+        input=fc_out, label=label, blank=num_classes, norm_by_times=True)
+    sum_cost = fluid.layers.reduce_sum(cost)
+
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=args.learning_rate, momentum=args.momentum)
+    optimizer.minimize(sum_cost)
+
+    decoded_out = fluid.layers.ctc_greedy_decoder(
+        input=fc_out, blank=num_classes)
+    casted_label = fluid.layers.cast(x=label, dtype='int64')
+    error_evaluator = fluid.evaluator.EditDistance(
+        input=decoded_out, label=casted_label)
+    return sum_cost, error_evaluator
+
+
+def ctc_infer(images, num_classes):
+    fc_out = encoder_net(images, num_classes, is_test=True)
+    return fluid.layers.ctc_greedy_decoder(input=fc_out, blank=num_classes)
+
+
+def ctc_eval(images, label, num_classes):
+    fc_out = encoder_net(images, num_classes, is_test=True)
+    decoded_out = fluid.layers.ctc_greedy_decoder(
+        input=fc_out, blank=num_classes)
+
+    casted_label = fluid.layers.cast(x=label, dtype='int64')
+    error_evaluator = fluid.evaluator.EditDistance(
+        input=decoded_out, label=casted_label)
+
+    cost = fluid.layers.warpctc(
+        input=fc_out, label=label, blank=num_classes, norm_by_times=True)
+
+    return error_evaluator, cost
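Reviewer note on `encoder_net`: the final `fc` layer is sized `num_classes + 1` because Warp-CTC reserves one extra index (`num_classes` here) for the blank label, and `im2sequence` with `filter_size=[H, 1]` turns the conv feature map into one timestep per column. A minimal shape trace, assuming the `[1, 48, 512]` input from `ctc_reader.DATA_SHAPE` (sketch only, not part of the diff):

```python
# Trace shapes through encoder_net for a [1, 48, 512] (CHW) input.
# Each of the four conv_bn_pool stages ends in stride-2 max pooling,
# so H and W are halved four times.
C, H, W = 1, 48, 512
for out_ch in [16, 32, 64, 128]:
    C, H, W = out_ch, H // 2, W // 2
print([C, H, W])  # [128, 3, 32]
# im2sequence(filter_size=[H, 1], stride=[1, 1]) then yields one timestep
# per remaining column: 32 steps, each of dimension C * H * 1 = 384.
```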
diff --git a/fluid/ocr_recognition/ctc_reader.py b/fluid/ocr_recognition/ctc_reader.py
index e5264c33de526846161c1e3ada2555addba53e0d..f095c9b3bb7cdf36c247cca1c93ea7d636b91d24 100644
--- a/fluid/ocr_recognition/ctc_reader.py
+++ b/fluid/ocr_recognition/ctc_reader.py
@@ -4,6 +4,10 @@ import numpy as np
 from PIL import Image
 from paddle.v2.image import load_image
 
+import paddle.v2 as paddle
+
+NUM_CLASSES = 10784
+DATA_SHAPE = [1, 48, 512]
 
 
 class DataGenerator(object):
@@ -15,10 +19,10 @@ class DataGenerator(object):
         Reader interface for training.
 
         :param img_root_dir: The root path of the image for training.
-        :type file_list: str 
+        :type file_list: str
 
         :param img_label_list: The path of the file for training.
-        :type file_list: str 
+        :type file_list: str
 
         '''
 
@@ -76,7 +80,7 @@ class DataGenerator(object):
         Reader interface for inference.
 
         :param img_root_dir: The root path of the images for training.
-        :type file_list: str 
+        :type file_list: str
 
         :param img_label_list: The path of the file for testing.
         :type file_list: list
@@ -95,3 +99,28 @@ class DataGenerator(object):
             yield img, label
 
         return reader
+
+
+def num_classes():
+    return NUM_CLASSES
+
+
+def data_shape():
+    return DATA_SHAPE
+
+
+def train(batch_size):
+    generator = DataGenerator()
+    return generator.train_reader(
+        "/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/train_images/",
+        "/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/train.list",
+        batch_size)
+
+
+def test(batch_size=1):
+    generator = DataGenerator()
+    return paddle.batch(
+        generator.test_reader(
+            "/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/test_images/",
+            "/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/test.list"
+        ), batch_size)
+
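Reviewer note: the paths above are hard-coded to the author's machine, so reproducing this requires a reader module with the same contract: `num_classes()`, `data_shape()`, and `train(batch_size)` / `test(batch_size)` returning batched readers whose samples are `(image, label)` pairs, with `image` a float32 array in `DATA_SHAPE` (CHW) order and `label` a sequence of int class ids. A hypothetical drop-in replacement (the two helper functions are placeholders, not part of this PR):

```python
import numpy as np
import paddle.v2 as paddle

DATA_SHAPE = [1, 48, 512]

def train(batch_size):
    def reader():
        # load_my_index / load_my_image are hypothetical helpers that
        # yield (image_path, list_of_label_ids) and a CHW float image.
        for img_path, text_ids in load_my_index("train.list"):
            img = load_my_image(img_path, DATA_SHAPE)
            yield img.astype("float32"), text_ids
    return paddle.batch(reader, batch_size)
```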
diff --git a/fluid/ocr_recognition/ctc_train.py b/fluid/ocr_recognition/ctc_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..85b1d2e708f73d7ac049af276626a38e76d19399
--- /dev/null
+++ b/fluid/ocr_recognition/ctc_train.py
@@ -0,0 +1,95 @@
+"""Trainer for OCR CTC model."""
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+import dummy_reader
+import ctc_reader
+import argparse
+from load_model import load_param
+import functools
+import sys
+from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
+from crnn_ctc_model import ctc_train_net
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size',      int,   32,     "Minibatch size.")
+add_arg('pass_num',        int,   100,    "# of training epochs.")
+add_arg('log_period',      int,   1000,   "Log period.")
+add_arg('learning_rate',   float, 1.0e-3, "Learning rate.")
+add_arg('l2',              float, 0.0004, "L2 regularizer.")
+add_arg('max_clip',        float, 10.0,   "Max clip threshold.")
+add_arg('min_clip',        float, -10.0,  "Min clip threshold.")
+add_arg('momentum',        float, 0.9,    "Momentum.")
+add_arg('rnn_hidden_size', int,   200,    "Hidden size of rnn layers.")
+add_arg('device',          int,   0,      "Device id. '-1' means running on CPU "
+                                          "while '0' means GPU-0.")
+# yapf: enable
+
+
+def load_parameter(place):
+    params = load_param('./name.map',
+                        './data/model/results_without_avg_window/pass-00000/')
+    for name in params:
+        # print "param: %s" % name
+        t = fluid.global_scope().find_var(name).get_tensor()
+        t.set(params[name], place)
+
+
+def train(args, data_reader=dummy_reader):
+    """OCR CTC training."""
+    num_classes = data_reader.num_classes()
+    data_shape = data_reader.data_shape()
+    # define network
+    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
+    label = fluid.layers.data(
+        name='label', shape=[1], dtype='int32', lod_level=1)
+    sum_cost, error_evaluator = ctc_train_net(images, label, args, num_classes)
+    # data reader
+    train_reader = data_reader.train(args.batch_size)
+    test_reader = data_reader.test()
+    # prepare environment
+    place = fluid.CPUPlace()
+    if args.device >= 0:
+        place = fluid.CUDAPlace(args.device)
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+
+    #load_parameter(place)
+
+    inference_program = fluid.io.get_inference_program(error_evaluator)
+
+    for pass_id in range(args.pass_num):
+        error_evaluator.reset(exe)
+        batch_id = 1
+        total_loss = 0.0
+        total_seq_error = 0.0
+        # train a pass
+        for data in train_reader():
+            batch_loss, _, batch_seq_error = exe.run(
+                fluid.default_main_program(),
+                feed=get_feeder_data(data, place),
+                fetch_list=[sum_cost] + error_evaluator.metrics)
+            total_loss += batch_loss[0]
+            total_seq_error += batch_seq_error[0]
+            if batch_id % 10 == 1:
+                print '.',
+                sys.stdout.flush()
+            if batch_id % args.log_period == 1:
+                print "\nPass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
+                    pass_id, batch_id,
+                    total_loss / (batch_id * args.batch_size),
+                    total_seq_error / (batch_id * args.batch_size))
+                sys.stdout.flush()
+            batch_id += 1
+
+        # evaluate model on test data
+        error_evaluator.reset(exe)
+        for data in test_reader():
+            exe.run(inference_program, feed=get_feeder_data(data, place))
+        _, test_seq_error = error_evaluator.eval(exe)
+        print "\nEnd pass[%d]; Test seq error: %s.\n" % (
+            pass_id, str(test_seq_error[0]))
+
+
+def main():
+    args = parser.parse_args()
+    print_arguments(args)
+    train(args, data_reader=ctc_reader)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fluid/ocr_recognition/dummy_reader.py b/fluid/ocr_recognition/dummy_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..def91b1dd95857e7df740271cac486001da5f24b
--- /dev/null
+++ b/fluid/ocr_recognition/dummy_reader.py
@@ -0,0 +1,52 @@
+"""A dummy reader for test."""
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle.v2 as paddle
+
+DATA_SHAPE = [1, 512, 512]
+NUM_CLASSES = 20
+
+
+def _reader_creator(num_sample=1024, min_seq_len=1, max_seq_len=10):
+    def reader():
+        for i in range(num_sample):
+            sequence_len = np.random.randint(min_seq_len, max_seq_len)
+            x = np.random.uniform(0.1, 1, DATA_SHAPE).astype("float32")
+            # Labels lie in [0, NUM_CLASSES); index NUM_CLASSES is
+            # reserved for the CTC blank.
+            y = np.random.randint(0, NUM_CLASSES,
+                                  [sequence_len]).astype("int32")
+            yield x, y
+
+    return reader
+
+
+def train(batch_size, num_sample=128):
+    """Get train dataset reader."""
+    return paddle.batch(_reader_creator(num_sample=num_sample), batch_size)
+
+
+def test(batch_size=1, num_sample=16):
+    """Get test dataset reader."""
+    return paddle.batch(_reader_creator(num_sample=num_sample), batch_size)
+
+
+def data_shape():
+    """Get image shape in CHW order."""
+    return DATA_SHAPE
+
+
+def num_classes():
+    """Get number of total classes."""
+    return NUM_CLASSES
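Reviewer note: since `train()` takes the reader module as a parameter, the whole pipeline can be smoke-tested without the OCR dataset by passing `dummy_reader` explicitly. A sketch of such a run, assuming the flags defined in `ctc_train.py` above:

```python
# Smoke-test sketch (not part of this PR): one short CPU pass over
# random data from dummy_reader, exercising the graph end to end.
import dummy_reader
from ctc_train import parser, train

args = parser.parse_args(
    ['--batch_size', '2', '--pass_num', '1', '--device', '-1'])
train(args, data_reader=dummy_reader)
```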
diff --git a/fluid/ocr_recognition/eval.py b/fluid/ocr_recognition/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..342d0f16cd5f321d56988273cd6f47759e31bef0
--- /dev/null
+++ b/fluid/ocr_recognition/eval.py
@@ -0,0 +1,55 @@
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+from load_model import load_param
+from utility import get_feeder_data
+from crnn_ctc_model import ctc_eval
+import ctc_reader
+import dummy_reader
+
+
+def load_parameter(place):
+    params = load_param('./name.map', './data/model/results/pass-00062/')
+    for name in params:
+        print "param: %s" % name
+        t = fluid.global_scope().find_var(name).get_tensor()
+        t.set(params[name], place)
+
+
+def evaluate(eval=ctc_eval, data_reader=dummy_reader):
+    """OCR evaluation."""
+    num_classes = data_reader.num_classes()
+    data_shape = data_reader.data_shape()
+    # define network
+    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
+    label = fluid.layers.data(
+        name='label', shape=[1], dtype='int32', lod_level=1)
+    evaluator, cost = eval(images, label, num_classes)
+
+    # data reader
+    test_reader = data_reader.test()
+    # prepare environment
+    place = fluid.CUDAPlace(0)
+    #place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+    # print fluid.default_main_program()  # debug: dump the whole program
+    load_parameter(place)
+    evaluator.reset(exe)
+    count = 0
+    for data in test_reader():
+        count += 1
+        print 'Process samples: %d\r' % (count, ),
+        result, avg_distance, avg_seq_error = exe.run(
+            fluid.default_main_program(),
+            feed=get_feeder_data(data, place),
+            fetch_list=[cost] + evaluator.metrics)
+    avg_distance, avg_seq_error = evaluator.eval(exe)
+    print "avg_distance: %s; avg_seq_error: %s" % (avg_distance, avg_seq_error)
+
+
+def main():
+    evaluate(data_reader=ctc_reader)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fluid/ocr_recognition/inference.py b/fluid/ocr_recognition/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..32bc59e9b04dd91e2060b55adbb6264e7797fbe5
--- /dev/null
+++ b/fluid/ocr_recognition/inference.py
@@ -0,0 +1,48 @@
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+from load_model import load_param
+from utility import get_feeder_data
+from crnn_ctc_model import ctc_infer
+import ctc_reader
+import dummy_reader
+
+
+def load_parameter(place):
+    params = load_param('./name.map', './data/model/results/pass-00062/')
+    for name in params:
+        print "param: %s" % name
+        t = fluid.global_scope().find_var(name).get_tensor()
+        t.set(params[name], place)
+
+
+def inference(infer=ctc_infer, data_reader=dummy_reader):
+    """OCR inference."""
+    num_classes = data_reader.num_classes()
+    data_shape = data_reader.data_shape()
+    # define network
+    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
+    # ctc_infer returns the greedily decoded sequence.
+    sequence = infer(images, num_classes)
+    # data reader
+    test_reader = data_reader.test()
+    # prepare environment
+    place = fluid.CUDAPlace(0)
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+
+    load_parameter(place)
+
+    for data in test_reader():
+        result = exe.run(fluid.default_main_program(),
+                         feed=get_feeder_data(
+                             data, place, need_label=False),
+                         fetch_list=[sequence])
+        print "result: %s" % (list(result[0].flatten()), )
+
+
+def main():
+    inference(data_reader=ctc_reader)
+
+
+if __name__ == "__main__":
+    main()
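Reviewer note: both `eval.py` and `inference.py` rely on `ctc_greedy_decoder`, which takes the per-timestep argmax, collapses adjacent repeats, and drops the blank index (`num_classes` in this model). A pure-NumPy reference of that decoding rule (sketch only, not part of the diff):

```python
import numpy as np

def greedy_ctc_decode(probs, blank):
    """probs: [T, num_classes + 1] array of per-timestep scores."""
    decoded, prev = [], None
    for idx in np.argmax(probs, axis=1):
        # Keep an index only if it is not blank and not a repeat of the
        # immediately preceding timestep.
        if idx != blank and idx != prev:
            decoded.append(int(idx))
        prev = idx
    return decoded

probs = np.array([[0.1, 0.2, 0.7],   # blank (index 2)
                  [0.8, 0.1, 0.1],   # class 0
                  [0.7, 0.2, 0.1],   # class 0 repeated -> collapsed
                  [0.1, 0.8, 0.1]])  # class 1
print(greedy_ctc_decode(probs, blank=2))  # [0, 1]
```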
diff --git a/fluid/ocr_recognition/load_model.py b/fluid/ocr_recognition/load_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fea9398866f3f3c276f6e998a18c6bdd0a2a488a
--- /dev/null
+++ b/fluid/ocr_recognition/load_model.py
@@ -0,0 +1,33 @@
+import sys
+import numpy as np
+import ast
+
+
+def load_parameter(file_name):
+    with open(file_name, 'rb') as f:
+        f.read(16)  # skip header.
+        return np.fromfile(f, dtype=np.float32)
+
+
+def load_param(name_map_file, old_param_dir):
+    result = {}
+    name_map = {}
+    shape_map = {}
+    with open(name_map_file, 'r') as map_file:
+        for param in map_file:
+            old_name, new_name, shape = param.strip().split('=')
+            name_map[new_name] = old_name
+            shape_map[new_name] = ast.literal_eval(shape)
+
+    for new_name in name_map:
+        result[new_name] = load_parameter("/".join(
+            [old_param_dir, name_map[new_name]])).reshape(shape_map[new_name])
+    return result
+
+
+if __name__ == "__main__":
+    name_map_file = "./name.map"
+    old_param_dir = "./data/model/results/pass-00062/"
+    result = load_param(name_map_file, old_param_dir)
+    for p in result:
+        print "name: %s; param.shape: %s" % (p, result[p].shape)
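Reviewer note: `load_param` implies that each line of `name.map` holds three `=`-separated fields: the old v2 parameter file name, the new Fluid parameter name, and the shape as a Python literal. The map file itself is not in this diff, so the entries below are hypothetical, purely to illustrate the parsing:

```python
import ast

# Hypothetical name.map entries; the real file ships outside this diff.
# Format per load_param: <old_v2_param_file>=<new_fluid_param_name>=<shape>
example_lines = [
    "___fc_layer_0__.w0=fc_0.w_0=[768, 600]",
    "___fc_layer_0__.wbias=fc_0.b_0=[600]",
]
for line in example_lines:
    old_name, new_name, shape = line.strip().split('=')
    print(new_name, ast.literal_eval(shape))
```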
diff --git a/fluid/ocr_recognition/utility.py b/fluid/ocr_recognition/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..67a5bfa018bad5a4d69ba9d0d3cb63ff59214775
--- /dev/null
+++ b/fluid/ocr_recognition/utility.py
@@ -0,0 +1,90 @@
+"""Contains common utility functions."""
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import distutils.util
+import numpy as np
+from paddle.fluid import core
+
+
+def print_arguments(args):
+    """Print argparse's arguments.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument("name", default="John", type=str, help="User name.")
+        args = parser.parse_args()
+        print_arguments(args)
+
+    :param args: Input argparse.Namespace for printing.
+    :type args: argparse.Namespace
+    """
+    print("----------- Configuration Arguments -----------")
+    for arg, value in sorted(vars(args).iteritems()):
+        print("%s: %s" % (arg, value))
+    print("------------------------------------------------")
+
+
+def add_arguments(argname, type, default, help, argparser, **kwargs):
+    """Add argparse's argument.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        add_arguments("name", str, "John", "User name.", parser)
+        args = parser.parse_args()
+    """
+    type = distutils.util.strtobool if type == bool else type
+    argparser.add_argument(
+        "--" + argname,
+        default=default,
+        type=type,
+        help=help + ' Default: %(default)s.',
+        **kwargs)
+
+
+def to_lodtensor(data, place):
+    """Convert a batch of variable-length sequences to one LoDTensor."""
+    seq_lens = [len(seq) for seq in data]
+    cur_len = 0
+    lod = [cur_len]
+    for l in seq_lens:
+        cur_len += l
+        lod.append(cur_len)
+    flattened_data = np.concatenate(data, axis=0).astype("int32")
+    flattened_data = flattened_data.reshape([len(flattened_data), 1])
+    res = core.LoDTensor()
+    res.set(flattened_data, place)
+    res.set_lod([lod])
+    return res
+
+
+def get_feeder_data(data, place, need_label=True):
+    """Build the feed dict for one batch of (image, label) samples."""
+    pixel_tensor = core.LoDTensor()
+    pixel_data = np.concatenate(
+        map(lambda x: x[0][np.newaxis, :], data), axis=0).astype("float32")
+    pixel_tensor.set(pixel_data, place)
+    label_tensor = to_lodtensor(map(lambda x: x[1], data), place)
+    if need_label:
+        return {"pixel": pixel_tensor, "label": label_tensor}
+    else:
+        return {"pixel": pixel_tensor}
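Reviewer note: `to_lodtensor` is where the variable-length labels meet Fluid's LoD (level-of-detail) representation: the batch is flattened into one `[sum_len, 1]` tensor plus a list of cumulative offsets. A standalone illustration of that bookkeeping (sketch only, mirroring the function above):

```python
import numpy as np

labels = [np.array([3, 7, 2], dtype="int32"),
          np.array([5, 1], dtype="int32")]
lod = [0]
for seq in labels:
    lod.append(lod[-1] + len(seq))  # cumulative sequence offsets
flattened = np.concatenate(labels).reshape([-1, 1])
print(lod)                # [0, 3, 5]: sequence i is rows lod[i]:lod[i+1]
print(flattened.ravel())  # [3 7 2 5 1]
```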