Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
3f5ff9e6
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3f5ff9e6
编写于
6月 14, 2022
作者:
Z
zhiboniu
提交者:
zhiboniu
6月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ppvehicle plate
上级
333370d9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
2323 addition
and
0 deletion
+2323
-0
deploy/python/preprocess.py
deploy/python/preprocess.py
+79
-0
deploy/python/utils.py
deploy/python/utils.py
+8
-0
deploy/python/vechile_plate.py
deploy/python/vechile_plate.py
+621
-0
deploy/python/vechile_plateutils.py
deploy/python/vechile_plateutils.py
+667
-0
deploy/python/vecplatepostprocess.py
deploy/python/vecplatepostprocess.py
+948
-0
未找到文件。
deploy/python/preprocess.py
浏览文件 @
3f5ff9e6
...
@@ -40,6 +40,85 @@ def decode_image(im_file, im_info):
...
@@ -40,6 +40,85 @@ def decode_image(im_file, im_info):
return
im
,
im_info
return
im
,
im_info
class
Resize_Mult32
(
object
):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def
__init__
(
self
,
limit_side_len
,
limit_type
,
interp
=
cv2
.
INTER_LINEAR
):
self
.
limit_side_len
=
limit_side_len
self
.
limit_type
=
limit_type
self
.
interp
=
interp
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im_channel
=
im
.
shape
[
2
]
im_scale_y
,
im_scale_x
=
self
.
generate_scale
(
im
)
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
im_info
[
'im_shape'
]
=
np
.
array
(
im
.
shape
[:
2
]).
astype
(
'float32'
)
im_info
[
'scale_factor'
]
=
np
.
array
(
[
im_scale_y
,
im_scale_x
]).
astype
(
'float32'
)
return
im
,
im_info
def
generate_scale
(
self
,
img
):
"""
Args:
img (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
limit_side_len
=
self
.
limit_side_len
h
,
w
,
c
=
img
.
shape
# limit the max side
if
self
.
limit_type
==
'max'
:
if
max
(
h
,
w
)
>
limit_side_len
:
if
h
>
w
:
ratio
=
float
(
limit_side_len
)
/
h
else
:
ratio
=
float
(
limit_side_len
)
/
w
else
:
ratio
=
1.
elif
self
.
limit_type
==
'min'
:
if
min
(
h
,
w
)
<
limit_side_len
:
if
h
<
w
:
ratio
=
float
(
limit_side_len
)
/
h
else
:
ratio
=
float
(
limit_side_len
)
/
w
else
:
ratio
=
1.
elif
self
.
limit_type
==
'resize_long'
:
ratio
=
float
(
limit_side_len
)
/
max
(
h
,
w
)
else
:
raise
Exception
(
'not support limit type, image '
)
resize_h
=
int
(
h
*
ratio
)
resize_w
=
int
(
w
*
ratio
)
resize_h
=
max
(
int
(
round
(
resize_h
/
32
)
*
32
),
32
)
resize_w
=
max
(
int
(
round
(
resize_w
/
32
)
*
32
),
32
)
im_scale_y
=
resize_h
/
float
(
h
)
im_scale_x
=
resize_w
/
float
(
w
)
return
im_scale_y
,
im_scale_x
class
Resize
(
object
):
class
Resize
(
object
):
"""resize image by target_size and max_size
"""resize image by target_size and max_size
Args:
Args:
...
...
deploy/python/utils.py
浏览文件 @
3f5ff9e6
...
@@ -27,6 +27,14 @@ def argsparser():
...
@@ -27,6 +27,14 @@ def argsparser():
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."
),
"'infer_cfg.yml', created by tools/export_model.py."
),
required
=
True
)
required
=
True
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
parser
.
add_argument
(
"--det_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_limit_side_len"
,
type
=
float
,
default
=
960
)
parser
.
add_argument
(
"--det_limit_type"
,
type
=
str
,
default
=
'max'
)
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'SVTR_LCNet'
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 48, 320"
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
6
)
parser
.
add_argument
(
parser
.
add_argument
(
"--image_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of image file."
)
"--image_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of image file."
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
deploy/python/vechile_plate.py
0 → 100644
浏览文件 @
3f5ff9e6
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
yaml
import
glob
from
functools
import
reduce
import
time
import
cv2
import
numpy
as
np
import
math
import
paddle
import
sys
# add deploy path of PadleDetection to sys.path
parent_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
*
([
'..'
])))
sys
.
path
.
insert
(
0
,
parent_path
)
from
utils
import
Timer
,
get_current_memory_mb
from
infer
import
Detector
,
get_test_images
,
print_arguments
,
create_inputs
from
vechile_plateutils
import
create_predictor
,
get_infer_gpuid
,
argsparser
,
get_rotate_crop_image
from
vecplatepostprocess
import
build_post_process
from
preprocess
import
preprocess
,
Resize
,
NormalizeImage
,
Permute
,
PadStride
,
LetterBoxResize
,
WarpAffine
,
Pad
,
decode_image
,
Resize_Mult32
class
PlateDetector
(
object
):
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
det_algorithm
=
args
.
det_algorithm
self
.
pre_process_list
=
{
'Resize_Mult32'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
,
},
'NormalizeImage'
:
{
'mean'
:
[
0.485
,
0.456
,
0.406
],
'std'
:
[
0.229
,
0.224
,
0.225
],
'is_scale'
:
True
,
},
'Permute'
:
{}
}
postprocess_params
=
{}
postprocess_params
[
'name'
]
=
'DBPostProcess'
postprocess_params
[
"thresh"
]
=
0.3
postprocess_params
[
"box_thresh"
]
=
0.6
postprocess_params
[
"max_candidates"
]
=
1000
postprocess_params
[
"unclip_ratio"
]
=
1.5
postprocess_params
[
"use_dilation"
]
=
False
postprocess_params
[
"score_mode"
]
=
"fast"
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
create_predictor
(
args
,
'det'
)
if
args
.
run_benchmark
:
import
auto_log
pid
=
os
.
getpid
()
gpu_id
=
get_infer_gpuid
()
self
.
autolog
=
auto_log
.
AutoLogger
(
model_name
=
"det"
,
model_precision
=
"fp32"
,
batch_size
=
1
,
data_shape
=
"dynamic"
,
save_path
=
None
,
inference_config
=
self
.
config
,
pids
=
pid
,
process_name
=
None
,
gpu_ids
=
gpu_id
if
args
.
device
==
"GPU"
else
None
,
time_keys
=
[
'preprocess_time'
,
'inference_time'
,
'postprocess_time'
],
warmup
=
2
,
)
def
preprocess
(
self
,
image_list
):
preprocess_ops
=
[]
for
op_type
,
new_op_info
in
self
.
pre_process_list
.
items
():
preprocess_ops
.
append
(
eval
(
op_type
)(
**
new_op_info
))
input_im_lst
=
[]
input_im_info_lst
=
[]
for
im_path
in
image_list
:
im
,
im_info
=
preprocess
(
im_path
,
preprocess_ops
)
input_im_lst
.
append
(
im
)
input_im_info_lst
.
append
(
im_info
[
'im_shape'
])
return
np
.
stack
(
input_im_lst
,
axis
=
0
),
input_im_info_lst
def
order_points_clockwise
(
self
,
pts
):
rect
=
np
.
zeros
((
4
,
2
),
dtype
=
"float32"
)
s
=
pts
.
sum
(
axis
=
1
)
rect
[
0
]
=
pts
[
np
.
argmin
(
s
)]
rect
[
2
]
=
pts
[
np
.
argmax
(
s
)]
diff
=
np
.
diff
(
pts
,
axis
=
1
)
rect
[
1
]
=
pts
[
np
.
argmin
(
diff
)]
rect
[
3
]
=
pts
[
np
.
argmax
(
diff
)]
return
rect
def
clip_det_res
(
self
,
points
,
img_height
,
img_width
):
for
pno
in
range
(
points
.
shape
[
0
]):
points
[
pno
,
0
]
=
int
(
min
(
max
(
points
[
pno
,
0
],
0
),
img_width
-
1
))
points
[
pno
,
1
]
=
int
(
min
(
max
(
points
[
pno
,
1
],
0
),
img_height
-
1
))
return
points
def
filter_tag_det_res
(
self
,
dt_boxes
,
image_shape
):
img_height
,
img_width
=
image_shape
[
0
:
2
]
dt_boxes_new
=
[]
for
box
in
dt_boxes
:
box
=
self
.
order_points_clockwise
(
box
)
box
=
self
.
clip_det_res
(
box
,
img_height
,
img_width
)
rect_width
=
int
(
np
.
linalg
.
norm
(
box
[
0
]
-
box
[
1
]))
rect_height
=
int
(
np
.
linalg
.
norm
(
box
[
0
]
-
box
[
3
]))
if
rect_width
<=
3
or
rect_height
<=
3
:
continue
dt_boxes_new
.
append
(
box
)
dt_boxes
=
np
.
array
(
dt_boxes_new
)
return
dt_boxes
def
filter_tag_det_res_only_clip
(
self
,
dt_boxes
,
image_shape
):
img_height
,
img_width
=
image_shape
[
0
:
2
]
dt_boxes_new
=
[]
for
box
in
dt_boxes
:
box
=
self
.
clip_det_res
(
box
,
img_height
,
img_width
)
dt_boxes_new
.
append
(
box
)
dt_boxes
=
np
.
array
(
dt_boxes_new
)
return
dt_boxes
def
predict_image
(
self
,
img
):
st
=
time
.
time
()
if
self
.
args
.
run_benchmark
:
self
.
autolog
.
times
.
start
()
img
,
shape_list
=
self
.
preprocess
(
img
)
if
img
is
None
:
return
None
,
0
# img = np.expand_dims(img, axis=0)
# shape_list = np.expand_dims(shape_list, axis=0)
# img = img.copy()
if
self
.
args
.
run_benchmark
:
self
.
autolog
.
times
.
stamp
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
args
.
run_benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
{}
preds
[
'maps'
]
=
outputs
[
0
]
#self.predictor.try_shrink_memory()
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
shape_list
[
0
])
if
self
.
args
.
run_benchmark
:
self
.
autolog
.
times
.
end
(
stamp
=
True
)
et
=
time
.
time
()
return
dt_boxes
,
et
-
st
class
TextRecognizer
(
object
):
def
__init__
(
self
,
FLAGS
,
input_shape
=
[
3
,
48
,
320
],
batch_size
=
8
,
rec_algorithm
=
"SVTR"
,
word_dict_path
=
"rec_word_dict.txt"
,
use_gpu
=
True
,
benchmark
=
False
):
self
.
rec_image_shape
=
input_shape
self
.
rec_batch_num
=
batch_size
self
.
rec_algorithm
=
rec_algorithm
isuse_space_char
=
True
postprocess_params
=
{
'name'
:
'CTCLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
if
self
.
rec_algorithm
==
"SRN"
:
postprocess_params
=
{
'name'
:
'SRNLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
elif
self
.
rec_algorithm
==
"RARE"
:
postprocess_params
=
{
'name'
:
'AttnLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
elif
self
.
rec_algorithm
==
'NRTR'
:
postprocess_params
=
{
'name'
:
'NRTRLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
elif
self
.
rec_algorithm
==
"SAR"
:
postprocess_params
=
{
'name'
:
'SARLabelDecode'
,
"character_dict_path"
:
word_dict_path
,
"use_space_char"
:
isuse_space_char
}
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
create_predictor
(
FLAGS
,
'rec'
)
self
.
benchmark
=
benchmark
self
.
use_onnx
=
False
if
benchmark
:
import
auto_log
pid
=
os
.
getpid
()
gpu_id
=
get_infer_gpuid
()
self
.
autolog
=
auto_log
.
AutoLogger
(
model_name
=
"rec"
,
model_precision
=
'fp32'
,
batch_size
=
batch_size
,
data_shape
=
"dynamic"
,
save_path
=
None
,
#save_log_path,
inference_config
=
self
.
config
,
pids
=
pid
,
process_name
=
None
,
gpu_ids
=
gpu_id
if
use_gpu
else
None
,
time_keys
=
[
'preprocess_time'
,
'inference_time'
,
'postprocess_time'
],
warmup
=
0
)
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
if
self
.
rec_algorithm
==
'NRTR'
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
# return padding_im
image_pil
=
Image
.
fromarray
(
np
.
uint8
(
img
))
img
=
image_pil
.
resize
([
100
,
32
],
Image
.
ANTIALIAS
)
img
=
np
.
array
(
img
)
norm_img
=
np
.
expand_dims
(
img
,
-
1
)
norm_img
=
norm_img
.
transpose
((
2
,
0
,
1
))
return
norm_img
.
astype
(
np
.
float32
)
/
128.
-
1.
assert
imgC
==
img
.
shape
[
2
]
imgW
=
int
((
imgH
*
max_wh_ratio
))
if
self
.
use_onnx
:
w
=
self
.
input_tensor
.
shape
[
3
:][
0
]
if
w
is
not
None
and
w
>
0
:
imgW
=
w
h
,
w
=
img
.
shape
[:
2
]
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
if
self
.
rec_algorithm
==
'RARE'
:
if
resized_w
>
self
.
rec_image_shape
[
2
]:
resized_w
=
self
.
rec_image_shape
[
2
]
imgW
=
self
.
rec_image_shape
[
2
]
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
padding_im
=
np
.
zeros
((
imgC
,
imgH
,
imgW
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
return
padding_im
def
resize_norm_img_svtr
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
return
resized_image
def
resize_norm_img_srn
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
img_black
=
np
.
zeros
((
imgH
,
imgW
))
im_hei
=
img
.
shape
[
0
]
im_wid
=
img
.
shape
[
1
]
if
im_wid
<=
im_hei
*
1
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
1
,
imgH
))
elif
im_wid
<=
im_hei
*
2
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
2
,
imgH
))
elif
im_wid
<=
im_hei
*
3
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
3
,
imgH
))
else
:
img_new
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
))
img_np
=
np
.
asarray
(
img_new
)
img_np
=
cv2
.
cvtColor
(
img_np
,
cv2
.
COLOR_BGR2GRAY
)
img_black
[:,
0
:
img_np
.
shape
[
1
]]
=
img_np
img_black
=
img_black
[:,
:,
np
.
newaxis
]
row
,
col
,
c
=
img_black
.
shape
c
=
1
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
self
,
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
(
(
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
(
(
max_text_length
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
return
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
process_image_srn
(
self
,
img
,
image_shape
,
num_heads
,
max_text_length
):
norm_img
=
self
.
resize_norm_img_srn
(
img
,
image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
self
.
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
)
gsrm_slf_attn_bias1
=
gsrm_slf_attn_bias1
.
astype
(
np
.
float32
)
gsrm_slf_attn_bias2
=
gsrm_slf_attn_bias2
.
astype
(
np
.
float32
)
encoder_word_pos
=
encoder_word_pos
.
astype
(
np
.
int64
)
gsrm_word_pos
=
gsrm_word_pos
.
astype
(
np
.
int64
)
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
def
resize_norm_img_sar
(
self
,
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
valid_ratio
=
1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor
=
int
(
1
/
width_downsample_ratio
)
# resize
ratio
=
w
/
float
(
h
)
resize_w
=
math
.
ceil
(
imgH
*
ratio
)
if
resize_w
%
width_divisor
!=
0
:
resize_w
=
round
(
resize_w
/
width_divisor
)
*
width_divisor
if
imgW_min
is
not
None
:
resize_w
=
max
(
imgW_min
,
resize_w
)
if
imgW_max
is
not
None
:
valid_ratio
=
min
(
1.0
,
1.0
*
resize_w
/
imgW_max
)
resize_w
=
min
(
imgW_max
,
resize_w
)
resized_image
=
cv2
.
resize
(
img
,
(
resize_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
# norm
if
image_shape
[
0
]
==
1
:
resized_image
=
resized_image
/
255
resized_image
=
resized_image
[
np
.
newaxis
,
:]
else
:
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
resize_shape
=
resized_image
.
shape
padding_im
=
-
1.0
*
np
.
ones
((
imgC
,
imgH
,
imgW_max
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resize_w
]
=
resized_image
pad_shape
=
padding_im
.
shape
return
padding_im
,
resize_shape
,
pad_shape
,
valid_ratio
def
__call__
(
self
,
img_list
):
img_num
=
len
(
img_list
)
# Calculate the aspect ratio of all text bars
width_list
=
[]
for
img
in
img_list
:
width_list
.
append
(
img
.
shape
[
1
]
/
float
(
img
.
shape
[
0
]))
# Sorting can speed up the recognition process
indices
=
np
.
argsort
(
np
.
array
(
width_list
))
rec_res
=
[[
''
,
0.0
]]
*
img_num
batch_num
=
self
.
rec_batch_num
st
=
time
.
time
()
if
self
.
benchmark
:
self
.
autolog
.
times
.
start
()
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
end_img_no
=
min
(
img_num
,
beg_img_no
+
batch_num
)
norm_img_batch
=
[]
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
max_wh_ratio
=
imgW
/
imgH
# max_wh_ratio = 0
for
ino
in
range
(
beg_img_no
,
end_img_no
):
h
,
w
=
img_list
[
indices
[
ino
]].
shape
[
0
:
2
]
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
if
self
.
rec_algorithm
==
"SAR"
:
norm_img
,
_
,
_
,
valid_ratio
=
self
.
resize_norm_img_sar
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
valid_ratio
=
np
.
expand_dims
(
valid_ratio
,
axis
=
0
)
valid_ratios
=
[]
valid_ratios
.
append
(
valid_ratio
)
norm_img_batch
.
append
(
norm_img
)
elif
self
.
rec_algorithm
==
"SRN"
:
norm_img
=
self
.
process_image_srn
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
)
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
gsrm_slf_attn_bias1_list
=
[]
gsrm_slf_attn_bias2_list
=
[]
encoder_word_pos_list
.
append
(
norm_img
[
1
])
gsrm_word_pos_list
.
append
(
norm_img
[
2
])
gsrm_slf_attn_bias1_list
.
append
(
norm_img
[
3
])
gsrm_slf_attn_bias2_list
.
append
(
norm_img
[
4
])
norm_img_batch
.
append
(
norm_img
[
0
])
elif
self
.
rec_algorithm
==
"SVTR"
:
norm_img
=
self
.
resize_norm_img_svtr
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
else
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
norm_img_batch
.
copy
()
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
if
self
.
rec_algorithm
==
"SRN"
:
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
)
inputs
=
[
norm_img_batch
,
encoder_word_pos_list
,
gsrm_word_pos_list
,
gsrm_slf_attn_bias1_list
,
gsrm_slf_attn_bias2_list
,
]
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
preds
=
{
"predict"
:
outputs
[
2
]}
else
:
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
{
"predict"
:
outputs
[
2
]}
elif
self
.
rec_algorithm
==
"SAR"
:
valid_ratios
=
np
.
concatenate
(
valid_ratios
)
inputs
=
[
norm_img_batch
,
valid_ratios
,
]
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
preds
=
outputs
[
0
]
else
:
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
outputs
[
0
]
else
:
if
self
.
use_onnx
:
input_dict
=
{}
input_dict
[
self
.
input_tensor
.
name
]
=
norm_img_batch
outputs
=
self
.
predictor
.
run
(
self
.
output_tensors
,
input_dict
)
preds
=
outputs
[
0
]
else
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
if
len
(
outputs
)
!=
1
:
preds
=
outputs
else
:
preds
=
outputs
[
0
]
rec_result
=
self
.
postprocess_op
(
preds
)
for
rno
in
range
(
len
(
rec_result
)):
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
rec_result
[
rno
]
if
self
.
benchmark
:
self
.
autolog
.
times
.
end
(
stamp
=
True
)
return
rec_res
,
time
.
time
()
-
st
class
PlateRecognizer
(
object
):
def
__init__
(
self
):
self
.
batch_size
=
8
self
.
platedetector
=
PlateDetector
(
FLAGS
)
self
.
textrecognizer
=
TextRecognizer
(
FLAGS
,
input_shape
=
[
3
,
48
,
320
],
batch_size
=
8
,
rec_algorithm
=
"SVTR"
,
word_dict_path
=
"rec_word_dict.txt"
,
use_gpu
=
True
,
benchmark
=
False
)
def
get_platelicense
(
self
,
image_list
):
plate_text_list
=
[]
plateboxes
,
det_time
=
self
.
platedetector
.
predict_image
(
image_list
)
for
idx
,
boxes_pcar
in
enumerate
(
plateboxes
):
plate_images
=
get_rotate_crop_image
(
image_list
[
idx
],
boxes_pcar
)
print
(
plate_images
.
shape
)
plate_texts
=
self
.
textrecognizer
(
plate_images
)
plate_text_list
.
append
(
plate_texts
)
import
pdb
pdb
.
set_trace
()
return
results
def
main
():
detector
=
PlateRecognizer
()
# predict from image
if
FLAGS
.
image_dir
is
None
and
FLAGS
.
image_file
is
not
None
:
assert
FLAGS
.
batch_size
==
1
,
"batch_size should be 1, when image_file is not None"
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
for
img
in
img_list
:
image
=
cv2
.
imread
(
img
)
results
=
detector
.
get_platelicense
([
image
])
if
not
FLAGS
.
run_benchmark
:
detector
.
det_times
.
info
(
average
=
True
)
else
:
mems
=
{
'cpu_rss_mb'
:
detector
.
cpu_mem
/
len
(
img_list
),
'gpu_rss_mb'
:
detector
.
gpu_mem
/
len
(
img_list
),
'gpu_util'
:
detector
.
gpu_util
*
100
/
len
(
img_list
)
}
perf_info
=
detector
.
det_times
.
report
(
average
=
True
)
model_dir
=
FLAGS
.
model_dir
mode
=
FLAGS
.
run_mode
model_info
=
{
'model_name'
:
model_dir
.
strip
(
'/'
).
split
(
'/'
)[
-
1
],
'precision'
:
mode
.
split
(
'_'
)[
-
1
]
}
data_info
=
{
'batch_size'
:
FLAGS
.
batch_size
,
'shape'
:
"dynamic_shape"
,
'data_num'
:
perf_info
[
'img_num'
]
}
det_log
=
PaddleInferBenchmark
(
detector
.
config
,
model_info
,
data_info
,
perf_info
,
mems
)
det_log
(
'Attr'
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
FLAGS
.
device
=
FLAGS
.
device
.
upper
()
assert
FLAGS
.
device
in
[
'CPU'
,
'GPU'
,
'XPU'
],
"device should be CPU, GPU or XPU"
assert
not
FLAGS
.
use_gpu
,
"use_gpu has been deprecated, please use --device"
main
()
deploy/python/vechile_plateutils.py
0 → 100644
浏览文件 @
3f5ff9e6
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
sys
import
platform
import
cv2
import
numpy
as
np
import
paddle
from
PIL
import
Image
,
ImageDraw
,
ImageFont
import
math
from
paddle
import
inference
import
time
import
ast
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
parser
.
add_argument
(
"--det_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_limit_side_len"
,
type
=
float
,
default
=
960
)
parser
.
add_argument
(
"--det_limit_type"
,
type
=
str
,
default
=
'max'
)
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'SVTR_LCNet'
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 48, 320"
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
6
)
parser
.
add_argument
(
"--image_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of image file."
)
parser
.
add_argument
(
"--image_dir"
,
type
=
str
,
default
=
None
,
help
=
"Dir of image file, `image_file` has a higher priority."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
,
help
=
"batch_size for inference."
)
parser
.
add_argument
(
"--video_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of video file, `video_file` or `camera_id` has a highest priority."
)
parser
.
add_argument
(
"--camera_id"
,
type
=
int
,
default
=-
1
,
help
=
"device id of camera to predict."
)
parser
.
add_argument
(
"--threshold"
,
type
=
float
,
default
=
0.5
,
help
=
"Threshold of score."
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Directory of output visualization files."
)
parser
.
add_argument
(
"--run_mode"
,
type
=
str
,
default
=
'paddle'
,
help
=
"mode of running(paddle/trt_fp32/trt_fp16/trt_int8)"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'cpu'
,
help
=
"Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Deprecated, please use `--device`."
)
parser
.
add_argument
(
"--run_benchmark"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether to predict a image_file repeatedly for benchmark"
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use mkldnn with CPU."
)
parser
.
add_argument
(
"--enable_mkldnn_bfloat16"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use mkldnn bfloat16 inference with CPU."
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
1
,
help
=
"Num of threads with CPU."
)
parser
.
add_argument
(
"--trt_min_shape"
,
type
=
int
,
default
=
1
,
help
=
"min_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_max_shape"
,
type
=
int
,
default
=
1280
,
help
=
"max_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_opt_shape"
,
type
=
int
,
default
=
640
,
help
=
"opt_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_calib_mode"
,
type
=
bool
,
default
=
False
,
help
=
"If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True."
)
parser
.
add_argument
(
'--save_images'
,
action
=
'store_true'
,
help
=
'Save visualization image results.'
)
parser
.
add_argument
(
'--save_mot_txts'
,
action
=
'store_true'
,
help
=
'Save tracking results (txt).'
)
parser
.
add_argument
(
'--save_mot_txt_per_img'
,
action
=
'store_true'
,
help
=
'Save tracking results (txt) for each image.'
)
parser
.
add_argument
(
'--scaled'
,
type
=
bool
,
default
=
False
,
help
=
"Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector."
)
parser
.
add_argument
(
"--tracker_config"
,
type
=
str
,
default
=
None
,
help
=
(
"tracker donfig"
))
parser
.
add_argument
(
"--reid_model_dir"
,
type
=
str
,
default
=
None
,
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."
))
parser
.
add_argument
(
"--reid_batch_size"
,
type
=
int
,
default
=
50
,
help
=
"max batch_size for reid model inference."
)
parser
.
add_argument
(
'--use_dark'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'whether to use darkpose to get better keypoint position predict '
)
parser
.
add_argument
(
"--action_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of input file for action recognition."
)
parser
.
add_argument
(
"--window_size"
,
type
=
int
,
default
=
50
,
help
=
"Temporal size of skeleton feature for action recognition."
)
parser
.
add_argument
(
"--random_pad"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether do random padding for action recognition."
)
parser
.
add_argument
(
"--save_results"
,
type
=
bool
,
default
=
False
,
help
=
"Whether save detection result to file using coco format"
)
return
parser
def
create_predictor
(
args
,
mode
):
if
mode
==
"det"
:
model_dir
=
args
.
det_model_dir
elif
mode
==
'cls'
:
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
elif
mode
==
'table'
:
model_dir
=
args
.
table_model_dir
else
:
model_dir
=
args
.
e2e_model_dir
if
model_dir
is
None
:
print
(
"not find {} model file path {}"
.
format
(
mode
,
model_dir
))
sys
.
exit
(
0
)
model_file_path
=
model_dir
+
"/inference.pdmodel"
params_file_path
=
model_dir
+
"/inference.pdiparams"
if
not
os
.
path
.
exists
(
model_file_path
):
raise
ValueError
(
"not find model file path {}"
.
format
(
model_file_path
))
if
not
os
.
path
.
exists
(
params_file_path
):
raise
ValueError
(
"not find params file path {}"
.
format
(
params_file_path
))
config
=
inference
.
Config
(
model_file_path
,
params_file_path
)
if
args
.
device
==
"GPU"
:
gpu_id
=
get_infer_gpuid
()
if
gpu_id
is
None
:
print
(
"GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
)
config
.
enable_use_gpu
(
500
,
0
)
precision_map
=
{
'trt_int8'
:
inference
.
PrecisionType
.
Int8
,
'trt_fp32'
:
inference
.
PrecisionType
.
Float32
,
'trt_fp16'
:
inference
.
PrecisionType
.
Half
}
if
args
.
run_mode
in
precision_map
.
keys
():
config
.
enable_tensorrt_engine
(
workspace_size
=
(
1
<<
25
)
*
batch_size
,
max_batch_size
=
batch_size
,
min_subgraph_size
=
min_subgraph_size
,
precision_mode
=
precision_map
[
args
.
run_mode
],
use_static
=
False
,
use_calib_mode
=
trt_calib_mode
)
# skip the minmum trt subgraph
use_dynamic_shape
=
True
if
mode
==
"det"
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
50
,
50
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
20
,
20
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
10
,
10
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
10
,
10
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
20
,
20
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
20
,
20
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
20
,
20
],
"elementwise_add_7"
:
[
1
,
56
,
2
,
2
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
256
,
2
,
2
]
}
max_input_shape
=
{
"x"
:
[
1
,
3
,
1536
,
1536
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
400
,
400
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
200
,
200
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
200
,
200
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
400
,
400
],
"elementwise_add_7"
:
[
1
,
56
,
400
,
400
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
256
,
400
,
400
]
}
opt_input_shape
=
{
"x"
:
[
1
,
3
,
640
,
640
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
160
,
160
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
80
,
80
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
80
,
80
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
160
,
160
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
160
,
160
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
160
,
160
],
"elementwise_add_7"
:
[
1
,
56
,
40
,
40
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
256
,
40
,
40
]
}
min_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
20
,
20
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
20
,
20
]
}
max_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
400
,
400
]
}
opt_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
160
,
160
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
160
,
160
]
}
min_input_shape
.
update
(
min_pact_shape
)
max_input_shape
.
update
(
max_pact_shape
)
opt_input_shape
.
update
(
opt_pact_shape
)
elif
mode
==
"rec"
:
imgH
=
int
(
args
.
rec_image_shape
.
split
(
','
)[
-
2
])
min_input_shape
=
{
"x"
:
[
1
,
3
,
imgH
,
10
]}
max_input_shape
=
{
"x"
:
[
args
.
batch_size
,
3
,
imgH
,
2304
]}
opt_input_shape
=
{
"x"
:
[
args
.
batch_size
,
3
,
imgH
,
320
]}
elif
mode
==
"cls"
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
48
,
10
]}
max_input_shape
=
{
"x"
:
[
args
.
batch_size
,
3
,
48
,
1024
]}
opt_input_shape
=
{
"x"
:
[
args
.
batch_size
,
3
,
48
,
320
]}
else
:
use_dynamic_shape
=
False
if
use_dynamic_shape
:
config
.
set_trt_dynamic_shape_info
(
min_input_shape
,
max_input_shape
,
opt_input_shape
)
else
:
config
.
disable_gpu
()
if
hasattr
(
args
,
"cpu_threads"
):
config
.
set_cpu_math_library_num_threads
(
args
.
cpu_threads
)
else
:
# default cpu threads as 10
config
.
set_cpu_math_library_num_threads
(
10
)
if
args
.
enable_mkldnn
:
# cache 10 different shapes for mkldnn to avoid memory leak
config
.
set_mkldnn_cache_capacity
(
10
)
config
.
enable_mkldnn
()
if
args
.
run_mode
==
"fp16"
:
config
.
enable_mkldnn_bfloat16
()
# enable memory optim
config
.
enable_memory_optim
()
config
.
disable_glog_info
()
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
delete_pass
(
"matmul_transpose_reshape_fuse_pass"
)
if
mode
==
'table'
:
config
.
delete_pass
(
"fc_fuse_pass"
)
# not supported for table
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_ir_optim
(
True
)
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
for
name
in
input_names
:
input_tensor
=
predictor
.
get_input_handle
(
name
)
output_tensors
=
get_output_tensors
(
args
,
mode
,
predictor
)
return
predictor
,
input_tensor
,
output_tensors
,
config
def
get_output_tensors
(
args
,
mode
,
predictor
):
output_names
=
predictor
.
get_output_names
()
output_tensors
=
[]
if
mode
==
"rec"
and
args
.
rec_algorithm
in
[
"CRNN"
,
"SVTR_LCNet"
]:
output_name
=
'softmax_0.tmp_0'
if
output_name
in
output_names
:
return
[
predictor
.
get_output_handle
(
output_name
)]
else
:
for
output_name
in
output_names
:
output_tensor
=
predictor
.
get_output_handle
(
output_name
)
output_tensors
.
append
(
output_tensor
)
else
:
for
output_name
in
output_names
:
output_tensor
=
predictor
.
get_output_handle
(
output_name
)
output_tensors
.
append
(
output_tensor
)
return
output_tensors
def
get_infer_gpuid
():
sysstr
=
platform
.
system
()
if
sysstr
==
"Windows"
:
return
0
if
not
paddle
.
fluid
.
core
.
is_compiled_with_rocm
():
cmd
=
"env | grep CUDA_VISIBLE_DEVICES"
else
:
cmd
=
"env | grep HIP_VISIBLE_DEVICES"
env_cuda
=
os
.
popen
(
cmd
).
readlines
()
if
len
(
env_cuda
)
==
0
:
return
0
else
:
gpu_id
=
env_cuda
[
0
].
strip
().
split
(
"="
)[
1
]
return
int
(
gpu_id
[
0
])
def
draw_e2e_res
(
dt_boxes
,
strs
,
img_path
):
src_im
=
cv2
.
imread
(
img_path
)
for
box
,
str
in
zip
(
dt_boxes
,
strs
):
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
cv2
.
putText
(
src_im
,
str
,
org
=
(
int
(
box
[
0
,
0
,
0
]),
int
(
box
[
0
,
0
,
1
])),
fontFace
=
cv2
.
FONT_HERSHEY_COMPLEX
,
fontScale
=
0.7
,
color
=
(
0
,
255
,
0
),
thickness
=
1
)
return
src_im
def
draw_text_det_res
(
dt_boxes
,
img_path
):
src_im
=
cv2
.
imread
(
img_path
)
for
box
in
dt_boxes
:
box
=
np
.
array
(
box
).
astype
(
np
.
int32
).
reshape
(
-
1
,
2
)
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
return
src_im
def
resize_img
(
img
,
input_size
=
600
):
"""
resize img and limit the longest side of the image to input_size
"""
img
=
np
.
array
(
img
)
im_shape
=
img
.
shape
im_size_max
=
np
.
max
(
im_shape
[
0
:
2
])
im_scale
=
float
(
input_size
)
/
float
(
im_size_max
)
img
=
cv2
.
resize
(
img
,
None
,
None
,
fx
=
im_scale
,
fy
=
im_scale
)
return
img
def
draw_ocr
(
image
,
boxes
,
txts
=
None
,
scores
=
None
,
drop_score
=
0.5
,
font_path
=
"./doc/fonts/simfang.ttf"
):
"""
Visualize the results of OCR detection and recognition
args:
image(Image|array): RGB image
boxes(list): boxes with shape(N, 4, 2)
txts(list): the texts
scores(list): txxs corresponding scores
drop_score(float): only scores greater than drop_threshold will be visualized
font_path: the path of font which is used to draw text
return(array):
the visualized img
"""
if
scores
is
None
:
scores
=
[
1
]
*
len
(
boxes
)
box_num
=
len
(
boxes
)
for
i
in
range
(
box_num
):
if
scores
is
not
None
and
(
scores
[
i
]
<
drop_score
or
math
.
isnan
(
scores
[
i
])):
continue
box
=
np
.
reshape
(
np
.
array
(
boxes
[
i
]),
[
-
1
,
1
,
2
]).
astype
(
np
.
int64
)
image
=
cv2
.
polylines
(
np
.
array
(
image
),
[
box
],
True
,
(
255
,
0
,
0
),
2
)
if
txts
is
not
None
:
img
=
np
.
array
(
resize_img
(
image
,
input_size
=
600
))
txt_img
=
text_visual
(
txts
,
scores
,
img_h
=
img
.
shape
[
0
],
img_w
=
600
,
threshold
=
drop_score
,
font_path
=
font_path
)
img
=
np
.
concatenate
([
np
.
array
(
img
),
np
.
array
(
txt_img
)],
axis
=
1
)
return
img
return
image
def
draw_ocr_box_txt
(
image
,
boxes
,
txts
,
scores
=
None
,
drop_score
=
0.5
,
font_path
=
"./doc/simfang.ttf"
):
h
,
w
=
image
.
height
,
image
.
width
img_left
=
image
.
copy
()
img_right
=
Image
.
new
(
'RGB'
,
(
w
,
h
),
(
255
,
255
,
255
))
import
random
random
.
seed
(
0
)
draw_left
=
ImageDraw
.
Draw
(
img_left
)
draw_right
=
ImageDraw
.
Draw
(
img_right
)
for
idx
,
(
box
,
txt
)
in
enumerate
(
zip
(
boxes
,
txts
)):
if
scores
is
not
None
and
scores
[
idx
]
<
drop_score
:
continue
color
=
(
random
.
randint
(
0
,
255
),
random
.
randint
(
0
,
255
),
random
.
randint
(
0
,
255
))
draw_left
.
polygon
(
box
,
fill
=
color
)
draw_right
.
polygon
(
[
box
[
0
][
0
],
box
[
0
][
1
],
box
[
1
][
0
],
box
[
1
][
1
],
box
[
2
][
0
],
box
[
2
][
1
],
box
[
3
][
0
],
box
[
3
][
1
]
],
outline
=
color
)
box_height
=
math
.
sqrt
((
box
[
0
][
0
]
-
box
[
3
][
0
])
**
2
+
(
box
[
0
][
1
]
-
box
[
3
][
1
])
**
2
)
box_width
=
math
.
sqrt
((
box
[
0
][
0
]
-
box
[
1
][
0
])
**
2
+
(
box
[
0
][
1
]
-
box
[
1
][
1
])
**
2
)
if
box_height
>
2
*
box_width
:
font_size
=
max
(
int
(
box_width
*
0.9
),
10
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
cur_y
=
box
[
0
][
1
]
for
c
in
txt
:
char_size
=
font
.
getsize
(
c
)
draw_right
.
text
(
(
box
[
0
][
0
]
+
3
,
cur_y
),
c
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
cur_y
+=
char_size
[
1
]
else
:
font_size
=
max
(
int
(
box_height
*
0.8
),
10
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
draw_right
.
text
(
[
box
[
0
][
0
],
box
[
0
][
1
]],
txt
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
img_left
=
Image
.
blend
(
image
,
img_left
,
0.5
)
img_show
=
Image
.
new
(
'RGB'
,
(
w
*
2
,
h
),
(
255
,
255
,
255
))
img_show
.
paste
(
img_left
,
(
0
,
0
,
w
,
h
))
img_show
.
paste
(
img_right
,
(
w
,
0
,
w
*
2
,
h
))
return
np
.
array
(
img_show
)
def
str_count
(
s
):
"""
Count the number of Chinese characters,
a single English character and a single number
equal to half the length of Chinese characters.
args:
s(string): the input of string
return(int):
the number of Chinese characters
"""
import
string
count_zh
=
count_pu
=
0
s_len
=
len
(
s
)
en_dg_count
=
0
for
c
in
s
:
if
c
in
string
.
ascii_letters
or
c
.
isdigit
()
or
c
.
isspace
():
en_dg_count
+=
1
elif
c
.
isalpha
():
count_zh
+=
1
else
:
count_pu
+=
1
return
s_len
-
math
.
ceil
(
en_dg_count
/
2
)
def
text_visual
(
texts
,
scores
,
img_h
=
400
,
img_w
=
600
,
threshold
=
0.
,
font_path
=
"./doc/simfang.ttf"
):
"""
create new blank img and draw txt on it
args:
texts(list): the text will be draw
scores(list|None): corresponding score of each txt
img_h(int): the height of blank img
img_w(int): the width of blank img
font_path: the path of font which is used to draw text
return(array):
"""
if
scores
is
not
None
:
assert
len
(
texts
)
==
len
(
scores
),
"The number of txts and corresponding scores must match"
def
create_blank_img
():
blank_img
=
np
.
ones
(
shape
=
[
img_h
,
img_w
],
dtype
=
np
.
int8
)
*
255
blank_img
[:,
img_w
-
1
:]
=
0
blank_img
=
Image
.
fromarray
(
blank_img
).
convert
(
"RGB"
)
draw_txt
=
ImageDraw
.
Draw
(
blank_img
)
return
blank_img
,
draw_txt
blank_img
,
draw_txt
=
create_blank_img
()
font_size
=
20
txt_color
=
(
0
,
0
,
0
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
gap
=
font_size
+
5
txt_img_list
=
[]
count
,
index
=
1
,
0
for
idx
,
txt
in
enumerate
(
texts
):
index
+=
1
if
scores
[
idx
]
<
threshold
or
math
.
isnan
(
scores
[
idx
]):
index
-=
1
continue
first_line
=
True
while
str_count
(
txt
)
>=
img_w
//
font_size
-
4
:
tmp
=
txt
txt
=
tmp
[:
img_w
//
font_size
-
4
]
if
first_line
:
new_txt
=
str
(
index
)
+
': '
+
txt
first_line
=
False
else
:
new_txt
=
' '
+
txt
draw_txt
.
text
((
0
,
gap
*
count
),
new_txt
,
txt_color
,
font
=
font
)
txt
=
tmp
[
img_w
//
font_size
-
4
:]
if
count
>=
img_h
//
gap
-
1
:
txt_img_list
.
append
(
np
.
array
(
blank_img
))
blank_img
,
draw_txt
=
create_blank_img
()
count
=
0
count
+=
1
if
first_line
:
new_txt
=
str
(
index
)
+
': '
+
txt
+
' '
+
'%.3f'
%
(
scores
[
idx
])
else
:
new_txt
=
" "
+
txt
+
" "
+
'%.3f'
%
(
scores
[
idx
])
draw_txt
.
text
((
0
,
gap
*
count
),
new_txt
,
txt_color
,
font
=
font
)
# whether add new blank img or not
if
count
>=
img_h
//
gap
-
1
and
idx
+
1
<
len
(
texts
):
txt_img_list
.
append
(
np
.
array
(
blank_img
))
blank_img
,
draw_txt
=
create_blank_img
()
count
=
0
count
+=
1
txt_img_list
.
append
(
np
.
array
(
blank_img
))
if
len
(
txt_img_list
)
==
1
:
blank_img
=
np
.
array
(
txt_img_list
[
0
])
else
:
blank_img
=
np
.
concatenate
(
txt_img_list
,
axis
=
1
)
return
np
.
array
(
blank_img
)
def
base64_to_cv2
(
b64str
):
import
base64
data
=
base64
.
b64decode
(
b64str
.
encode
(
'utf8'
))
data
=
np
.
fromstring
(
data
,
np
.
uint8
)
data
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
return
data
def
draw_boxes
(
image
,
boxes
,
scores
=
None
,
drop_score
=
0.5
):
if
scores
is
None
:
scores
=
[
1
]
*
len
(
boxes
)
for
(
box
,
score
)
in
zip
(
boxes
,
scores
):
if
score
<
drop_score
:
continue
box
=
np
.
reshape
(
np
.
array
(
box
),
[
-
1
,
1
,
2
]).
astype
(
np
.
int64
)
image
=
cv2
.
polylines
(
np
.
array
(
image
),
[
box
],
True
,
(
255
,
0
,
0
),
2
)
return
image
def
get_rotate_crop_image
(
img
,
points
):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert
len
(
points
)
==
4
,
"shape of points must be 4*2"
img_crop_width
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
1
]),
np
.
linalg
.
norm
(
points
[
2
]
-
points
[
3
])))
img_crop_height
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
3
]),
np
.
linalg
.
norm
(
points
[
1
]
-
points
[
2
])))
pts_std
=
np
.
float32
([[
0
,
0
],
[
img_crop_width
,
0
],
[
img_crop_width
,
img_crop_height
],
[
0
,
img_crop_height
]])
M
=
cv2
.
getPerspectiveTransform
(
points
,
pts_std
)
dst_img
=
cv2
.
warpPerspective
(
img
,
M
,
(
img_crop_width
,
img_crop_height
),
borderMode
=
cv2
.
BORDER_REPLICATE
,
flags
=
cv2
.
INTER_CUBIC
)
dst_img_height
,
dst_img_width
=
dst_img
.
shape
[
0
:
2
]
if
dst_img_height
*
1.0
/
dst_img_width
>=
1.5
:
dst_img
=
np
.
rot90
(
dst_img
)
return
dst_img
def
check_gpu
(
use_gpu
):
if
use_gpu
and
not
paddle
.
is_compiled_with_cuda
():
use_gpu
=
False
return
use_gpu
if
__name__
==
'__main__'
:
pass
deploy/python/vecplatepostprocess.py
0 → 100644
浏览文件 @
3f5ff9e6
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
from
paddle.nn
import
functional
as
F
import
re
from
shapely.geometry
import
Polygon
import
pyclipper
import
cv2
import
copy
def
build_post_process
(
config
,
global_config
=
None
):
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'SRNLabelDecode'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
,
'PRENLabelDecode'
,
'DistillationSARLabelDecode'
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
from
.pse_postprocess
import
PSEPostProcess
support_dict
.
append
(
'PSEPostProcess'
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
if
module_name
==
"None"
:
return
if
global_config
is
not
None
:
config
.
update
(
global_config
)
assert
module_name
in
support_dict
,
Exception
(
'post process only support {}'
.
format
(
support_dict
))
module_class
=
eval
(
module_name
)(
**
config
)
return
module_class
class
DBPostProcess
(
object
):
"""
The post process for Differentiable Binarization (DB).
"""
def
__init__
(
self
,
thresh
=
0.3
,
box_thresh
=
0.7
,
max_candidates
=
1000
,
unclip_ratio
=
2.0
,
use_dilation
=
False
,
score_mode
=
"fast"
,
**
kwargs
):
self
.
thresh
=
thresh
self
.
box_thresh
=
box_thresh
self
.
max_candidates
=
max_candidates
self
.
unclip_ratio
=
unclip_ratio
self
.
min_size
=
3
self
.
score_mode
=
score_mode
assert
score_mode
in
[
"slow"
,
"fast"
],
"Score mode must be in [slow, fast] but got: {}"
.
format
(
score_mode
)
self
.
dilation_kernel
=
None
if
not
use_dilation
else
np
.
array
(
[[
1
,
1
],
[
1
,
1
]])
def
boxes_from_bitmap
(
self
,
pred
,
_bitmap
,
dest_width
,
dest_height
):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap
=
_bitmap
height
,
width
=
bitmap
.
shape
outs
=
cv2
.
findContours
((
bitmap
*
255
).
astype
(
np
.
uint8
),
cv2
.
RETR_LIST
,
cv2
.
CHAIN_APPROX_SIMPLE
)
if
len
(
outs
)
==
3
:
img
,
contours
,
_
=
outs
[
0
],
outs
[
1
],
outs
[
2
]
elif
len
(
outs
)
==
2
:
contours
,
_
=
outs
[
0
],
outs
[
1
]
num_contours
=
min
(
len
(
contours
),
self
.
max_candidates
)
boxes
=
[]
scores
=
[]
for
index
in
range
(
num_contours
):
contour
=
contours
[
index
]
points
,
sside
=
self
.
get_mini_boxes
(
contour
)
if
sside
<
self
.
min_size
:
continue
points
=
np
.
array
(
points
)
if
self
.
score_mode
==
"fast"
:
score
=
self
.
box_score_fast
(
pred
,
points
.
reshape
(
-
1
,
2
))
else
:
score
=
self
.
box_score_slow
(
pred
,
contour
)
if
self
.
box_thresh
>
score
:
continue
box
=
self
.
unclip
(
points
).
reshape
(
-
1
,
1
,
2
)
box
,
sside
=
self
.
get_mini_boxes
(
box
)
if
sside
<
self
.
min_size
+
2
:
continue
box
=
np
.
array
(
box
)
box
[:,
0
]
=
np
.
clip
(
np
.
round
(
box
[:,
0
]
/
width
*
dest_width
),
0
,
dest_width
)
box
[:,
1
]
=
np
.
clip
(
np
.
round
(
box
[:,
1
]
/
height
*
dest_height
),
0
,
dest_height
)
boxes
.
append
(
box
.
astype
(
np
.
int16
))
scores
.
append
(
score
)
return
np
.
array
(
boxes
,
dtype
=
np
.
int16
),
scores
def
unclip
(
self
,
box
):
unclip_ratio
=
self
.
unclip_ratio
poly
=
Polygon
(
box
)
distance
=
poly
.
area
*
unclip_ratio
/
poly
.
length
offset
=
pyclipper
.
PyclipperOffset
()
offset
.
AddPath
(
box
,
pyclipper
.
JT_ROUND
,
pyclipper
.
ET_CLOSEDPOLYGON
)
expanded
=
np
.
array
(
offset
.
Execute
(
distance
))
return
expanded
def
get_mini_boxes
(
self
,
contour
):
bounding_box
=
cv2
.
minAreaRect
(
contour
)
points
=
sorted
(
list
(
cv2
.
boxPoints
(
bounding_box
)),
key
=
lambda
x
:
x
[
0
])
index_1
,
index_2
,
index_3
,
index_4
=
0
,
1
,
2
,
3
if
points
[
1
][
1
]
>
points
[
0
][
1
]:
index_1
=
0
index_4
=
1
else
:
index_1
=
1
index_4
=
0
if
points
[
3
][
1
]
>
points
[
2
][
1
]:
index_2
=
2
index_3
=
3
else
:
index_2
=
3
index_3
=
2
box
=
[
points
[
index_1
],
points
[
index_2
],
points
[
index_3
],
points
[
index_4
]
]
return
box
,
min
(
bounding_box
[
1
])
def
box_score_fast
(
self
,
bitmap
,
_box
):
'''
box_score_fast: use bbox mean score as the mean score
'''
h
,
w
=
bitmap
.
shape
[:
2
]
box
=
_box
.
copy
()
xmin
=
np
.
clip
(
np
.
floor
(
box
[:,
0
].
min
()).
astype
(
np
.
int
),
0
,
w
-
1
)
xmax
=
np
.
clip
(
np
.
ceil
(
box
[:,
0
].
max
()).
astype
(
np
.
int
),
0
,
w
-
1
)
ymin
=
np
.
clip
(
np
.
floor
(
box
[:,
1
].
min
()).
astype
(
np
.
int
),
0
,
h
-
1
)
ymax
=
np
.
clip
(
np
.
ceil
(
box
[:,
1
].
max
()).
astype
(
np
.
int
),
0
,
h
-
1
)
mask
=
np
.
zeros
((
ymax
-
ymin
+
1
,
xmax
-
xmin
+
1
),
dtype
=
np
.
uint8
)
box
[:,
0
]
=
box
[:,
0
]
-
xmin
box
[:,
1
]
=
box
[:,
1
]
-
ymin
cv2
.
fillPoly
(
mask
,
box
.
reshape
(
1
,
-
1
,
2
).
astype
(
np
.
int32
),
1
)
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
def
box_score_slow
(
self
,
bitmap
,
contour
):
'''
box_score_slow: use polyon mean score as the mean score
'''
h
,
w
=
bitmap
.
shape
[:
2
]
contour
=
contour
.
copy
()
contour
=
np
.
reshape
(
contour
,
(
-
1
,
2
))
xmin
=
np
.
clip
(
np
.
min
(
contour
[:,
0
]),
0
,
w
-
1
)
xmax
=
np
.
clip
(
np
.
max
(
contour
[:,
0
]),
0
,
w
-
1
)
ymin
=
np
.
clip
(
np
.
min
(
contour
[:,
1
]),
0
,
h
-
1
)
ymax
=
np
.
clip
(
np
.
max
(
contour
[:,
1
]),
0
,
h
-
1
)
mask
=
np
.
zeros
((
ymax
-
ymin
+
1
,
xmax
-
xmin
+
1
),
dtype
=
np
.
uint8
)
contour
[:,
0
]
=
contour
[:,
0
]
-
xmin
contour
[:,
1
]
=
contour
[:,
1
]
-
ymin
cv2
.
fillPoly
(
mask
,
contour
.
reshape
(
1
,
-
1
,
2
).
astype
(
np
.
int32
),
1
)
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
def
__call__
(
self
,
outs_dict
,
shape_list
):
pred
=
outs_dict
[
'maps'
]
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
pred
[:,
0
,
:,
:]
segmentation
=
pred
>
self
.
thresh
boxes_batch
=
[]
for
batch_index
in
range
(
pred
.
shape
[
0
]):
src_h
,
src_w
=
shape_list
[
batch_index
]
if
self
.
dilation_kernel
is
not
None
:
mask
=
cv2
.
dilate
(
np
.
array
(
segmentation
[
batch_index
]).
astype
(
np
.
uint8
),
self
.
dilation_kernel
)
else
:
mask
=
segmentation
[
batch_index
]
boxes
,
scores
=
self
.
boxes_from_bitmap
(
pred
[
batch_index
],
mask
,
src_w
,
src_h
)
boxes_batch
.
append
({
'points'
:
boxes
})
return
boxes_batch
class
BaseRecLabelDecode
(
object
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
character_str
=
[]
if
character_dict_path
is
None
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
else
:
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
line
=
line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
self
.
character_str
.
append
(
line
)
if
use_space_char
:
self
.
character_str
.
append
(
" "
)
dict_character
=
list
(
self
.
character_str
)
dict_character
=
self
.
add_special_char
(
dict_character
)
self
.
dict
=
{}
for
i
,
char
in
enumerate
(
dict_character
):
self
.
dict
[
char
]
=
i
self
.
character
=
dict_character
def
add_special_char
(
self
,
dict_character
):
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
selection
=
np
.
ones
(
len
(
text_index
[
batch_idx
]),
dtype
=
bool
)
if
is_remove_duplicate
:
selection
[
1
:]
=
text_index
[
batch_idx
][
1
:]
!=
text_index
[
batch_idx
][:
-
1
]
for
ignored_token
in
ignored_tokens
:
selection
&=
text_index
[
batch_idx
]
!=
ignored_token
char_list
=
[
self
.
character
[
text_id
]
for
text_id
in
text_index
[
batch_idx
][
selection
]
]
if
text_prob
is
not
None
:
conf_list
=
text_prob
[
batch_idx
][
selection
]
else
:
conf_list
=
[
1
]
*
len
(
selection
)
if
len
(
conf_list
)
==
0
:
conf_list
=
[
0
]
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
get_ignored_tokens
(
self
):
return
[
0
]
# for ctc blank
class
CTCLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
CTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
tuple
)
or
isinstance
(
preds
,
list
):
preds
=
preds
[
-
1
]
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
True
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
]
+
dict_character
return
dict_character
class
DistillationCTCLabelDecode
(
CTCLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'ctc'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
NRTRLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
True
,
**
kwargs
):
super
(
NRTRLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
len
(
preds
)
==
2
:
preds_id
=
preds
[
0
]
preds_prob
=
preds
[
1
]
if
isinstance
(
preds_id
,
paddle
.
Tensor
):
preds_id
=
preds_id
.
numpy
()
if
isinstance
(
preds_prob
,
paddle
.
Tensor
):
preds_prob
=
preds_prob
.
numpy
()
if
preds_id
[
0
][
0
]
==
2
:
preds_idx
=
preds_id
[:,
1
:]
preds_prob
=
preds_prob
[:,
1
:]
else
:
preds_idx
=
preds_id
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
[:,
1
:])
else
:
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
[:,
1
:])
return
text
,
label
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
==
3
:
# end
break
try
:
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
except
:
continue
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
.
lower
(),
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
class
AttnLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
AttnLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
[
beg_idx
,
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
class
SEEDLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SEEDLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
padding_str
=
"padding"
self
.
end_str
=
"eos"
self
.
unknown
=
"unknown"
dict_character
=
dict_character
+
[
self
.
end_str
,
self
.
padding_str
,
self
.
unknown
]
return
dict_character
def
get_ignored_tokens
(
self
):
end_idx
=
self
.
get_beg_end_flag_idx
(
"eos"
)
return
[
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"sos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"eos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
%
beg_or_end
return
idx
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
[
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds_idx
=
preds
[
"rec_pred"
]
if
isinstance
(
preds_idx
,
paddle
.
Tensor
):
preds_idx
=
preds_idx
.
numpy
()
if
"rec_pred_scores"
in
preds
:
preds_idx
=
preds
[
"rec_pred"
]
preds_prob
=
preds
[
"rec_pred_scores"
]
else
:
preds_idx
=
preds
[
"rec_pred"
].
argmax
(
axis
=
2
)
preds_prob
=
preds
[
"rec_pred"
].
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
class
SRNLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SRNLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
self
.
max_text_length
=
kwargs
.
get
(
'max_text_length'
,
25
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
pred
=
preds
[
'predict'
]
char_num
=
len
(
self
.
character_str
)
+
2
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
np
.
reshape
(
pred
,
[
-
1
,
char_num
])
preds_idx
=
np
.
argmax
(
pred
,
axis
=
1
)
preds_prob
=
np
.
max
(
pred
,
axis
=
1
)
preds_idx
=
np
.
reshape
(
preds_idx
,
[
-
1
,
self
.
max_text_length
])
preds_prob
=
np
.
reshape
(
preds_prob
,
[
-
1
,
self
.
max_text_length
])
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
add_special_char
(
self
,
dict_character
):
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
return
dict_character
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
class
TableLabelDecode
(
object
):
""" """
def
__init__
(
self
,
character_dict_path
,
**
kwargs
):
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
character_dict_path
)
list_character
=
self
.
add_special_char
(
list_character
)
list_elem
=
self
.
add_special_char
(
list_elem
)
self
.
dict_character
=
{}
self
.
dict_idx_character
=
{}
for
i
,
char
in
enumerate
(
list_character
):
self
.
dict_idx_character
[
i
]
=
char
self
.
dict_character
[
char
]
=
i
self
.
dict_elem
=
{}
self
.
dict_idx_elem
=
{}
for
i
,
elem
in
enumerate
(
list_elem
):
self
.
dict_idx_elem
[
i
]
=
elem
self
.
dict_elem
[
elem
]
=
i
def
load_char_elem_dict
(
self
,
character_dict_path
):
list_character
=
[]
list_elem
=
[]
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
list_character
.
append
(
character
)
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
list_elem
.
append
(
elem
)
return
list_character
,
list_elem
def
add_special_char
(
self
,
list_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
loc_preds
=
preds
[
'loc_preds'
]
if
isinstance
(
structure_probs
,
paddle
.
Tensor
):
structure_probs
=
structure_probs
.
numpy
()
if
isinstance
(
loc_preds
,
paddle
.
Tensor
):
loc_preds
=
loc_preds
.
numpy
()
structure_idx
=
structure_probs
.
argmax
(
axis
=
2
)
structure_probs
=
structure_probs
.
max
(
axis
=
2
)
structure_str
,
structure_pos
,
result_score_list
,
result_elem_idx_list
=
self
.
decode
(
structure_idx
,
structure_probs
,
'elem'
)
res_html_code_list
=
[]
res_loc_list
=
[]
batch_num
=
len
(
structure_str
)
for
bno
in
range
(
batch_num
):
res_loc
=
[]
for
sno
in
range
(
len
(
structure_str
[
bno
])):
text
=
structure_str
[
bno
][
sno
]
if
text
in
[
'<td>'
,
'<td'
]:
pos
=
structure_pos
[
bno
][
sno
]
res_loc
.
append
(
loc_preds
[
bno
,
pos
])
res_html_code
=
''
.
join
(
structure_str
[
bno
])
res_loc
=
np
.
array
(
res_loc
)
res_html_code_list
.
append
(
res_html_code
)
res_loc_list
.
append
(
res_loc
)
return
{
'res_html_code'
:
res_html_code_list
,
'res_loc'
:
res_loc_list
,
'res_score_list'
:
result_score_list
,
'res_elem_idx_list'
:
result_elem_idx_list
,
'structure_str_list'
:
structure_str
}
def
decode
(
self
,
text_index
,
structure_probs
,
char_or_elem
):
"""convert text-label into text-index.
"""
if
char_or_elem
==
"char"
:
current_dict
=
self
.
dict_idx_character
else
:
current_dict
=
self
.
dict_idx_elem
ignored_tokens
=
self
.
get_ignored_tokens
(
'elem'
)
beg_idx
,
end_idx
=
ignored_tokens
result_list
=
[]
result_pos_list
=
[]
result_score_list
=
[]
result_elem_idx_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
elem_pos_list
=
[]
elem_idx_list
=
[]
score_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
tmp_elem_idx
=
int
(
text_index
[
batch_idx
][
idx
])
if
idx
>
0
and
tmp_elem_idx
==
end_idx
:
break
if
tmp_elem_idx
in
ignored_tokens
:
continue
char_list
.
append
(
current_dict
[
tmp_elem_idx
])
elem_pos_list
.
append
(
idx
)
score_list
.
append
(
structure_probs
[
batch_idx
,
idx
])
elem_idx_list
.
append
(
tmp_elem_idx
)
result_list
.
append
(
char_list
)
result_pos_list
.
append
(
elem_pos_list
)
result_score_list
.
append
(
score_list
)
result_elem_idx_list
.
append
(
elem_idx_list
)
return
result_list
,
result_pos_list
,
result_score_list
,
result_elem_idx_list
def
get_ignored_tokens
(
self
,
char_or_elem
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
,
char_or_elem
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
,
char_or_elem
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
,
char_or_elem
):
if
char_or_elem
==
"char"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_character
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_character
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of char"
\
%
beg_or_end
elif
char_or_elem
==
"elem"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_elem
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_elem
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
%
beg_or_end
else
:
assert
False
,
"Unsupport type %s in char_or_elem"
\
%
char_or_elem
return
idx
class
SARLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SARLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
self
.
rm_symbol
=
kwargs
.
get
(
'rm_symbol'
,
False
)
def
add_special_char
(
self
,
dict_character
):
beg_end_str
=
"<BOS/EOS>"
unknown_str
=
"<UKN>"
padding_str
=
"<PAD>"
dict_character
=
dict_character
+
[
unknown_str
]
self
.
unknown_idx
=
len
(
dict_character
)
-
1
dict_character
=
dict_character
+
[
beg_end_str
]
self
.
start_idx
=
len
(
dict_character
)
-
1
self
.
end_idx
=
len
(
dict_character
)
-
1
dict_character
=
dict_character
+
[
padding_str
]
self
.
padding_idx
=
len
(
dict_character
)
-
1
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
self
.
end_idx
):
if
text_prob
is
None
and
idx
==
0
:
continue
else
:
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
if
self
.
rm_symbol
:
comp
=
re
.
compile
(
'[^A-Z^a-z^0-9^
\u4e00
-
\u9fa5
]'
)
text
=
text
.
lower
()
text
=
comp
.
sub
(
''
,
text
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
get_ignored_tokens
(
self
):
return
[
self
.
padding_idx
]
class
DistillationSARLabelDecode
(
SARLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationSARLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'sar'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
PRENLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
PRENLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
padding_str
=
'<PAD>'
# 0
end_str
=
'<EOS>'
# 1
unknown_str
=
'<UNK>'
# 2
dict_character
=
[
padding_str
,
end_str
,
unknown_str
]
+
dict_character
self
.
padding_idx
=
0
self
.
end_idx
=
1
self
.
unknown_idx
=
2
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
):
""" convert text-index into text-label. """
result_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
==
self
.
end_idx
:
break
if
text_index
[
batch_idx
][
idx
]
in
\
[
self
.
padding_idx
,
self
.
unknown_idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
if
len
(
text
)
>
0
:
result_list
.
append
((
text
,
np
.
mean
(
conf_list
).
tolist
()))
else
:
# here confidence of empty recog result is 1
result_list
.
append
((
''
,
1
))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录