diff --git a/scene_text_recognition/README.md b/scene_text_recognition/README.md
index 62afdc35cb62ab63fce0231e73e0a131765b6aa5..de1418dd524b7961aea53ac8859f1d265dea6b33 100644
--- a/scene_text_recognition/README.md
+++ b/scene_text_recognition/README.md
@@ -4,18 +4,18 @@
在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
-本文将针对 **场景文字识别 (STR, Scene Text Recognition)** 任务,演示如何用 PaddlePaddle 实现 一个端对端 CTC 的模型 **CRNN(Convolutional Recurrent Neural Network)**
-\[[2](#参考文献)\],具体的,本文使用如下图片进行训练,需要识别文字对应的文字 "keep"。
+本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。以下图为例,给定一个场景图片,STR需要从图片中识别出对应的文字"keep":
+# 场景文字识别 (STR, Scene Text Recognition)
+
+## STR任务简介
+
+在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
+
+本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。以下图为例,给定一个场景图片,STR需要从图片中识别出对应的文字"keep":
+
+
+
+图 1. 数据示例 "keep"
+
+
+
+## 使用 PaddlePaddle 训练与预测
+
+### 模型训练
+训练脚本 [./train.py](./train.py) 中设置了如下命令行参数:
+
+```
+usage: train.py [-h] --image_shape IMAGE_SHAPE --train_file_list
+ TRAIN_FILE_LIST --test_file_list TEST_FILE_LIST
+ [--batch_size BATCH_SIZE]
+ [--model_output_prefix MODEL_OUTPUT_PREFIX]
+ [--trainer_count TRAINER_COUNT]
+ [--save_period_by_batch SAVE_PERIOD_BY_BATCH]
+ [--num_passes NUM_PASSES]
+
+PaddlePaddle CTC example
+
+optional arguments:
+ -h, --help show this help message and exit
+ --image_shape IMAGE_SHAPE
+ image's shape, format is like '173,46'
+ --train_file_list TRAIN_FILE_LIST
+ path of the file which contains path list of train
+ image files
+ --test_file_list TEST_FILE_LIST
+ path of the file which contains path list of test
+ image files
+ --batch_size BATCH_SIZE
+ size of a mini-batch
+ --model_output_prefix MODEL_OUTPUT_PREFIX
+ prefix of path for model to store (default:
+ ./model.ctc)
+ --trainer_count TRAINER_COUNT
+ number of training threads
+ --save_period_by_batch SAVE_PERIOD_BY_BATCH
+ save model to disk every N batches
+ --num_passes NUM_PASSES
+ number of passes to train (default: 1)
+```
+
+重要的几个参数包括:
+
+- `image_shape` 图片的尺寸
+- `train_file_list` 训练数据的列表文件,每行一个路径加对应的text,具体格式为:
+```
+word_1.png, "PROPER"
+word_2.png, "FOOD"
+```
+- `test_file_list` 测试数据的列表文件,格式同上
+
+### 预测
+预测部分由 infer.py 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。使用时,需要在 infer.py 中指定模型文件路径、图片固定尺寸、batch_size 和图片文件的列表文件。例如:
+```python
+model_path = "model.ctc-pass-9-batch-150-test.tar.gz"
+image_shape = "173,46"
+batch_size = 50
+infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
+```
+然后运行```python infer.py```
+
+
+### 具体执行的过程:
+
+1. 从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),共有三个文件: Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip 和 Challenge2_Test_Task3_GT.txt,
+分别对应训练集的图片及其对应的单词、测试集的图片、测试集图片对应的单词。然后执行以下命令,对数据解压并移动至目标文件夹:
+
+```
+mkdir -p data/train_data
+mkdir -p data/test_data
+unzip Challenge2_Training_Task3_Images_GT.zip -d data/train_data
+unzip Challenge2_Test_Task3_Images.zip -d data/test_data
+mv Challenge2_Test_Task3_GT.txt data/test_data
+```
+
+2. 获取训练数据文件夹中 `gt.txt` 的路径(data/train_data/gt.txt)和测试数据文件夹中 `Challenge2_Test_Task3_GT.txt` 的路径(data/test_data/Challenge2_Test_Task3_GT.txt)
+
+3. 执行如下命令开始训练:
+```
+python train.py --train_file_list data/train_data/gt.txt --test_file_list data/test_data/Challenge2_Test_Task3_GT.txt --image_shape '173,46'
+```
+4. 训练过程中,模型参数会定期保存到以 `--model_output_prefix` 为前缀的文件中,默认前缀为 ./model.ctc
+
+5. 设置 infer.py 中的相关参数(模型所在路径),运行 `python infer.py` 进行预测
+
+
+### 其他数据集
+
+- [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)(41G)
+- [ICDAR 2003 Robust Reading Competitions](http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions)
+
+### 注意事项
+
+- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行
+- 本模型参数较多,占用显存比较大,实际执行时可以调节batch_size 控制显存占用
+- 本模型使用的数据集较小,可以选用其他更大的数据集\[[3](#参考文献)\]来训练需要的模型
+
+## 参考文献
+
+1. [Google Now Using ReCAPTCHA To Decode Street View Addresses](https://techcrunch.com/2012/03/29/google-now-using-recaptcha-to-decode-street-view-addresses/)
+2. [Focused Scene Text](http://rrc.cvc.uab.es/?ch=2&com=introduction)
+3. [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)
+
+
+
+
+
+
diff --git a/scene_text_recognition/infer.py b/scene_text_recognition/infer.py
index 2c8f675e247265fb4c2229f6a2f2f616a7d7cb3e..ff1f43be56f3b108e5a940628b7eab2bd20017a8 100644
--- a/scene_text_recognition/infer.py
+++ b/scene_text_recognition/infer.py
@@ -1,13 +1,14 @@
import logging
import argparse
-import paddle.v2 as paddle
import gzip
+
+import paddle.v2 as paddle
from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset
from decoder import ctc_greedy_decoder
-def infer(inferer, test_batch, labels):
+def infer_batch(inferer, test_batch, labels):
infer_results = inferer.infer(input=test_batch)
num_steps = len(infer_results) // len(test_batch)
probs_split = [
@@ -23,15 +24,11 @@ def infer(inferer, test_batch, labels):
results.append(output_transcription)
for result, label in zip(results, labels):
- print("\nOutput Transcription: %s\nTarget Transcription: %s" % (result,
- label))
+ print("\nOutput Transcription: %s\nTarget Transcription: %s" %
+ (result, label))
-if __name__ == "__main__":
- model_path = "model.ctc-pass-1-batch-150-test-10.2607016472.tar.gz"
- image_shape = "173,46"
- batch_size = 50
- infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
+def infer(model_path, image_shape, batch_size, infer_file_list):
image_shape = tuple(map(int, image_shape.split(',')))
infer_generator = get_file_list(infer_file_list)
@@ -49,8 +46,17 @@ if __name__ == "__main__":
test_batch.append([image])
labels.append(label)
if len(test_batch) == batch_size:
- infer(inferer, test_batch, labels)
+ infer_batch(inferer, test_batch, labels)
test_batch = []
labels = []
if test_batch:
- infer(inferer, test_batch, labels)
+ infer_batch(inferer, test_batch, labels)
+
+
+if __name__ == "__main__":
+ # Hard-coded example settings; edit them to point at your model and data.
+ model_path = "model.ctc-pass-9-batch-150-test.tar.gz"
+ image_shape = "173,46"
+ batch_size = 50
+ infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
+ infer(model_path, image_shape, batch_size, infer_file_list)
diff --git a/scene_text_recognition/model.py b/scene_text_recognition/model.py
index d5ce3d8c4641f5ad6507b5011225fb61f1cd67cc..2ea1240d4f42a82ce2b2a853cb32dc16a7eb42f7 100644
--- a/scene_text_recognition/model.py
+++ b/scene_text_recognition/model.py
@@ -5,66 +5,15 @@ from paddle.v2.activation import Relu, Linear
from paddle.v2.networks import img_conv_group, simple_gru
-def conv_groups(input_image, num, with_bn):
- '''
- a deep CNN.
- @input_image: input image
- @num: number of CONV filters
- @with_bn: whether with batch normal
- '''
- assert num % 4 == 0
-
- tmp = img_conv_group(
- input=input_image,
- num_channels=1,
- conv_padding=1,
- conv_num_filter=[16] * (num / 4),
- conv_filter_size=3,
- conv_act=Relu(),
- conv_with_batchnorm=with_bn,
- pool_size=2,
- pool_stride=2, )
-
- tmp = img_conv_group(
- input=tmp,
- conv_padding=1,
- conv_num_filter=[32] * (num / 4),
- conv_filter_size=3,
- conv_act=Relu(),
- conv_with_batchnorm=with_bn,
- pool_size=2,
- pool_stride=2, )
-
- tmp = img_conv_group(
- input=tmp,
- conv_padding=1,
- conv_num_filter=[64] * (num / 4),
- conv_filter_size=3,
- conv_act=Relu(),
- conv_with_batchnorm=with_bn,
- pool_size=2,
- pool_stride=2, )
-
- tmp = img_conv_group(
- input=tmp,
- conv_padding=1,
- conv_num_filter=[128] * (num / 4),
- conv_filter_size=3,
- conv_act=Relu(),
- conv_with_batchnorm=with_bn,
- pool_size=2,
- pool_stride=2, )
-
- return tmp
-
-
class Model(object):
def __init__(self, num_classes, shape, is_infer=False):
'''
- @num_classes: int
- size of the character dict
- @shape: tuple of 2 int
- size of the input images
+ :param num_classes: size of the character dict.
+ :type num_classes: int
+ :param shape: size of the input images.
+ :type shape: tuple of 2 int
+ :param is_infer: infer mode or not
+ :type is_infer: bool
'''
self.num_classes = num_classes
self.shape = shape
@@ -90,7 +39,7 @@ class Model(object):
def __build_nn__(self):
# CNN output image features, 128 float matrixes
- conv_features = conv_groups(self.image, 8, True)
+ conv_features = self.conv_groups(self.image, 8, True)
# cutting CNN output into a sequence of feature vectors, which are
# 1 pixel wide and 11 pixel high.
@@ -125,3 +74,41 @@ class Model(object):
size=self.num_classes + 1,
norm_by_times=True,
blank=self.num_classes)
+
+ self.eval = evaluator.ctc_error(input=self.output, label=self.label)
+
+    def conv_groups(self, input_image, num, with_bn):
+        '''
+        Chain four img_conv_group calls (16, 32, 64 and 128 filters).
+
+        :param input_image: input image.
+        :type input_image: LayerOutput
+        :param num: must be a multiple of 4; each group is given num // 4
+                    entries in its conv_num_filter list.
+        :type num: int
+        :param with_bn: whether with batch normal.
+        :type with_bn: bool
+        :return: output of the last img_conv_group call.
+        '''
+        assert num % 4 == 0
+
+        filter_num_list = [16, 32, 64, 128]
+        # `//` keeps the repeat count an int; under true division `/`
+        # yields a float and `list * float` raises TypeError.
+        convs_per_group = num // 4
+
+        tmp = input_image
+        for index, num_filter in enumerate(filter_num_list):
+            tmp = img_conv_group(
+                input=tmp,
+                # num_channels is passed explicitly (1) for the first group only.
+                num_channels=1 if index == 0 else None,
+                conv_padding=1,
+                conv_num_filter=[num_filter] * convs_per_group,
+                conv_filter_size=3,
+                conv_act=Relu(),
+                conv_with_batchnorm=with_bn,
+                pool_size=2,
+                pool_stride=2, )
+
+        return tmp
diff --git a/scene_text_recognition/train.py b/scene_text_recognition/train.py
index c8bcfd51c9407d13388fa1a33272b0c53ab4b852..212102c532e06b2d74aea5a95251c20af9d77747 100644
--- a/scene_text_recognition/train.py
+++ b/scene_text_recognition/train.py
@@ -1,7 +1,8 @@
import logging
import argparse
-import paddle.v2 as paddle
import gzip
+
+import paddle.v2 as paddle
from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset
@@ -33,70 +34,76 @@ parser.add_argument(
parser.add_argument(
'--save_period_by_batch',
type=int,
- default=50,
+ default=150,
help='save model to disk every N batches')
parser.add_argument(
'--num_passes',
type=int,
- default=1,
- help='number of passes to train (default: 1)')
+ default=10,
+ help='number of passes to train (default: 10)')
args = parser.parse_args()
-image_shape = tuple(map(int, args.image_shape.split(',')))
-
-print 'image_shape', image_shape
-print 'batch_size', args.batch_size
-print 'train_file_list', args.train_file_list
-print 'test_file_list', args.test_file_list
-
-train_generator = get_file_list(args.train_file_list)
-test_generator = get_file_list(args.test_file_list)
-infer_generator = None
-
-dataset = ImageDataset(
- train_generator,
- test_generator,
- infer_generator,
- fixed_shape=image_shape,
- is_infer=False)
-
-paddle.init(use_gpu=True, trainer_count=args.trainer_count)
-
-model = Model(AsciiDic().size(), image_shape, is_infer=False)
-params = paddle.parameters.create(model.cost)
-optimizer = paddle.optimizer.Momentum(momentum=0)
-trainer = paddle.trainer.SGD(
- cost=model.cost, parameters=params, update_equation=optimizer)
-
-
-def event_handler(event):
- if isinstance(event, paddle.event.EndIteration):
- if event.batch_id % 100 == 0:
- print "Pass %d, batch %d, Samples %d, Cost %f" % (
- event.pass_id, event.batch_id, event.batch_id * args.batch_size,
- event.cost)
-
- if event.batch_id > 0 and event.batch_id % args.save_period_by_batch == 0:
- result = trainer.test(
- reader=paddle.batch(dataset.test, batch_size=10),
- feeding={'image': 0,
- 'label': 1})
- print "Test %d-%d, Cost %f " % (event.pass_id, event.batch_id,
- result.cost)
-
- path = "{}-pass-{}-batch-{}-test-{}.tar.gz".format(
- args.model_output_prefix, event.pass_id, event.batch_id,
- result.cost)
- with gzip.open(path, 'w') as f:
- params.to_tar(f)
-
-
-trainer.train(
- reader=paddle.batch(
- paddle.reader.shuffle(dataset.train, buf_size=500),
- batch_size=args.batch_size),
- feeding={'image': 0,
- 'label': 1},
- event_handler=event_handler,
- num_passes=args.num_passes)
+
+def main():
+ """Run training as configured by the module-level `args`."""
+ image_shape = tuple(map(int, args.image_shape.split(',')))
+ print 'image_shape', image_shape
+ print 'batch_size', args.batch_size
+ print 'train_file_list', args.train_file_list
+ print 'test_file_list', args.test_file_list
+ # Only train and test file lists are loaded; there is no inference list here.
+ train_generator = get_file_list(args.train_file_list)
+ test_generator = get_file_list(args.test_file_list)
+ infer_generator = None
+
+ dataset = ImageDataset(
+ train_generator,
+ test_generator,
+ infer_generator,
+ fixed_shape=image_shape,
+ is_infer=False)
+ # use_gpu=True: this script always initialises PaddlePaddle in GPU mode.
+ paddle.init(use_gpu=True, trainer_count=args.trainer_count)
+
+ model = Model(AsciiDic().size(), image_shape, is_infer=False)
+ params = paddle.parameters.create(model.cost)
+ optimizer = paddle.optimizer.Momentum(momentum=0)
+ trainer = paddle.trainer.SGD(
+ cost=model.cost,
+ parameters=params,
+ update_equation=optimizer,
+ extra_layers=model.eval)
+ # Callback handed to trainer.train() below via event_handler=.
+ def event_handler(event):
+ if isinstance(event, paddle.event.EndIteration):
+ if event.batch_id % 100 == 0:
+ print "Pass %d, batch %d, Samples %d, Cost %f, Eval %s" % (
+ event.pass_id, event.batch_id,
+ event.batch_id * args.batch_size, event.cost, event.metrics)
+ # Past batch 0, every save_period_by_batch batches: test, then save params.
+ if event.batch_id > 0 and event.batch_id % args.save_period_by_batch == 0:
+ result = trainer.test(
+ reader=paddle.batch(dataset.test, batch_size=10),
+ feeding={'image': 0,
+ 'label': 1})
+ print "Test %d-%d, Cost %f, Eval %s" % (
+ event.pass_id, event.batch_id, result.cost, result.metrics)
+ # The checkpoint file name carries the pass id and batch id.
+ path = "{}-pass-{}-batch-{}-test.tar.gz".format(
+ args.model_output_prefix, event.pass_id, event.batch_id)
+ with gzip.open(path, 'w') as f:
+ params.to_tar(f)
+
+ trainer.train(
+ reader=paddle.batch(
+ paddle.reader.shuffle(dataset.train, buf_size=500),
+ batch_size=args.batch_size),
+ feeding={'image': 0,
+ 'label': 1},
+ event_handler=event_handler,
+ num_passes=args.num_passes)
+
+
+if __name__ == "__main__":
+ main()