diff --git a/scene_text_recognition/README.md b/scene_text_recognition/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9974d1d74b6d3cd6c426ae95fd6969cfc09f4610 --- /dev/null +++ b/scene_text_recognition/README.md @@ -0,0 +1,128 @@ +# 场景文字识别 (STR, Scene Text Recognition) + +## STR任务简介 + +许多场景图像中包含着丰富的文本信息,它们可以从很大程度上帮助人们去认知场景图像的内容及含义,因此场景图像中的文本识别对所在图像的信息获取具有极其重要的作用。同时,场景图像文字识别技术的发展也促进了一些新型应用的产生,例如:\[[1](#参考文献)\]通过使用深度学习模型来自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。 + +本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 。任务如下图所示,给定一张场景图片,`STR` 需要从中识别出对应的文字"keep"。 + +

+
+图 1. 输入数据示例 "keep" +

+ + +## 使用 PaddlePaddle 训练与预测 + +### 安装依赖包 +```bash +pip install -r requirements.txt +``` + +### 修改配置参数 + + `config.py` 脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码片段如下: +```python +class TrainerConfig(object): + + # Whether to use GPU in training or not. + use_gpu = True + # The number of computing threads. + trainer_count = 1 + + # The training batch size. + batch_size = 10 + + ... + + +class ModelConfig(object): + + # Number of the filters for convolution group. + filter_num = 8 + + ... +``` + +修改 `config.py` 脚本可以实现对参数的调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。 + +### 模型训练 +训练脚本 [./train.py](./train.py) 中设置了如下命令行参数: + +``` +Options: + --train_file_list_path TEXT The path of the file which contains path list + of train image files. [required] + --test_file_list_path TEXT The path of the file which contains path list + of test image files. [required] + --label_dict_path TEXT The path of label dictionary. If this parameter + is set, but the file does not exist, label + dictionay will be built from the training data + automatically. [required] + --model_save_dir TEXT The path to save the trained models (default: + 'models'). + --help Show this message and exit. + +``` + +- `train_file_list` :训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成,格式为: +``` +word_1.png, "PROPER" +word_2.png, "FOOD" +``` +- `test_file_list` :测试数据的列表文件,格式同上。 +- `label_dict_path` :训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。 +- `model_save_dir` :模型参数的保存目录,默认为`./models`。 + +### 具体执行的过程: + +1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: `Challenge2_Training_Task3_Images_GT.zip`、`Challenge2_Test_Task3_Images.zip` 和 `Challenge2_Test_Task3_GT.txt`。 +分别对应训练集的图片和图片对应的单词、测试集的图片、测试数据对应的单词。然后执行以下命令,对数据解压并移动至目标文件夹: + +```bash +mkdir -p data/train_data +mkdir -p data/test_data +unzip Challenge2_Training_Task3_Images_GT.zip -d data/train_data +unzip Challenge2_Test_Task3_Images.zip -d data/test_data +mv Challenge2_Test_Task3_GT.txt data/test_data +``` + +2.获取训练数据文件夹中 `gt.txt` 的路径 (data/train_data)和测试数据文件夹中`Challenge2_Test_Task3_GT.txt`的路径(data/test_data)。 + +3.执行如下命令进行训练: +```bash +python train.py \ +--train_file_list_path 'data/train_data/gt.txt' \ +--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' \ +--label_dict_path 'label_dict.txt' +``` +4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。 + + +### 预测 +预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径和图片文件的列表文件。执行如下代码: +```bash +python infer.py \ +--model_path 'models/params_pass_00000.tar.gz' \ +--image_shape '173,46' \ +--label_dict_path 'label_dict.txt' \ +--infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' +``` +即可进行预测。 + +### 其他数据集 + +- [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)(41G) +- [ICDAR 2003 Robust Reading Competitions](http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions) + +### 注意事项 + +- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行。 +- 本模型参数较多,占用显存比较大,实际执行时可以通过调节 `batch_size` 来控制显存占用。 +- 本例使用的数据集较小,如有需要,可以选用其他更大的数据集\[[3](#参考文献)\]来训练模型。 + +## 参考文献 + +1. [Google Now Using ReCAPTCHA To Decode Street View Addresses](https://techcrunch.com/2012/03/29/google-now-using-recaptcha-to-decode-street-view-addresses/) +2. [Focused Scene Text](http://rrc.cvc.uab.es/?ch=2&com=introduction) +3. [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) diff --git a/scene_text_recognition/config.py b/scene_text_recognition/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc563549f409d7abf044a9cf9a95919f8bd6852 --- /dev/null +++ b/scene_text_recognition/config.py @@ -0,0 +1,75 @@ +__all__ = ["TrainerConfig", "ModelConfig"] + + +class TrainerConfig(object): + + # Whether to use GPU in training or not. + use_gpu = True + + # The number of computing threads. + trainer_count = 1 + + # The training batch size. + batch_size = 10 + + # The epoch number. + num_passes = 10 + + # Parameter updates momentum. + momentum = 0 + + # The shape of images. + image_shape = (173, 46) + + # The buffer size of the data reader. + # The number of buffer size samples will be shuffled in training. + buf_size = 1000 + + # The parameter is used to control logging period. + # Training log will be printed every log_period. + log_period = 50 + + +class ModelConfig(object): + + # Number of the filters for convolution group. + filter_num = 8 + + # Use batch normalization or not in image convolution group. + with_bn = True + + # The number of channels for block expand layer. + num_channels = 128 + + # The parameter stride_x in block expand layer. + stride_x = 1 + + # The parameter stride_y in block expand layer. + stride_y = 1 + + # The parameter block_x in block expand layer. + block_x = 1 + + # The parameter block_y in block expand layer. + block_y = 11 + + # The hidden size for gru. + hidden_size = num_channels + + # Use norm_by_times or not in warp ctc layer. + norm_by_times = True + + # The list for number of filter in image convolution group layer. + filter_num_list = [16, 32, 64, 128] + + # The parameter conv_padding in image convolution group layer. + conv_padding = 1 + + # The parameter conv_filter_size in image convolution group layer. + conv_filter_size = 3 + + # The parameter pool_size in image convolution group layer. + pool_size = 2 + + # The parameter pool_stride in image convolution group layer. + pool_stride = 2 diff --git a/scene_text_recognition/decoder.py b/scene_text_recognition/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba02a453070f955acd281031f4b29608bcaf65c --- /dev/null +++ b/scene_text_recognition/decoder.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from itertools import groupby +import numpy as np + + +def ctc_greedy_decoder(probs_seq, vocabulary): + """CTC greedy (best path) decoder. + Path consisting of the most probable tokens are further post-processed to + remove consecutive repetitions and all blanks. + :param probs_seq: 2-D list of probabilities over the vocabulary for each + character. Each element is a list of float probabilities + for one character. + :type probs_seq: list + :param vocabulary: Vocabulary list. + :type vocabulary: list + :return: Decoding result string. + :rtype: baseline + """ + # dimension verification + for probs in probs_seq: + if not len(probs) == len(vocabulary) + 1: + raise ValueError("probs_seq dimension mismatchedd with vocabulary") + # argmax to get the best index for each time step + max_index_list = list(np.array(probs_seq).argmax(axis=1)) + # remove consecutive duplicate indexes + index_list = [index_group[0] for index_group in groupby(max_index_list)] + # remove blank indexes + blank_index = len(vocabulary) + index_list = [index for index in index_list if index != blank_index] + # convert index list to string + return ''.join([vocabulary[index] for index in index_list]) diff --git a/scene_text_recognition/images/503.jpg b/scene_text_recognition/images/503.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87253cd25a0e0f36b8430d01054ebe0d2f068356 Binary files /dev/null and b/scene_text_recognition/images/503.jpg differ diff --git a/scene_text_recognition/images/504.jpg b/scene_text_recognition/images/504.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ba19785d45c28e35fa2de2ffea0f5bf97e1ece09 Binary files /dev/null and b/scene_text_recognition/images/504.jpg differ diff --git a/scene_text_recognition/images/505.jpg b/scene_text_recognition/images/505.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f6c2b806cd63793f706cb87cf996e4e16b5cfe97 Binary files /dev/null and b/scene_text_recognition/images/505.jpg differ diff --git a/scene_text_recognition/images/ctc.png b/scene_text_recognition/images/ctc.png new file mode 100644 index 0000000000000000000000000000000000000000..45b7df3517758ab20ff796133204b385b634e039 Binary files /dev/null and b/scene_text_recognition/images/ctc.png differ diff --git a/scene_text_recognition/images/feature_vector.png b/scene_text_recognition/images/feature_vector.png new file mode 100644 index 0000000000000000000000000000000000000000..f47473fb87462cc0b02270f6121928eae2710627 Binary files /dev/null and b/scene_text_recognition/images/feature_vector.png differ diff --git a/scene_text_recognition/images/transcription.png b/scene_text_recognition/images/transcription.png new file mode 100644 index 0000000000000000000000000000000000000000..cba1f75838d720ab8e28c2d3aa977a008cc618e1 Binary files /dev/null and b/scene_text_recognition/images/transcription.png differ diff --git a/scene_text_recognition/index.html b/scene_text_recognition/index.html new file mode 100644 index 0000000000000000000000000000000000000000..4331b2b9636c159aedc1e96f0731e44cca9889cf --- /dev/null +++ b/scene_text_recognition/index.html @@ -0,0 +1,192 @@ + + + + + + + + + + + + + + + + + +
+
+ + + + + + + diff --git a/scene_text_recognition/infer.py b/scene_text_recognition/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..64bc4ddeb38be3bb2e24da69311b49d72987f6d7 --- /dev/null +++ b/scene_text_recognition/infer.py @@ -0,0 +1,86 @@ +import click +import gzip + +import paddle.v2 as paddle +from network_conf import Model +from reader import DataGenerator +from decoder import ctc_greedy_decoder +from utils import get_file_list, load_dict, load_reverse_dict + + +def infer_batch(inferer, test_batch, labels, reversed_char_dict): + infer_results = inferer.infer(input=test_batch) + num_steps = len(infer_results) // len(test_batch) + probs_split = [ + infer_results[i * num_steps:(i + 1) * num_steps] + for i in xrange(0, len(test_batch)) + ] + results = [] + # Best path decode. + for i, probs in enumerate(probs_split): + output_transcription = ctc_greedy_decoder( + probs_seq=probs, vocabulary=reversed_char_dict) + results.append(output_transcription) + + for result, label in zip(results, labels): + print("\nOutput Transcription: %s\nTarget Transcription: %s" % + (result, label)) + + +@click.command('infer') +@click.option( + "--model_path", type=str, required=True, help=("The path of saved model.")) +@click.option( + "--image_shape", + type=str, + required=True, + help=("The fixed size for image dataset (format is like: '173,46').")) +@click.option( + "--batch_size", + type=int, + default=10, + help=("The number of examples in one batch (default: 10).")) +@click.option( + "--label_dict_path", + type=str, + required=True, + help=("The path of label dictionary. ")) +@click.option( + "--infer_file_list_path", + type=str, + required=True, + help=("The path of the file which contains " + "path list of image files for inference.")) +def infer(model_path, image_shape, batch_size, label_dict_path, + infer_file_list_path): + + image_shape = tuple(map(int, image_shape.split(','))) + infer_file_list = get_file_list(infer_file_list_path) + + char_dict = load_dict(label_dict_path) + reversed_char_dict = load_reverse_dict(label_dict_path) + dict_size = len(char_dict) + data_generator = DataGenerator(char_dict=char_dict, image_shape=image_shape) + + paddle.init(use_gpu=True, trainer_count=1) + parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path)) + model = Model(dict_size, image_shape, is_infer=True) + inferer = paddle.inference.Inference( + output_layer=model.log_probs, parameters=parameters) + + test_batch = [] + labels = [] + for i, (image, + label) in enumerate(data_generator.infer_reader(infer_file_list)()): + test_batch.append([image]) + labels.append(label) + if len(test_batch) == batch_size: + infer_batch(inferer, test_batch, labels, reversed_char_dict) + test_batch = [] + labels = [] + if test_batch: + infer_batch(inferer, test_batch, labels, reversed_char_dict) + + +if __name__ == "__main__": + infer() diff --git a/scene_text_recognition/network_conf.py b/scene_text_recognition/network_conf.py new file mode 100644 index 0000000000000000000000000000000000000000..bd92bae8737b428d7494229c7137ae2b164022ab --- /dev/null +++ b/scene_text_recognition/network_conf.py @@ -0,0 +1,128 @@ +from paddle import v2 as paddle +from paddle.v2 import layer +from paddle.v2 import evaluator +from paddle.v2.activation import Relu, Linear +from paddle.v2.networks import img_conv_group, simple_gru +from config import ModelConfig as conf + + +class Model(object): + def __init__(self, num_classes, shape, is_infer=False): + ''' + :param num_classes: The size of the character dict. + :type num_classes: int + :param shape: The size of the input images. + :type shape: tuple of 2 int + :param is_infer: The boolean parameter indicating + inferring or training. + :type shape: bool + ''' + self.num_classes = num_classes + self.shape = shape + self.is_infer = is_infer + self.image_vector_size = shape[0] * shape[1] + + self.__declare_input_layers__() + self.__build_nn__() + + def __declare_input_layers__(self): + ''' + Define the input layer. + ''' + # Image input as a float vector. + self.image = layer.data( + name='image', + type=paddle.data_type.dense_vector(self.image_vector_size), + height=self.shape[0], + width=self.shape[1]) + + # Label input as an ID list + if not self.is_infer: + self.label = layer.data( + name='label', + type=paddle.data_type.integer_value_sequence(self.num_classes)) + + def __build_nn__(self): + ''' + Build the network topology. + ''' + # Get the image features with CNN. + conv_features = self.conv_groups(self.image, conf.filter_num, + conf.with_bn) + + # Expand the output of CNN into a sequence of feature vectors. + sliced_feature = layer.block_expand( + input=conv_features, + num_channels=conf.num_channels, + stride_x=conf.stride_x, + stride_y=conf.stride_y, + block_x=conf.block_x, + block_y=conf.block_y) + + # Use RNN to capture sequence information forwards and backwards. + gru_forward = simple_gru( + input=sliced_feature, size=conf.hidden_size, act=Relu()) + gru_backward = simple_gru( + input=sliced_feature, + size=conf.hidden_size, + act=Relu(), + reverse=True) + + # Map the output of RNN to character distribution. + self.output = layer.fc( + input=[gru_forward, gru_backward], + size=self.num_classes + 1, + act=Linear()) + + self.log_probs = paddle.layer.mixed( + input=paddle.layer.identity_projection(input=self.output), + act=paddle.activation.Softmax()) + + # Use warp CTC to calculate cost for a CTC task. + if not self.is_infer: + self.cost = layer.warp_ctc( + input=self.output, + label=self.label, + size=self.num_classes + 1, + norm_by_times=conf.norm_by_times, + blank=self.num_classes) + + self.eval = evaluator.ctc_error(input=self.output, label=self.label) + + def conv_groups(self, input, num, with_bn): + ''' + Get the image features with image convolution group. + + :param input: Input layer. + :type input: LayerOutput + :param num: Number of the filters. + :type num: int + :param with_bn: Use batch normalization or not. + :type with_bn: bool + ''' + assert num % 4 == 0 + + filter_num_list = conf.filter_num_list + is_input_image = True + tmp = input + + for num_filter in filter_num_list: + + if is_input_image: + num_channels = 1 + is_input_image = False + else: + num_channels = None + + tmp = img_conv_group( + input=tmp, + num_channels=num_channels, + conv_padding=conf.conv_padding, + conv_num_filter=[num_filter] * (num / 4), + conv_filter_size=conf.conv_filter_size, + conv_act=Relu(), + conv_with_batchnorm=with_bn, + pool_size=conf.pool_size, + pool_stride=conf.pool_stride, ) + + return tmp diff --git a/scene_text_recognition/reader.py b/scene_text_recognition/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..91321e34bf6ae748dfbfcf8fff22ee890769616c --- /dev/null +++ b/scene_text_recognition/reader.py @@ -0,0 +1,64 @@ +import os +import cv2 + +from paddle.v2.image import load_image + + +class DataGenerator(object): + def __init__(self, char_dict, image_shape): + ''' + :param char_dict: The dictionary class for labels. + :type char_dict: class + :param image_shape: The fixed shape of images. + :type image_shape: tuple + ''' + self.image_shape = image_shape + self.char_dict = char_dict + + def train_reader(self, file_list): + ''' + Reader interface for training. + + :param file_list: The path list of the image file for training. + :type file_list: list + ''' + + def reader(): + UNK_ID = self.char_dict[''] + for image_path, label in file_list: + label = [self.char_dict.get(c, UNK_ID) for c in label] + yield self.load_image(image_path), label + + return reader + + def infer_reader(self, file_list): + ''' + Reader interface for inference. + + :param file_list: The path list of the image file for inference. + :type file_list: list + ''' + + def reader(): + for image_path, label in file_list: + yield self.load_image(image_path), label + + return reader + + def load_image(self, path): + ''' + Load an image and transform it to 1-dimention vector. + + :param path: The path of the image data. + :type path: str + ''' + image = load_image(path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # Resize all images to a fixed shape. + if self.image_shape: + image = cv2.resize( + image, self.image_shape, interpolation=cv2.INTER_CUBIC) + + image = image.flatten() / 255. + return image diff --git a/scene_text_recognition/requirements.txt b/scene_text_recognition/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..eb8ed79b09459ccc1fe16e2180a555fade31c58e --- /dev/null +++ b/scene_text_recognition/requirements.txt @@ -0,0 +1,2 @@ +click +opencv-python \ No newline at end of file diff --git a/scene_text_recognition/train.py b/scene_text_recognition/train.py new file mode 100644 index 0000000000000000000000000000000000000000..79eccc0b29582ad09c5efc4d950197f9ea44baef --- /dev/null +++ b/scene_text_recognition/train.py @@ -0,0 +1,107 @@ +import gzip +import os +import click + +import paddle.v2 as paddle +from config import TrainerConfig as conf +from network_conf import Model +from reader import DataGenerator +from utils import get_file_list, build_label_dict, load_dict + + +@click.command('train') +@click.option( + "--train_file_list_path", + type=str, + required=True, + help=("The path of the file which contains " + "path list of train image files.")) +@click.option( + "--test_file_list_path", + type=str, + required=True, + help=("The path of the file which contains " + "path list of test image files.")) +@click.option( + "--label_dict_path", + type=str, + required=True, + help=("The path of label dictionary. " + "If this parameter is set, but the file does not exist, " + "label dictionay will be built from " + "the training data automatically.")) +@click.option( + "--model_save_dir", + type=str, + default="models", + help="The path to save the trained models (default: 'models').") +def train(train_file_list_path, test_file_list_path, label_dict_path, + model_save_dir): + + if not os.path.exists(model_save_dir): + os.mkdir(model_save_dir) + + train_file_list = get_file_list(train_file_list_path) + test_file_list = get_file_list(test_file_list_path) + + if not os.path.exists(label_dict_path): + print(("Label dictionary is not given, the dictionary " + "is automatically built from the training data.")) + build_label_dict(train_file_list, label_dict_path) + + char_dict = load_dict(label_dict_path) + dict_size = len(char_dict) + data_generator = DataGenerator( + char_dict=char_dict, image_shape=conf.image_shape) + + paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count) + # Create optimizer. + optimizer = paddle.optimizer.Momentum(momentum=conf.momentum) + # Define network topology. + model = Model(dict_size, conf.image_shape, is_infer=False) + # Create all the trainable parameters. + params = paddle.parameters.create(model.cost) + + trainer = paddle.trainer.SGD( + cost=model.cost, + parameters=params, + update_equation=optimizer, + extra_layers=model.eval) + # Feeding dictionary. + feeding = {'image': 0, 'label': 1} + + def event_handler(event): + if isinstance(event, paddle.event.EndIteration): + if event.batch_id % conf.log_period == 0: + print("Pass %d, batch %d, Samples %d, Cost %f, Eval %s" % + (event.pass_id, event.batch_id, event.batch_id * + conf.batch_size, event.cost, event.metrics)) + + if isinstance(event, paddle.event.EndPass): + # Here, because training and testing data share a same format, + # we still use the reader.train_reader to read the testing data. + result = trainer.test( + reader=paddle.batch( + data_generator.train_reader(test_file_list), + batch_size=conf.batch_size), + feeding=feeding) + print("Test %d, Cost %f, Eval %s" % + (event.pass_id, result.cost, result.metrics)) + with gzip.open( + os.path.join(model_save_dir, "params_pass_%05d.tar.gz" % + event.pass_id), "w") as f: + trainer.save_parameter_to_tar(f) + + trainer.train( + reader=paddle.batch( + paddle.reader.shuffle( + data_generator.train_reader(train_file_list), + buf_size=conf.buf_size), + batch_size=conf.batch_size), + feeding=feeding, + event_handler=event_handler, + num_passes=conf.num_passes) + + +if __name__ == "__main__": + train() diff --git a/scene_text_recognition/utils.py b/scene_text_recognition/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..86bd3a1f477a710c3f245ee3dd582044f37a0f8d --- /dev/null +++ b/scene_text_recognition/utils.py @@ -0,0 +1,69 @@ +import os +from collections import defaultdict + + +def get_file_list(image_file_list): + ''' + Generate the file list for training and testing data. + + :param image_file_list: The path of the file which contains + path list of image files. + :type image_file_list: str + ''' + dirname = os.path.dirname(image_file_list) + path_list = [] + with open(image_file_list) as f: + for line in f: + line_split = line.strip().split(',', 1) + filename = line_split[0].strip() + path = os.path.join(dirname, filename) + label = line_split[1][2:-1].strip() + if label: + path_list.append((path, label)) + + return path_list + + +def build_label_dict(file_list, save_path): + """ + Build label dictionary from training data. + + :param file_list: The list which contains the labels + of training data. + :type file_list: list + :params save_path: The path where the label dictionary will be saved. + :type save_path: str + """ + values = defaultdict(int) + for path, label in file_list: + for c in label: + if c: + values[c] += 1 + + values[''] = 0 + with open(save_path, "w") as f: + for v, count in sorted( + values.iteritems(), key=lambda x: x[1], reverse=True): + f.write("%s\t%d\n" % (v, count)) + + +def load_dict(dict_path): + """ + Load label dictionary from the dictionary path. + + :param dict_path: The path of word dictionary. + :type dict_path: str + """ + return dict((line.strip().split("\t")[0], idx) + for idx, line in enumerate(open(dict_path, "r").readlines())) + + +def load_reverse_dict(dict_path): + """ + Load the reversed label dictionary from dictionary path. + + :param dict_path: The path of word dictionary. + :type dict_path: str + """ + return dict((idx, line.strip().split("\t")[0]) + for idx, line in enumerate(open(dict_path, "r").readlines()))