Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
cb370419
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cb370419
编写于
7月 11, 2022
作者:
xuyang2233
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modified pr
上级
4a3b874a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
88 addition
and
68 deletion
+88
-68
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+1
-1
ppocr/modeling/heads/rec_spin_att_head.py
ppocr/modeling/heads/rec_spin_att_head.py
+5
-0
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+81
-66
tools/export_model.py
tools/export_model.py
+1
-1
未找到文件。
ppocr/data/imaug/label_ops.py
浏览文件 @
cb370419
...
@@ -1217,7 +1217,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
...
@@ -1217,7 +1217,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
dict_character
=
[
'</s>'
]
+
dict_character
dict_character
=
[
'</s>'
]
+
dict_character
return
dict_character
return
dict_character
class
SPINAttnLabelEncode
(
BaseRec
LabelEncode
):
class
SPINAttnLabelEncode
(
Attn
LabelEncode
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
def
__init__
(
self
,
def
__init__
(
self
,
...
...
ppocr/modeling/heads/rec_spin_att_head.py
浏览文件 @
cb370419
...
@@ -12,6 +12,11 @@
...
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py
"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
cb370419
...
@@ -669,7 +669,86 @@ class ABINetLabelDecode(NRTRLabelDecode):
...
@@ -669,7 +669,86 @@ class ABINetLabelDecode(NRTRLabelDecode):
return
dict_character
return
dict_character
class
SPINAttnLabelDecode
(
BaseRecLabelDecode
):
# class SPINAttnLabelDecode(BaseRecLabelDecode):
# """ Convert between text-label and text-index """
# def __init__(self, character_dict_path=None, use_space_char=False,
# **kwargs):
# super(SPINAttnLabelDecode, self).__init__(character_dict_path,
# use_space_char)
# def add_special_char(self, dict_character):
# self.beg_str = "sos"
# self.end_str = "eos"
# dict_character = dict_character
# dict_character = [self.beg_str] + [self.end_str] + dict_character
# return dict_character
# def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
# """ convert text-index into text-label. """
# result_list = []
# ignored_tokens = self.get_ignored_tokens()
# [beg_idx, end_idx] = self.get_ignored_tokens()
# batch_size = len(text_index)
# for batch_idx in range(batch_size):
# char_list = []
# conf_list = []
# for idx in range(len(text_index[batch_idx])):
# if text_index[batch_idx][idx] == int(beg_idx):
# continue
# if int(text_index[batch_idx][idx]) == int(end_idx):
# break
# if is_remove_duplicate:
# # only for predict
# if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
# batch_idx][idx]:
# continue
# char_list.append(self.character[int(text_index[batch_idx][
# idx])])
# if text_prob is not None:
# conf_list.append(text_prob[batch_idx][idx])
# else:
# conf_list.append(1)
# text = ''.join(char_list)
# result_list.append((text.lower(), np.mean(conf_list).tolist()))
# return result_list
# def __call__(self, preds, label=None, *args, **kwargs):
# """
# text = self.decode(text)
# if label is None:
# return text
# else:
# label = self.decode(label, is_remove_duplicate=False)
# return text, label
# """
# if isinstance(preds, paddle.Tensor):
# preds = preds.numpy()
# preds_idx = preds.argmax(axis=2)
# preds_prob = preds.max(axis=2)
# text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
# if label is None:
# return text
# label = self.decode(label, is_remove_duplicate=False)
# return text, label
# def get_ignored_tokens(self):
# beg_idx = self.get_beg_end_flag_idx("beg")
# end_idx = self.get_beg_end_flag_idx("end")
# return [beg_idx, end_idx]
# def get_beg_end_flag_idx(self, beg_or_end):
# if beg_or_end == "beg":
# idx = np.array(self.dict[self.beg_str])
# elif beg_or_end == "end":
# idx = np.array(self.dict[self.end_str])
# else:
# assert False, "unsupport type %s in get_beg_end_flag_idx" \
# % beg_or_end
# return idx
class
SPINAttnLabelDecode
(
AttnLabelDecode
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
...
@@ -682,68 +761,4 @@ class SPINAttnLabelDecode(BaseRecLabelDecode):
...
@@ -682,68 +761,4 @@ class SPINAttnLabelDecode(BaseRecLabelDecode):
self
.
end_str
=
"eos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
dict_character
=
dict_character
dict_character
=
[
self
.
beg_str
]
+
[
self
.
end_str
]
+
dict_character
dict_character
=
[
self
.
beg_str
]
+
[
self
.
end_str
]
+
dict_character
return
dict_character
return
dict_character
\ No newline at end of file
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
[
beg_idx
,
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
==
int
(
beg_idx
):
continue
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
.
lower
(),
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
\ No newline at end of file
tools/export_model.py
浏览文件 @
cb370419
...
@@ -91,7 +91,7 @@ def export_single_model(model,
...
@@ -91,7 +91,7 @@ def export_single_model(model,
]
]
# print([None, 3, 32, 128])
# print([None, 3, 32, 128])
model
=
to_static
(
model
,
input_spec
=
other_shape
)
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"NRTR"
or
arch_config
[
"algorithm"
]
==
"SPIN"
:
elif
arch_config
[
"algorithm"
]
in
[
"NRTR"
,
"SPIN"
]
:
other_shape
=
[
other_shape
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
1
,
32
,
100
],
dtype
=
"float32"
),
shape
=
[
None
,
1
,
32
,
100
],
dtype
=
"float32"
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录