Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
3f5ff9e6
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3f5ff9e6
编写于
6月 14, 2022
作者:
Z
zhiboniu
提交者:
zhiboniu
6月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ppvehicle plate
上级
333370d9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
2323 addition
and
0 deletion
+2323
-0
deploy/python/preprocess.py
deploy/python/preprocess.py
+79
-0
deploy/python/utils.py
deploy/python/utils.py
+8
-0
deploy/python/vechile_plate.py
deploy/python/vechile_plate.py
+621
-0
deploy/python/vechile_plateutils.py
deploy/python/vechile_plateutils.py
+667
-0
deploy/python/vecplatepostprocess.py
deploy/python/vecplatepostprocess.py
+948
-0
未找到文件。
deploy/python/preprocess.py
浏览文件 @
3f5ff9e6
...
...
@@ -40,6 +40,85 @@ def decode_image(im_file, im_info):
return
im
,
im_info
class
Resize_Mult32
(
object
):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def
__init__
(
self
,
limit_side_len
,
limit_type
,
interp
=
cv2
.
INTER_LINEAR
):
self
.
limit_side_len
=
limit_side_len
self
.
limit_type
=
limit_type
self
.
interp
=
interp
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im_channel
=
im
.
shape
[
2
]
im_scale_y
,
im_scale_x
=
self
.
generate_scale
(
im
)
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
im_info
[
'im_shape'
]
=
np
.
array
(
im
.
shape
[:
2
]).
astype
(
'float32'
)
im_info
[
'scale_factor'
]
=
np
.
array
(
[
im_scale_y
,
im_scale_x
]).
astype
(
'float32'
)
return
im
,
im_info
def
generate_scale
(
self
,
img
):
"""
Args:
img (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
limit_side_len
=
self
.
limit_side_len
h
,
w
,
c
=
img
.
shape
# limit the max side
if
self
.
limit_type
==
'max'
:
if
max
(
h
,
w
)
>
limit_side_len
:
if
h
>
w
:
ratio
=
float
(
limit_side_len
)
/
h
else
:
ratio
=
float
(
limit_side_len
)
/
w
else
:
ratio
=
1.
elif
self
.
limit_type
==
'min'
:
if
min
(
h
,
w
)
<
limit_side_len
:
if
h
<
w
:
ratio
=
float
(
limit_side_len
)
/
h
else
:
ratio
=
float
(
limit_side_len
)
/
w
else
:
ratio
=
1.
elif
self
.
limit_type
==
'resize_long'
:
ratio
=
float
(
limit_side_len
)
/
max
(
h
,
w
)
else
:
raise
Exception
(
'not support limit type, image '
)
resize_h
=
int
(
h
*
ratio
)
resize_w
=
int
(
w
*
ratio
)
resize_h
=
max
(
int
(
round
(
resize_h
/
32
)
*
32
),
32
)
resize_w
=
max
(
int
(
round
(
resize_w
/
32
)
*
32
),
32
)
im_scale_y
=
resize_h
/
float
(
h
)
im_scale_x
=
resize_w
/
float
(
w
)
return
im_scale_y
,
im_scale_x
class
Resize
(
object
):
"""resize image by target_size and max_size
Args:
...
...
deploy/python/utils.py
浏览文件 @
3f5ff9e6
...
...
@@ -27,6 +27,14 @@ def argsparser():
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."
),
required
=
True
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
parser
.
add_argument
(
"--det_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_limit_side_len"
,
type
=
float
,
default
=
960
)
parser
.
add_argument
(
"--det_limit_type"
,
type
=
str
,
default
=
'max'
)
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'SVTR_LCNet'
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 48, 320"
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
6
)
parser
.
add_argument
(
"--image_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of image file."
)
parser
.
add_argument
(
...
...
deploy/python/vechile_plate.py
0 → 100644
浏览文件 @
3f5ff9e6
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
yaml
import
glob
from
functools
import
reduce
import
time
import
cv2
import
numpy
as
np
import
math
import
paddle
import
sys
# add deploy path of PadleDetection to sys.path
parent_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
*
([
'..'
])))
sys
.
path
.
insert
(
0
,
parent_path
)
from
utils
import
Timer
,
get_current_memory_mb
from
infer
import
Detector
,
get_test_images
,
print_arguments
,
create_inputs
from
vechile_plateutils
import
create_predictor
,
get_infer_gpuid
,
argsparser
,
get_rotate_crop_image
from
vecplatepostprocess
import
build_post_process
from
preprocess
import
preprocess
,
Resize
,
NormalizeImage
,
Permute
,
PadStride
,
LetterBoxResize
,
WarpAffine
,
Pad
,
decode_image
,
Resize_Mult32
class
PlateDetector
(
object
):
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
det_algorithm
=
args
.
det_algorithm
self
.
pre_process_list
=
{
'Resize_Mult32'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
,
},
'NormalizeImage'
:
{
'mean'
:
[
0.485
,
0.456
,
0.406
],
'std'
:
[
0.229
,
0.224
,
0.225
],
'is_scale'
:
True
,
},
'Permute'
:
{}
}
postprocess_params
=
{}
postprocess_params
[
'name'
]
=
'DBPostProcess'
postprocess_params
[
"thresh"
]
=
0.3
postprocess_params
[
"box_thresh"
]
=
0.6
postprocess_params
[
"max_candidates"
]
=
1000
postprocess_params
[
"unclip_ratio"
]
=
1.5
postprocess_params
[
"use_dilation"
]
=
False
postprocess_params
[
"score_mode"
]
=
"fast"
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
create_predictor
(
args
,
'det'
)
if
args
.
run_benchmark
:
import
auto_log
pid
=
os
.
getpid
()
gpu_id
=
get_infer_gpuid
()
self
.
autolog
=
auto_log
.
AutoLogger
(
model_name
=
"det"
,
model_precision
=
"fp32"
,
batch_size
=
1
,
data_shape
=
"dynamic"
,
save_path
=
None
,
inference_config
=
self
.
config
,
pids
=
pid
,
process_name
=
None
,
gpu_ids
=
gpu_id
if
args
.
device
==
"GPU"
else
None
,
time_keys
=
[
'preprocess_time'
,
'inference_time'
,
'postprocess_time'
],
warmup
=
2
,
)
def
preprocess
(
self
,
image_list
):
preprocess_ops
=
[]
for
op_type
,
new_op_info
in
self
.
pre_process_list
.
items
():
preprocess_ops
.
append
(
eval
(
op_type
)(
**
new_op_info
))
input_im_lst
=
[]
input_im_info_lst
=
[]
for
im_path
in
image_list
:
im
,
im_info
=
preprocess
(
im_path
,
preprocess_ops
)
input_im_lst
.
append
(
im
)
input_im_info_lst
.
append
(
im_info
[
'im_shape'
])
return
np
.
stack
(
input_im_lst
,
axis
=
0
),
input_im_info_lst
def
order_points_clockwise
(
self
,
pts
):
rect
=
np
.
zeros
((
4
,
2
),
dtype
=
"float32"
)
s
=
pts
.
sum
(
axis
=
1
)
rect
[
0
]
=
pts
[
np
.
argmin
(
s
)]
rect
[
2
]
=
pts
[
np
.
argmax
(
s
)]
diff
=
np
.
diff
(
pts
,
axis
=
1
)
rect
[
1
]
=
pts
[
np
.
argmin
(
diff
)]
rect
[
3
]
=
pts
[
np
.
argmax
(
diff
)]
return
rect
def
clip_det_res
(
self
,
points
,
img_height
,
img_width
):
for
pno
in
range
(
points
.
shape
[
0
]):
points
[
pno
,
0
]
=
int
(
min
(
max
(
points
[
pno
,
0
],
0
),
img_width
-
1
))
points
[
pno
,
1
]
=
int
(
min
(
max
(
points
[
pno
,
1
],
0
),
img_height
-
1
))
return
points
def
filter_tag_det_res
(
self
,
dt_boxes
,
image_shape
):
img_height
,
img_width
=
image_shape
[
0
:
2
]
dt_boxes_new
=
[]
for
box
in
dt_boxes
:
box
=
self
.
order_points_clockwise
(
box
)
box
=
self
.
clip_det_res
(
box
,
img_height
,
img_width
)
rect_width
=
int
(
np
.
linalg
.
norm
(
box
[
0
]
-
box
[
1
]))
rect_height
=
int
(
np
.
linalg
.
norm
(
box
[
0
]
-
box
[
3
]))
if
rect_width
<=
3
or
rect_height
<=
3
:
continue
dt_boxes_new
.
append
(
box
)
dt_boxes
=
np
.
array
(
dt_boxes_new
)
return
dt_boxes
def
filter_tag_det_res_only_clip
(
self
,
dt_boxes
,
image_shape
):
img_height
,
img_width
=
image_shape
[
0
:
2
]
dt_boxes_new
=
[]
for
box
in
dt_boxes
:
box
=
self
.
clip_det_res
(
box
,
img_height
,
img_width
)
dt_boxes_new
.
append
(
box
)
dt_boxes
=
np
.
array
(
dt_boxes_new
)
return
dt_boxes
def
predict_image
(
self
,
img
):
st
=
time
.
time
()
if
self
.
args
.
run_benchmark
:
self
.
autolog
.
times
.
start
()
img
,
shape_list
=
self
.
preprocess
(
img
)
if
img
is
None
:
return
None
,
0
# img = np.expand_dims(img, axis=0)
# shape_list = np.expand_dims(shape_list, axis=0)
# img = img.copy()
if
self
.
args
.
run_benchmark
:
self
.
autolog
.
times
.
stamp
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
args
.
run_benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
{}
preds
[
'maps'
]
=
outputs
[
0
]
#self.predictor.try_shrink_memory()
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
shape_list
[
0
])
if
self
.
args
.
run_benchmark
:
self
.
autolog
.
times
.
end
(
stamp
=
True
)
et
=
time
.
time
()
return
dt_boxes
,
et
-
st
class
TextRecognizer
(
object
):
def
__init__
(
self
,
FLAGS
,
input_shape
=
[
3
,
48
,
320
],
batch_size
=
8
,
rec_algorithm
=
"SVTR"
,
word_dict_path
=
"rec_word_dict.txt"
,
use_gpu
=
True
,
benchmark
=
False
):
self
.
rec_image_shape
=
input_shape
self
.
rec_batch_num
=
batch_size
self
.
rec_algorithm
=
rec_algorithm
isuse_space_char
=
True
postprocess_params
=
{
'name'
:
'CTCLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
if
self
.
rec_algorithm
==
"SRN"
:
postprocess_params
=
{
'name'
:
'SRNLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
elif
self
.
rec_algorithm
==
"RARE"
:
postprocess_params
=
{
'name'
:
'AttnLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
elif
self
.
rec_algorithm
==
'NRTR'
:
postprocess_params
=
{
'name'
:
'NRTRLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
elif
self
.
rec_algorithm
==
"SAR"
:
postprocess_params
=
{
'name'
:
'SARLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
create_predictor
(
FLAGS
,
'rec'
)
self
.
benchmark
=
benchmark
self
.
use_onnx
=
False
if
benchmark
:
import
auto_log
pid
=
os
.
getpid
()
gpu_id
=
get_infer_gpuid
()
self
.
autolog
=
auto_log
.
AutoLogger
(
model_name
=
"rec"
,
model_precision
=
'fp32'
,
batch_size
=
batch_size
,
data_shape
=
"dynamic"
,
save_path
=
None
,
#save_log_path,
inference_config
=
self
.
config
,
pids
=
pid
,
process_name
=
None
,
gpu_ids
=
gpu_id
if
use_gpu
else
None
,
time_keys
=
[
'preprocess_time'
,
'inference_time'
,
'postprocess_time'
],
warmup
=
0
)
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
if
self
.
rec_algorithm
==
'NRTR'
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
# return padding_im
image_pil
=
Image
.
fromarray
(
np
.
uint8
(
img
))
img
=
image_pil
.
resize
([
100
,
32
],
Image
.
ANTIALIAS
)
img
=
np
.
array
(
img
)
norm_img
=
np
.
expand_dims
(
img
,
-
1
)
norm_img
=
norm_img
.
transpose
((
2
,
0
,
1
))
return
norm_img
.
astype
(
np
.
float32
)
/
128.
-
1.
assert
imgC
==
img
.
shape
[
2
]
imgW
=
int
((
imgH
*
max_wh_ratio
))
if
self
.
use_onnx
:
w
=
self
.
input_tensor
.
shape
[
3
:][
0
]
if
w
is
not
None
and
w
>
0
:
imgW
=
w
h
,
w
=
img
.
shape
[:
2
]
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
if
self
.
rec_algorithm
==
'RARE'
:
if
resized_w
>
self
.
rec_image_shape
[
2
]:
resized_w
=
self
.
rec_image_shape
[
2
]
imgW
=
self
.
rec_image_shape
[
2
]
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
padding_im
=
np
.
zeros
((
imgC
,
imgH
,
imgW
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
return
padding_im
def
resize_norm_img_svtr
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
return
resized_image
def
resize_norm_img_srn
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
img_black
=
np
.
zeros
((
imgH
,
imgW
))
im_hei
=
img
.
shape
[
0
]
im_wid
=
img
.
shape
[
1
]
if
im_wid
<=
im_hei
*
1
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
1
,
imgH
))
elif
im_wid
<=
im_hei
*
2
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
2
,
imgH
))
elif
im_wid
<=
im_hei
*
3
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
3
,
imgH
))
else
:
img_new
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
))
img_np
=
np
.
asarray
(
img_new
)
img_np
=
cv2
.
cvtColor
(
img_np
,
cv2
.
COLOR_BGR2GRAY
)
img_black
[:,
0
:
img_np
.
shape
[
1
]]
=
img_np
img_black
=
img_black
[:,
:,
np
.
newaxis
]
row
,
col
,
c
=
img_black
.
shape
c
=
1
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
self
,
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
(
(
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
(
(
max_text_length
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
return
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
process_image_srn
(
self
,
img
,
image_shape
,
num_heads
,
max_text_length
):
norm_img
=
self
.
resize_norm_img_srn
(
img
,
image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
self
.
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
)
gsrm_slf_attn_bias1
=
gsrm_slf_attn_bias1
.
astype
(
np
.
float32
)
gsrm_slf_attn_bias2
=
gsrm_slf_attn_bias2
.
astype
(
np
.
float32
)
encoder_word_pos
=
encoder_word_pos
.
astype
(
np
.
int64
)
gsrm_word_pos
=
gsrm_word_pos
.
astype
(
np
.
int64
)
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
def
resize_norm_img_sar
(
self
,
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
valid_ratio
=
1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor
=
int
(
1
/
width_downsample_ratio
)
# resize
ratio
=
w
/
float
(
h
)
resize_w
=
math
.
ceil
(
imgH
*
ratio
)
if
resize_w
%
width_divisor
!=
0
:
resize_w
=
round
(
resize_w
/
width_divisor
)
*
width_divisor
if
imgW_min
is
not
None
:
resize_w
=
max
(
imgW_min
,
resize_w
)
if
imgW_max
is
not
None
:
valid_ratio
=
min
(
1.0
,
1.0
*
resize_w
/
imgW_max
)
resize_w
=
min
(
imgW_max
,
resize_w
)
resized_image
=
cv2
.
resize
(
img
,
(
resize_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
# norm
if
image_shape
[
0
]
==
1
:
resized_image
=
resized_image
/
255
resized_image
=
resized_image
[
np
.
newaxis
,
:]
else
:
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
resize_shape
=
resized_image
.
shape
padding_im
=
-
1.0
*
np
.
ones
((
imgC
,
imgH
,
imgW_max
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resize_w
]
=
resized_image
pad_shape
=
padding_im
.
shape
return
padding_im
,
resize_shape
,
pad_shape
,
valid_ratio
def
__call__
(
self
,
img_list
):
img_num
=
len
(
img_list
)
# Calculate the aspect ratio of all text bars
width_list
=
[]
for
img
in
img_list
:
width_list
.
append
(
img
.
shape
[
1
]
/
float
(
img
.
shape
[
0
]))
# Sorting can speed up the recognition process
indices
=
np
.
argsort
(
np
.
array
(
width_list
))
rec_res
=
[[
''
,
0.0
]]
*
img_num
batch_num
=
self
.
rec_batch_num
st
=
time
.
time
()
if
self
.
benchmark
:
self
.
autolog
.
times
.
start
()
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
end_img_no
=
min
(
img_num
,
beg_img_no
+
batch_num
)
norm_img_batch
=
[]
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
max_wh_ratio
=
imgW
/
imgH
# max_wh_ratio = 0
for
ino
in
range
(
beg_img_no
,
end_img_no
):
h
,
w
=
img_list
[
indices
[
ino
]].
shape
[
0
:
2
]
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
if
self
.
rec_algorithm
==
"SAR"
:
norm_img
,
_
,
_
,
valid_ratio
=
self
.
resize_norm_img_sar
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
valid_ratio
=
np
.
expand_dims
(
valid_ratio
,
axis
=
0
)
valid_ratios
=
[]
valid_ratios
.
append
(
valid_ratio
)
norm_img_batch
.
append
(
norm_img
)
elif
self
.
rec_algorithm
==
"SRN"
:
norm_img
=
self
.
process_image_srn
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
)
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
gsrm_slf_attn_bias1_list
=
[]
gsrm_slf_attn_bias2_list
=
[]
encoder_word_pos_list
.
append
(
norm_img
[
1
])
gsrm_word_pos_list
.
append
(
norm_img
[
2
])
gsrm_slf_attn_bias1_list
.
append
(
norm_img
[
3
])
gsrm_slf_attn_bias2_list
.
append
(
norm_img
[
4
])
norm_img_batch
.
append
(
norm_img
[
0
])
elif
self
.
rec_algorithm
==
"SVTR"
:
norm_img
=
self
.
resize_norm_img_svtr
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
else
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
norm_img_batch
.
copy
()
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
if
self
.
rec_algorithm
==
"SRN"
:
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
)
inputs
=
[
norm_img_batch
,
encoder_word_pos_list
,
gsrm_word_pos_list
,
gsrm_slf_attn_bias1_list
,
gsrm_slf_attn_bias2_list
,
]
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
preds
=
{
"predict"
:
outputs
[
2
]}
else
:
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
{
"predict"
:
outputs
[
2
]}
elif
self
.
rec_algorithm
==
"SAR"
:
valid_ratios
=
np
.
concatenate
(
valid_ratios
)
inputs
=
[
norm_img_batch
,
valid_ratios
,
]
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
preds
=
outputs
[
0
]
else
:
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
outputs
[
0
]
else
:
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
preds
=
outputs
[
0
]
else
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
if
len
(
outputs
)
!=
1
:
preds
=
outputs
else
:
preds
=
outputs
[
0
]
rec_result
=
self
.
postprocess_op
(
preds
)
for
rno
in
range
(
len
(
rec_result
)):
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
rec_result
[
rno
]
if
self
.
benchmark
:
self
.
autolog
.
times
.
end
(
stamp
=
True
)
return
rec_res
,
time
.
time
()
-
st
class
PlateRecognizer
(
object
):
def
__init__
(
self
):
self
.
batch_size
=
8
self
.
platedetector
=
PlateDetector
(
FLAGS
)
self
.
textrecognizer
=
TextRecognizer
(
FLAGS
,
input_shape
=
[
3
,
48
,
320
],
batch_size
=
8
,
rec_algorithm
=
"SVTR"
,
word_dict_path
=
"rec_word_dict.txt"
,
use_gpu
=
True
,
benchmark
=
False
)
def
get_platelicense
(
self
,
image_list
):
plate_text_list
=
[]
plateboxes
,
det_time
=
self
.
platedetector
.
predict_image
(
image_list
)
for
idx
,
boxes_pcar
in
enumerate
(
plateboxes
):
plate_images
=
get_rotate_crop_image
(
image_list
[
idx
],
boxes_pcar
)
print
(
plate_images
.
shape
)
plate_texts
=
self
.
textrecognizer
(
plate_images
)
plate_text_list
.
append
(
plate_texts
)
import
pdb
pdb
.
set_trace
()
return
results
def
main
():
detector
=
PlateRecognizer
()
# predict from image
if
FLAGS
.
image_dir
is
None
and
FLAGS
.
image_file
is
not
None
:
assert
FLAGS
.
batch_size
==
1
,
"batch_size should be 1, when image_file is not None"
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
for
img
in
img_list
:
image
=
cv2
.
imread
(
img
)
results
=
detector
.
get_platelicense
([
image
])
if
not
FLAGS
.
run_benchmark
:
detector
.
det_times
.
info
(
average
=
True
)
else
:
mems
=
{
'cpu_rss_mb'
:
detector
.
cpu_mem
/
len
(
img_list
),
'gpu_rss_mb'
:
detector
.
gpu_mem
/
len
(
img_list
),
'gpu_util'
:
detector
.
gpu_util
*
100
/
len
(
img_list
)
}
perf_info
=
detector
.
det_times
.
report
(
average
=
True
)
model_dir
=
FLAGS
.
model_dir
mode
=
FLAGS
.
run_mode
model_info
=
{
'model_name'
:
model_dir
.
strip
(
'/'
).
split
(
'/'
)[
-
1
],
'precision'
:
mode
.
split
(
'_'
)[
-
1
]
}
data_info
=
{
'batch_size'
:
FLAGS
.
batch_size
,
'shape'
:
"dynamic_shape"
,
'data_num'
:
perf_info
[
'img_num'
]
}
det_log
=
PaddleInferBenchmark
(
detector
.
config
,
model_info
,
data_info
,
perf_info
,
mems
)
det_log
(
'Attr'
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
FLAGS
.
device
=
FLAGS
.
device
.
upper
()
assert
FLAGS
.
device
in
[
'CPU'
,
'GPU'
,
'XPU'
],
"device should be CPU, GPU or XPU"
assert
not
FLAGS
.
use_gpu
,
"use_gpu has been deprecated, please use --device"
main
()
deploy/python/vechile_plateutils.py
0 → 100644
浏览文件 @
3f5ff9e6
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
sys
import
platform
import
cv2
import
numpy
as
np
import
paddle
from
PIL
import
Image
,
ImageDraw
,
ImageFont
import
math
from
paddle
import
inference
import
time
import
ast
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
parser
.
add_argument
(
"--det_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_limit_side_len"
,
type
=
float
,
default
=
960
)
parser
.
add_argument
(
"--det_limit_type"
,
type
=
str
,
default
=
'max'
)
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'SVTR_LCNet'
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 48, 320"
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
6
)
parser
.
add_argument
(
"--image_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of image file."
)
parser
.
add_argument
(
"--image_dir"
,
type
=
str
,
default
=
None
,
help
=
"Dir of image file, `image_file` has a higher priority."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
,
help
=
"batch_size for inference."
)
parser
.
add_argument
(
"--video_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of video file, `video_file` or `camera_id` has a highest priority."
)
parser
.
add_argument
(
"--camera_id"
,
type
=
int
,
default
=-
1
,
help
=
"device id of camera to predict."
)
parser
.
add_argument
(
"--threshold"
,
type
=
float
,
default
=
0.5
,
help
=
"Threshold of score."
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Directory of output visualization files."
)
parser
.
add_argument
(
"--run_mode"
,
type
=
str
,
default
=
'paddle'
,
help
=
"mode of running(paddle/trt_fp32/trt_fp16/trt_int8)"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'cpu'
,
help
=
"Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Deprecated, please use `--device`."
)
parser
.
add_argument
(
"--run_benchmark"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether to predict a image_file repeatedly for benchmark"
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use mkldnn with CPU."
)
parser
.
add_argument
(
"--enable_mkldnn_bfloat16"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use mkldnn bfloat16 inference with CPU."
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
1
,
help
=
"Num of threads with CPU."
)
parser
.
add_argument
(
"--trt_min_shape"
,
type
=
int
,
default
=
1
,
help
=
"min_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_max_shape"
,
type
=
int
,
default
=
1280
,
help
=
"max_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_opt_shape"
,
type
=
int
,
default
=
640
,
help
=
"opt_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_calib_mode"
,
type
=
bool
,
default
=
False
,
help
=
"If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True."
)
parser
.
add_argument
(
'--save_images'
,
action
=
'store_true'
,
help
=
'Save visualization image results.'
)
parser
.
add_argument
(
'--save_mot_txts'
,
action
=
'store_true'
,
help
=
'Save tracking results (txt).'
)
parser
.
add_argument
(
'--save_mot_txt_per_img'
,
action
=
'store_true'
,
help
=
'Save tracking results (txt) for each image.'
)
parser
.
add_argument
(
'--scaled'
,
type
=
bool
,
default
=
False
,
help
=
"Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector."
)
parser
.
add_argument
(
"--tracker_config"
,
type
=
str
,
default
=
None
,
help
=
(
"tracker donfig"
))
parser
.
add_argument
(
"--reid_model_dir"
,
type
=
str
,
default
=
None
,
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."
))
parser
.
add_argument
(
"--reid_batch_size"
,
type
=
int
,
default
=
50
,
help
=
"max batch_size for reid model inference."
)
parser
.
add_argument
(
'--use_dark'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'whether to use darkpose to get better keypoint position predict '
)
parser
.
add_argument
(
"--action_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of input file for action recognition."
)
parser
.
add_argument
(
"--window_size"
,
type
=
int
,
default
=
50
,
help
=
"Temporal size of skeleton feature for action recognition."
)
parser
.
add_argument
(
"--random_pad"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether do random padding for action recognition."
)
parser
.
add_argument
(
"--save_results"
,
type
=
bool
,
default
=
False
,
help
=
"Whether save detection result to file using coco format"
)
return
parser
def
create_predictor
(
args
,
mode
):
if
mode
==
"det"
:
model_dir
=
args
.
det_model_dir
elif
mode
==
'cls'
:
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
elif
mode
==
'table'
:
model_dir
=
args
.
table_model_dir
else
:
model_dir
=
args
.
e2e_model_dir
if
model_dir
is
None
:
print
(
"not find {} model file path {}"
.
format
(
mode
,
model_dir
))
sys
.
exit
(
0
)
model_file_path
=
model_dir
+
"/inference.pdmodel"
params_file_path
=
model_dir
+
"/inference.pdiparams"
if
not
os
.
path
.
exists
(
model_file_path
):
raise
ValueError
(
"not find model file path {}"
.
format
(
model_file_path
))
if
not
os
.
path
.
exists
(
params_file_path
):
raise
ValueError
(
"not find params file path {}"
.
format
(
params_file_path
))
config
=
inference
.
Config
(
model_file_path
,
params_file_path
)
if
args
.
device
==
"GPU"
:
gpu_id
=
get_infer_gpuid
()
if
gpu_id
is
None
:
print
(
"GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
)
config
.
enable_use_gpu
(
500
,
0
)
precision_map
=
{
'trt_int8'
:
inference
.
PrecisionType
.
Int8
,
'trt_fp32'
:
inference
.
PrecisionType
.
Float32
,
'trt_fp16'
:
inference
.
PrecisionType
.
Half
}
if
args
.
run_mode
in
precision_map
.
keys
():
config
.
enable_tensorrt_engine
(
workspace_size
=
(
1
<<
25
)
*
batch_size
,
max_batch_size
=
batch_size
,
min_subgraph_size
=
min_subgraph_size
,
precision_mode
=
precision_map
[
args
.
run_mode
],
use_static
=
False
,
use_calib_mode
=
trt_calib_mode
)
# skip the minmum trt subgraph
use_dynamic_shape
=
True
if
mode
==
"det"
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
50
,
50
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
20
,
20
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
10
,
10
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
10
,
10
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
20
,
20
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
20
,
20
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
20
,
20
],
"elementwise_add_7"
:
[
1
,
56
,
2
,
2
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
256
,
2
,
2
]
}
max_input_shape
=
{
"x"
:
[
1
,
3
,
1536
,
1536
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
400
,
400
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
200
,
200
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
200
,
200
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
400
,
400
],
"elementwise_add_7"
:
[
1
,
56
,
400
,
400
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
256
,
400
,
400
]
}
opt_input_shape
=
{
"x"
:
[
1
,
3
,
640
,
640
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
160
,
160
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
80
,
80
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
80
,
80
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
160
,
160
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
160
,
160
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
160
,
160
],
"elementwise_add_7"
:
[
1
,
56
,
40
,
40
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
256
,
40
,
40
]
}
min_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
20
,
20
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
20
,
20
]
}
max_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
400
,
400
]
}
opt_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
160
,
160
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
160
,
160
]
}
min_input_shape
.
update
(
min_pact_shape
)
max_input_shape
.
update
(
max_pact_shape
)
opt_input_shape
.
update
(
opt_pact_shape
)
elif
mode
==
"rec"
:
imgH
=
int
(
args
.
rec_image_shape
.
split
(
','
)[
-
2
])
min_input_shape
=
{
"x"
:
[
1
,
3
,
imgH
,
10
]}
max_input_shape
=
{
"x"
:
[
args
.
batch_size
,
3
,
imgH
,
2304
]}
opt_input_shape
=
{
"x"
:
[
args
.
batch_size
,
3
,
imgH
,
320
]}
elif
mode
==
"cls"
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
48
,
10
]}
max_input_shape
=
{
"x"
:
[
args
.
batch_size
,
3
,
48
,
1024
]}
opt_input_shape
=
{
"x"
:
[
args
.
batch_size
,
3
,
48
,
320
]}
else
:
use_dynamic_shape
=
False
if
use_dynamic_shape
:
config
.
set_trt_dynamic_shape_info
(
min_input_shape
,
max_input_shape
,
opt_input_shape
)
else
:
config
.
disable_gpu
()
if
hasattr
(
args
,
"cpu_threads"
):
config
.
set_cpu_math_library_num_threads
(
args
.
cpu_threads
)
else
:
# default cpu threads as 10
config
.
set_cpu_math_library_num_threads
(
10
)
if
args
.
enable_mkldnn
:
# cache 10 different shapes for mkldnn to avoid memory leak
config
.
set_mkldnn_cache_capacity
(
10
)
config
.
enable_mkldnn
()
if
args
.
run_mode
==
"fp16"
:
config
.
enable_mkldnn_bfloat16
()
# enable memory optim
config
.
enable_memory_optim
()
config
.
disable_glog_info
()
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
delete_pass
(
"matmul_transpose_reshape_fuse_pass"
)
if
mode
==
'table'
:
config
.
delete_pass
(
"fc_fuse_pass"
)
# not supported for table
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_ir_optim
(
True
)
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
for
name
in
input_names
:
input_tensor
=
predictor
.
get_input_handle
(
name
)
output_tensors
=
get_output_tensors
(
args
,
mode
,
predictor
)
return
predictor
,
input_tensor
,
output_tensors
,
config
def
get_output_tensors
(
args
,
mode
,
predictor
):
output_names
=
predictor
.
get_output_names
()
output_tensors
=
[]
if
mode
==
"rec"
and
args
.
rec_algorithm
in
[
"CRNN"
,
"SVTR_LCNet"
]:
output_name
=
'softmax_0.tmp_0'
if
output_name
in
output_names
:
return
[
predictor
.
get_output_handle
(
output_name
)]
else
:
for
output_name
in
output_names
:
output_tensor
=
predictor
.
get_output_handle
(
output_name
)
output_tensors
.
append
(
output_tensor
)
else
:
for
output_name
in
output_names
:
output_tensor
=
predictor
.
get_output_handle
(
output_name
)
output_tensors
.
append
(
output_tensor
)
return
output_tensors
def
get_infer_gpuid
():
sysstr
=
platform
.
system
()
if
sysstr
==
"Windows"
:
return
0
if
not
paddle
.
fluid
.
core
.
is_compiled_with_rocm
():
cmd
=
"env | grep CUDA_VISIBLE_DEVICES"
else
:
cmd
=
"env | grep HIP_VISIBLE_DEVICES"
env_cuda
=
os
.
popen
(
cmd
).
readlines
()
if
len
(
env_cuda
)
==
0
:
return
0
else
:
gpu_id
=
env_cuda
[
0
].
strip
().
split
(
"="
)[
1
]
return
int
(
gpu_id
[
0
])
def
draw_e2e_res
(
dt_boxes
,
strs
,
img_path
):
src_im
=
cv2
.
imread
(
img_path
)
for
box
,
str
in
zip
(
dt_boxes
,
strs
):
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
cv2
.
putText
(
src_im
,
str
,
org
=
(
int
(
box
[
0
,
0
,
0
]),
int
(
box
[
0
,
0
,
1
])),
fontFace
=
cv2
.
FONT_HERSHEY_COMPLEX
,
fontScale
=
0.7
,
color
=
(
0
,
255
,
0
),
thickness
=
1
)
return
src_im
def
draw_text_det_res
(
dt_boxes
,
img_path
):
src_im
=
cv2
.
imread
(
img_path
)
for
box
in
dt_boxes
:
box
=
np
.
array
(
box
).
astype
(
np
.
int32
).
reshape
(
-
1
,
2
)
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
return
src_im
def
resize_img
(
img
,
input_size
=
600
):
"""
resize img and limit the longest side of the image to input_size
"""
img
=
np
.
array
(
img
)
im_shape
=
img
.
shape
im_size_max
=
np
.
max
(
im_shape
[
0
:
2
])
im_scale
=
float
(
input_size
)
/
float
(
im_size_max
)
img
=
cv2
.
resize
(
img
,
None
,
None
,
fx
=
im_scale
,
fy
=
im_scale
)
return
img
def
draw_ocr
(
image
,
boxes
,
txts
=
None
,
scores
=
None
,
drop_score
=
0.5
,
font_path
=
"./doc/fonts/simfang.ttf"
):
"""
Visualize the results of OCR detection and recognition
args:
image(Image|array): RGB image
boxes(list): boxes with shape(N, 4, 2)
txts(list): the texts
scores(list): txxs corresponding scores
drop_score(float): only scores greater than drop_threshold will be visualized
font_path: the path of font which is used to draw text
return(array):
the visualized img
"""
if
scores
is
None
:
scores
=
[
1
]
*
len
(
boxes
)
box_num
=
len
(
boxes
)
for
i
in
range
(
box_num
):
if
scores
is
not
None
and
(
scores
[
i
]
<
drop_score
or
math
.
isnan
(
scores
[
i
])):
continue
box
=
np
.
reshape
(
np
.
array
(
boxes
[
i
]),
[
-
1
,
1
,
2
]).
astype
(
np
.
int64
)
image
=
cv2
.
polylines
(
np
.
array
(
image
),
[
box
],
True
,
(
255
,
0
,
0
),
2
)
if
txts
is
not
None
:
img
=
np
.
array
(
resize_img
(
image
,
input_size
=
600
))
txt_img
=
text_visual
(
txts
,
scores
,
img_h
=
img
.
shape
[
0
],
img_w
=
600
,
threshold
=
drop_score
,
font_path
=
font_path
)
img
=
np
.
concatenate
([
np
.
array
(
img
),
np
.
array
(
txt_img
)],
axis
=
1
)
return
img
return
image
def
draw_ocr_box_txt
(
image
,
boxes
,
txts
,
scores
=
None
,
drop_score
=
0.5
,
font_path
=
"./doc/simfang.ttf"
):
h
,
w
=
image
.
height
,
image
.
width
img_left
=
image
.
copy
()
img_right
=
Image
.
new
(
'RGB'
,
(
w
,
h
),
(
255
,
255
,
255
))
import
random
random
.
seed
(
0
)
draw_left
=
ImageDraw
.
Draw
(
img_left
)
draw_right
=
ImageDraw
.
Draw
(
img_right
)
for
idx
,
(
box
,
txt
)
in
enumerate
(
zip
(
boxes
,
txts
)):
if
scores
is
not
None
and
scores
[
idx
]
<
drop_score
:
continue
color
=
(
random
.
randint
(
0
,
255
),
random
.
randint
(
0
,
255
),
random
.
randint
(
0
,
255
))
draw_left
.
polygon
(
box
,
fill
=
color
)
draw_right
.
polygon
(
[
box
[
0
][
0
],
box
[
0
][
1
],
box
[
1
][
0
],
box
[
1
][
1
],
box
[
2
][
0
],
box
[
2
][
1
],
box
[
3
][
0
],
box
[
3
][
1
]
],
outline
=
color
)
box_height
=
math
.
sqrt
((
box
[
0
][
0
]
-
box
[
3
][
0
])
**
2
+
(
box
[
0
][
1
]
-
box
[
3
][
1
])
**
2
)
box_width
=
math
.
sqrt
((
box
[
0
][
0
]
-
box
[
1
][
0
])
**
2
+
(
box
[
0
][
1
]
-
box
[
1
][
1
])
**
2
)
if
box_height
>
2
*
box_width
:
font_size
=
max
(
int
(
box_width
*
0.9
),
10
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
cur_y
=
box
[
0
][
1
]
for
c
in
txt
:
char_size
=
font
.
getsize
(
c
)
draw_right
.
text
(
(
box
[
0
][
0
]
+
3
,
cur_y
),
c
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
cur_y
+=
char_size
[
1
]
else
:
font_size
=
max
(
int
(
box_height
*
0.8
),
10
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
draw_right
.
text
(
[
box
[
0
][
0
],
box
[
0
][
1
]],
txt
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
img_left
=
Image
.
blend
(
image
,
img_left
,
0.5
)
img_show
=
Image
.
new
(
'RGB'
,
(
w
*
2
,
h
),
(
255
,
255
,
255
))
img_show
.
paste
(
img_left
,
(
0
,
0
,
w
,
h
))
img_show
.
paste
(
img_right
,
(
w
,
0
,
w
*
2
,
h
))
return
np
.
array
(
img_show
)
def
str_count
(
s
):
"""
Count the number of Chinese characters,
a single English character and a single number
equal to half the length of Chinese characters.
args:
s(string): the input of string
return(int):
the number of Chinese characters
"""
import
string
count_zh
=
count_pu
=
0
s_len
=
len
(
s
)
en_dg_count
=
0
for
c
in
s
:
if
c
in
string
.
ascii_letters
or
c
.
isdigit
()
or
c
.
isspace
():
en_dg_count
+=
1
elif
c
.
isalpha
():
count_zh
+=
1
else
:
count_pu
+=
1
return
s_len
-
math
.
ceil
(
en_dg_count
/
2
)
def
text_visual
(
texts
,
scores
,
img_h
=
400
,
img_w
=
600
,
threshold
=
0.
,
font_path
=
"./doc/simfang.ttf"
):
"""
create new blank img and draw txt on it
args:
texts(list): the text will be draw
scores(list|None): corresponding score of each txt
img_h(int): the height of blank img
img_w(int): the width of blank img
font_path: the path of font which is used to draw text
return(array):
"""
if
scores
is
not
None
:
assert
len
(
texts
)
==
len
(
scores
),
"The number of txts and corresponding scores must match"
def
create_blank_img
():
blank_img
=
np
.
ones
(
shape
=
[
img_h
,
img_w
],
dtype
=
np
.
int8
)
*
255
blank_img
[:,
img_w
-
1
:]
=
0
blank_img
=
Image
.
fromarray
(
blank_img
).
convert
(
"RGB"
)
draw_txt
=
ImageDraw
.
Draw
(
blank_img
)
return
blank_img
,
draw_txt
blank_img
,
draw_txt
=
create_blank_img
()
font_size
=
20
txt_color
=
(
0
,
0
,
0
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
gap
=
font_size
+
5
txt_img_list
=
[]
count
,
index
=
1
,
0
for
idx
,
txt
in
enumerate
(
texts
):
index
+=
1
if
scores
[
idx
]
<
threshold
or
math
.
isnan
(
scores
[
idx
]):
index
-=
1
continue
first_line
=
True
while
str_count
(
txt
)
>=
img_w
//
font_size
-
4
:
tmp
=
txt
txt
=
tmp
[:
img_w
//
font_size
-
4
]
if
first_line
:
new_txt
=
str
(
index
)
+
': '
+
txt
first_line
=
False
else
:
new_txt
=
' '
+
txt
draw_txt
.
text
((
0
,
gap
*
count
),
new_txt
,
txt_color
,
font
=
font
)
txt
=
tmp
[
img_w
//
font_size
-
4
:]
if
count
>=
img_h
//
gap
-
1
:
txt_img_list
.
append
(
np
.
array
(
blank_img
))
blank_img
,
draw_txt
=
create_blank_img
()
count
=
0
count
+=
1
if
first_line
:
new_txt
=
str
(
index
)
+
': '
+
txt
+
' '
+
'%.3f'
%
(
scores
[
idx
])
else
:
new_txt
=
" "
+
txt
+
" "
+
'%.3f'
%
(
scores
[
idx
])
draw_txt
.
text
((
0
,
gap
*
count
),
new_txt
,
txt_color
,
font
=
font
)
# whether add new blank img or not
if
count
>=
img_h
//
gap
-
1
and
idx
+
1
<
len
(
texts
):
txt_img_list
.
append
(
np
.
array
(
blank_img
))
blank_img
,
draw_txt
=
create_blank_img
()
count
=
0
count
+=
1
txt_img_list
.
append
(
np
.
array
(
blank_img
))
if
len
(
txt_img_list
)
==
1
:
blank_img
=
np
.
array
(
txt_img_list
[
0
])
else
:
blank_img
=
np
.
concatenate
(
txt_img_list
,
axis
=
1
)
return
np
.
array
(
blank_img
)
def
base64_to_cv2
(
b64str
):
import
base64
data
=
base64
.
b64decode
(
b64str
.
encode
(
'utf8'
))
data
=
np
.
fromstring
(
data
,
np
.
uint8
)
data
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
return
data
def
draw_boxes
(
image
,
boxes
,
scores
=
None
,
drop_score
=
0.5
):
if
scores
is
None
:
scores
=
[
1
]
*
len
(
boxes
)
for
(
box
,
score
)
in
zip
(
boxes
,
scores
):
if
score
<
drop_score
:
continue
box
=
np
.
reshape
(
np
.
array
(
box
),
[
-
1
,
1
,
2
]).
astype
(
np
.
int64
)
image
=
cv2
.
polylines
(
np
.
array
(
image
),
[
box
],
True
,
(
255
,
0
,
0
),
2
)
return
image
def
get_rotate_crop_image
(
img
,
points
):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert
len
(
points
)
==
4
,
"shape of points must be 4*2"
img_crop_width
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
1
]),
np
.
linalg
.
norm
(
points
[
2
]
-
points
[
3
])))
img_crop_height
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
3
]),
np
.
linalg
.
norm
(
points
[
1
]
-
points
[
2
])))
pts_std
=
np
.
float32
([[
0
,
0
],
[
img_crop_width
,
0
],
[
img_crop_width
,
img_crop_height
],
[
0
,
img_crop_height
]])
M
=
cv2
.
getPerspectiveTransform
(
points
,
pts_std
)
dst_img
=
cv2
.
warpPerspective
(
img
,
M
,
(
img_crop_width
,
img_crop_height
),
borderMode
=
cv2
.
BORDER_REPLICATE
,
flags
=
cv2
.
INTER_CUBIC
)
dst_img_height
,
dst_img_width
=
dst_img
.
shape
[
0
:
2
]
if
dst_img_height
*
1.0
/
dst_img_width
>=
1.5
:
dst_img
=
np
.
rot90
(
dst_img
)
return
dst_img
def
check_gpu
(
use_gpu
):
if
use_gpu
and
not
paddle
.
is_compiled_with_cuda
():
use_gpu
=
False
return
use_gpu
if
__name__
==
'__main__'
:
pass
deploy/python/vecplatepostprocess.py
0 → 100644
浏览文件 @
3f5ff9e6
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
from
paddle.nn
import
functional
as
F
import
re
from
shapely.geometry
import
Polygon
import
pyclipper
import
cv2
import
copy
def
build_post_process
(
config
,
global_config
=
None
):
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'SRNLabelDecode'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
,
'PRENLabelDecode'
,
'DistillationSARLabelDecode'
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
from
.pse_postprocess
import
PSEPostProcess
support_dict
.
append
(
'PSEPostProcess'
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
if
module_name
==
"None"
:
return
if
global_config
is
not
None
:
config
.
update
(
global_config
)
assert
module_name
in
support_dict
,
Exception
(
'post process only support {}'
.
format
(
support_dict
))
module_class
=
eval
(
module_name
)(
**
config
)
return
module_class
class
DBPostProcess
(
object
):
"""
The post process for Differentiable Binarization (DB).
"""
def
__init__
(
self
,
thresh
=
0.3
,
box_thresh
=
0.7
,
max_candidates
=
1000
,
unclip_ratio
=
2.0
,
use_dilation
=
False
,
score_mode
=
"fast"
,
**
kwargs
):
self
.
thresh
=
thresh
self
.
box_thresh
=
box_thresh
self
.
max_candidates
=
max_candidates
self
.
unclip_ratio
=
unclip_ratio
self
.
min_size
=
3
self
.
score_mode
=
score_mode
assert
score_mode
in
[
"slow"
,
"fast"
],
"Score mode must be in [slow, fast] but got: {}"
.
format
(
score_mode
)
self
.
dilation_kernel
=
None
if
not
use_dilation
else
np
.
array
(
[[
1
,
1
],
[
1
,
1
]])
def
boxes_from_bitmap
(
self
,
pred
,
_bitmap
,
dest_width
,
dest_height
):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap
=
_bitmap
height
,
width
=
bitmap
.
shape
outs
=
cv2
.
findContours
((
bitmap
*
255
).
astype
(
np
.
uint8
),
cv2
.
RETR_LIST
,
cv2
.
CHAIN_APPROX_SIMPLE
)
if
len
(
outs
)
==
3
:
img
,
contours
,
_
=
outs
[
0
],
outs
[
1
],
outs
[
2
]
elif
len
(
outs
)
==
2
:
contours
,
_
=
outs
[
0
],
outs
[
1
]
num_contours
=
min
(
len
(
contours
),
self
.
max_candidates
)
boxes
=
[]
scores
=
[]
for
index
in
range
(
num_contours
):
contour
=
contours
[
index
]
points
,
sside
=
self
.
get_mini_boxes
(
contour
)
if
sside
<
self
.
min_size
:
continue
points
=
np
.
array
(
points
)
if
self
.
score_mode
==
"fast"
:
score
=
self
.
box_score_fast
(
pred
,
points
.
reshape
(
-
1
,
2
))
else
:
score
=
self
.
box_score_slow
(
pred
,
contour
)
if
self
.
box_thresh
>
score
:
continue
box
=
self
.
unclip
(
points
).
reshape
(
-
1
,
1
,
2
)
box
,
sside
=
self
.
get_mini_boxes
(
box
)
if
sside
<
self
.
min_size
+
2
:
continue
box
=
np
.
array
(
box
)
box
[:,
0
]
=
np
.
clip
(
np
.
round
(
box
[:,
0
]
/
width
*
dest_width
),
0
,
dest_width
)
box
[:,
1
]
=
np
.
clip
(
np
.
round
(
box
[:,
1
]
/
height
*
dest_height
),
0
,
dest_height
)
boxes
.
append
(
box
.
astype
(
np
.
int16
))
scores
.
append
(
score
)
return
np
.
array
(
boxes
,
dtype
=
np
.
int16
),
scores
def
unclip
(
self
,
box
):
unclip_ratio
=
self
.
unclip_ratio
poly
=
Polygon
(
box
)
distance
=
poly
.
area
*
unclip_ratio
/
poly
.
length
offset
=
pyclipper
.
PyclipperOffset
()
offset
.
AddPath
(
box
,
pyclipper
.
JT_ROUND
,
pyclipper
.
ET_CLOSEDPOLYGON
)
expanded
=
np
.
array
(
offset
.
Execute
(
distance
))
return
expanded
def
get_mini_boxes
(
self
,
contour
):
bounding_box
=
cv2
.
minAreaRect
(
contour
)
points
=
sorted
(
list
(
cv2
.
boxPoints
(
bounding_box
)),
key
=
lambda
x
:
x
[
0
])
index_1
,
index_2
,
index_3
,
index_4
=
0
,
1
,
2
,
3
if
points
[
1
][
1
]
>
points
[
0
][
1
]:
index_1
=
0
index_4
=
1
else
:
index_1
=
1
index_4
=
0
if
points
[
3
][
1
]
>
points
[
2
][
1
]:
index_2
=
2
index_3
=
3
else
:
index_2
=
3
index_3
=
2
box
=
[
points
[
index_1
],
points
[
index_2
],
points
[
index_3
],
points
[
index_4
]
]
return
box
,
min
(
bounding_box
[
1
])
def
box_score_fast
(
self
,
bitmap
,
_box
):
'''
box_score_fast: use bbox mean score as the mean score
'''
h
,
w
=
bitmap
.
shape
[:
2
]
box
=
_box
.
copy
()
xmin
=
np
.
clip
(
np
.
floor
(
box
[:,
0
].
min
()).
astype
(
np
.
int
),
0
,
w
-
1
)
xmax
=
np
.
clip
(
np
.
ceil
(
box
[:,
0
].
max
()).
astype
(
np
.
int
),
0
,
w
-
1
)
ymin
=
np
.
clip
(
np
.
floor
(
box
[:,
1
].
min
()).
astype
(
np
.
int
),
0
,
h
-
1
)
ymax
=
np
.
clip
(
np
.
ceil
(
box
[:,
1
].
max
()).
astype
(
np
.
int
),
0
,
h
-
1
)
mask
=
np
.
zeros
((
ymax
-
ymin
+
1
,
xmax
-
xmin
+
1
),
dtype
=
np
.
uint8
)
box
[:,
0
]
=
box
[:,
0
]
-
xmin
box
[:,
1
]
=
box
[:,
1
]
-
ymin
cv2
.
fillPoly
(
mask
,
box
.
reshape
(
1
,
-
1
,
2
).
astype
(
np
.
int32
),
1
)
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
def
box_score_slow
(
self
,
bitmap
,
contour
):
'''
box_score_slow: use polyon mean score as the mean score
'''
h
,
w
=
bitmap
.
shape
[:
2
]
contour
=
contour
.
copy
()
contour
=
np
.
reshape
(
contour
,
(
-
1
,
2
))
xmin
=
np
.
clip
(
np
.
min
(
contour
[:,
0
]),
0
,
w
-
1
)
xmax
=
np
.
clip
(
np
.
max
(
contour
[:,
0
]),
0
,
w
-
1
)
ymin
=
np
.
clip
(
np
.
min
(
contour
[:,
1
]),
0
,
h
-
1
)
ymax
=
np
.
clip
(
np
.
max
(
contour
[:,
1
]),
0
,
h
-
1
)
mask
=
np
.
zeros
((
ymax
-
ymin
+
1
,
xmax
-
xmin
+
1
),
dtype
=
np
.
uint8
)
contour
[:,
0
]
=
contour
[:,
0
]
-
xmin
contour
[:,
1
]
=
contour
[:,
1
]
-
ymin
cv2
.
fillPoly
(
mask
,
contour
.
reshape
(
1
,
-
1
,
2
).
astype
(
np
.
int32
),
1
)
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
def
__call__
(
self
,
outs_dict
,
shape_list
):
pred
=
outs_dict
[
'maps'
]
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
pred
[:,
0
,
:,
:]
segmentation
=
pred
>
self
.
thresh
boxes_batch
=
[]
for
batch_index
in
range
(
pred
.
shape
[
0
]):
src_h
,
src_w
=
shape_list
[
batch_index
]
if
self
.
dilation_kernel
is
not
None
:
mask
=
cv2
.
dilate
(
np
.
array
(
segmentation
[
batch_index
]).
astype
(
np
.
uint8
),
self
.
dilation_kernel
)
else
:
mask
=
segmentation
[
batch_index
]
boxes
,
scores
=
self
.
boxes_from_bitmap
(
pred
[
batch_index
],
mask
,
src_w
,
src_h
)
boxes_batch
.
append
({
'points'
:
boxes
})
return
boxes_batch
class
BaseRecLabelDecode
(
object
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
character_str
=
[]
if
character_dict_path
is
None
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
else
:
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
line
=
line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
self
.
character_str
.
append
(
line
)
if
use_space_char
:
self
.
character_str
.
append
(
" "
)
dict_character
=
list
(
self
.
character_str
)
dict_character
=
self
.
add_special_char
(
dict_character
)
self
.
dict
=
{}
for
i
,
char
in
enumerate
(
dict_character
):
self
.
dict
[
char
]
=
i
self
.
character
=
dict_character
def
add_special_char
(
self
,
dict_character
):
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
selection
=
np
.
ones
(
len
(
text_index
[
batch_idx
]),
dtype
=
bool
)
if
is_remove_duplicate
:
selection
[
1
:]
=
text_index
[
batch_idx
][
1
:]
!=
text_index
[
batch_idx
][:
-
1
]
for
ignored_token
in
ignored_tokens
:
selection
&=
text_index
[
batch_idx
]
!=
ignored_token
char_list
=
[
self
.
character
[
text_id
]
for
text_id
in
text_index
[
batch_idx
][
selection
]
]
if
text_prob
is
not
None
:
conf_list
=
text_prob
[
batch_idx
][
selection
]
else
:
conf_list
=
[
1
]
*
len
(
selection
)
if
len
(
conf_list
)
==
0
:
conf_list
=
[
0
]
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
get_ignored_tokens
(
self
):
return
[
0
]
# for ctc blank
class
CTCLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
CTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
tuple
)
or
isinstance
(
preds
,
list
):
preds
=
preds
[
-
1
]
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
True
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
]
+
dict_character
return
dict_character
class
DistillationCTCLabelDecode
(
CTCLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'ctc'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
NRTRLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
True
,
**
kwargs
):
super
(
NRTRLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
len
(
preds
)
==
2
:
preds_id
=
preds
[
0
]
preds_prob
=
preds
[
1
]
if
isinstance
(
preds_id
,
paddle
.
Tensor
):
preds_id
=
preds_id
.
numpy
()
if
isinstance
(
preds_prob
,
paddle
.
Tensor
):
preds_prob
=
preds_prob
.
numpy
()
if
preds_id
[
0
][
0
]
==
2
:
preds_idx
=
preds_id
[:,
1
:]
preds_prob
=
preds_prob
[:,
1
:]
else
:
preds_idx
=
preds_id
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
[:,
1
:])
else
:
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
[:,
1
:])
return
text
,
label
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
==
3
:
# end
break
try
:
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
except
:
continue
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
.
lower
(),
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
class
AttnLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
AttnLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
[
beg_idx
,
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
class
SEEDLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SEEDLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
padding_str
=
"padding"
self
.
end_str
=
"eos"
self
.
unknown
=
"unknown"
dict_character
=
dict_character
+
[
self
.
end_str
,
self
.
padding_str
,
self
.
unknown
]
return
dict_character
def
get_ignored_tokens
(
self
):
end_idx
=
self
.
get_beg_end_flag_idx
(
"eos"
)
return
[
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"sos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"eos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
%
beg_or_end
return
idx
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
[
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds_idx
=
preds
[
"rec_pred"
]
if
isinstance
(
preds_idx
,
paddle
.
Tensor
):
preds_idx
=
preds_idx
.
numpy
()
if
"rec_pred_scores"
in
preds
:
preds_idx
=
preds
[
"rec_pred"
]
preds_prob
=
preds
[
"rec_pred_scores"
]
else
:
preds_idx
=
preds
[
"rec_pred"
].
argmax
(
axis
=
2
)
preds_prob
=
preds
[
"rec_pred"
].
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
class
SRNLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SRNLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
self
.
max_text_length
=
kwargs
.
get
(
'max_text_length'
,
25
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
pred
=
preds
[
'predict'
]
char_num
=
len
(
self
.
character_str
)
+
2
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
np
.
reshape
(
pred
,
[
-
1
,
char_num
])
preds_idx
=
np
.
argmax
(
pred
,
axis
=
1
)
preds_prob
=
np
.
max
(
pred
,
axis
=
1
)
preds_idx
=
np
.
reshape
(
preds_idx
,
[
-
1
,
self
.
max_text_length
])
preds_prob
=
np
.
reshape
(
preds_prob
,
[
-
1
,
self
.
max_text_length
])
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
add_special_char
(
self
,
dict_character
):
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
return
dict_character
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
class
TableLabelDecode
(
object
):
""" """
def
__init__
(
self
,
character_dict_path
,
**
kwargs
):
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
character_dict_path
)
list_character
=
self
.
add_special_char
(
list_character
)
list_elem
=
self
.
add_special_char
(
list_elem
)
self
.
dict_character
=
{}
self
.
dict_idx_character
=
{}
for
i
,
char
in
enumerate
(
list_character
):
self
.
dict_idx_character
[
i
]
=
char
self
.
dict_character
[
char
]
=
i
self
.
dict_elem
=
{}
self
.
dict_idx_elem
=
{}
for
i
,
elem
in
enumerate
(
list_elem
):
self
.
dict_idx_elem
[
i
]
=
elem
self
.
dict_elem
[
elem
]
=
i
def
load_char_elem_dict
(
self
,
character_dict_path
):
list_character
=
[]
list_elem
=
[]
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
list_character
.
append
(
character
)
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
list_elem
.
append
(
elem
)
return
list_character
,
list_elem
def
add_special_char
(
self
,
list_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
loc_preds
=
preds
[
'loc_preds'
]
if
isinstance
(
structure_probs
,
paddle
.
Tensor
):
structure_probs
=
structure_probs
.
numpy
()
if
isinstance
(
loc_preds
,
paddle
.
Tensor
):
loc_preds
=
loc_preds
.
numpy
()
structure_idx
=
structure_probs
.
argmax
(
axis
=
2
)
structure_probs
=
structure_probs
.
max
(
axis
=
2
)
structure_str
,
structure_pos
,
result_score_list
,
result_elem_idx_list
=
self
.
decode
(
structure_idx
,
structure_probs
,
'elem'
)
res_html_code_list
=
[]
res_loc_list
=
[]
batch_num
=
len
(
structure_str
)
for
bno
in
range
(
batch_num
):
res_loc
=
[]
for
sno
in
range
(
len
(
structure_str
[
bno
])):
text
=
structure_str
[
bno
][
sno
]
if
text
in
[
'<td>'
,
'<td'
]:
pos
=
structure_pos
[
bno
][
sno
]
res_loc
.
append
(
loc_preds
[
bno
,
pos
])
res_html_code
=
''
.
join
(
structure_str
[
bno
])
res_loc
=
np
.
array
(
res_loc
)
res_html_code_list
.
append
(
res_html_code
)
res_loc_list
.
append
(
res_loc
)
return
{
'res_html_code'
:
res_html_code_list
,
'res_loc'
:
res_loc_list
,
'res_score_list'
:
result_score_list
,
'res_elem_idx_list'
:
result_elem_idx_list
,
'structure_str_list'
:
structure_str
}
def
decode
(
self
,
text_index
,
structure_probs
,
char_or_elem
):
"""convert text-label into text-index.
"""
if
char_or_elem
==
"char"
:
current_dict
=
self
.
dict_idx_character
else
:
current_dict
=
self
.
dict_idx_elem
ignored_tokens
=
self
.
get_ignored_tokens
(
'elem'
)
beg_idx
,
end_idx
=
ignored_tokens
result_list
=
[]
result_pos_list
=
[]
result_score_list
=
[]
result_elem_idx_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
elem_pos_list
=
[]
elem_idx_list
=
[]
score_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
tmp_elem_idx
=
int
(
text_index
[
batch_idx
][
idx
])
if
idx
>
0
and
tmp_elem_idx
==
end_idx
:
break
if
tmp_elem_idx
in
ignored_tokens
:
continue
char_list
.
append
(
current_dict
[
tmp_elem_idx
])
elem_pos_list
.
append
(
idx
)
score_list
.
append
(
structure_probs
[
batch_idx
,
idx
])
elem_idx_list
.
append
(
tmp_elem_idx
)
result_list
.
append
(
char_list
)
result_pos_list
.
append
(
elem_pos_list
)
result_score_list
.
append
(
score_list
)
result_elem_idx_list
.
append
(
elem_idx_list
)
return
result_list
,
result_pos_list
,
result_score_list
,
result_elem_idx_list
def
get_ignored_tokens
(
self
,
char_or_elem
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
,
char_or_elem
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
,
char_or_elem
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
,
char_or_elem
):
if
char_or_elem
==
"char"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_character
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_character
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of char"
\
%
beg_or_end
elif
char_or_elem
==
"elem"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_elem
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_elem
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
%
beg_or_end
else
:
assert
False
,
"Unsupport type %s in char_or_elem"
\
%
char_or_elem
return
idx
class
SARLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SARLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
self
.
rm_symbol
=
kwargs
.
get
(
'rm_symbol'
,
False
)
def
add_special_char
(
self
,
dict_character
):
beg_end_str
=
"<BOS/EOS>"
unknown_str
=
"<UKN>"
padding_str
=
"<PAD>"
dict_character
=
dict_character
+
[
unknown_str
]
self
.
unknown_idx
=
len
(
dict_character
)
-
1
dict_character
=
dict_character
+
[
beg_end_str
]
self
.
start_idx
=
len
(
dict_character
)
-
1
self
.
end_idx
=
len
(
dict_character
)
-
1
dict_character
=
dict_character
+
[
padding_str
]
self
.
padding_idx
=
len
(
dict_character
)
-
1
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
self
.
end_idx
):
if
text_prob
is
None
and
idx
==
0
:
continue
else
:
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
if
self
.
rm_symbol
:
comp
=
re
.
compile
(
'[^A-Z^a-z^0-9^
\u4e00
-
\u9fa5
]'
)
text
=
text
.
lower
()
text
=
comp
.
sub
(
''
,
text
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
get_ignored_tokens
(
self
):
return
[
self
.
padding_idx
]
class
DistillationSARLabelDecode
(
SARLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationSARLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'sar'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
PRENLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
PRENLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
padding_str
=
'<PAD>'
# 0
end_str
=
'<EOS>'
# 1
unknown_str
=
'<UNK>'
# 2
dict_character
=
[
padding_str
,
end_str
,
unknown_str
]
+
dict_character
self
.
padding_idx
=
0
self
.
end_idx
=
1
self
.
unknown_idx
=
2
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
):
""" convert text-index into text-label. """
result_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
==
self
.
end_idx
:
break
if
text_index
[
batch_idx
][
idx
]
in
\
[
self
.
padding_idx
,
self
.
unknown_idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
if
len
(
text
)
>
0
:
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
else
:
# here confidence of empty recog result is 1
result_list
.
append
((
''
,
1
))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录