提交 5f902e0c 编写于 作者: P peterzhang2029

refine doc and add evals for ctc

上级 6170035a
...@@ -4,18 +4,18 @@ ...@@ -4,18 +4,18 @@
在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。 在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
本文将针对 **场景文字识别 (STR, Scene Text Recognition)** 任务,演示如何用 PaddlePaddle 实现 一个端对端 CTC 的模型 **CRNN(Convolutional Recurrent Neural Network)** 本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。以下图为例,给定一个场景图片,STR需要从图片中识别出对应的文字"keep":
\[[2](#参考文献)\],具体的,本文使用如下图片进行训练,需要识别文字对应的文字 "keep"。
<p align="center"> <p align="center">
<img src="./images/503.jpg"/><br/> <img src="./images/503.jpg"/><br/>
图 1. 数据示例 "keep" 图 1. 数据示例 "keep"
</p> </p>
## 使用 PaddlePaddle 训练与预测 ## 使用 PaddlePaddle 训练与预测
### 模型训练 ### 模型训练
训练脚本参照 [./train.py](./train.py)设置了如下命令行参数: 训练脚本 [./train.py](./train.py)设置了如下命令行参数:
``` ```
usage: train.py [-h] --image_shape IMAGE_SHAPE --train_file_list usage: train.py [-h] --image_shape IMAGE_SHAPE --train_file_list
...@@ -51,19 +51,20 @@ optional arguments: ...@@ -51,19 +51,20 @@ optional arguments:
number of passes to train (default: 1) number of passes to train (default: 1)
``` ```
其中最重要的几个参数包括: 重要的几个参数包括:
- `image_shape` 图片的尺寸 - `image_shape` 图片的尺寸
- `train_file_list` 训练数据的列表文件,每行一个路径加对应的text,格式类似 - `train_file_list` 训练数据的列表文件,每行一个路径加对应的text,具体格式为
``` ```
word_1.png, "PROPER" word_1.png, "PROPER"
word_2.png, "FOOD"
``` ```
- `test_file_list` 测试数据的列表文件,格式同上 - `test_file_list` 测试数据的列表文件,格式同上
### 预测 ### 预测
预测部分由infer.py完成,本示例对于ctc的预测使用的是最优路径解码算法(CTC greedy decoder),即在每个时间步选择一个概率最大的字符。在使用过程中,需要在infer.py中指定具体的模型目录、图片固定尺寸、batch_size和图片文件的列表文件。例如: 预测部分由infer.py完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在infer.py中指定具体的模型目录、图片固定尺寸、batch_size和图片文件的列表文件。例如:
```python ```python
model_path = "model.ctc-pass-9-batch-150-test-10.0065517931.tar.gz" model_path = "model.ctc-pass-9-batch-150-test.tar.gz"
image_shape = "173,46" image_shape = "173,46"
batch_size = 50 batch_size = 50
infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt' infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
...@@ -73,7 +74,7 @@ infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt' ...@@ -73,7 +74,7 @@ infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
### 具体执行的过程: ### 具体执行的过程:
1.从官方下载数据\[[3](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt。 1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt。
分别对应训练集的图片和图片对应的单词,测试集的图片,测试数据对应的单词,然后执行以下命令,对数据解压并移动至目标文件夹: 分别对应训练集的图片和图片对应的单词,测试集的图片,测试数据对应的单词,然后执行以下命令,对数据解压并移动至目标文件夹:
``` ```
...@@ -104,11 +105,10 @@ python train.py --train_file_list data/train_data/gt.txt --test_file_list data/t ...@@ -104,11 +105,10 @@ python train.py --train_file_list data/train_data/gt.txt --test_file_list data/t
- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行 - 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行
- 本模型参数较多,占用显存比较大,实际执行时可以调节batch_size 控制显存占用 - 本模型参数较多,占用显存比较大,实际执行时可以调节batch_size 控制显存占用
- 本模型使用的数据集较小,可以选用其他更大的数据集\[[4](#参考文献)\]来训练需要的模型 - 本模型使用的数据集较小,可以选用其他更大的数据集\[[3](#参考文献)\]来训练需要的模型
## 参考文献 ## 参考文献
1. [Google Now Using ReCAPTCHA To Decode Street View Addresses](https://techcrunch.com/2012/03/29/google-now-using-recaptcha-to-decode-street-view-addresses/) 1. [Google Now Using ReCAPTCHA To Decode Street View Addresses](https://techcrunch.com/2012/03/29/google-now-using-recaptcha-to-decode-street-view-addresses/)
2. Shi B, Bai X, Yao C. [An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition](https://arxiv.org/pdf/1507.05717.pdf)[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2016. APA 2. [Focused Scene Text](http://rrc.cvc.uab.es/?ch=2&com=introduction)
3. [Focused Scene Text](http://rrc.cvc.uab.es/?ch=2&com=introduction) 3. [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)
4. [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function
import os import os
from paddle.v2.image import load_image
import cv2 import cv2
from paddle.v2.image import load_image
class AsciiDic(object): class AsciiDic(object):
UNK = 0 UNK = 0
...@@ -30,7 +32,6 @@ class AsciiDic(object): ...@@ -30,7 +32,6 @@ class AsciiDic(object):
def word2ids(self, sent): def word2ids(self, sent):
''' '''
transform a word to a list of ids. transform a word to a list of ids.
@sent: str
''' '''
return [self.lookup(c) for c in list(sent)] return [self.lookup(c) for c in list(sent)]
...@@ -46,11 +47,11 @@ class ImageDataset(object): ...@@ -46,11 +47,11 @@ class ImageDataset(object):
fixed_shape=None, fixed_shape=None,
is_infer=False): is_infer=False):
''' '''
@image_paths_generator: function :param train_image_paths_generator:
return a list of images' paths, called like: return list of train images' paths.
:type train_image_paths_generator: function
for path in image_paths_generator(): :param fixed_shape: fixed shape of images.
load_image(path) :type fixed_shape: tuple
''' '''
if is_infer == False: if is_infer == False:
self.train_filelist = [p for p in train_image_paths_generator] self.train_filelist = [p for p in train_image_paths_generator]
...@@ -93,7 +94,7 @@ def get_file_list(image_file_list): ...@@ -93,7 +94,7 @@ def get_file_list(image_file_list):
pwd = os.path.dirname(image_file_list) pwd = os.path.dirname(image_file_list)
with open(image_file_list) as f: with open(image_file_list) as f:
for line in f: for line in f:
fs = line.strip().split(',') fs = line.strip().split(',', 1)
file = fs[0].strip() file = fs[0].strip()
path = os.path.join(pwd, file) path = os.path.join(pwd, file)
yield path, fs[1][2:-1] yield path, fs[1][2:-1]
"""Contains various CTC decoders."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
<html>
<head>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
extensions: ["tex2jax.js", "TeX/AMSsymbols.js", "TeX/AMSmath.js"],
jax: ["input/TeX", "output/HTML-CSS"],
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true
},
"HTML-CSS": { availableFonts: ["TeX"] }
});
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js" async></script>
<script type="text/javascript" src="../.tools/theme/marked.js">
</script>
<link href="http://cdn.bootcss.com/highlight.js/9.9.0/styles/darcula.min.css" rel="stylesheet">
<script src="http://cdn.bootcss.com/highlight.js/9.9.0/highlight.min.js"></script>
<link href="http://cdn.bootcss.com/bootstrap/4.0.0-alpha.6/css/bootstrap.min.css" rel="stylesheet">
<link href="https://cdn.jsdelivr.net/perfect-scrollbar/0.6.14/css/perfect-scrollbar.min.css" rel="stylesheet">
<link href="../.tools/theme/github-markdown.css" rel='stylesheet'>
</head>
<style type="text/css" >
.markdown-body {
box-sizing: border-box;
min-width: 200px;
max-width: 980px;
margin: 0 auto;
padding: 45px;
}
</style>
<body>
<div id="context" class="container-fluid markdown-body">
</div>
<!-- This block will be replaced by each markdown file content. Please do not change lines below.-->
<div id="markdown" style='display:none'>
# 场景文字识别 (STR, Scene Text Recognition)
## STR任务简介
在现实生活中,包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。以下图为例,给定一个场景图片,STR需要从图片中识别出对应的文字"keep":
<p align="center">
<img src="./images/503.jpg"/><br/>
图 1. 数据示例 "keep"
</p>
## 使用 PaddlePaddle 训练与预测
### 模型训练
训练脚本 [./train.py](./train.py) 中设置了如下命令行参数:
```
usage: train.py [-h] --image_shape IMAGE_SHAPE --train_file_list
TRAIN_FILE_LIST --test_file_list TEST_FILE_LIST
[--batch_size BATCH_SIZE]
[--model_output_prefix MODEL_OUTPUT_PREFIX]
[--trainer_count TRAINER_COUNT]
[--save_period_by_batch SAVE_PERIOD_BY_BATCH]
[--num_passes NUM_PASSES]
PaddlePaddle CTC example
optional arguments:
-h, --help show this help message and exit
--image_shape IMAGE_SHAPE
image's shape, format is like '173,46'
--train_file_list TRAIN_FILE_LIST
path of the file which contains path list of train
image files
--test_file_list TEST_FILE_LIST
path of the file which contains path list of test
image files
--batch_size BATCH_SIZE
size of a mini-batch
--model_output_prefix MODEL_OUTPUT_PREFIX
prefix of path for model to store (default:
./model.ctc)
--trainer_count TRAINER_COUNT
number of training threads
--save_period_by_batch SAVE_PERIOD_BY_BATCH
save model to disk every N batches
--num_passes NUM_PASSES
number of passes to train (default: 1)
```
重要的几个参数包括:
- `image_shape` 图片的尺寸
- `train_file_list` 训练数据的列表文件,每行一个路径加对应的text,具体格式为:
```
word_1.png, "PROPER"
word_2.png, "FOOD"
```
- `test_file_list` 测试数据的列表文件,格式同上
### 预测
预测部分由infer.py完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在infer.py中指定具体的模型目录、图片固定尺寸、batch_size和图片文件的列表文件。例如:
```python
model_path = "model.ctc-pass-9-batch-150-test.tar.gz"
image_shape = "173,46"
batch_size = 50
infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
```
然后运行```python infer.py```
### 具体执行的过程:
1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件: Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt。
分别对应训练集的图片和图片对应的单词,测试集的图片,测试数据对应的单词,然后执行以下命令,对数据解压并移动至目标文件夹:
```
mkdir -p data/train_data
mkdir -p data/test_data
unzip Challenge2_Training_Task3_Images_GT.zip -d data/train_data
unzip Challenge2_Test_Task3_Images.zip -d data/test_data
mv Challenge2_Test_Task3_GT.txt data/test_data
```
2.获取训练数据文件夹中 `gt.txt` 的路径 (data/train_data)和测试数据文件夹中`Challenge2_Test_Task3_GT.txt`的路径(data/test_data)
3.执行命令
```
python train.py --train_file_list data/train_data/gt.txt --test_file_list data/test_data/Challenge2_Test_Task3_GT.txt --image_shape '173,46'
```
4.训练过程中,模型参数会自动备份到指定目录,默认为 ./model.ctc
5.设置infer.py中的相关参数(模型所在路径),运行```python infer.py``` 进行预测
### 其他数据集
- [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)(41G)
- [ICDAR 2003 Robust Reading Competitions](http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions)
### 注意事项
- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行
- 本模型参数较多,占用显存比较大,实际执行时可以调节batch_size 控制显存占用
- 本模型使用的数据集较小,可以选用其他更大的数据集\[[3](#参考文献)\]来训练需要的模型
## 参考文献
1. [Google Now Using ReCAPTCHA To Decode Street View Addresses](https://techcrunch.com/2012/03/29/google-now-using-recaptcha-to-decode-street-view-addresses/)
2. [Focused Scene Text](http://rrc.cvc.uab.es/?ch=2&com=introduction)
3. [SynthText in the Wild Dataset](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)
</div>
<!-- You can change the lines below now. -->
<script type="text/javascript">
marked.setOptions({
renderer: new marked.Renderer(),
gfm: true,
breaks: false,
smartypants: true,
highlight: function(code, lang) {
code = code.replace(/&amp;/g, "&")
code = code.replace(/&gt;/g, ">")
code = code.replace(/&lt;/g, "<")
code = code.replace(/&nbsp;/g, " ")
return hljs.highlightAuto(code, [lang]).value;
}
});
document.getElementById("context").innerHTML = marked(
document.getElementById("markdown").innerHTML)
</script>
</body>
import logging import logging
import argparse import argparse
import paddle.v2 as paddle
import gzip import gzip
import paddle.v2 as paddle
from model import Model from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset from data_provider import get_file_list, AsciiDic, ImageDataset
from decoder import ctc_greedy_decoder from decoder import ctc_greedy_decoder
def infer(inferer, test_batch, labels): def infer_batch(inferer, test_batch, labels):
infer_results = inferer.infer(input=test_batch) infer_results = inferer.infer(input=test_batch)
num_steps = len(infer_results) // len(test_batch) num_steps = len(infer_results) // len(test_batch)
probs_split = [ probs_split = [
...@@ -23,15 +24,11 @@ def infer(inferer, test_batch, labels): ...@@ -23,15 +24,11 @@ def infer(inferer, test_batch, labels):
results.append(output_transcription) results.append(output_transcription)
for result, label in zip(results, labels): for result, label in zip(results, labels):
print("\nOutput Transcription: %s\nTarget Transcription: %s" % (result, print("\nOutput Transcription: %s\nTarget Transcription: %s" %
label)) (result, label))
if __name__ == "__main__": def infer(model_path, image_shape, batch_size, infer_file_list):
model_path = "model.ctc-pass-1-batch-150-test-10.2607016472.tar.gz"
image_shape = "173,46"
batch_size = 50
infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
image_shape = tuple(map(int, image_shape.split(','))) image_shape = tuple(map(int, image_shape.split(',')))
infer_generator = get_file_list(infer_file_list) infer_generator = get_file_list(infer_file_list)
...@@ -49,8 +46,17 @@ if __name__ == "__main__": ...@@ -49,8 +46,17 @@ if __name__ == "__main__":
test_batch.append([image]) test_batch.append([image])
labels.append(label) labels.append(label)
if len(test_batch) == batch_size: if len(test_batch) == batch_size:
infer(inferer, test_batch, labels) infer_batch(inferer, test_batch, labels)
test_batch = [] test_batch = []
labels = [] labels = []
if test_batch: if test_batch:
infer(inferer, test_batch, labels) infer_batch(inferer, test_batch, labels)
if __name__ == "__main__":
model_path = "model.ctc-pass-9-batch-150-test.tar.gz"
image_shape = "173,46"
batch_size = 50
infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'
infer(model_path, image_shape, batch_size, infer_file_list)
...@@ -5,66 +5,15 @@ from paddle.v2.activation import Relu, Linear ...@@ -5,66 +5,15 @@ from paddle.v2.activation import Relu, Linear
from paddle.v2.networks import img_conv_group, simple_gru from paddle.v2.networks import img_conv_group, simple_gru
def conv_groups(input_image, num, with_bn):
'''
a deep CNN.
@input_image: input image
@num: number of CONV filters
@with_bn: whether with batch normal
'''
assert num % 4 == 0
tmp = img_conv_group(
input=input_image,
num_channels=1,
conv_padding=1,
conv_num_filter=[16] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )
tmp = img_conv_group(
input=tmp,
conv_padding=1,
conv_num_filter=[32] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )
tmp = img_conv_group(
input=tmp,
conv_padding=1,
conv_num_filter=[64] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )
tmp = img_conv_group(
input=tmp,
conv_padding=1,
conv_num_filter=[128] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )
return tmp
class Model(object): class Model(object):
def __init__(self, num_classes, shape, is_infer=False): def __init__(self, num_classes, shape, is_infer=False):
''' '''
@num_classes: int :param num_classes: size of the character dict.
size of the character dict :type num_classes: int
@shape: tuple of 2 int :param shape: size of the input images.
size of the input images :type shape: tuple of 2 int
:param is_infer: infer mode or not
:type shape: bool
''' '''
self.num_classes = num_classes self.num_classes = num_classes
self.shape = shape self.shape = shape
...@@ -90,7 +39,7 @@ class Model(object): ...@@ -90,7 +39,7 @@ class Model(object):
def __build_nn__(self): def __build_nn__(self):
# CNN output image features, 128 float matrixes # CNN output image features, 128 float matrixes
conv_features = conv_groups(self.image, 8, True) conv_features = self.conv_groups(self.image, 8, True)
# cutting CNN output into a sequence of feature vectors, which are # cutting CNN output into a sequence of feature vectors, which are
# 1 pixel wide and 11 pixel high. # 1 pixel wide and 11 pixel high.
...@@ -125,3 +74,41 @@ class Model(object): ...@@ -125,3 +74,41 @@ class Model(object):
size=self.num_classes + 1, size=self.num_classes + 1,
norm_by_times=True, norm_by_times=True,
blank=self.num_classes) blank=self.num_classes)
self.eval = evaluator.ctc_error(input=self.output, label=self.label)
def conv_groups(self, input_image, num, with_bn):
'''
:param input_image: input image.
:type input_image: LayerOutput
:param num: number of CONV filters.
:type num: int
:param with_bn: whether with batch normal.
:type with_bn: bool
'''
assert num % 4 == 0
filter_num_list = [16, 32, 64, 128]
is_input_image = True
tmp = input_image
for num_filter in filter_num_list:
if is_input_image:
num_channels = 1
is_input_image = False
else:
num_channels = None
tmp = img_conv_group(
input=tmp,
num_channels=num_channels,
conv_padding=1,
conv_num_filter=[num_filter] * (num / 4),
conv_filter_size=3,
conv_act=Relu(),
conv_with_batchnorm=with_bn,
pool_size=2,
pool_stride=2, )
return tmp
import logging import logging
import argparse import argparse
import paddle.v2 as paddle
import gzip import gzip
import paddle.v2 as paddle
from model import Model from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset from data_provider import get_file_list, AsciiDic, ImageDataset
...@@ -33,66 +34,68 @@ parser.add_argument( ...@@ -33,66 +34,68 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
'--save_period_by_batch', '--save_period_by_batch',
type=int, type=int,
default=50, default=150,
help='save model to disk every N batches') help='save model to disk every N batches')
parser.add_argument( parser.add_argument(
'--num_passes', '--num_passes',
type=int, type=int,
default=1, default=10,
help='number of passes to train (default: 1)') help='number of passes to train (default: 1)')
args = parser.parse_args() args = parser.parse_args()
image_shape = tuple(map(int, args.image_shape.split(',')))
print 'image_shape', image_shape def main():
print 'batch_size', args.batch_size image_shape = tuple(map(int, args.image_shape.split(',')))
print 'train_file_list', args.train_file_list
print 'test_file_list', args.test_file_list print 'image_shape', image_shape
print 'batch_size', args.batch_size
print 'train_file_list', args.train_file_list
print 'test_file_list', args.test_file_list
train_generator = get_file_list(args.train_file_list) train_generator = get_file_list(args.train_file_list)
test_generator = get_file_list(args.test_file_list) test_generator = get_file_list(args.test_file_list)
infer_generator = None infer_generator = None
dataset = ImageDataset( dataset = ImageDataset(
train_generator, train_generator,
test_generator, test_generator,
infer_generator, infer_generator,
fixed_shape=image_shape, fixed_shape=image_shape,
is_infer=False) is_infer=False)
paddle.init(use_gpu=True, trainer_count=args.trainer_count) paddle.init(use_gpu=True, trainer_count=args.trainer_count)
model = Model(AsciiDic().size(), image_shape, is_infer=False) model = Model(AsciiDic().size(), image_shape, is_infer=False)
params = paddle.parameters.create(model.cost) params = paddle.parameters.create(model.cost)
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0)
trainer = paddle.trainer.SGD( trainer = paddle.trainer.SGD(
cost=model.cost, parameters=params, update_equation=optimizer) cost=model.cost,
parameters=params,
update_equation=optimizer,
extra_layers=model.eval)
def event_handler(event):
def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
print "Pass %d, batch %d, Samples %d, Cost %f" % ( print "Pass %d, batch %d, Samples %d, Cost %f, Eval %s" % (
event.pass_id, event.batch_id, event.batch_id * args.batch_size, event.pass_id, event.batch_id,
event.cost) event.batch_id * args.batch_size, event.cost, event.metrics)
if event.batch_id > 0 and event.batch_id % args.save_period_by_batch == 0: if event.batch_id > 0 and event.batch_id % args.save_period_by_batch == 0:
result = trainer.test( result = trainer.test(
reader=paddle.batch(dataset.test, batch_size=10), reader=paddle.batch(dataset.test, batch_size=10),
feeding={'image': 0, feeding={'image': 0,
'label': 1}) 'label': 1})
print "Test %d-%d, Cost %f " % (event.pass_id, event.batch_id, print "Test %d-%d, Cost %f, Eval %s" % (
result.cost) event.pass_id, event.batch_id, result.cost, result.metrics)
path = "{}-pass-{}-batch-{}-test-{}.tar.gz".format( path = "{}-pass-{}-batch-{}-test.tar.gz".format(
args.model_output_prefix, event.pass_id, event.batch_id, args.model_output_prefix, event.pass_id, event.batch_id)
result.cost)
with gzip.open(path, 'w') as f: with gzip.open(path, 'w') as f:
params.to_tar(f) params.to_tar(f)
trainer.train(
trainer.train(
reader=paddle.batch( reader=paddle.batch(
paddle.reader.shuffle(dataset.train, buf_size=500), paddle.reader.shuffle(dataset.train, buf_size=500),
batch_size=args.batch_size), batch_size=args.batch_size),
...@@ -100,3 +103,7 @@ trainer.train( ...@@ -100,3 +103,7 @@ trainer.train(
'label': 1}, 'label': 1},
event_handler=event_handler, event_handler=event_handler,
num_passes=args.num_passes) num_passes=args.num_passes)
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册