Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
672e8565
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
672e8565
编写于
10月 26, 2017
作者:
P
peterzhang2029
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update the dictionary module
上级
0fa990bb
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
146 addition
and
94 deletion
+146
-94
scene_text_recognition/README.md
scene_text_recognition/README.md
+22
-14
scene_text_recognition/config.py
scene_text_recognition/config.py
+1
-1
scene_text_recognition/index.html
scene_text_recognition/index.html
+22
-14
scene_text_recognition/infer.py
scene_text_recognition/infer.py
+17
-8
scene_text_recognition/model.py
scene_text_recognition/model.py
+4
-5
scene_text_recognition/reader.py
scene_text_recognition/reader.py
+10
-8
scene_text_recognition/train.py
scene_text_recognition/train.py
+20
-4
scene_text_recognition/utils.py
scene_text_recognition/utils.py
+50
-40
未找到文件。
scene_text_recognition/README.md
浏览文件 @
672e8565
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
## STR任务简介
## STR任务简介
在现实生活中,
包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,
\[
[
1
](
#参考文献
)
\]
使用深度学习模型
自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
在现实生活中,
许多图片中的文字为图片所处场景的理解提供了丰富的语义信息(例如:路牌、菜单、街道标语等)。同时,场景图片文字识别技术的发展也促进了一些新型应用的产生,例如:
\[
[
1
](
#参考文献
)
\]
通过使用深度学习模型来
自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
本例将演示如何用 PaddlePaddle 完成
**场景文字识别 (STR, Scene Text Recognition)**
任务。
以下图为例,给定一个场景图片,STR需要从图片
中识别出对应的文字"keep"。
本例将演示如何用 PaddlePaddle 完成
**场景文字识别 (STR, Scene Text Recognition)**
任务。
如下图所示,给定一张场景图片,
`STR`
需要从
中识别出对应的文字"keep"。
<p
align=
"center"
>
<p
align=
"center"
>
<img
src=
"./images/503.jpg"
/><br/>
<img
src=
"./images/503.jpg"
/><br/>
...
@@ -21,7 +21,7 @@ pip install -r requirements.txt
...
@@ -21,7 +21,7 @@ pip install -r requirements.txt
### 指定训练配置参数
### 指定训练配置参数
通过
`config.py`
脚本修改训练和模型配置参数,脚本中有对可配置参数的详细解释,示例
如下:
`config.py`
脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码
如下:
```
python
```
python
class
TrainerConfig
(
object
):
class
TrainerConfig
(
object
):
...
@@ -43,7 +43,8 @@ class ModelConfig(object):
...
@@ -43,7 +43,8 @@ class ModelConfig(object):
...
...
```
```
修改
`config.py`
对参数进行调整。例如,通过修改
`use_gpu`
参数来指定是否使用 GPU 进行训练。
修改
`config.py`
脚本可以实现对参数的调整。例如,通过修改
`use_gpu`
参数来指定是否使用 GPU 进行训练。
### 模型训练
### 模型训练
训练脚本
[
./train.py
](
./train.py
)
中设置了如下命令行参数:
训练脚本
[
./train.py
](
./train.py
)
中设置了如下命令行参数:
...
@@ -54,24 +55,29 @@ Options:
...
@@ -54,24 +55,29 @@ Options:
of train image files. [required]
of train image files. [required]
--test_file_list_path TEXT The path of the file which contains path list
--test_file_list_path TEXT The path of the file which contains path list
of test image files. [required]
of test image files. [required]
--label_dict_path TEXT The path of label dictionary. If this parameter
is set, but the file does not exist, label
dictionay will be built from the training data
automatically. [required]
--model_save_dir TEXT The path to save the trained models (default:
--model_save_dir TEXT The path to save the trained models (default:
'models').
'models').
--help Show this message and exit.
--help Show this message and exit.
```
```
-
`train_file_list`
训练数据的列表文件,每行一个路径加对应的text
,具体格式为:
-
`train_file_list`
:训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成
,具体格式为:
```
```
word_1.png, "PROPER"
word_1.png, "PROPER"
word_2.png, "FOOD"
word_2.png, "FOOD"
```
```
-
`test_file_list`
测试数据的列表文件,格式同上。
-
`test_file_list`
:测试数据的列表文件,格式同上。
-
`model_save_dir`
模型参数会的保存目录目录, 默认为当前目录下的
`models`
目录。
-
`label_dict_path`
:训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。
-
`model_save_dir`
:模型参数的保存目录,默认为
`./models`
。
### 具体执行的过程:
### 具体执行的过程:
1.
从官方网站下载数据
\[
[
2
](
#参考文献
)
\]
(Task 2.3: Word Recognition (2013 edition)),会有三个文件:
Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt
。
1.
从官方网站下载数据
\[
[
2
](
#参考文献
)
\]
(Task 2.3: Word Recognition (2013 edition)),会有三个文件:
`Challenge2_Training_Task3_Images_GT.zip`
、
`Challenge2_Test_Task3_Images.zip`
和
`Challenge2_Test_Task3_GT.txt`
。
分别对应训练集的图片和图片对应的单词
,测试集的图片,测试数据对应的单词,
然后执行以下命令,对数据解压并移动至目标文件夹:
分别对应训练集的图片和图片对应的单词
、测试集的图片、测试数据对应的单词。
然后执行以下命令,对数据解压并移动至目标文件夹:
```
bash
```
bash
mkdir
-p
data/train_data
mkdir
-p
data/train_data
...
@@ -87,17 +93,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data
...
@@ -87,17 +93,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data
```
bash
```
bash
python train.py
\
python train.py
\
--train_file_list_path
'data/train_data/gt.txt'
\
--train_file_list_path
'data/train_data/gt.txt'
\
--test_file_list_path
'data/test_data/Challenge2_Test_Task3_GT.txt'
--test_file_list_path
'data/test_data/Challenge2_Test_Task3_GT.txt'
\
--label_dict_path
'label_dict.txt'
```
```
4.
训练过程中,模型参数会自动备份到指定目录,默认会保存在
`./models`
目录下。
4.
训练过程中,模型参数会自动备份到指定目录,默认会保存在
`./models`
目录下。
### 预测
### 预测
预测部分由
`infer.py`
完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在
`infer.py`
中指定具体的模型
目录、图片固定尺寸、batch_size(默认设置为10)
和图片文件的列表文件。执行如下代码:
预测部分由
`infer.py`
完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在
`infer.py`
中指定具体的模型
保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径
和图片文件的列表文件。执行如下代码:
```
bash
```
bash
python infer.py
\
python infer.py
\
--model_path
'models/params_pass_00000.tar.gz'
\
--model_path
'models/params_pass_00000.tar.gz'
\
--image_shape
'173,46'
\
--image_shape
'173,46'
\
--label_dict_path
'label_dict.txt'
\
--infer_file_list_path
'data/test_data/Challenge2_Test_Task3_GT.txt'
--infer_file_list_path
'data/test_data/Challenge2_Test_Task3_GT.txt'
```
```
即可进行预测。
即可进行预测。
...
@@ -109,9 +117,9 @@ python infer.py \
...
@@ -109,9 +117,9 @@ python infer.py \
### 注意事项
### 注意事项
-
由于模型依赖的
`warp CTC`
只有CUDA的实现,本模型只支持 GPU 运行
-
由于模型依赖的
`warp CTC`
只有CUDA的实现,本模型只支持 GPU 运行
。
-
本模型参数较多,占用显存比较大,实际执行时可以
调节
`batch_size`
控制显存占用
-
本模型参数较多,占用显存比较大,实际执行时可以
通过调节
`batch_size`
来控制显存占用。
-
本
模型使用的数据集较小,可以选用其他更大的数据集
\[
[
3
](
#参考文献
)
\]
来训练需要的模型
-
本
例使用的数据集较小,如有需要,可以选用其他更大的数据集
\[
[
3
](
#参考文献
)
\]
来训练模型。
## 参考文献
## 参考文献
...
...
scene_text_recognition/config.py
浏览文件 @
672e8565
scene_text_recognition/index.html
浏览文件 @
672e8565
...
@@ -44,9 +44,9 @@
...
@@ -44,9 +44,9 @@
## STR任务简介
## STR任务简介
在现实生活中,
包括路牌、菜单、大厦标语在内的很多场景均会有文字出现,这些场景的照片中的文字为图片场景的理解提供了更多信息,\[[1](#参考文献)\]使用深度学习模型
自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
在现实生活中,
许多图片中的文字为图片所处场景的理解提供了丰富的语义信息(例如:路牌、菜单、街道标语等)。同时,场景图片文字识别技术的发展也促进了一些新型应用的产生,例如:\[[1](#参考文献)\]通过使用深度学习模型来
自动识别路牌中的文字,帮助街景应用获取更加准确的地址信息。
本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。
以下图为例,给定一个场景图片,STR需要从图片
中识别出对应的文字"keep"。
本例将演示如何用 PaddlePaddle 完成 **场景文字识别 (STR, Scene Text Recognition)** 任务。
如下图所示,给定一张场景图片,`STR` 需要从
中识别出对应的文字"keep"。
<p
align=
"center"
>
<p
align=
"center"
>
<img
src=
"./images/503.jpg"
/><br/>
<img
src=
"./images/503.jpg"
/><br/>
...
@@ -63,7 +63,7 @@ pip install -r requirements.txt
...
@@ -63,7 +63,7 @@ pip install -r requirements.txt
### 指定训练配置参数
### 指定训练配置参数
通过 `config.py` 脚本修改训练和模型配置参数,脚本中有对可配置参数的详细解释,示例
如下:
`config.py` 脚本中包含了模型配置和训练相关的参数以及对应的详细解释,代码
如下:
```python
```python
class TrainerConfig(object):
class TrainerConfig(object):
...
@@ -85,7 +85,8 @@ class ModelConfig(object):
...
@@ -85,7 +85,8 @@ class ModelConfig(object):
...
...
```
```
修改 `config.py` 对参数进行调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。
修改 `config.py` 脚本可以实现对参数的调整。例如,通过修改 `use_gpu` 参数来指定是否使用 GPU 进行训练。
### 模型训练
### 模型训练
训练脚本 [./train.py](./train.py) 中设置了如下命令行参数:
训练脚本 [./train.py](./train.py) 中设置了如下命令行参数:
...
@@ -96,24 +97,29 @@ Options:
...
@@ -96,24 +97,29 @@ Options:
of train image files. [required]
of train image files. [required]
--test_file_list_path TEXT The path of the file which contains path list
--test_file_list_path TEXT The path of the file which contains path list
of test image files. [required]
of test image files. [required]
--label_dict_path TEXT The path of label dictionary. If this parameter
is set, but the file does not exist, label
dictionay will be built from the training data
automatically. [required]
--model_save_dir TEXT The path to save the trained models (default:
--model_save_dir TEXT The path to save the trained models (default:
'models').
'models').
--help Show this message and exit.
--help Show this message and exit.
```
```
- `train_file_list`
训练数据的列表文件,每行一个路径加对应的text
,具体格式为:
- `train_file_list`
:训练数据的列表文件,每行由图片的存储路径和对应的标记文本组成
,具体格式为:
```
```
word_1.png, "PROPER"
word_1.png, "PROPER"
word_2.png, "FOOD"
word_2.png, "FOOD"
```
```
- `test_file_list` 测试数据的列表文件,格式同上。
- `test_file_list` :测试数据的列表文件,格式同上。
- `model_save_dir` 模型参数会的保存目录目录, 默认为当前目录下的`models`目录。
- `label_dict_path` :训练数据中标记字典的存储路径,如果指定路径中字典文件不存在,程序会使用训练数据中的标记数据自动生成标记字典。
- `model_save_dir` :模型参数的保存目录,默认为`./models`。
### 具体执行的过程:
### 具体执行的过程:
1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件:
Challenge2_Training_Task3_Images_GT.zip、Challenge2_Test_Task3_Images.zip和 Challenge2_Test_Task3_GT.txt
。
1.从官方网站下载数据\[[2](#参考文献)\](Task 2.3: Word Recognition (2013 edition)),会有三个文件:
`Challenge2_Training_Task3_Images_GT.zip`、`Challenge2_Test_Task3_Images.zip` 和 `Challenge2_Test_Task3_GT.txt`
。
分别对应训练集的图片和图片对应的单词
,测试集的图片,测试数据对应的单词,
然后执行以下命令,对数据解压并移动至目标文件夹:
分别对应训练集的图片和图片对应的单词
、测试集的图片、测试数据对应的单词。
然后执行以下命令,对数据解压并移动至目标文件夹:
```bash
```bash
mkdir -p data/train_data
mkdir -p data/train_data
...
@@ -129,17 +135,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data
...
@@ -129,17 +135,19 @@ mv Challenge2_Test_Task3_GT.txt data/test_data
```bash
```bash
python train.py \
python train.py \
--train_file_list_path 'data/train_data/gt.txt' \
--train_file_list_path 'data/train_data/gt.txt' \
--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt'
--test_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt' \
--label_dict_path 'label_dict.txt'
```
```
4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。
4.训练过程中,模型参数会自动备份到指定目录,默认会保存在 `./models` 目录下。
### 预测
### 预测
预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型
目录、图片固定尺寸、batch_size(默认设置为10)
和图片文件的列表文件。执行如下代码:
预测部分由 `infer.py` 完成,使用的是最优路径解码算法,即:在每个时间步选择一个概率最大的字符。在使用过程中,需要在 `infer.py` 中指定具体的模型
保存路径、图片固定尺寸、batch_size(默认为10)、标记词典路径
和图片文件的列表文件。执行如下代码:
```bash
```bash
python infer.py \
python infer.py \
--model_path 'models/params_pass_00000.tar.gz' \
--model_path 'models/params_pass_00000.tar.gz' \
--image_shape '173,46' \
--image_shape '173,46' \
--label_dict_path 'label_dict.txt' \
--infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt'
--infer_file_list_path 'data/test_data/Challenge2_Test_Task3_GT.txt'
```
```
即可进行预测。
即可进行预测。
...
@@ -151,9 +159,9 @@ python infer.py \
...
@@ -151,9 +159,9 @@ python infer.py \
### 注意事项
### 注意事项
- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行
- 由于模型依赖的 `warp CTC` 只有CUDA的实现,本模型只支持 GPU 运行
。
- 本模型参数较多,占用显存比较大,实际执行时可以
调节`batch_size`控制显存占用
- 本模型参数较多,占用显存比较大,实际执行时可以
通过调节 `batch_size` 来控制显存占用。
- 本
模型使用的数据集较小,可以选用其他更大的数据集\[[3](#参考文献)\]来训练需要的模型
- 本
例使用的数据集较小,如有需要,可以选用其他更大的数据集\[[3](#参考文献)\]来训练模型。
## 参考文献
## 参考文献
...
...
scene_text_recognition/infer.py
浏览文件 @
672e8565
...
@@ -5,10 +5,10 @@ import paddle.v2 as paddle
...
@@ -5,10 +5,10 @@ import paddle.v2 as paddle
from
model
import
Model
from
model
import
Model
from
reader
import
DataGenerator
from
reader
import
DataGenerator
from
decoder
import
ctc_greedy_decoder
from
decoder
import
ctc_greedy_decoder
from
utils
import
AsciiDic
,
get_file_lis
t
from
utils
import
get_file_list
,
load_dict
,
load_reverse_dic
t
def
infer_batch
(
inferer
,
test_batch
,
labels
):
def
infer_batch
(
inferer
,
test_batch
,
labels
,
reversed_char_dict
):
infer_results
=
inferer
.
infer
(
input
=
test_batch
)
infer_results
=
inferer
.
infer
(
input
=
test_batch
)
num_steps
=
len
(
infer_results
)
//
len
(
test_batch
)
num_steps
=
len
(
infer_results
)
//
len
(
test_batch
)
probs_split
=
[
probs_split
=
[
...
@@ -19,7 +19,7 @@ def infer_batch(inferer, test_batch, labels):
...
@@ -19,7 +19,7 @@ def infer_batch(inferer, test_batch, labels):
# Best path decode.
# Best path decode.
for
i
,
probs
in
enumerate
(
probs_split
):
for
i
,
probs
in
enumerate
(
probs_split
):
output_transcription
=
ctc_greedy_decoder
(
output_transcription
=
ctc_greedy_decoder
(
probs_seq
=
probs
,
vocabulary
=
AsciiDic
().
id2word
()
)
probs_seq
=
probs
,
vocabulary
=
reversed_char_dict
)
results
.
append
(
output_transcription
)
results
.
append
(
output_transcription
)
for
result
,
label
in
zip
(
results
,
labels
):
for
result
,
label
in
zip
(
results
,
labels
):
...
@@ -40,17 +40,26 @@ def infer_batch(inferer, test_batch, labels):
...
@@ -40,17 +40,26 @@ def infer_batch(inferer, test_batch, labels):
type
=
int
,
type
=
int
,
default
=
10
,
default
=
10
,
help
=
(
"The number of examples in one batch (default: 10)."
))
help
=
(
"The number of examples in one batch (default: 10)."
))
@
click
.
option
(
"--label_dict_path"
,
type
=
str
,
required
=
True
,
help
=
(
"The path of label dictionary. "
))
@
click
.
option
(
@
click
.
option
(
"--infer_file_list_path"
,
"--infer_file_list_path"
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
(
"The path of the file which contains "
help
=
(
"The path of the file which contains "
"path list of image files for inference."
))
"path list of image files for inference."
))
def
infer
(
model_path
,
image_shape
,
batch_size
,
infer_file_list_path
):
def
infer
(
model_path
,
image_shape
,
batch_size
,
label_dict_path
,
infer_file_list_path
):
image_shape
=
tuple
(
map
(
int
,
image_shape
.
split
(
','
)))
image_shape
=
tuple
(
map
(
int
,
image_shape
.
split
(
','
)))
infer_file_list
=
get_file_list
(
infer_file_list_path
)
infer_file_list
=
get_file_list
(
infer_file_list_path
)
char_dict
=
AsciiDic
()
dict_size
=
char_dict
.
size
()
char_dict
=
load_dict
(
label_dict_path
)
reversed_char_dict
=
load_reverse_dict
(
label_dict_path
)
dict_size
=
len
(
char_dict
)
data_generator
=
DataGenerator
(
char_dict
=
char_dict
,
image_shape
=
image_shape
)
data_generator
=
DataGenerator
(
char_dict
=
char_dict
,
image_shape
=
image_shape
)
paddle
.
init
(
use_gpu
=
True
,
trainer_count
=
1
)
paddle
.
init
(
use_gpu
=
True
,
trainer_count
=
1
)
...
@@ -66,11 +75,11 @@ def infer(model_path, image_shape, batch_size, infer_file_list_path):
...
@@ -66,11 +75,11 @@ def infer(model_path, image_shape, batch_size, infer_file_list_path):
test_batch
.
append
([
image
])
test_batch
.
append
([
image
])
labels
.
append
(
label
)
labels
.
append
(
label
)
if
len
(
test_batch
)
==
batch_size
:
if
len
(
test_batch
)
==
batch_size
:
infer_batch
(
inferer
,
test_batch
,
labels
)
infer_batch
(
inferer
,
test_batch
,
labels
,
reversed_char_dict
)
test_batch
=
[]
test_batch
=
[]
labels
=
[]
labels
=
[]
if
test_batch
:
if
test_batch
:
infer_batch
(
inferer
,
test_batch
,
labels
)
infer_batch
(
inferer
,
test_batch
,
labels
,
reversed_char_dict
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
scene_text_recognition/model.py
浏览文件 @
672e8565
...
@@ -45,12 +45,11 @@ class Model(object):
...
@@ -45,12 +45,11 @@ class Model(object):
'''
'''
Build the network topology.
Build the network topology.
'''
'''
#
CNN output image features
.
#
Get the image features with CNN
.
conv_features
=
self
.
conv_groups
(
self
.
image
,
conf
.
filter_num
,
conv_features
=
self
.
conv_groups
(
self
.
image
,
conf
.
filter_num
,
conf
.
with_bn
)
conf
.
with_bn
)
# Cut CNN output into a sequence of feature vectors, which are
# Expand the output of CNN into a sequence of feature vectors.
# 1 pixel wide and 11 pixel high.
sliced_feature
=
layer
.
block_expand
(
sliced_feature
=
layer
.
block_expand
(
input
=
conv_features
,
input
=
conv_features
,
num_channels
=
conf
.
num_channels
,
num_channels
=
conf
.
num_channels
,
...
@@ -59,7 +58,7 @@ class Model(object):
...
@@ -59,7 +58,7 @@ class Model(object):
block_x
=
conf
.
block_x
,
block_x
=
conf
.
block_x
,
block_y
=
conf
.
block_y
)
block_y
=
conf
.
block_y
)
#
RNNs
to capture sequence information forwards and backwards.
#
Use RNN
to capture sequence information forwards and backwards.
gru_forward
=
simple_gru
(
gru_forward
=
simple_gru
(
input
=
sliced_feature
,
size
=
conf
.
hidden_size
,
act
=
Relu
())
input
=
sliced_feature
,
size
=
conf
.
hidden_size
,
act
=
Relu
())
gru_backward
=
simple_gru
(
gru_backward
=
simple_gru
(
...
@@ -68,7 +67,7 @@ class Model(object):
...
@@ -68,7 +67,7 @@ class Model(object):
act
=
Relu
(),
act
=
Relu
(),
reverse
=
True
)
reverse
=
True
)
# Map
each step
of RNN to character distribution.
# Map
the output
of RNN to character distribution.
self
.
output
=
layer
.
fc
(
self
.
output
=
layer
.
fc
(
input
=
[
gru_forward
,
gru_backward
],
input
=
[
gru_forward
,
gru_backward
],
size
=
self
.
num_classes
+
1
,
size
=
self
.
num_classes
+
1
,
...
...
scene_text_recognition/reader.py
浏览文件 @
672e8565
...
@@ -24,8 +24,10 @@ class DataGenerator(object):
...
@@ -24,8 +24,10 @@ class DataGenerator(object):
'''
'''
def
reader
():
def
reader
():
for
i
,
(
image
,
label
)
in
enumerate
(
file_list
):
UNK_ID
=
self
.
char_dict
[
'<unk>'
]
yield
self
.
load_image
(
image
),
self
.
char_dict
.
word2ids
(
label
)
for
image_path
,
label
in
file_list
:
label
=
[
self
.
char_dict
.
get
(
c
,
UNK_ID
)
for
c
in
label
]
yield
self
.
load_image
(
image_path
),
label
return
reader
return
reader
...
@@ -38,14 +40,14 @@ class DataGenerator(object):
...
@@ -38,14 +40,14 @@ class DataGenerator(object):
'''
'''
def
reader
():
def
reader
():
for
i
,
(
image
,
label
)
in
enumerate
(
file_list
)
:
for
i
mage_path
,
label
in
file_list
:
yield
self
.
load_image
(
image
),
label
yield
self
.
load_image
(
image
_path
),
label
return
reader
return
reader
def
load_image
(
self
,
path
):
def
load_image
(
self
,
path
):
'''
'''
Load
image and transform
to 1-dimention vector.
Load
an image and transform it
to 1-dimention vector.
:param path: The path of the image data.
:param path: The path of the image data.
:type path: str
:type path: str
...
...
scene_text_recognition/train.py
浏览文件 @
672e8565
...
@@ -6,7 +6,7 @@ import paddle.v2 as paddle
...
@@ -6,7 +6,7 @@ import paddle.v2 as paddle
from
config
import
TrainerConfig
as
conf
from
config
import
TrainerConfig
as
conf
from
model
import
Model
from
model
import
Model
from
reader
import
DataGenerator
from
reader
import
DataGenerator
from
utils
import
get_file_list
,
AsciiDic
from
utils
import
get_file_list
,
build_label_dict
,
load_dict
@
click
.
command
(
'train'
)
@
click
.
command
(
'train'
)
...
@@ -22,19 +22,35 @@ from utils import get_file_list, AsciiDic
...
@@ -22,19 +22,35 @@ from utils import get_file_list, AsciiDic
required
=
True
,
required
=
True
,
help
=
(
"The path of the file which contains "
help
=
(
"The path of the file which contains "
"path list of test image files."
))
"path list of test image files."
))
@
click
.
option
(
"--label_dict_path"
,
type
=
str
,
required
=
True
,
help
=
(
"The path of label dictionary. "
"If this parameter is set, but the file does not exist, "
"label dictionay will be built from "
"the training data automatically."
))
@
click
.
option
(
@
click
.
option
(
"--model_save_dir"
,
"--model_save_dir"
,
type
=
str
,
type
=
str
,
default
=
"models"
,
default
=
"models"
,
help
=
"The path to save the trained models (default: 'models')."
)
help
=
"The path to save the trained models (default: 'models')."
)
def
train
(
train_file_list_path
,
test_file_list_path
,
model_save_dir
):
def
train
(
train_file_list_path
,
test_file_list_path
,
label_dict_path
,
model_save_dir
):
if
not
os
.
path
.
exists
(
model_save_dir
):
if
not
os
.
path
.
exists
(
model_save_dir
):
os
.
mkdir
(
model_save_dir
)
os
.
mkdir
(
model_save_dir
)
train_file_list
=
get_file_list
(
train_file_list_path
)
train_file_list
=
get_file_list
(
train_file_list_path
)
test_file_list
=
get_file_list
(
test_file_list_path
)
test_file_list
=
get_file_list
(
test_file_list_path
)
char_dict
=
AsciiDic
()
dict_size
=
char_dict
.
size
()
if
not
os
.
path
.
exists
(
label_dict_path
):
print
((
"Label dictionary is not given, the dictionary "
"is automatically built from the training data."
))
build_label_dict
(
train_file_list
,
label_dict_path
)
char_dict
=
load_dict
(
label_dict_path
)
dict_size
=
len
(
char_dict
)
data_generator
=
DataGenerator
(
data_generator
=
DataGenerator
(
char_dict
=
char_dict
,
image_shape
=
conf
.
image_shape
)
char_dict
=
char_dict
,
image_shape
=
conf
.
image_shape
)
...
...
scene_text_recognition/utils.py
浏览文件 @
672e8565
import
os
import
os
from
collections
import
defaultdict
class
AsciiDic
(
object
):
UNK_ID
=
0
def
__init__
(
self
):
self
.
dic
=
{
'<unk>'
:
self
.
UNK_ID
,
}
self
.
chars
=
[
chr
(
i
)
for
i
in
range
(
40
,
171
)]
for
id
,
c
in
enumerate
(
self
.
chars
):
self
.
dic
[
c
]
=
id
+
1
def
lookup
(
self
,
w
):
return
self
.
dic
.
get
(
w
,
self
.
UNK_ID
)
def
id2word
(
self
):
'''
Return a reversed char dict.
'''
self
.
id2word
=
{}
for
key
,
value
in
self
.
dic
.
items
():
self
.
id2word
[
value
]
=
key
return
self
.
id2word
def
word2ids
(
self
,
word
):
'''
Transform a word to a list of ids.
:param word: The word appears in image data.
:type word: str
'''
return
[
self
.
lookup
(
c
)
for
c
in
list
(
word
)]
def
size
(
self
):
return
len
(
self
.
dic
)
def
get_file_list
(
image_file_list
):
def
get_file_list
(
image_file_list
):
...
@@ -53,7 +17,53 @@ def get_file_list(image_file_list):
...
@@ -53,7 +17,53 @@ def get_file_list(image_file_list):
line_split
=
line
.
strip
().
split
(
','
,
1
)
line_split
=
line
.
strip
().
split
(
','
,
1
)
filename
=
line_split
[
0
].
strip
()
filename
=
line_split
[
0
].
strip
()
path
=
os
.
path
.
join
(
dirname
,
filename
)
path
=
os
.
path
.
join
(
dirname
,
filename
)
label
=
line_split
[
1
][
2
:
-
1
]
label
=
line_split
[
1
][
2
:
-
1
].
strip
()
if
label
:
path_list
.
append
((
path
,
label
))
path_list
.
append
((
path
,
label
))
return
path_list
return
path_list
def
build_label_dict
(
file_list
,
save_path
):
"""
Build label dictionary from training data.
:param file_list: The list which contains the labels
of training data.
:type file_list: list
:params save_path: The path where the label dictionary will be saved.
:type save_path: str
"""
values
=
defaultdict
(
int
)
for
path
,
label
in
file_list
:
for
c
in
label
:
if
c
:
values
[
c
]
+=
1
values
[
'<unk>'
]
=
0
with
open
(
save_path
,
"w"
)
as
f
:
for
v
,
count
in
sorted
(
values
.
iteritems
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
):
f
.
write
(
"%s
\t
%d
\n
"
%
(
v
,
count
))
def
load_dict
(
dict_path
):
"""
Load label dictionary from the dictionary path.
:param dict_path: The path of word dictionary.
:type dict_path: str
"""
return
dict
((
line
.
strip
().
split
(
"
\t
"
)[
0
],
idx
)
for
idx
,
line
in
enumerate
(
open
(
dict_path
,
"r"
).
readlines
()))
def
load_reverse_dict
(
dict_path
):
"""
Load the reversed label dictionary from dictionary path.
:param dict_path: The path of word dictionary.
:type dict_path: str
"""
return
dict
((
idx
,
line
.
strip
().
split
(
"
\t
"
)[
0
])
for
idx
,
line
in
enumerate
(
open
(
dict_path
,
"r"
).
readlines
()))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录