Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
6dd494b6
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1532
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6dd494b6
编写于
5月 20, 2020
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add anno
上级
fc2f9c2e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
65 addition
and
7 deletion
+65
-7
ppocr/modeling/architectures/rec_model.py
ppocr/modeling/architectures/rec_model.py
+21
-0
ppocr/modeling/losses/rec_attention_loss.py
ppocr/modeling/losses/rec_attention_loss.py
+1
-0
ppocr/modeling/losses/rec_ctc_loss.py
ppocr/modeling/losses/rec_ctc_loss.py
+1
-0
ppocr/utils/character.py
ppocr/utils/character.py
+42
-7
未找到文件。
ppocr/modeling/architectures/rec_model.py
浏览文件 @
6dd494b6
...
...
@@ -25,6 +25,14 @@ from copy import deepcopy
class
RecModel
(
object
):
"""
Rec model architecture
Args:
params(object): Params from yaml file and settings from command line
"""
def
__init__
(
self
,
params
):
super
(
RecModel
,
self
).
__init__
()
global_params
=
params
[
'Global'
]
...
...
@@ -58,6 +66,13 @@ class RecModel(object):
self
.
max_text_length
=
global_params
[
'max_text_length'
]
def
create_feed
(
self
,
mode
):
"""
Create feed dict and DataLoader object
Args:
mode(str): runtime mode, can be "train", "eval" or "test"
Return: image, labels, loader
"""
image_shape
=
deepcopy
(
self
.
image_shape
)
image_shape
.
insert
(
0
,
-
1
)
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
...
...
@@ -96,9 +111,13 @@ class RecModel(object):
inputs
=
image
else
:
inputs
=
self
.
tps
(
image
)
# backbone
conv_feas
=
self
.
backbone
(
inputs
)
# predict
predicts
=
self
.
head
(
conv_feas
,
labels
,
mode
)
decoded_out
=
predicts
[
'decoded_out'
]
#loss
if
mode
==
"train"
:
loss
=
self
.
loss
(
predicts
,
labels
)
if
self
.
loss_type
==
"attention"
:
...
...
@@ -108,9 +127,11 @@ class RecModel(object):
outputs
=
{
'total_loss'
:
loss
,
'decoded_out'
:
\
decoded_out
,
'label'
:
label
}
return
loader
,
outputs
# export_model
elif
mode
==
"export"
:
predict
=
predicts
[
'predict'
]
predict
=
fluid
.
layers
.
softmax
(
predict
)
return
[
image
,
{
'decoded_out'
:
decoded_out
,
'predicts'
:
predict
}]
# eval or test
else
:
return
loader
,
{
'decoded_out'
:
decoded_out
}
ppocr/modeling/losses/rec_attention_loss.py
浏览文件 @
6dd494b6
...
...
@@ -33,6 +33,7 @@ class AttentionLoss(object):
predict
=
predicts
[
'predict'
]
label_out
=
labels
[
'label_out'
]
label_out
=
fluid
.
layers
.
cast
(
x
=
label_out
,
dtype
=
'int64'
)
# calculate attention loss
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label_out
)
sum_cost
=
fluid
.
layers
.
reduce_sum
(
cost
)
return
sum_cost
ppocr/modeling/losses/rec_ctc_loss.py
浏览文件 @
6dd494b6
...
...
@@ -30,6 +30,7 @@ class CTCLoss(object):
def
__call__
(
self
,
predicts
,
labels
):
predict
=
predicts
[
'predict'
]
label
=
labels
[
'label'
]
# calculate ctc loss
cost
=
fluid
.
layers
.
warpctc
(
input
=
predict
,
label
=
label
,
blank
=
self
.
char_num
,
norm_by_times
=
True
)
sum_cost
=
fluid
.
layers
.
reduce_sum
(
cost
)
...
...
ppocr/utils/character.py
浏览文件 @
6dd494b6
...
...
@@ -20,14 +20,22 @@ import sys
class
CharacterOps
(
object
):
""" Convert between text-label and text-index """
"""
Convert between text-label and text-index
Args:
config: config from yaml file
"""
def
__init__
(
self
,
config
):
self
.
character_type
=
config
[
'character_type'
]
self
.
loss_type
=
config
[
'loss_type'
]
# use the default dictionary(36 char)
if
self
.
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
# use the custom dictionary
elif
self
.
character_type
==
"ch"
:
character_dict_path
=
config
[
'character_dict_path'
]
self
.
character_str
=
""
...
...
@@ -47,26 +55,29 @@ class CharacterOps(object):
"Nonsupport type of the character: {}"
.
format
(
self
.
character_str
)
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
# add start and end str for attention
if
self
.
loss_type
==
"attention"
:
dict_character
=
[
self
.
beg_str
,
self
.
end_str
]
+
dict_character
# create char dict
self
.
dict
=
{}
for
i
,
char
in
enumerate
(
dict_character
):
self
.
dict
[
char
]
=
i
self
.
character
=
dict_character
def
encode
(
self
,
text
):
"""convert text-label into text-index.
input:
"""
convert text-label into text-index.
Args:
text: text labels of each image. [batch_size]
output
:
Reture
:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
# Ignore capital
if
self
.
character_type
==
"en"
:
text
=
text
.
lower
()
text_list
=
[]
for
char
in
text
:
if
char
not
in
self
.
dict
:
...
...
@@ -76,7 +87,15 @@ class CharacterOps(object):
return
text
def
decode
(
self
,
text_index
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
"""
convert text-index into text-label.
Args:
text_index: text index for each image
is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return:
text: text label
"""
char_list
=
[]
char_num
=
self
.
get_char_num
()
...
...
@@ -98,6 +117,9 @@ class CharacterOps(object):
return
text
def
get_char_num
(
self
):
"""
Get character num
"""
return
len
(
self
.
character
)
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
...
...
@@ -122,6 +144,19 @@ def cal_predicts_accuracy(char_ops,
labels
,
labels_lod
,
is_remove_duplicate
=
False
):
"""
Calculate predicts accrarcy
Args:
char_ops: CharacterOps
preds: preds result,text index
preds_lod:
labels:
labels_lod:
is_remove_duplicate:
Return:
"""
acc_num
=
0
img_num
=
0
for
ino
in
range
(
len
(
labels_lod
)
-
1
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录