diff --git a/scene_text_recognition/README.md b/scene_text_recognition/README.md index 5e83a68eb5a025b0dbc139c60ab6c4fc35eb66c8..8ad3153e8e932783c513a10a6248154357140470 100644 --- a/scene_text_recognition/README.md +++ b/scene_text_recognition/README.md @@ -2,9 +2,9 @@ ## STR任务简介 -在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。 +在现实生活中,许多图片中的文字为图片所处场景的理解提供了丰富的语义信息(例如:路牌、菜单、街道标语等)。同时,场景图片文字识别技术的发展也促进了一些新型应用的产生,例如:\[[1](#参考文献)\]通过使用深度学习模型来自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。 -本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。以下图为例,给定一个场景图片,STR需要从图片中识别出对应的文字"keep"。 +本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。如下图所示,给定一张场景图片,`STR` 需要从中识别出对应的文字"keep"。


@@ -21,7 +21,7 @@ pip install -r requirements.txt ### 指定训练配置参数 -通过 `config.py` 脚本修改训练和模型配置参数,脚本中有对可配置参数的详细解释,示例如下: + `config.py` 脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码如下: ```python class TrainerConfig(object): @@ -43,7 +43,8 @@ class ModelConfig(object): ... ``` -修改 `config.py` 对参数进行调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。 + +修改 `config.py` 脚本可以实现对参数的调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。 ### 模型训练 训练脚本 [./train.py](./train.py) 中设置了如下命令行参数: @@ -54,24 +55,29 @@ Options: of train image files. [required] --test_file_list_path TEXT The path of the file which contains path list of test image files. [required] + --label_dict_path TEXT The path of label dictionary. If this parameter + is set, but the file does not exist, label + dictionay will be built from the training data + automatically. [required] --model_save_dir TEXT The path to save the trained models (default: 'models'). --help Show this message and exit. ``` -- `train_file_list` 训练数据的列表文件,每行一个路径加对应的text,具体格式为: +- `train_file_list` :训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成,具体格式为: ``` word_1.png, "PROPER" word_2.png, "FOOD" ``` -- `test_file_list` 测试数据的列表文件,格式同上。 -- `model_save_dir` 模型参数会的保存目录目录, 默认为当前目录下的`models`目录。 +- `test_file_list` :测试数据的列表文件,格式同上。 +- `label_dict_path` :训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。 +- `model_save_dir` :模型参数的保存目录,默认为`./models`。 ### 具体执行的过程: -1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt。 -分别对应训练集的图片和图片对应的单词,测试集的图片,测试数据对应的单词,然后执行以下命令,对数据解压并移动至目标文件夹: +1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: `Challenge2_Training_Task3_Images_GT.zip`、`Challenge2_Test_Task3_Images.zip` 和 `Challenge2_Test_Task3_GT.txt`。 +分别对应训练集的图片和图片对应的单词、测试集的图片、测试数据对应的单词。然后执行以下命令,对数据解压并移动至目标文件夹: ```bash mkdir -p data/train_data @@ -87,17 +93,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data ```bash python train.py \ --train_file_list_path 'data/train_data/gt.txt' \ ---test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' +--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' \ +--label_dict_path 'label_dict.txt' ``` 4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。 ### 预测 -预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型目录、图片固定尺寸、batch_size(默认设置为10)和图片文件的列表文件。执行如下代码: +预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径和图片文件的列表文件。执行如下代码: ```bash python infer.py \ --model_path 'models/params_pass_00000.tar.gz' \ --image_shape '173,46' \ +--label_dict_path 'label_dict.txt' \ --infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' ``` 即可进行预测。 @@ -109,9 +117,9 @@ python infer.py \ ### 注意事项 -- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行 -- 本模型参数较多,占用显存比较大,实际执行时可以调节`batch_size`控制显存占用 -- 本模型使用的数据集较小,可以选用其他更大的数据集\[[3](#参考文献)\]来训练需要的模型 +- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行。 +- 本模型参数较多,占用显存比较大,实际执行时可以通过调节 `batch_size` 来控制显存占用。 +- 本例使用的数据集较小,如有需要,可以选用其他更大的数据集\[[3](#参考文献)\]来训练模型。 ## 参考文献 diff --git a/scene_text_recognition/config.py b/scene_text_recognition/config.py index 9cc563549f409d7abf044a9cf9a95919f8bd6852..0bed3c04151bbf6b03dae4932489bb48ab134a0e 100644 --- a/scene_text_recognition/config.py +++ b/scene_text_recognition/config.py @@ -72,4 +72,4 @@ class ModelConfig(object): pool_size = 2 # The parameter pool_stride in image convolution group layer. - pool_stride = 2 + pool_stride = 2 \ No newline at end of file diff --git a/scene_text_recognition/index.html b/scene_text_recognition/index.html index 64a1160a4b09a313dbef3b93f0f40068d01592b4..dd16a2af73c08d95af69a1ccde59e77a00d478d7 100644 --- a/scene_text_recognition/index.html +++ b/scene_text_recognition/index.html @@ -44,9 +44,9 @@ ## STR任务简介 -在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。 +在现实生活中,许多图片中的文字为图片所处场景的理解提供了丰富的语义信息(例如:路牌、菜单、街道标语等)。同时,场景图片文字识别技术的发展也促进了一些新型应用的产生,例如:\[[1](#参考文献)\]通过使用深度学习模型来自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。 -本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。以下图为例,给定一个场景图片,STR需要从图片中识别出对应的文字"keep"。 +本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。如下图所示,给定一张场景图片,`STR` 需要从中识别出对应的文字"keep"。


@@ -63,7 +63,7 @@ pip install -r requirements.txt ### 指定训练配置参数 -通过 `config.py` 脚本修改训练和模型配置参数,脚本中有对可配置参数的详细解释,示例如下: + `config.py` 脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码如下: ```python class TrainerConfig(object): @@ -85,7 +85,8 @@ class ModelConfig(object): ... ``` -修改 `config.py` 对参数进行调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。 + +修改 `config.py` 脚本可以实现对参数的调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。 ### 模型训练 训练脚本 [./train.py](./train.py) 中设置了如下命令行参数: @@ -96,24 +97,29 @@ Options: of train image files. [required] --test_file_list_path TEXT The path of the file which contains path list of test image files. [required] + --label_dict_path TEXT The path of label dictionary. If this parameter + is set, but the file does not exist, label + dictionay will be built from the training data + automatically. [required] --model_save_dir TEXT The path to save the trained models (default: 'models'). --help Show this message and exit. ``` -- `train_file_list` 训练数据的列表文件,每行一个路径加对应的text,具体格式为: +- `train_file_list` :训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成,具体格式为: ``` word_1.png, "PROPER" word_2.png, "FOOD" ``` -- `test_file_list` 测试数据的列表文件,格式同上。 -- `model_save_dir` 模型参数会的保存目录目录, 默认为当前目录下的`models`目录。 +- `test_file_list` :测试数据的列表文件,格式同上。 +- `label_dict_path` :训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。 +- `model_save_dir` :模型参数的保存目录,默认为`./models`。 ### 具体执行的过程: -1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt。 -分别对应训练集的图片和图片对应的单词,测试集的图片,测试数据对应的单词,然后执行以下命令,对数据解压并移动至目标文件夹: +1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: `Challenge2_Training_Task3_Images_GT.zip`、`Challenge2_Test_Task3_Images.zip` 和 `Challenge2_Test_Task3_GT.txt`。 +分别对应训练集的图片和图片对应的单词、测试集的图片、测试数据对应的单词。然后执行以下命令,对数据解压并移动至目标文件夹: ```bash mkdir -p data/train_data @@ -129,17 +135,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data ```bash python train.py \ --train_file_list_path 'data/train_data/gt.txt' \ ---test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' +--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' \ +--label_dict_path 'label_dict.txt' ``` 4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。 ### 预测 -预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型目录、图片固定尺寸、batch_size(默认设置为10)和图片文件的列表文件。执行如下代码: +预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径和图片文件的列表文件。执行如下代码: ```bash python infer.py \ --model_path 'models/params_pass_00000.tar.gz' \ --image_shape '173,46' \ +--label_dict_path 'label_dict.txt' \ --infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' ``` 即可进行预测。 @@ -151,9 +159,9 @@ python infer.py \ ### 注意事项 -- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行 -- 本模型参数较多,占用显存比较大,实际执行时可以调节`batch_size`控制显存占用 -- 本模型使用的数据集较小,可以选用其他更大的数据集\[[3](#参考文献)\]来训练需要的模型 +- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行。 +- 本模型参数较多,占用显存比较大,实际执行时可以通过调节 `batch_size` 来控制显存占用。 +- 本例使用的数据集较小,如有需要,可以选用其他更大的数据集\[[3](#参考文献)\]来训练模型。 ## 参考文献 diff --git a/scene_text_recognition/infer.py b/scene_text_recognition/infer.py index b53c600b426d1c95c6a5e633b16eb2582c7d3a39..c572f5001adbaa15c5febdf2ae528fd647b66378 100644 --- a/scene_text_recognition/infer.py +++ b/scene_text_recognition/infer.py @@ -5,10 +5,10 @@ import paddle.v2 as paddle from model import Model from reader import DataGenerator from decoder import ctc_greedy_decoder -from utils import AsciiDic, get_file_list +from utils import get_file_list, load_dict, load_reverse_dict -def infer_batch(inferer, test_batch, labels): +def infer_batch(inferer, test_batch, labels, reversed_char_dict): infer_results = inferer.infer(input=test_batch) num_steps = len(infer_results) // len(test_batch) probs_split = [ @@ -19,7 +19,7 @@ def infer_batch(inferer, test_batch, labels): # Best path decode. for i, probs in enumerate(probs_split): output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=AsciiDic().id2word()) + probs_seq=probs, vocabulary=reversed_char_dict) results.append(output_transcription) for result, label in zip(results, labels): @@ -40,17 +40,26 @@ def infer_batch(inferer, test_batch, labels): type=int, default=10, help=("The number of examples in one batch (default: 10).")) +@click.option( + "--label_dict_path", + type=str, + required=True, + help=("The path of label dictionary. ")) @click.option( "--infer_file_list_path", type=str, required=True, help=("The path of the file which contains " "path list of image files for inference.")) -def infer(model_path, image_shape, batch_size, infer_file_list_path): +def infer(model_path, image_shape, batch_size, label_dict_path, + infer_file_list_path): + image_shape = tuple(map(int, image_shape.split(','))) infer_file_list = get_file_list(infer_file_list_path) - char_dict = AsciiDic() - dict_size = char_dict.size() + + char_dict = load_dict(label_dict_path) + reversed_char_dict = load_reverse_dict(label_dict_path) + dict_size = len(char_dict) data_generator = DataGenerator(char_dict=char_dict, image_shape=image_shape) paddle.init(use_gpu=True, trainer_count=1) @@ -66,11 +75,11 @@ def infer(model_path, image_shape, batch_size, infer_file_list_path): test_batch.append([image]) labels.append(label) if len(test_batch) == batch_size: - infer_batch(inferer, test_batch, labels) + infer_batch(inferer, test_batch, labels, reversed_char_dict) test_batch = [] labels = [] if test_batch: - infer_batch(inferer, test_batch, labels) + infer_batch(inferer, test_batch, labels, reversed_char_dict) if __name__ == "__main__": diff --git a/scene_text_recognition/model.py b/scene_text_recognition/model.py index 86dd852ceecf39eed38be609336822da5920a217..4fb297a503af74aeadb14f1654e0f6a0e094bbff 100644 --- a/scene_text_recognition/model.py +++ b/scene_text_recognition/model.py @@ -45,12 +45,11 @@ class Model(object): ''' Build the network topology. ''' - # CNN output image features. + # Get the image features with CNN. conv_features = self.conv_groups(self.image, conf.filter_num, conf.with_bn) - # Cut CNN output into a sequence of feature vectors, which are - # 1 pixel wide and 11 pixel high. + # Expand the output of CNN into a sequence of feature vectors. sliced_feature = layer.block_expand( input=conv_features, num_channels=conf.num_channels, @@ -59,7 +58,7 @@ class Model(object): block_x=conf.block_x, block_y=conf.block_y) - # RNNs to capture sequence information forwards and backwards. + # Use RNN to capture sequence information forwards and backwards. gru_forward = simple_gru( input=sliced_feature, size=conf.hidden_size, act=Relu()) gru_backward = simple_gru( @@ -68,7 +67,7 @@ class Model(object): act=Relu(), reverse=True) - # Map each step of RNN to character distribution. + # Map the output of RNN to character distribution. self.output = layer.fc( input=[gru_forward, gru_backward], size=self.num_classes + 1, diff --git a/scene_text_recognition/reader.py b/scene_text_recognition/reader.py index 013477adbbfbd8de432b40aeed6d709ec4e61f62..91321e34bf6ae748dfbfcf8fff22ee890769616c 100644 --- a/scene_text_recognition/reader.py +++ b/scene_text_recognition/reader.py @@ -18,35 +18,37 @@ class DataGenerator(object): def train_reader(self, file_list): ''' Reader interface for training. - + :param file_list: The path list of the image file for training. :type file_list: list ''' def reader(): - for i, (image, label) in enumerate(file_list): - yield self.load_image(image), self.char_dict.word2ids(label) + UNK_ID = self.char_dict[''] + for image_path, label in file_list: + label = [self.char_dict.get(c, UNK_ID) for c in label] + yield self.load_image(image_path), label return reader def infer_reader(self, file_list): ''' Reader interface for inference. - + :param file_list: The path list of the image file for inference. :type file_list: list ''' def reader(): - for i, (image, label) in enumerate(file_list): - yield self.load_image(image), label + for image_path, label in file_list: + yield self.load_image(image_path), label return reader def load_image(self, path): ''' - Load image and transform to 1-dimention vector. - + Load an image and transform it to 1-dimention vector. + :param path: The path of the image data. :type path: str ''' diff --git a/scene_text_recognition/train.py b/scene_text_recognition/train.py index 557f1ba5ee9bf0e8507b56af9d7460a20012a171..4497bee813615be9bfe3ad94a9c9eea918dcc2a4 100644 --- a/scene_text_recognition/train.py +++ b/scene_text_recognition/train.py @@ -6,7 +6,7 @@ import paddle.v2 as paddle from config import TrainerConfig as conf from model import Model from reader import DataGenerator -from utils import get_file_list, AsciiDic +from utils import get_file_list, build_label_dict, load_dict @click.command('train') @@ -22,19 +22,35 @@ from utils import get_file_list, AsciiDic required=True, help=("The path of the file which contains " "path list of test image files.")) +@click.option( + "--label_dict_path", + type=str, + required=True, + help=("The path of label dictionary. " + "If this parameter is set, but the file does not exist, " + "label dictionay will be built from " + "the training data automatically.")) @click.option( "--model_save_dir", type=str, default="models", help="The path to save the trained models (default: 'models').") -def train(train_file_list_path, test_file_list_path, model_save_dir): +def train(train_file_list_path, test_file_list_path, label_dict_path, + model_save_dir): if not os.path.exists(model_save_dir): os.mkdir(model_save_dir) + train_file_list = get_file_list(train_file_list_path) test_file_list = get_file_list(test_file_list_path) - char_dict = AsciiDic() - dict_size = char_dict.size() + + if not os.path.exists(label_dict_path): + print(("Label dictionary is not given, the dictionary " + "is automatically built from the training data.")) + build_label_dict(train_file_list, label_dict_path) + + char_dict = load_dict(label_dict_path) + dict_size = len(char_dict) data_generator = DataGenerator( char_dict=char_dict, image_shape=conf.image_shape) diff --git a/scene_text_recognition/utils.py b/scene_text_recognition/utils.py index dd43113ab045b4bdf1ad1b5a81d5dd0898b5fc6e..86bd3a1f477a710c3f245ee3dd582044f37a0f8d 100644 --- a/scene_text_recognition/utils.py +++ b/scene_text_recognition/utils.py @@ -1,41 +1,5 @@ import os - - -class AsciiDic(object): - UNK_ID = 0 - - def __init__(self): - self.dic = { - '': self.UNK_ID, - } - self.chars = [chr(i) for i in range(40, 171)] - for id, c in enumerate(self.chars): - self.dic[c] = id + 1 - - def lookup(self, w): - return self.dic.get(w, self.UNK_ID) - - def id2word(self): - ''' - Return a reversed char dict. - ''' - self.id2word = {} - for key, value in self.dic.items(): - self.id2word[value] = key - - return self.id2word - - def word2ids(self, word): - ''' - Transform a word to a list of ids. - - :param word: The word appears in image data. - :type word: str - ''' - return [self.lookup(c) for c in list(word)] - - def size(self): - return len(self.dic) +from collections import defaultdict def get_file_list(image_file_list): @@ -43,7 +7,7 @@ def get_file_list(image_file_list): Generate the file list for training and testing data. :param image_file_list: The path of the file which contains - path list of image files. + path list of image files. :type image_file_list: str ''' dirname = os.path.dirname(image_file_list) @@ -53,7 +17,53 @@ def get_file_list(image_file_list): line_split = line.strip().split(',', 1) filename = line_split[0].strip() path = os.path.join(dirname, filename) - label = line_split[1][2:-1] - path_list.append((path, label)) + label = line_split[1][2:-1].strip() + if label: + path_list.append((path, label)) return path_list + + +def build_label_dict(file_list, save_path): + """ + Build label dictionary from training data. + + :param file_list: The list which contains the labels + of training data. + :type file_list: list + :params save_path: The path where the label dictionary will be saved. + :type save_path: str + """ + values = defaultdict(int) + for path, label in file_list: + for c in label: + if c: + values[c] += 1 + + values[''] = 0 + with open(save_path, "w") as f: + for v, count in sorted( + values.iteritems(), key=lambda x: x[1], reverse=True): + f.write("%s\t%d\n" % (v, count)) + + +def load_dict(dict_path): + """ + Load label dictionary from the dictionary path. + + :param dict_path: The path of word dictionary. + :type dict_path: str + """ + return dict((line.strip().split("\t")[0], idx) + for idx, line in enumerate(open(dict_path, "r").readlines())) + + +def load_reverse_dict(dict_path): + """ + Load the reversed label dictionary from dictionary path. + + :param dict_path: The path of word dictionary. + :type dict_path: str + """ + return dict((idx, line.strip().split("\t")[0]) + for idx, line in enumerate(open(dict_path, "r").readlines()))