PaddlePaddle / models
Commit 9b0d8621 (unverified)
Authored Apr 01, 2018 by whs; committed via GitHub on Apr 01, 2018
Merge pull request #789 from wanghaoshuang/refine_ctc
Refine OCR CTC model.
Parents: 958812f7, 97cfb9de
Showing 3 changed files with 47 additions and 28 deletions (+47, -28)
fluid/ocr_recognition/crnn_ctc_model.py (+7, -15)
fluid/ocr_recognition/ctc_reader.py (+35, -7)
fluid/ocr_recognition/ctc_train.py (+5, -6)
fluid/ocr_recognition/crnn_ctc_model.py
@@ -187,25 +187,17 @@ def ctc_train_net(images, label, args, num_classes):
     error_evaluator = fluid.evaluator.EditDistance(
         input=decoded_out, label=casted_label)
-    inference_program = fluid.default_main_program().clone()
-    with fluid.program_guard(inference_program):
-        inference_program = fluid.io.get_inference_program(error_evaluator)
+    inference_program = fluid.default_main_program().clone(for_test=True)
     optimizer = fluid.optimizer.Momentum(
         learning_rate=args.learning_rate, momentum=args.momentum)
     _, params_grads = optimizer.minimize(sum_cost)
+    model_average = None
+    if args.model_average:
+        model_average = fluid.optimizer.ModelAverage(
+            params_grads,
+            args.average_window,
+            min_average_window=args.min_average_window,
+            max_average_window=args.max_average_window)
-    decoded_out = fluid.layers.ctc_greedy_decoder(
-        input=fc_out, blank=num_classes)
-    casted_label = fluid.layers.cast(x=label, dtype='int64')
-    error_evaluator = fluid.evaluator.EditDistance(
-        input=decoded_out, label=casted_label)
-    model_average = fluid.optimizer.ModelAverage(
-        params_grads,
-        args.average_window,
-        min_average_window=args.min_average_window,
-        max_average_window=args.max_average_window)
     return sum_cost, error_evaluator, inference_program, model_average
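The substantive change in this file is ordering: the greedy decoder and edit-distance evaluator are built first, the inference program is then snapshotted with clone(for_test=True) (replacing the old clone() plus fluid.io.get_inference_program pair), and only afterwards are the optimizer ops appended; ModelAverage also becomes opt-in via args.model_average. Below is a minimal sketch of that ordering on a toy classifier, not the CRNN-CTC network itself, assuming the era-appropriate Fluid API used throughout this repo:

import paddle.fluid as fluid

# Toy network purely to illustrate the clone(for_test=True) ordering.
image = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=image, size=10, act='softmax')
avg_cost = fluid.layers.mean(fluid.layers.cross_entropy(input=prediction, label=label))

# 1. Build every op the evaluation pass needs (metrics included) ...
accuracy = fluid.layers.accuracy(input=prediction, label=label)

# 2. ... then snapshot the program; for_test=True also switches layers such as
#    dropout/batch_norm to inference behaviour, which a bare clone() did not.
inference_program = fluid.default_main_program().clone(for_test=True)

# 3. Only now append backward and optimizer ops to the training program.
optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_cost)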
fluid/ocr_recognition/ctc_reader.py
import os
import cv2
import tarfile
import numpy as np
from PIL import Image
from os import path
from paddle.v2.image import load_image
import paddle.v2 as paddle

NUM_CLASSES = 10784
DATA_SHAPE = [1, 48, 512]

DATA_MD5 = "1de60d54d19632022144e4e58c2637b5"
DATA_URL = "http://cloud.dlnel.org/filepub/?uuid=df937251-3c0b-480d-9a7b-0080dfeee65c"
CACHE_DIR_NAME = "ctc_data"
SAVED_FILE_NAME = "data.tar.gz"
DATA_DIR_NAME = "data"
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"


class DataGenerator(object):
    def __init__(self):
        ...
@@ -102,25 +113,42 @@ class DataGenerator(object):
 def num_classes():
     '''Get classes number of this dataset.
     '''
     return NUM_CLASSES


 def data_shape():
     '''Get image shape of this dataset. It is a dummy shape for this dataset.
     '''
     return DATA_SHAPE


 def train(batch_size):
     generator = DataGenerator()
+    data_dir = download_data()
     return generator.train_reader(
-        "/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/train_images/",
-        "/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/train.list",
-        batch_size)
+        path.join(data_dir, TRAIN_DATA_DIR_NAME),
+        path.join(data_dir, TRAIN_LIST_FILE_NAME), batch_size)


 def test(batch_size=1):
     generator = DataGenerator()
+    data_dir = download_data()
     return paddle.batch(
         generator.test_reader(
-            "/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/test_images/",
-            "/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/test.list"),
-        batch_size)
+            path.join(data_dir, TRAIN_DATA_DIR_NAME),
+            path.join(data_dir, TRAIN_LIST_FILE_NAME)),
+        batch_size)


+def download_data():
+    '''Download train and test data.
+    '''
+    tar_file = paddle.dataset.common.download(
+        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
+    data_dir = path.join(path.dirname(tar_file), DATA_DIR_NAME)
+    if not path.isdir(data_dir):
+        t = tarfile.open(tar_file, "r:gz")
+        t.extractall(path=path.dirname(tar_file))
+        t.close()
+    return data_dir
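For orientation, a hedged usage sketch of the refactored entry points: the module (assumed importable as ctc_reader) now resolves its data location through download_data() instead of the removed hard-coded /home/disk1/... paths, so the first call fetches data.tar.gz through paddle.dataset.common.download's cache and extracts it next to the cached archive. The batch size below is arbitrary.

import ctc_reader

# First call triggers download_data(): fetch the archive if needed, extract it,
# and return the directory the readers should use.
test_reader = ctc_reader.test(batch_size=4)

for batch in test_reader():
    print(len(batch))  # number of samples in this mini-batch
    break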
fluid/ocr_recognition/ctc_train.py
@@ -8,6 +8,7 @@ import functools
 import sys
 from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
 from crnn_ctc_model import ctc_train_net
+import time

 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
@@ -23,11 +24,10 @@ add_arg('momentum', float, 0.9, "Momentum.")
 add_arg('rnn_hidden_size', int, 200, "Hidden size of rnn layers.")
 add_arg('device', int, 0, "Device id.'-1' means running on CPU"
         "while '0' means GPU-0.")
 add_arg('model_average', bool, True, "Whether to average model for evaluation.")
 add_arg('min_average_window', int, 10000, "Min average window.")
 add_arg('max_average_window', int, 15625, "Max average window.")
 add_arg('average_window', float, 0.15, "Average window.")
-add_arg('parallel', bool, True, "Whether use parallel training.")
+add_arg('parallel', bool, False, "Whether use parallel training.")
 # yapf: disable

 def load_parameter(place):
     ...
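add_arguments itself lives in utility.py, which is outside this diff. For readers following the add_arg('parallel', bool, False, ...) change, the sketch below shows the shape such a helper plausibly has, inferred from the call sites above and not the repo's exact code; the bool handling matters because argparse's built-in bool() treats any non-empty string, including "False", as true.

import argparse
import distutils.util
import functools

def add_arguments(argname, type, default, help, argparser):
    # Map bool flags through strtobool so "--parallel False" really means false.
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument("--" + argname, default=default, type=type, help=help)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('parallel', bool, False, "Whether use parallel training.")

args = parser.parse_args(['--parallel', 'True'])
print(args.parallel)  # 1 (strtobool returns an int)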
@@ -70,11 +70,12 @@ def train(args, data_reader=dummy_reader):
             fetch_list=[sum_cost] + error_evaluator.metrics)
         total_loss += batch_loss[0]
         total_seq_error += batch_seq_error[0]
-        if batch_id % 10 == 1:
+        if batch_id % 100 == 1:
             print '.',
             sys.stdout.flush()
         if batch_id % args.log_period == 1:
-            print "\nPass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
+            print "\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
+                time.time(),
                 pass_id, batch_id, total_loss / (batch_id * args.batch_size),
                 total_seq_error / (batch_id * args.batch_size))
             sys.stdout.flush()
         batch_id += 1
@@ -84,8 +85,6 @@ def train(args, data_reader=dummy_reader):
         for data in test_reader():
             exe.run(inference_program, feed=get_feeder_data(data, place))
         _, test_seq_error = error_evaluator.eval(exe)
-        if model_average != None:
-            model_average.restore(exe)
         print "\nEnd pass[%d]; Test seq error: %s.\n" % (
             pass_id, str(test_seq_error[0]))
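The removed restore(exe) call is one half of the ModelAverage apply/restore pairing; its counterpart is not visible in this hunk. As context, here is a hedged sketch of how averaged weights are typically swapped in around evaluation with fluid.optimizer.ModelAverage in this era of the API, where apply() can be used as a context manager that substitutes the averaged parameters and restores the raw ones on exit. It reuses the names from the script above (exe, test_reader, inference_program, error_evaluator, get_feeder_data) and is not the file's final code:

        if model_average is not None:
            # Evaluate with the averaged parameters, then restore the raw ones.
            with model_average.apply(exe):
                for data in test_reader():
                    exe.run(inference_program, feed=get_feeder_data(data, place))
                _, test_seq_error = error_evaluator.eval(exe)
        else:
            for data in test_reader():
                exe.run(inference_program, feed=get_feeder_data(data, place))
            _, test_seq_error = error_evaluator.eval(exe)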