Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
5161d825
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5161d825
编写于
6月 21, 2022
作者:
Z
zhiboniu
提交者:
zhiboniu
6月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete other rec algorithm
上级
653009ab
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
22 addition
and
928 deletion
+22
-928
deploy/pphuman/config/infer_cfg_ppvehicle.yml
deploy/pphuman/config/infer_cfg_ppvehicle.yml
+0
-1
deploy/pphuman/ppvehicle/vehicle_plate.py
deploy/pphuman/ppvehicle/vehicle_plate.py
+18
-261
deploy/pphuman/ppvehicle/vehicle_plateutils.py
deploy/pphuman/ppvehicle/vehicle_plateutils.py
+3
-8
deploy/pphuman/ppvehicle/vehicleplate_postprocess.py
deploy/pphuman/ppvehicle/vehicleplate_postprocess.py
+1
-658
未找到文件。
deploy/pphuman/config/infer_cfg_ppvehicle.yml
浏览文件 @
5161d825
...
@@ -19,7 +19,6 @@ VEHICLE_PLATE:
...
@@ -19,7 +19,6 @@ VEHICLE_PLATE:
det_model_dir
:
output_inference/ch_PP-OCRv3_det_infer/
det_model_dir
:
output_inference/ch_PP-OCRv3_det_infer/
det_limit_side_len
:
480
det_limit_side_len
:
480
det_limit_type
:
"
max"
det_limit_type
:
"
max"
rec_algorithm
:
"
SVTR_LCNet"
rec_model_dir
:
output_inference/ch_PP-OCRv3_rec_infer/
rec_model_dir
:
output_inference/ch_PP-OCRv3_rec_infer/
rec_image_shape
:
[
3
,
48
,
320
]
rec_image_shape
:
[
3
,
48
,
320
]
rec_batch_num
:
6
rec_batch_num
:
6
...
...
deploy/pphuman/ppvehicle/vehicle_plate.py
浏览文件 @
5161d825
...
@@ -151,7 +151,6 @@ class TextRecognizer(object):
...
@@ -151,7 +151,6 @@ class TextRecognizer(object):
def
__init__
(
self
,
args
,
cfg
,
use_gpu
=
True
):
def
__init__
(
self
,
args
,
cfg
,
use_gpu
=
True
):
self
.
rec_image_shape
=
cfg
[
'rec_image_shape'
]
self
.
rec_image_shape
=
cfg
[
'rec_image_shape'
]
self
.
rec_batch_num
=
cfg
[
'rec_batch_num'
]
self
.
rec_batch_num
=
cfg
[
'rec_batch_num'
]
self
.
rec_algorithm
=
cfg
[
'rec_algorithm'
]
word_dict_path
=
cfg
[
'word_dict_path'
]
word_dict_path
=
cfg
[
'word_dict_path'
]
use_space_char
=
True
use_space_char
=
True
...
@@ -160,30 +159,6 @@ class TextRecognizer(object):
...
@@ -160,30 +159,6 @@ class TextRecognizer(object):
"character_dict_path"
:
word_dict_path
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
use_space_char
"use_space_char"
:
use_space_char
}
}
if
self
.
rec_algorithm
==
"SRN"
:
postprocess_params
=
{
'name'
:
'SRNLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
use_space_char
}
elif
self
.
rec_algorithm
==
"RARE"
:
postprocess_params
=
{
'name'
:
'AttnLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
use_space_char
}
elif
self
.
rec_algorithm
==
'NRTR'
:
postprocess_params
=
{
'name'
:
'NRTRLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
use_space_char
}
elif
self
.
rec_algorithm
==
"SAR"
:
postprocess_params
=
{
'name'
:
'SARLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
use_space_char
}
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
create_predictor
(
args
,
cfg
,
'rec'
)
create_predictor
(
args
,
cfg
,
'rec'
)
...
@@ -191,15 +166,6 @@ class TextRecognizer(object):
...
@@ -191,15 +166,6 @@ class TextRecognizer(object):
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
if
self
.
rec_algorithm
==
'NRTR'
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
# return padding_im
image_pil
=
Image
.
fromarray
(
np
.
uint8
(
img
))
img
=
image_pil
.
resize
([
100
,
32
],
Image
.
ANTIALIAS
)
img
=
np
.
array
(
img
)
norm_img
=
np
.
expand_dims
(
img
,
-
1
)
norm_img
=
norm_img
.
transpose
((
2
,
0
,
1
))
return
norm_img
.
astype
(
np
.
float32
)
/
128.
-
1.
assert
imgC
==
img
.
shape
[
2
]
assert
imgC
==
img
.
shape
[
2
]
imgW
=
int
((
imgH
*
max_wh_ratio
))
imgW
=
int
((
imgH
*
max_wh_ratio
))
...
@@ -214,10 +180,6 @@ class TextRecognizer(object):
...
@@ -214,10 +180,6 @@ class TextRecognizer(object):
resized_w
=
imgW
resized_w
=
imgW
else
:
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
if
self
.
rec_algorithm
==
'RARE'
:
if
resized_w
>
self
.
rec_image_shape
[
2
]:
resized_w
=
self
.
rec_image_shape
[
2
]
imgW
=
self
.
rec_image_shape
[
2
]
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
...
@@ -227,124 +189,6 @@ class TextRecognizer(object):
...
@@ -227,124 +189,6 @@ class TextRecognizer(object):
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
return
padding_im
return
padding_im
def
resize_norm_img_svtr
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
return
resized_image
def
resize_norm_img_srn
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
img_black
=
np
.
zeros
((
imgH
,
imgW
))
im_hei
=
img
.
shape
[
0
]
im_wid
=
img
.
shape
[
1
]
if
im_wid
<=
im_hei
*
1
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
1
,
imgH
))
elif
im_wid
<=
im_hei
*
2
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
2
,
imgH
))
elif
im_wid
<=
im_hei
*
3
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
3
,
imgH
))
else
:
img_new
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
))
img_np
=
np
.
asarray
(
img_new
)
img_np
=
cv2
.
cvtColor
(
img_np
,
cv2
.
COLOR_BGR2GRAY
)
img_black
[:,
0
:
img_np
.
shape
[
1
]]
=
img_np
img_black
=
img_black
[:,
:,
np
.
newaxis
]
row
,
col
,
c
=
img_black
.
shape
c
=
1
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
self
,
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
(
(
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
(
(
max_text_length
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
return
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
process_image_srn
(
self
,
img
,
image_shape
,
num_heads
,
max_text_length
):
norm_img
=
self
.
resize_norm_img_srn
(
img
,
image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
self
.
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
)
gsrm_slf_attn_bias1
=
gsrm_slf_attn_bias1
.
astype
(
np
.
float32
)
gsrm_slf_attn_bias2
=
gsrm_slf_attn_bias2
.
astype
(
np
.
float32
)
encoder_word_pos
=
encoder_word_pos
.
astype
(
np
.
int64
)
gsrm_word_pos
=
gsrm_word_pos
.
astype
(
np
.
int64
)
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
def
resize_norm_img_sar
(
self
,
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
valid_ratio
=
1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor
=
int
(
1
/
width_downsample_ratio
)
# resize
ratio
=
w
/
float
(
h
)
resize_w
=
math
.
ceil
(
imgH
*
ratio
)
if
resize_w
%
width_divisor
!=
0
:
resize_w
=
round
(
resize_w
/
width_divisor
)
*
width_divisor
if
imgW_min
is
not
None
:
resize_w
=
max
(
imgW_min
,
resize_w
)
if
imgW_max
is
not
None
:
valid_ratio
=
min
(
1.0
,
1.0
*
resize_w
/
imgW_max
)
resize_w
=
min
(
imgW_max
,
resize_w
)
resized_image
=
cv2
.
resize
(
img
,
(
resize_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
# norm
if
image_shape
[
0
]
==
1
:
resized_image
=
resized_image
/
255
resized_image
=
resized_image
[
np
.
newaxis
,
:]
else
:
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
resize_shape
=
resized_image
.
shape
padding_im
=
-
1.0
*
np
.
ones
((
imgC
,
imgH
,
imgW_max
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resize_w
]
=
resized_image
pad_shape
=
padding_im
.
shape
return
padding_im
,
resize_shape
,
pad_shape
,
valid_ratio
def
predict_text
(
self
,
img_list
):
def
predict_text
(
self
,
img_list
):
img_num
=
len
(
img_list
)
img_num
=
len
(
img_list
)
# Calculate the aspect ratio of all text bars
# Calculate the aspect ratio of all text bars
...
@@ -367,103 +211,16 @@ class TextRecognizer(object):
...
@@ -367,103 +211,16 @@ class TextRecognizer(object):
wh_ratio
=
w
*
1.0
/
h
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
for
ino
in
range
(
beg_img_no
,
end_img_no
):
if
self
.
rec_algorithm
==
"SAR"
:
norm_img
,
_
,
_
,
valid_ratio
=
self
.
resize_norm_img_sar
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
valid_ratio
=
np
.
expand_dims
(
valid_ratio
,
axis
=
0
)
valid_ratios
=
[]
valid_ratios
.
append
(
valid_ratio
)
norm_img_batch
.
append
(
norm_img
)
elif
self
.
rec_algorithm
==
"SRN"
:
norm_img
=
self
.
process_image_srn
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
)
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
gsrm_slf_attn_bias1_list
=
[]
gsrm_slf_attn_bias2_list
=
[]
encoder_word_pos_list
.
append
(
norm_img
[
1
])
gsrm_word_pos_list
.
append
(
norm_img
[
2
])
gsrm_slf_attn_bias1_list
.
append
(
norm_img
[
3
])
gsrm_slf_attn_bias2_list
.
append
(
norm_img
[
4
])
norm_img_batch
.
append
(
norm_img
[
0
])
elif
self
.
rec_algorithm
==
"SVTR"
:
norm_img
=
self
.
resize_norm_img_svtr
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
else
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
norm_img_batch
.
copy
()
norm_img_batch
=
norm_img_batch
.
copy
()
if
self
.
rec_algorithm
==
"SRN"
:
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
)
inputs
=
[
norm_img_batch
,
encoder_word_pos_list
,
gsrm_word_pos_list
,
gsrm_slf_attn_bias1_list
,
gsrm_slf_attn_bias2_list
,
]
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
preds
=
{
"predict"
:
outputs
[
2
]}
else
:
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
{
"predict"
:
outputs
[
2
]}
elif
self
.
rec_algorithm
==
"SAR"
:
valid_ratios
=
np
.
concatenate
(
valid_ratios
)
inputs
=
[
norm_img_batch
,
valid_ratios
,
]
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
preds
=
outputs
[
0
]
else
:
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
outputs
[
0
]
else
:
if
self
.
use_onnx
:
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
input_dict
)
preds
=
outputs
[
0
]
preds
=
outputs
[
0
]
else
:
else
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
...
...
deploy/pphuman/ppvehicle/vehicle_plateutils.py
浏览文件 @
5161d825
...
@@ -185,7 +185,6 @@ def create_predictor(args, cfg, mode):
...
@@ -185,7 +185,6 @@ def create_predictor(args, cfg, mode):
def
get_output_tensors
(
cfg
,
mode
,
predictor
):
def
get_output_tensors
(
cfg
,
mode
,
predictor
):
output_names
=
predictor
.
get_output_names
()
output_names
=
predictor
.
get_output_names
()
output_tensors
=
[]
output_tensors
=
[]
if
mode
==
"rec"
and
cfg
[
'rec_algorithm'
]
in
[
"CRNN"
,
"SVTR_LCNet"
]:
output_name
=
'softmax_0.tmp_0'
output_name
=
'softmax_0.tmp_0'
if
output_name
in
output_names
:
if
output_name
in
output_names
:
return
[
predictor
.
get_output_handle
(
output_name
)]
return
[
predictor
.
get_output_handle
(
output_name
)]
...
@@ -193,10 +192,6 @@ def get_output_tensors(cfg, mode, predictor):
...
@@ -193,10 +192,6 @@ def get_output_tensors(cfg, mode, predictor):
for
output_name
in
output_names
:
for
output_name
in
output_names
:
output_tensor
=
predictor
.
get_output_handle
(
output_name
)
output_tensor
=
predictor
.
get_output_handle
(
output_name
)
output_tensors
.
append
(
output_tensor
)
output_tensors
.
append
(
output_tensor
)
else
:
for
output_name
in
output_names
:
output_tensor
=
predictor
.
get_output_handle
(
output_name
)
output_tensors
.
append
(
output_tensor
)
return
output_tensors
return
output_tensors
...
...
deploy/pphuman/ppvehicle/vehicleplate_postprocess.py
浏览文件 @
5161d825
...
@@ -23,16 +23,7 @@ import copy
...
@@ -23,16 +23,7 @@ import copy
def
build_post_process
(
config
,
global_config
=
None
):
def
build_post_process
(
config
,
global_config
=
None
):
support_dict
=
[
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
]
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'SRNLabelDecode'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
,
'PRENLabelDecode'
,
'DistillationSARLabelDecode'
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
from
.pse_postprocess
import
PSEPostProcess
support_dict
.
append
(
'PSEPostProcess'
)
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
'name'
)
...
@@ -298,651 +289,3 @@ class CTCLabelDecode(BaseRecLabelDecode):
...
@@ -298,651 +289,3 @@ class CTCLabelDecode(BaseRecLabelDecode):
def
add_special_char
(
self
,
dict_character
):
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
]
+
dict_character
dict_character
=
[
'blank'
]
+
dict_character
return
dict_character
return
dict_character
class
DistillationCTCLabelDecode
(
CTCLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'ctc'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
NRTRLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
True
,
**
kwargs
):
super
(
NRTRLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
len
(
preds
)
==
2
:
preds_id
=
preds
[
0
]
preds_prob
=
preds
[
1
]
if
isinstance
(
preds_id
,
paddle
.
Tensor
):
preds_id
=
preds_id
.
numpy
()
if
isinstance
(
preds_prob
,
paddle
.
Tensor
):
preds_prob
=
preds_prob
.
numpy
()
if
preds_id
[
0
][
0
]
==
2
:
preds_idx
=
preds_id
[:,
1
:]
preds_prob
=
preds_prob
[:,
1
:]
else
:
preds_idx
=
preds_id
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
[:,
1
:])
else
:
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
[:,
1
:])
return
text
,
label
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
==
3
:
# end
break
try
:
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
except
:
continue
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
.
lower
(),
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
class
AttnLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
AttnLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
[
beg_idx
,
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
class
SEEDLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SEEDLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
padding_str
=
"padding"
self
.
end_str
=
"eos"
self
.
unknown
=
"unknown"
dict_character
=
dict_character
+
[
self
.
end_str
,
self
.
padding_str
,
self
.
unknown
]
return
dict_character
def
get_ignored_tokens
(
self
):
end_idx
=
self
.
get_beg_end_flag_idx
(
"eos"
)
return
[
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"sos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"eos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
%
beg_or_end
return
idx
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
[
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds_idx
=
preds
[
"rec_pred"
]
if
isinstance
(
preds_idx
,
paddle
.
Tensor
):
preds_idx
=
preds_idx
.
numpy
()
if
"rec_pred_scores"
in
preds
:
preds_idx
=
preds
[
"rec_pred"
]
preds_prob
=
preds
[
"rec_pred_scores"
]
else
:
preds_idx
=
preds
[
"rec_pred"
].
argmax
(
axis
=
2
)
preds_prob
=
preds
[
"rec_pred"
].
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
class
SRNLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SRNLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
self
.
max_text_length
=
kwargs
.
get
(
'max_text_length'
,
25
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
pred
=
preds
[
'predict'
]
char_num
=
len
(
self
.
character_str
)
+
2
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
np
.
reshape
(
pred
,
[
-
1
,
char_num
])
preds_idx
=
np
.
argmax
(
pred
,
axis
=
1
)
preds_prob
=
np
.
max
(
pred
,
axis
=
1
)
preds_idx
=
np
.
reshape
(
preds_idx
,
[
-
1
,
self
.
max_text_length
])
preds_prob
=
np
.
reshape
(
preds_prob
,
[
-
1
,
self
.
max_text_length
])
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
add_special_char
(
self
,
dict_character
):
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
return
dict_character
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
class
TableLabelDecode
(
object
):
""" """
def
__init__
(
self
,
character_dict_path
,
**
kwargs
):
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
character_dict_path
)
list_character
=
self
.
add_special_char
(
list_character
)
list_elem
=
self
.
add_special_char
(
list_elem
)
self
.
dict_character
=
{}
self
.
dict_idx_character
=
{}
for
i
,
char
in
enumerate
(
list_character
):
self
.
dict_idx_character
[
i
]
=
char
self
.
dict_character
[
char
]
=
i
self
.
dict_elem
=
{}
self
.
dict_idx_elem
=
{}
for
i
,
elem
in
enumerate
(
list_elem
):
self
.
dict_idx_elem
[
i
]
=
elem
self
.
dict_elem
[
elem
]
=
i
def
load_char_elem_dict
(
self
,
character_dict_path
):
list_character
=
[]
list_elem
=
[]
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
list_character
.
append
(
character
)
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
list_elem
.
append
(
elem
)
return
list_character
,
list_elem
def
add_special_char
(
self
,
list_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
loc_preds
=
preds
[
'loc_preds'
]
if
isinstance
(
structure_probs
,
paddle
.
Tensor
):
structure_probs
=
structure_probs
.
numpy
()
if
isinstance
(
loc_preds
,
paddle
.
Tensor
):
loc_preds
=
loc_preds
.
numpy
()
structure_idx
=
structure_probs
.
argmax
(
axis
=
2
)
structure_probs
=
structure_probs
.
max
(
axis
=
2
)
structure_str
,
structure_pos
,
result_score_list
,
result_elem_idx_list
=
self
.
decode
(
structure_idx
,
structure_probs
,
'elem'
)
res_html_code_list
=
[]
res_loc_list
=
[]
batch_num
=
len
(
structure_str
)
for
bno
in
range
(
batch_num
):
res_loc
=
[]
for
sno
in
range
(
len
(
structure_str
[
bno
])):
text
=
structure_str
[
bno
][
sno
]
if
text
in
[
'<td>'
,
'<td'
]:
pos
=
structure_pos
[
bno
][
sno
]
res_loc
.
append
(
loc_preds
[
bno
,
pos
])
res_html_code
=
''
.
join
(
structure_str
[
bno
])
res_loc
=
np
.
array
(
res_loc
)
res_html_code_list
.
append
(
res_html_code
)
res_loc_list
.
append
(
res_loc
)
return
{
'res_html_code'
:
res_html_code_list
,
'res_loc'
:
res_loc_list
,
'res_score_list'
:
result_score_list
,
'res_elem_idx_list'
:
result_elem_idx_list
,
'structure_str_list'
:
structure_str
}
def
decode
(
self
,
text_index
,
structure_probs
,
char_or_elem
):
"""convert text-label into text-index.
"""
if
char_or_elem
==
"char"
:
current_dict
=
self
.
dict_idx_character
else
:
current_dict
=
self
.
dict_idx_elem
ignored_tokens
=
self
.
get_ignored_tokens
(
'elem'
)
beg_idx
,
end_idx
=
ignored_tokens
result_list
=
[]
result_pos_list
=
[]
result_score_list
=
[]
result_elem_idx_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
elem_pos_list
=
[]
elem_idx_list
=
[]
score_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
tmp_elem_idx
=
int
(
text_index
[
batch_idx
][
idx
])
if
idx
>
0
and
tmp_elem_idx
==
end_idx
:
break
if
tmp_elem_idx
in
ignored_tokens
:
continue
char_list
.
append
(
current_dict
[
tmp_elem_idx
])
elem_pos_list
.
append
(
idx
)
score_list
.
append
(
structure_probs
[
batch_idx
,
idx
])
elem_idx_list
.
append
(
tmp_elem_idx
)
result_list
.
append
(
char_list
)
result_pos_list
.
append
(
elem_pos_list
)
result_score_list
.
append
(
score_list
)
result_elem_idx_list
.
append
(
elem_idx_list
)
return
result_list
,
result_pos_list
,
result_score_list
,
result_elem_idx_list
def
get_ignored_tokens
(
self
,
char_or_elem
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
,
char_or_elem
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
,
char_or_elem
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
,
char_or_elem
):
if
char_or_elem
==
"char"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_character
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_character
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of char"
\
%
beg_or_end
elif
char_or_elem
==
"elem"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_elem
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_elem
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
%
beg_or_end
else
:
assert
False
,
"Unsupport type %s in char_or_elem"
\
%
char_or_elem
return
idx
class
SARLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SARLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
self
.
rm_symbol
=
kwargs
.
get
(
'rm_symbol'
,
False
)
def
add_special_char
(
self
,
dict_character
):
beg_end_str
=
"<BOS/EOS>"
unknown_str
=
"<UKN>"
padding_str
=
"<PAD>"
dict_character
=
dict_character
+
[
unknown_str
]
self
.
unknown_idx
=
len
(
dict_character
)
-
1
dict_character
=
dict_character
+
[
beg_end_str
]
self
.
start_idx
=
len
(
dict_character
)
-
1
self
.
end_idx
=
len
(
dict_character
)
-
1
dict_character
=
dict_character
+
[
padding_str
]
self
.
padding_idx
=
len
(
dict_character
)
-
1
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
self
.
end_idx
):
if
text_prob
is
None
and
idx
==
0
:
continue
else
:
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
if
self
.
rm_symbol
:
comp
=
re
.
compile
(
'[^A-Z^a-z^0-9^
\u4e00
-
\u9fa5
]'
)
text
=
text
.
lower
()
text
=
comp
.
sub
(
''
,
text
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
get_ignored_tokens
(
self
):
return
[
self
.
padding_idx
]
class
DistillationSARLabelDecode
(
SARLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationSARLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'sar'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
PRENLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
PRENLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
padding_str
=
'<PAD>'
# 0
end_str
=
'<EOS>'
# 1
unknown_str
=
'<UNK>'
# 2
dict_character
=
[
padding_str
,
end_str
,
unknown_str
]
+
dict_character
self
.
padding_idx
=
0
self
.
end_idx
=
1
self
.
unknown_idx
=
2
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
):
""" convert text-index into text-label. """
result_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
==
self
.
end_idx
:
break
if
text_index
[
batch_idx
][
idx
]
in
\
[
self
.
padding_idx
,
self
.
unknown_idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
if
len
(
text
)
>
0
:
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
else
:
# here confidence of empty recog result is 1
result_list
.
append
((
''
,
1
))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录