未验证 提交 d2d3b0e9 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #329 from peterzhang2029/str_dev

Add the scene text recognition example.
# 场景文字识别 (STR, Scene Text Recognition)
## STR任务简介
许多场景图像中包含着丰富的文本信息,它们可以从很大程度上帮助人们去认知场景图像的内容及含义,因此场景图像中的文本识别对所在图像的信息获取具有极其重要的作用。同时,场景图像文字识别技术的发展也促进了一些新型应用的产生,例如:\[[1](#参考文献)\]通过使用深度学习模型来自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 。任务如下图所示,给定一张场景图片,`STR` 需要从中识别出对应的文字"keep"。
<p align="center">
<img src="./images/503.jpg"/><br/>
图 1. 输入数据示例 "keep"
</p>
## 使用 PaddlePaddle 训练与预测
### 安装依赖包
```bash
pip install -r requirements.txt
```
### 修改配置参数
`config.py` 脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码片段如下:
```python
class TrainerConfig(object):
# Whether to use GPU in training or not.
use_gpu = True
# The number of computing threads.
trainer_count = 1
# The training batch size.
batch_size = 10
...
class ModelConfig(object):
# Number of the filters for convolution group.
filter_num = 8
...
```
修改 `config.py` 脚本可以实现对参数的调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。
### 模型训练
训练脚本 [./train.py](./train.py) 中设置了如下命令行参数:
```
Options:
--train_file_list_path TEXT The path of the file which contains path list
of train image files. [required]
--test_file_list_path TEXT The path of the file which contains path list
of test image files. [required]
--label_dict_path TEXT The path of label dictionary. If this parameter
is set, but the file does not exist, label
dictionay will be built from the training data
automatically. [required]
--model_save_dir TEXT The path to save the trained models (default:
'models').
--help Show this message and exit.
```
- `train_file_list` :训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成,格式为:
```
word_1.png, "PROPER"
word_2.png, "FOOD"
```
- `test_file_list` :测试数据的列表文件,格式同上。
- `label_dict_path` :训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。
- `model_save_dir` :模型参数的保存目录,默认为`./models`
### 具体执行的过程:
1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: `Challenge2_Training_Task3_Images_GT.zip``Challenge2_Test_Task3_Images.zip``Challenge2_Test_Task3_GT.txt`
分别对应训练集的图片和图片对应的单词、测试集的图片、测试数据对应的单词。然后执行以下命令,对数据解压并移动至目标文件夹:
```bash
mkdir -p data/train_data
mkdir -p data/test_data
unzip Challenge2_Training_Task3_Images_GT.zip -d data/train_data
unzip Challenge2_Test_Task3_Images.zip -d data/test_data
mv Challenge2_Test_Task3_GT.txt data/test_data
```
2.获取训练数据文件夹中 `gt.txt` 的路径 (data/train_data)和测试数据文件夹中`Challenge2_Test_Task3_GT.txt`的路径(data/test_data)。
3.执行如下命令进行训练:
```bash
python train.py \
--train_file_list_path 'data/train_data/gt.txt' \
--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' \
--label_dict_path 'label_dict.txt'
```
4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。
### 预测
预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径和图片文件的列表文件。执行如下代码:
```bash
python infer.py \
--model_path 'models/params_pass_00000.tar.gz' \
--image_shape '173,46' \
--label_dict_path 'label_dict.txt' \
--infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt'
```
即可进行预测。
### 其他数据集
- [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)(41G)
- [ICDAR 2003 Robust Reading Competitions](http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions)
### 注意事项
- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行。
- 本模型参数较多,占用显存比较大,实际执行时可以通过调节 `batch_size` 来控制显存占用。
- 本例使用的数据集较小,如有需要,可以选用其他更大的数据集\[[3](#参考文献)\]来训练模型。
## 参考文献
1. [Google Now Using ReCAPTCHA To Decode Street View Addresses](https://techcrunch.com/2012/03/29/google-now-using-recaptcha-to-decode-street-view-addresses/)
2. [Focused Scene Text](http://rrc.cvc.uab.es/?ch=2&com=introduction)
3. [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)
__all__ = ["TrainerConfig", "ModelConfig"]
class TrainerConfig(object):
# Whether to use GPU in training or not.
use_gpu = True
# The number of computing threads.
trainer_count = 1
# The training batch size.
batch_size = 10
# The epoch number.
num_passes = 10
# Parameter updates momentum.
momentum = 0
# The shape of images.
image_shape = (173, 46)
# The buffer size of the data reader.
# The number of buffer size samples will be shuffled in training.
buf_size = 1000
# The parameter is used to control logging period.
# Training log will be printed every log_period.
log_period = 50
class ModelConfig(object):
# Number of the filters for convolution group.
filter_num = 8
# Use batch normalization or not in image convolution group.
with_bn = True
# The number of channels for block expand layer.
num_channels = 128
# The parameter stride_x in block expand layer.
stride_x = 1
# The parameter stride_y in block expand layer.
stride_y = 1
# The parameter block_x in block expand layer.
block_x = 1
# The parameter block_y in block expand layer.
block_y = 11
# The hidden size for gru.
hidden_size = num_channels
# Use norm_by_times or not in warp ctc layer.
norm_by_times = True
# The list for number of filter in image convolution group layer.
filter_num_list = [16, 32, 64, 128]
# The parameter conv_padding in image convolution group layer.
conv_padding = 1
# The parameter conv_filter_size in image convolution group layer.
conv_filter_size = 3
# The parameter pool_size in image convolution group layer.
pool_size = 2
# The parameter pool_stride in image convolution group layer.
pool_stride = 2
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from itertools import groupby
import numpy as np
def ctc_greedy_decoder(probs_seq, vocabulary):
"""CTC greedy (best path) decoder.
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
:rtype: baseline
"""
# dimension verification
for probs in probs_seq:
if not len(probs) == len(vocabulary) + 1:
raise ValueError("probs_seq dimension mismatchedd with vocabulary")
# argmax to get the best index for each time step
max_index_list = list(np.array(probs_seq).argmax(axis=1))
# remove consecutive duplicate indexes
index_list = [index_group[0] for index_group in groupby(max_index_list)]
# remove blank indexes
blank_index = len(vocabulary)
index_list = [index for index in index_list if index != blank_index]
# convert index list to string
return ''.join([vocabulary[index] for index in index_list])
<html>
<head>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
extensions: ["tex2jax.js", "TeX/AMSsymbols.js", "TeX/AMSmath.js"],
jax: ["input/TeX", "output/HTML-CSS"],
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true
},
"HTML-CSS": { availableFonts: ["TeX"] }
});
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js" async></script>
<script type="text/javascript" src="../.tools/theme/marked.js">
</script>
<link href="http://cdn.bootcss.com/highlight.js/9.9.0/styles/darcula.min.css" rel="stylesheet">
<script src="http://cdn.bootcss.com/highlight.js/9.9.0/highlight.min.js"></script>
<link href="http://cdn.bootcss.com/bootstrap/4.0.0-alpha.6/css/bootstrap.min.css" rel="stylesheet">
<link href="https://cdn.jsdelivr.net/perfect-scrollbar/0.6.14/css/perfect-scrollbar.min.css" rel="stylesheet">
<link href="../.tools/theme/github-markdown.css" rel='stylesheet'>
</head>
<style type="text/css" >
.markdown-body {
box-sizing: border-box;
min-width: 200px;
max-width: 980px;
margin: 0 auto;
padding: 45px;
}
</style>
<body>
<div id="context" class="container-fluid markdown-body">
</div>
<!-- This block will be replaced by each markdown file content. Please do not change lines below.-->
<div id="markdown" style='display:none'>
# 场景文字识别 (STR, Scene Text Recognition)
## STR任务简介
许多场景图像中包含着丰富的文本信息,它们可以从很大程度上帮助人们去认知场景图像的内容及含义,因此场景图像中的文本识别对所在图像的信息获取具有极其重要的作用。同时,场景图像文字识别技术的发展也促进了一些新型应用的产生,例如:\[[1](#参考文献)\]通过使用深度学习模型来自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 。任务如下图所示,给定一张场景图片,`STR` 需要从中识别出对应的文字"keep"。
<p align="center">
<img src="./images/503.jpg"/><br/>
图 1. 输入数据示例 "keep"
</p>
## 使用 PaddlePaddle 训练与预测
### 安装依赖包
```bash
pip install -r requirements.txt
```
### 修改配置参数
`config.py` 脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码片段如下:
```python
class TrainerConfig(object):
# Whether to use GPU in training or not.
use_gpu = True
# The number of computing threads.
trainer_count = 1
# The training batch size.
batch_size = 10
...
class ModelConfig(object):
# Number of the filters for convolution group.
filter_num = 8
...
```
修改 `config.py` 脚本可以实现对参数的调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。
### 模型训练
训练脚本 [./train.py](./train.py) 中设置了如下命令行参数:
```
Options:
--train_file_list_path TEXT The path of the file which contains path list
of train image files. [required]
--test_file_list_path TEXT The path of the file which contains path list
of test image files. [required]
--label_dict_path TEXT The path of label dictionary. If this parameter
is set, but the file does not exist, label
dictionay will be built from the training data
automatically. [required]
--model_save_dir TEXT The path to save the trained models (default:
'models').
--help Show this message and exit.
```
- `train_file_list` :训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成,格式为:
```
word_1.png, "PROPER"
word_2.png, "FOOD"
```
- `test_file_list` :测试数据的列表文件,格式同上。
- `label_dict_path` :训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。
- `model_save_dir` :模型参数的保存目录,默认为`./models`。
### 具体执行的过程:
1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: `Challenge2_Training_Task3_Images_GT.zip`、`Challenge2_Test_Task3_Images.zip` 和 `Challenge2_Test_Task3_GT.txt`。
分别对应训练集的图片和图片对应的单词、测试集的图片、测试数据对应的单词。然后执行以下命令,对数据解压并移动至目标文件夹:
```bash
mkdir -p data/train_data
mkdir -p data/test_data
unzip Challenge2_Training_Task3_Images_GT.zip -d data/train_data
unzip Challenge2_Test_Task3_Images.zip -d data/test_data
mv Challenge2_Test_Task3_GT.txt data/test_data
```
2.获取训练数据文件夹中 `gt.txt` 的路径 (data/train_data)和测试数据文件夹中`Challenge2_Test_Task3_GT.txt`的路径(data/test_data)。
3.执行如下命令进行训练:
```bash
python train.py \
--train_file_list_path 'data/train_data/gt.txt' \
--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' \
--label_dict_path 'label_dict.txt'
```
4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。
### 预测
预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径和图片文件的列表文件。执行如下代码:
```bash
python infer.py \
--model_path 'models/params_pass_00000.tar.gz' \
--image_shape '173,46' \
--label_dict_path 'label_dict.txt' \
--infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt'
```
即可进行预测。
### 其他数据集
- [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)(41G)
- [ICDAR 2003 Robust Reading Competitions](http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions)
### 注意事项
- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行。
- 本模型参数较多,占用显存比较大,实际执行时可以通过调节 `batch_size` 来控制显存占用。
- 本例使用的数据集较小,如有需要,可以选用其他更大的数据集\[[3](#参考文献)\]来训练模型。
## 参考文献
1. [Google Now Using ReCAPTCHA To Decode Street View Addresses](https://techcrunch.com/2012/03/29/google-now-using-recaptcha-to-decode-street-view-addresses/)
2. [Focused Scene Text](http://rrc.cvc.uab.es/?ch=2&com=introduction)
3. [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)
</div>
<!-- You can change the lines below now. -->
<script type="text/javascript">
marked.setOptions({
renderer: new marked.Renderer(),
gfm: true,
breaks: false,
smartypants: true,
highlight: function(code, lang) {
code = code.replace(/&amp;/g, "&")
code = code.replace(/&gt;/g, ">")
code = code.replace(/&lt;/g, "<")
code = code.replace(/&nbsp;/g, " ")
return hljs.highlightAuto(code, [lang]).value;
}
});
document.getElementById("context").innerHTML = marked(
document.getElementById("markdown").innerHTML)
</script>
</body>
import click
import gzip
import paddle.v2 as paddle
from network_conf import Model
from reader import DataGenerator
from decoder import ctc_greedy_decoder
from utils import get_file_list, load_dict, load_reverse_dict
def infer_batch(inferer, test_batch, labels, reversed_char_dict):
infer_results = inferer.infer(input=test_batch)
num_steps = len(infer_results) // len(test_batch)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(0, len(test_batch))
]
results = []
# Best path decode.
for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=reversed_char_dict)
results.append(output_transcription)
for result, label in zip(results, labels):
print("\nOutput Transcription: %s\nTarget Transcription: %s" %
(result, label))
@click.command('infer')
@click.option(
"--model_path", type=str, required=True, help=("The path of saved model."))
@click.option(
"--image_shape",
type=str,
required=True,
help=("The fixed size for image dataset (format is like: '173,46')."))
@click.option(
"--batch_size",
type=int,
default=10,
help=("The number of examples in one batch (default: 10)."))
@click.option(
"--label_dict_path",
type=str,
required=True,
help=("The path of label dictionary. "))
@click.option(
"--infer_file_list_path",
type=str,
required=True,
help=("The path of the file which contains "
"path list of image files for inference."))
def infer(model_path, image_shape, batch_size, label_dict_path,
infer_file_list_path):
image_shape = tuple(map(int, image_shape.split(',')))
infer_file_list = get_file_list(infer_file_list_path)
char_dict = load_dict(label_dict_path)
reversed_char_dict = load_reverse_dict(label_dict_path)
dict_size = len(char_dict)
data_generator = DataGenerator(char_dict=char_dict, image_shape=image_shape)
paddle.init(use_gpu=True, trainer_count=1)
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))
model = Model(dict_size, image_shape, is_infer=True)
inferer = paddle.inference.Inference(
output_layer=model.log_probs, parameters=parameters)
test_batch = []
labels = []
for i, (image,
label) in enumerate(data_generator.infer_reader(infer_file_list)()):
test_batch.append([image])
labels.append(label)
if len(test_batch) == batch_size:
infer_batch(inferer, test_batch, labels, reversed_char_dict)
test_batch = []
labels = []
if test_batch:
infer_batch(inferer, test_batch, labels, reversed_char_dict)
if __name__ == "__main__":
infer()
from paddle import v2 as paddle
from paddle.v2 import layer
from paddle.v2 import evaluator
from paddle.v2.activation import Relu, Linear
from paddle.v2.networks import img_conv_group, simple_gru
from config import ModelConfig as conf
class Model(object):
def __init__(self, num_classes, shape, is_infer=False):
'''
:param num_classes: The size of the character dict.
:type num_classes: int
:param shape: The size of the input images.
:type shape: tuple of 2 int
:param is_infer: The boolean parameter indicating
inferring or training.
:type shape: bool
'''
self.num_classes = num_classes
self.shape = shape
self.is_infer = is_infer
self.image_vector_size = shape[0] * shape[1]
self.__declare_input_layers__()
self.__build_nn__()
def __declare_input_layers__(self):
'''
Define the input layer.
'''
# Image input as a float vector.
self.image = layer.data(
name='image',
type=paddle.data_type.dense_vector(self.image_vector_size),
height=self.shape[0],
width=self.shape[1])
# Label input as an ID list
if not self.is_infer:
self.label = layer.data(
name='label',
type=paddle.data_type.integer_value_sequence(self.num_classes))
def __build_nn__(self):
'''
Build the network topology.
'''
# Get the image features with CNN.
conv_features = self.conv_groups(self.image, conf.filter_num,
conf.with_bn)
# Expand the output of CNN into a sequence of feature vectors.
sliced_feature = layer.block_expand(
input=conv_features,
num_channels=conf.num_channels,
stride_x=conf.stride_x,
stride_y=conf.stride_y,
block_x=conf.block_x,
block_y=conf.block_y)
# Use RNN to capture sequence information forwards and backwards.
gru_forward = simple_gru(
input=sliced_feature, size=conf.hidden_size, act=Relu())
gru_backward = simple_gru(
input=sliced_feature,
size=conf.hidden_size,
act=Relu(),
reverse=True)
# Map the output of RNN to character distribution.
self.output = layer.fc(
input=[gru_forward, gru_backward],
size=self.num_classes + 1,
act=Linear())
self.log_probs = paddle.layer.mixed(
input=paddle.layer.identity_projection(input=self.output),
act=paddle.activation.Softmax())
# Use warp CTC to calculate cost for a CTC task.
if not self.is_infer:
self.cost = layer.warp_ctc(
input=self.output,
label=self.label,
size=self.num_classes + 1,
norm_by_times=conf.norm_by_times,
blank=self.num_classes)
self.eval = evaluator.ctc_error(input=self.output, label=self.label)
def conv_groups(self, input, num, with_bn):
'''
Get the image features with image convolution group.
:param input: Input layer.
:type input: LayerOutput
:param num: Number of the filters.
:type num: int
:param with_bn: Use batch normalization or not.
:type with_bn: bool
'''
assert num % 4 == 0
filter_num_list = conf.filter_num_list
is_input_image = True
tmp = input
for num_filter in filter_num_list:
if is_input_image:
num_channels = 1
is_input_image = False
else:
num_channels = None
tmp = img_conv_group(
input=tmp,
num_channels=num_channels,
conv_padding=conf.conv_padding,
conv_num_filter=[num_filter] * (num / 4),
conv_filter_size=conf.conv_filter_size,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=conf.pool_size,
pool_stride=conf.pool_stride, )
return tmp
import os
import cv2
from paddle.v2.image import load_image
class DataGenerator(object):
def __init__(self, char_dict, image_shape):
'''
:param char_dict: The dictionary class for labels.
:type char_dict: class
:param image_shape: The fixed shape of images.
:type image_shape: tuple
'''
self.image_shape = image_shape
self.char_dict = char_dict
def train_reader(self, file_list):
'''
Reader interface for training.
:param file_list: The path list of the image file for training.
:type file_list: list
'''
def reader():
UNK_ID = self.char_dict['<unk>']
for image_path, label in file_list:
label = [self.char_dict.get(c, UNK_ID) for c in label]
yield self.load_image(image_path), label
return reader
def infer_reader(self, file_list):
'''
Reader interface for inference.
:param file_list: The path list of the image file for inference.
:type file_list: list
'''
def reader():
for image_path, label in file_list:
yield self.load_image(image_path), label
return reader
def load_image(self, path):
'''
Load an image and transform it to 1-dimention vector.
:param path: The path of the image data.
:type path: str
'''
image = load_image(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Resize all images to a fixed shape.
if self.image_shape:
image = cv2.resize(
image, self.image_shape, interpolation=cv2.INTER_CUBIC)
image = image.flatten() / 255.
return image
click
opencv-python
\ No newline at end of file
import gzip
import os
import click
import paddle.v2 as paddle
from config import TrainerConfig as conf
from network_conf import Model
from reader import DataGenerator
from utils import get_file_list, build_label_dict, load_dict
@click.command('train')
@click.option(
"--train_file_list_path",
type=str,
required=True,
help=("The path of the file which contains "
"path list of train image files."))
@click.option(
"--test_file_list_path",
type=str,
required=True,
help=("The path of the file which contains "
"path list of test image files."))
@click.option(
"--label_dict_path",
type=str,
required=True,
help=("The path of label dictionary. "
"If this parameter is set, but the file does not exist, "
"label dictionay will be built from "
"the training data automatically."))
@click.option(
"--model_save_dir",
type=str,
default="models",
help="The path to save the trained models (default: 'models').")
def train(train_file_list_path, test_file_list_path, label_dict_path,
model_save_dir):
if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir)
train_file_list = get_file_list(train_file_list_path)
test_file_list = get_file_list(test_file_list_path)
if not os.path.exists(label_dict_path):
print(("Label dictionary is not given, the dictionary "
"is automatically built from the training data."))
build_label_dict(train_file_list, label_dict_path)
char_dict = load_dict(label_dict_path)
dict_size = len(char_dict)
data_generator = DataGenerator(
char_dict=char_dict, image_shape=conf.image_shape)
paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count)
# Create optimizer.
optimizer = paddle.optimizer.Momentum(momentum=conf.momentum)
# Define network topology.
model = Model(dict_size, conf.image_shape, is_infer=False)
# Create all the trainable parameters.
params = paddle.parameters.create(model.cost)
trainer = paddle.trainer.SGD(
cost=model.cost,
parameters=params,
update_equation=optimizer,
extra_layers=model.eval)
# Feeding dictionary.
feeding = {'image': 0, 'label': 1}
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % conf.log_period == 0:
print("Pass %d, batch %d, Samples %d, Cost %f, Eval %s" %
(event.pass_id, event.batch_id, event.batch_id *
conf.batch_size, event.cost, event.metrics))
if isinstance(event, paddle.event.EndPass):
# Here, because training and testing data share a same format,
# we still use the reader.train_reader to read the testing data.
result = trainer.test(
reader=paddle.batch(
data_generator.train_reader(test_file_list),
batch_size=conf.batch_size),
feeding=feeding)
print("Test %d, Cost %f, Eval %s" %
(event.pass_id, result.cost, result.metrics))
with gzip.open(
os.path.join(model_save_dir, "params_pass_%05d.tar.gz" %
event.pass_id), "w") as f:
trainer.save_parameter_to_tar(f)
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
data_generator.train_reader(train_file_list),
buf_size=conf.buf_size),
batch_size=conf.batch_size),
feeding=feeding,
event_handler=event_handler,
num_passes=conf.num_passes)
if __name__ == "__main__":
train()
import os
from collections import defaultdict
def get_file_list(image_file_list):
'''
Generate the file list for training and testing data.
:param image_file_list: The path of the file which contains
path list of image files.
:type image_file_list: str
'''
dirname = os.path.dirname(image_file_list)
path_list = []
with open(image_file_list) as f:
for line in f:
line_split = line.strip().split(',', 1)
filename = line_split[0].strip()
path = os.path.join(dirname, filename)
label = line_split[1][2:-1].strip()
if label:
path_list.append((path, label))
return path_list
def build_label_dict(file_list, save_path):
"""
Build label dictionary from training data.
:param file_list: The list which contains the labels
of training data.
:type file_list: list
:params save_path: The path where the label dictionary will be saved.
:type save_path: str
"""
values = defaultdict(int)
for path, label in file_list:
for c in label:
if c:
values[c] += 1
values['<unk>'] = 0
with open(save_path, "w") as f:
for v, count in sorted(
values.iteritems(), key=lambda x: x[1], reverse=True):
f.write("%s\t%d\n" % (v, count))
def load_dict(dict_path):
"""
Load label dictionary from the dictionary path.
:param dict_path: The path of word dictionary.
:type dict_path: str
"""
return dict((line.strip().split("\t")[0], idx)
for idx, line in enumerate(open(dict_path, "r").readlines()))
def load_reverse_dict(dict_path):
"""
Load the reversed label dictionary from dictionary path.
:param dict_path: The path of word dictionary.
:type dict_path: str
"""
return dict((idx, line.strip().split("\t")[0])
for idx, line in enumerate(open(dict_path, "r").readlines()))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册