提交 672e8565 编写于 作者: P peterzhang2029

update the dictionary module

上级 0fa990bb
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
## STR任务简介 ## STR任务简介
在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。 在现实生活中,许多图片中的文字为图片所处场景的理解提供了丰富的语义信息(例如:路牌、菜单、街道标语等)。同时,场景图片文字识别技术的发展也促进了一些新型应用的产生,例如:\[[1](#参考文献)\]通过使用深度学习模型来自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。以下图为例,给定一个场景图片,STR需要从图片中识别出对应的文字"keep"。 本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。如下图所示,给定一张场景图片,`STR` 需要从中识别出对应的文字"keep"。
<p align="center"> <p align="center">
<img src="./images/503.jpg"/><br/> <img src="./images/503.jpg"/><br/>
...@@ -21,7 +21,7 @@ pip install -r requirements.txt ...@@ -21,7 +21,7 @@ pip install -r requirements.txt
### 指定训练配置参数 ### 指定训练配置参数
通过 `config.py` 脚本修改训练和模型配置参数,脚本中有对可配置参数的详细解释,示例如下: `config.py` 脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码如下:
```python ```python
class TrainerConfig(object): class TrainerConfig(object):
...@@ -43,7 +43,8 @@ class ModelConfig(object): ...@@ -43,7 +43,8 @@ class ModelConfig(object):
... ...
``` ```
修改 `config.py` 对参数进行调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。
修改 `config.py` 脚本可以实现对参数的调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。
### 模型训练 ### 模型训练
训练脚本 [./train.py](./train.py) 中设置了如下命令行参数: 训练脚本 [./train.py](./train.py) 中设置了如下命令行参数:
...@@ -54,24 +55,29 @@ Options: ...@@ -54,24 +55,29 @@ Options:
of train image files. [required] of train image files. [required]
--test_file_list_path TEXT The path of the file which contains path list --test_file_list_path TEXT The path of the file which contains path list
of test image files. [required] of test image files. [required]
--label_dict_path TEXT The path of label dictionary. If this parameter
is set, but the file does not exist, label
dictionay will be built from the training data
automatically. [required]
--model_save_dir TEXT The path to save the trained models (default: --model_save_dir TEXT The path to save the trained models (default:
'models'). 'models').
--help Show this message and exit. --help Show this message and exit.
``` ```
- `train_file_list` 训练数据的列表文件,每行一个路径加对应的text,具体格式为: - `train_file_list` :训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成,具体格式为:
``` ```
word_1.png, "PROPER" word_1.png, "PROPER"
word_2.png, "FOOD" word_2.png, "FOOD"
``` ```
- `test_file_list` 测试数据的列表文件,格式同上。 - `test_file_list` :测试数据的列表文件,格式同上。
- `model_save_dir` 模型参数会的保存目录目录, 默认为当前目录下的`models`目录。 - `label_dict_path` :训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。
- `model_save_dir` :模型参数的保存目录,默认为`./models`
### 具体执行的过程: ### 具体执行的过程:
1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt 1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: `Challenge2_Training_Task3_Images_GT.zip``Challenge2_Test_Task3_Images.zip``Challenge2_Test_Task3_GT.txt`
分别对应训练集的图片和图片对应的单词,测试集的图片,测试数据对应的单词,然后执行以下命令,对数据解压并移动至目标文件夹: 分别对应训练集的图片和图片对应的单词、测试集的图片、测试数据对应的单词。然后执行以下命令,对数据解压并移动至目标文件夹:
```bash ```bash
mkdir -p data/train_data mkdir -p data/train_data
...@@ -87,17 +93,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data ...@@ -87,17 +93,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data
```bash ```bash
python train.py \ python train.py \
--train_file_list_path 'data/train_data/gt.txt' \ --train_file_list_path 'data/train_data/gt.txt' \
--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' --test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' \
--label_dict_path 'label_dict.txt'
``` ```
4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。 4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。
### 预测 ### 预测
预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型目录、图片固定尺寸、batch_size(默认设置为10)和图片文件的列表文件。执行如下代码: 预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径和图片文件的列表文件。执行如下代码:
```bash ```bash
python infer.py \ python infer.py \
--model_path 'models/params_pass_00000.tar.gz' \ --model_path 'models/params_pass_00000.tar.gz' \
--image_shape '173,46' \ --image_shape '173,46' \
--label_dict_path 'label_dict.txt' \
--infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' --infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt'
``` ```
即可进行预测。 即可进行预测。
...@@ -109,9 +117,9 @@ python infer.py \ ...@@ -109,9 +117,9 @@ python infer.py \
### 注意事项 ### 注意事项
- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行 - 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行
- 本模型参数较多,占用显存比较大,实际执行时可以调节`batch_size`控制显存占用 - 本模型参数较多,占用显存比较大,实际执行时可以通过调节 `batch_size` 来控制显存占用。
-模型使用的数据集较小,可以选用其他更大的数据集\[[3](#参考文献)\]来训练需要的模型 -例使用的数据集较小,如有需要,可以选用其他更大的数据集\[[3](#参考文献)\]来训练模型。
## 参考文献 ## 参考文献
......
...@@ -44,9 +44,9 @@ ...@@ -44,9 +44,9 @@
## STR任务简介 ## STR任务简介
在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。 在现实生活中,许多图片中的文字为图片所处场景的理解提供了丰富的语义信息(例如:路牌、菜单、街道标语等)。同时,场景图片文字识别技术的发展也促进了一些新型应用的产生,例如:\[[1](#参考文献)\]通过使用深度学习模型来自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。以下图为例,给定一个场景图片,STR需要从图片中识别出对应的文字"keep"。 本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。如下图所示,给定一张场景图片,`STR` 需要从中识别出对应的文字"keep"。
<p align="center"> <p align="center">
<img src="./images/503.jpg"/><br/> <img src="./images/503.jpg"/><br/>
...@@ -63,7 +63,7 @@ pip install -r requirements.txt ...@@ -63,7 +63,7 @@ pip install -r requirements.txt
### 指定训练配置参数 ### 指定训练配置参数
通过 `config.py` 脚本修改训练和模型配置参数,脚本中有对可配置参数的详细解释,示例如下: `config.py` 脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码如下:
```python ```python
class TrainerConfig(object): class TrainerConfig(object):
...@@ -85,7 +85,8 @@ class ModelConfig(object): ...@@ -85,7 +85,8 @@ class ModelConfig(object):
... ...
``` ```
修改 `config.py` 对参数进行调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。
修改 `config.py` 脚本可以实现对参数的调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。
### 模型训练 ### 模型训练
训练脚本 [./train.py](./train.py) 中设置了如下命令行参数: 训练脚本 [./train.py](./train.py) 中设置了如下命令行参数:
...@@ -96,24 +97,29 @@ Options: ...@@ -96,24 +97,29 @@ Options:
of train image files. [required] of train image files. [required]
--test_file_list_path TEXT The path of the file which contains path list --test_file_list_path TEXT The path of the file which contains path list
of test image files. [required] of test image files. [required]
--label_dict_path TEXT The path of label dictionary. If this parameter
is set, but the file does not exist, label
dictionay will be built from the training data
automatically. [required]
--model_save_dir TEXT The path to save the trained models (default: --model_save_dir TEXT The path to save the trained models (default:
'models'). 'models').
--help Show this message and exit. --help Show this message and exit.
``` ```
- `train_file_list` 训练数据的列表文件,每行一个路径加对应的text,具体格式为: - `train_file_list` :训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成,具体格式为:
``` ```
word_1.png, "PROPER" word_1.png, "PROPER"
word_2.png, "FOOD" word_2.png, "FOOD"
``` ```
- `test_file_list` 测试数据的列表文件,格式同上。 - `test_file_list` :测试数据的列表文件,格式同上。
- `model_save_dir` 模型参数会的保存目录目录, 默认为当前目录下的`models`目录。 - `label_dict_path` :训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。
- `model_save_dir` :模型参数的保存目录,默认为`./models`。
### 具体执行的过程: ### 具体执行的过程:
1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt 1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: `Challenge2_Training_Task3_Images_GT.zip`、`Challenge2_Test_Task3_Images.zip` 和 `Challenge2_Test_Task3_GT.txt`
分别对应训练集的图片和图片对应的单词,测试集的图片,测试数据对应的单词,然后执行以下命令,对数据解压并移动至目标文件夹: 分别对应训练集的图片和图片对应的单词、测试集的图片、测试数据对应的单词。然后执行以下命令,对数据解压并移动至目标文件夹:
```bash ```bash
mkdir -p data/train_data mkdir -p data/train_data
...@@ -129,17 +135,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data ...@@ -129,17 +135,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data
```bash ```bash
python train.py \ python train.py \
--train_file_list_path 'data/train_data/gt.txt' \ --train_file_list_path 'data/train_data/gt.txt' \
--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' --test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' \
--label_dict_path 'label_dict.txt'
``` ```
4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。 4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。
### 预测 ### 预测
预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型目录、图片固定尺寸、batch_size(默认设置为10)和图片文件的列表文件。执行如下代码: 预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径和图片文件的列表文件。执行如下代码:
```bash ```bash
python infer.py \ python infer.py \
--model_path 'models/params_pass_00000.tar.gz' \ --model_path 'models/params_pass_00000.tar.gz' \
--image_shape '173,46' \ --image_shape '173,46' \
--label_dict_path 'label_dict.txt' \
--infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' --infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt'
``` ```
即可进行预测。 即可进行预测。
...@@ -151,9 +159,9 @@ python infer.py \ ...@@ -151,9 +159,9 @@ python infer.py \
### 注意事项 ### 注意事项
- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行 - 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行
- 本模型参数较多,占用显存比较大,实际执行时可以调节`batch_size`控制显存占用 - 本模型参数较多,占用显存比较大,实际执行时可以通过调节 `batch_size` 来控制显存占用。
- 本模型使用的数据集较小,可以选用其他更大的数据集\[[3](#参考文献)\]来训练需要的模型 - 本例使用的数据集较小,如有需要,可以选用其他更大的数据集\[[3](#参考文献)\]来训练模型。
## 参考文献 ## 参考文献
......
...@@ -5,10 +5,10 @@ import paddle.v2 as paddle ...@@ -5,10 +5,10 @@ import paddle.v2 as paddle
from model import Model from model import Model
from reader import DataGenerator from reader import DataGenerator
from decoder import ctc_greedy_decoder from decoder import ctc_greedy_decoder
from utils import AsciiDic, get_file_list from utils import get_file_list, load_dict, load_reverse_dict
def infer_batch(inferer, test_batch, labels): def infer_batch(inferer, test_batch, labels, reversed_char_dict):
infer_results = inferer.infer(input=test_batch) infer_results = inferer.infer(input=test_batch)
num_steps = len(infer_results) // len(test_batch) num_steps = len(infer_results) // len(test_batch)
probs_split = [ probs_split = [
...@@ -19,7 +19,7 @@ def infer_batch(inferer, test_batch, labels): ...@@ -19,7 +19,7 @@ def infer_batch(inferer, test_batch, labels):
# Best path decode. # Best path decode.
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder( output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=AsciiDic().id2word()) probs_seq=probs, vocabulary=reversed_char_dict)
results.append(output_transcription) results.append(output_transcription)
for result, label in zip(results, labels): for result, label in zip(results, labels):
...@@ -40,17 +40,26 @@ def infer_batch(inferer, test_batch, labels): ...@@ -40,17 +40,26 @@ def infer_batch(inferer, test_batch, labels):
type=int, type=int,
default=10, default=10,
help=("The number of examples in one batch (default: 10).")) help=("The number of examples in one batch (default: 10)."))
@click.option(
"--label_dict_path",
type=str,
required=True,
help=("The path of label dictionary. "))
@click.option( @click.option(
"--infer_file_list_path", "--infer_file_list_path",
type=str, type=str,
required=True, required=True,
help=("The path of the file which contains " help=("The path of the file which contains "
"path list of image files for inference.")) "path list of image files for inference."))
def infer(model_path, image_shape, batch_size, infer_file_list_path): def infer(model_path, image_shape, batch_size, label_dict_path,
infer_file_list_path):
image_shape = tuple(map(int, image_shape.split(','))) image_shape = tuple(map(int, image_shape.split(',')))
infer_file_list = get_file_list(infer_file_list_path) infer_file_list = get_file_list(infer_file_list_path)
char_dict = AsciiDic()
dict_size = char_dict.size() char_dict = load_dict(label_dict_path)
reversed_char_dict = load_reverse_dict(label_dict_path)
dict_size = len(char_dict)
data_generator = DataGenerator(char_dict=char_dict, image_shape=image_shape) data_generator = DataGenerator(char_dict=char_dict, image_shape=image_shape)
paddle.init(use_gpu=True, trainer_count=1) paddle.init(use_gpu=True, trainer_count=1)
...@@ -66,11 +75,11 @@ def infer(model_path, image_shape, batch_size, infer_file_list_path): ...@@ -66,11 +75,11 @@ def infer(model_path, image_shape, batch_size, infer_file_list_path):
test_batch.append([image]) test_batch.append([image])
labels.append(label) labels.append(label)
if len(test_batch) == batch_size: if len(test_batch) == batch_size:
infer_batch(inferer, test_batch, labels) infer_batch(inferer, test_batch, labels, reversed_char_dict)
test_batch = [] test_batch = []
labels = [] labels = []
if test_batch: if test_batch:
infer_batch(inferer, test_batch, labels) infer_batch(inferer, test_batch, labels, reversed_char_dict)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -45,12 +45,11 @@ class Model(object): ...@@ -45,12 +45,11 @@ class Model(object):
''' '''
Build the network topology. Build the network topology.
''' '''
# CNN output image features. # Get the image features with CNN.
conv_features = self.conv_groups(self.image, conf.filter_num, conv_features = self.conv_groups(self.image, conf.filter_num,
conf.with_bn) conf.with_bn)
# Cut CNN output into a sequence of feature vectors, which are # Expand the output of CNN into a sequence of feature vectors.
# 1 pixel wide and 11 pixel high.
sliced_feature = layer.block_expand( sliced_feature = layer.block_expand(
input=conv_features, input=conv_features,
num_channels=conf.num_channels, num_channels=conf.num_channels,
...@@ -59,7 +58,7 @@ class Model(object): ...@@ -59,7 +58,7 @@ class Model(object):
block_x=conf.block_x, block_x=conf.block_x,
block_y=conf.block_y) block_y=conf.block_y)
# RNNs to capture sequence information forwards and backwards. # Use RNN to capture sequence information forwards and backwards.
gru_forward = simple_gru( gru_forward = simple_gru(
input=sliced_feature, size=conf.hidden_size, act=Relu()) input=sliced_feature, size=conf.hidden_size, act=Relu())
gru_backward = simple_gru( gru_backward = simple_gru(
...@@ -68,7 +67,7 @@ class Model(object): ...@@ -68,7 +67,7 @@ class Model(object):
act=Relu(), act=Relu(),
reverse=True) reverse=True)
# Map each step of RNN to character distribution. # Map the output of RNN to character distribution.
self.output = layer.fc( self.output = layer.fc(
input=[gru_forward, gru_backward], input=[gru_forward, gru_backward],
size=self.num_classes + 1, size=self.num_classes + 1,
......
...@@ -24,8 +24,10 @@ class DataGenerator(object): ...@@ -24,8 +24,10 @@ class DataGenerator(object):
''' '''
def reader(): def reader():
for i, (image, label) in enumerate(file_list): UNK_ID = self.char_dict['<unk>']
yield self.load_image(image), self.char_dict.word2ids(label) for image_path, label in file_list:
label = [self.char_dict.get(c, UNK_ID) for c in label]
yield self.load_image(image_path), label
return reader return reader
...@@ -38,14 +40,14 @@ class DataGenerator(object): ...@@ -38,14 +40,14 @@ class DataGenerator(object):
''' '''
def reader(): def reader():
for i, (image, label) in enumerate(file_list): for image_path, label in file_list:
yield self.load_image(image), label yield self.load_image(image_path), label
return reader return reader
def load_image(self, path): def load_image(self, path):
''' '''
Load image and transform to 1-dimention vector. Load an image and transform it to 1-dimention vector.
:param path: The path of the image data. :param path: The path of the image data.
:type path: str :type path: str
......
...@@ -6,7 +6,7 @@ import paddle.v2 as paddle ...@@ -6,7 +6,7 @@ import paddle.v2 as paddle
from config import TrainerConfig as conf from config import TrainerConfig as conf
from model import Model from model import Model
from reader import DataGenerator from reader import DataGenerator
from utils import get_file_list, AsciiDic from utils import get_file_list, build_label_dict, load_dict
@click.command('train') @click.command('train')
...@@ -22,19 +22,35 @@ from utils import get_file_list, AsciiDic ...@@ -22,19 +22,35 @@ from utils import get_file_list, AsciiDic
required=True, required=True,
help=("The path of the file which contains " help=("The path of the file which contains "
"path list of test image files.")) "path list of test image files."))
@click.option(
"--label_dict_path",
type=str,
required=True,
help=("The path of label dictionary. "
"If this parameter is set, but the file does not exist, "
"label dictionay will be built from "
"the training data automatically."))
@click.option( @click.option(
"--model_save_dir", "--model_save_dir",
type=str, type=str,
default="models", default="models",
help="The path to save the trained models (default: 'models').") help="The path to save the trained models (default: 'models').")
def train(train_file_list_path, test_file_list_path, model_save_dir): def train(train_file_list_path, test_file_list_path, label_dict_path,
model_save_dir):
if not os.path.exists(model_save_dir): if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir) os.mkdir(model_save_dir)
train_file_list = get_file_list(train_file_list_path) train_file_list = get_file_list(train_file_list_path)
test_file_list = get_file_list(test_file_list_path) test_file_list = get_file_list(test_file_list_path)
char_dict = AsciiDic()
dict_size = char_dict.size() if not os.path.exists(label_dict_path):
print(("Label dictionary is not given, the dictionary "
"is automatically built from the training data."))
build_label_dict(train_file_list, label_dict_path)
char_dict = load_dict(label_dict_path)
dict_size = len(char_dict)
data_generator = DataGenerator( data_generator = DataGenerator(
char_dict=char_dict, image_shape=conf.image_shape) char_dict=char_dict, image_shape=conf.image_shape)
......
import os import os
from collections import defaultdict
class AsciiDic(object):
UNK_ID = 0
def __init__(self):
self.dic = {
'<unk>': self.UNK_ID,
}
self.chars = [chr(i) for i in range(40, 171)]
for id, c in enumerate(self.chars):
self.dic[c] = id + 1
def lookup(self, w):
return self.dic.get(w, self.UNK_ID)
def id2word(self):
'''
Return a reversed char dict.
'''
self.id2word = {}
for key, value in self.dic.items():
self.id2word[value] = key
return self.id2word
def word2ids(self, word):
'''
Transform a word to a list of ids.
:param word: The word appears in image data.
:type word: str
'''
return [self.lookup(c) for c in list(word)]
def size(self):
return len(self.dic)
def get_file_list(image_file_list): def get_file_list(image_file_list):
...@@ -53,7 +17,53 @@ def get_file_list(image_file_list): ...@@ -53,7 +17,53 @@ def get_file_list(image_file_list):
line_split = line.strip().split(',', 1) line_split = line.strip().split(',', 1)
filename = line_split[0].strip() filename = line_split[0].strip()
path = os.path.join(dirname, filename) path = os.path.join(dirname, filename)
label = line_split[1][2:-1] label = line_split[1][2:-1].strip()
if label:
path_list.append((path, label)) path_list.append((path, label))
return path_list return path_list
def build_label_dict(file_list, save_path):
"""
Build label dictionary from training data.
:param file_list: The list which contains the labels
of training data.
:type file_list: list
:params save_path: The path where the label dictionary will be saved.
:type save_path: str
"""
values = defaultdict(int)
for path, label in file_list:
for c in label:
if c:
values[c] += 1
values['<unk>'] = 0
with open(save_path, "w") as f:
for v, count in sorted(
values.iteritems(), key=lambda x: x[1], reverse=True):
f.write("%s\t%d\n" % (v, count))
def load_dict(dict_path):
"""
Load label dictionary from the dictionary path.
:param dict_path: The path of word dictionary.
:type dict_path: str
"""
return dict((line.strip().split("\t")[0], idx)
for idx, line in enumerate(open(dict_path, "r").readlines()))
def load_reverse_dict(dict_path):
"""
Load the reversed label dictionary from dictionary path.
:param dict_path: The path of word dictionary.
:type dict_path: str
"""
return dict((idx, line.strip().split("\t")[0])
for idx, line in enumerate(open(dict_path, "r").readlines()))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册