Commit bb8a2b02 authored by zhiboniu, committed by zhiboniu

update plate code

Parent f2a883ed
@@ -24,15 +24,14 @@ import math
 import paddle
 import sys
 
 # add deploy path of PadleDetection to sys.path
 parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 3)))
 sys.path.insert(0, parent_path)
 
-from python.infer import get_test_images, print_arguments
-from pphuman.ppvehicle.vehicle_plateutils import create_predictor, get_infer_gpuid, get_rotate_crop_image, draw_boxes, argsparser
-from pphuman.ppvehicle.vecplatepostprocess import build_post_process
 from python.preprocess import preprocess, NormalizeImage, Permute, Resize_Mult32
+from pphuman.ppvehicle.vehicle_plateutils import create_predictor, get_infer_gpuid, get_rotate_crop_image, draw_boxes
+from pphuman.ppvehicle.vehicleplate_postprocess import build_post_process
+from pphuman.pipe_utils import merge_cfg, print_arguments, argsparser
 
 
 class PlateDetector(object):
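
For readers unfamiliar with the `parent_path` trick in the context lines above, here is a minimal sketch of what it computes; the file location used is a hypothetical example, chosen so that the `python.*` and `pphuman.*` imports resolve.

    # Hypothetical location of this script; only the relative layout matters.
    import os

    file_path = "/work/PaddleDetection/deploy/pphuman/ppvehicle/vehicle_plate.py"
    # Joining three '..' segments and normalizing climbs three levels:
    # vehicle_plate.py -> ppvehicle -> pphuman -> deploy
    parent_path = os.path.abspath(os.path.join(file_path, *(['..'] * 3)))
    print(parent_path)  # /work/PaddleDetection/deploy
    # With deploy/ on sys.path, `python.infer` and `pphuman.ppvehicle.*` become importable.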
@@ -155,36 +154,36 @@ class TextRecognizer(object):
         self.rec_batch_num = cfg['rec_batch_num']
         self.rec_algorithm = cfg['rec_algorithm']
         word_dict_path = cfg['word_dict_path']
-        isuse_space_char = True
+        use_space_char = True
 
         postprocess_params = {
             'name': 'CTCLabelDecode',
             "character_dict_path": word_dict_path,
-            "use_space_char": isuse_space_char
+            "use_space_char": use_space_char
         }
         if self.rec_algorithm == "SRN":
             postprocess_params = {
                 'name': 'SRNLabelDecode',
                 "character_dict_path": word_dict_path,
-                "use_space_char": isuse_space_char
+                "use_space_char": use_space_char
             }
         elif self.rec_algorithm == "RARE":
             postprocess_params = {
                 'name': 'AttnLabelDecode',
                 "character_dict_path": word_dict_path,
-                "use_space_char": isuse_space_char
+                "use_space_char": use_space_char
             }
         elif self.rec_algorithm == 'NRTR':
             postprocess_params = {
                 'name': 'NRTRLabelDecode',
                 "character_dict_path": word_dict_path,
-                "use_space_char": isuse_space_char
+                "use_space_char": use_space_char
             }
         elif self.rec_algorithm == "SAR":
             postprocess_params = {
                 'name': 'SARLabelDecode',
                 "character_dict_path": word_dict_path,
-                "use_space_char": isuse_space_char
+                "use_space_char": use_space_char
             }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
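
The branch chain in this hunk simply maps `rec_algorithm` to a label-decode class name, with CTC as the default, before handing the dict to `build_post_process`. A hedged sketch of an equivalent table-driven formulation (an illustrative refactor, not part of this commit):

    # Equivalent selection expressed as a lookup table (illustration only).
    DECODER_BY_ALGORITHM = {
        "SRN": "SRNLabelDecode",
        "RARE": "AttnLabelDecode",
        "NRTR": "NRTRLabelDecode",
        "SAR": "SARLabelDecode",
    }

    def make_postprocess_params(rec_algorithm, word_dict_path, use_space_char=True):
        """Build the dict consumed by build_post_process(); CTC is the default decoder."""
        return {
            "name": DECODER_BY_ALGORITHM.get(rec_algorithm, "CTCLabelDecode"),
            "character_dict_path": word_dict_path,
            "use_space_char": use_space_char,
        }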
@@ -520,22 +519,24 @@ class PlateRecognizer(object):
 def main():
-    detector = PlateRecognizer(FLAGS)
+    cfg = merge_cfg(FLAGS)
+    print_arguments(cfg)
+    vehicleplate_cfg = cfg['VEHICLE_PLATE']
+    detector = PlateRecognizer(FLAGS, vehicleplate_cfg)
     # predict from image
     img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
     for img in img_list:
         image = cv2.imread(img)
         results = detector.get_platelicense([image])
+        print(results)
 
 
 if __name__ == '__main__':
     paddle.enable_static()
     parser = argsparser()
     FLAGS = parser.parse_args()
-    print_arguments(FLAGS)
     FLAGS.device = FLAGS.device.upper()
     assert FLAGS.device in ['CPU', 'GPU', 'XPU'
                             ], "device should be CPU, GPU or XPU"
-    # assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
 
     main()
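
After this hunk, `main()` no longer builds the recognizer from per-model CLI flags: `merge_cfg(FLAGS)` appears to load and merge the pipeline `--config`, and only its `VEHICLE_PLATE` section is passed to `PlateRecognizer`. A rough sketch of the shape that section seems to take, with key names drawn from the `cfg[...]` lookups in this file and values mirroring the defaults of the flags removed later in this commit; the model-dir paths are placeholders, not shipped defaults:

    # Illustrative only -- keys from the cfg[...] reads in this file, values from
    # the removed argsparser defaults, model paths are placeholders.
    vehicleplate_cfg = {
        "det_algorithm": "DB",
        "det_model_dir": "output_inference/plate_det_infer/",      # placeholder
        "det_limit_side_len": 960,
        "det_limit_type": "max",
        "rec_algorithm": "SVTR_LCNet",
        "rec_model_dir": "output_inference/plate_rec_infer/",      # placeholder
        "rec_image_shape": "3, 48, 320",
        "rec_batch_num": 6,
        "word_dict_path": "deploy/pphuman/ppvehicle/rec_word_dict.txt",
    }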
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,106 +26,6 @@ import time
 import ast
 
-def argsparser():
-    parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument(
-        "--config", type=str, default=None, help=("Path of configure"))
-    parser.add_argument("--det_algorithm", type=str, default='DB')
-    parser.add_argument("--det_model_dir", type=str)
-    parser.add_argument("--det_limit_side_len", type=float, default=960)
-    parser.add_argument("--det_limit_type", type=str, default='max')
-    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
-    parser.add_argument("--rec_model_dir", type=str)
-    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
-    parser.add_argument("--rec_batch_num", type=int, default=6)
-    parser.add_argument(
-        "--word_dict_path",
-        type=str,
-        default="deploy/pphuman/ppvehicle/rec_word_dict.txt")
-    parser.add_argument(
-        "--image_file", type=str, default=None, help="Path of image file.")
-    parser.add_argument(
-        "--image_dir",
-        type=str,
-        default=None,
-        help="Dir of image file, `image_file` has a higher priority.")
-    parser.add_argument(
-        "--video_file",
-        type=str,
-        default=None,
-        help="Path of video file, `video_file` or `camera_id` has a highest priority."
-    )
-    parser.add_argument(
-        "--video_dir",
-        type=str,
-        default=None,
-        help="Dir of video file, `video_file` has a higher priority.")
-    parser.add_argument(
-        "--model_dir", nargs='*', help="set model dir in pipeline")
-    parser.add_argument(
-        "--camera_id",
-        type=int,
-        default=-1,
-        help="device id of camera to predict.")
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output",
-        help="Directory of output visualization files.")
-    parser.add_argument(
-        "--run_mode",
-        type=str,
-        default='paddle',
-        help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
-    parser.add_argument(
-        "--device",
-        type=str,
-        default='cpu',
-        help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
-    )
-    parser.add_argument(
-        "--enable_mkldnn",
-        type=ast.literal_eval,
-        default=False,
-        help="Whether use mkldnn with CPU.")
-    parser.add_argument(
-        "--cpu_threads", type=int, default=1, help="Num of threads with CPU.")
-    parser.add_argument(
-        "--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.")
-    parser.add_argument(
-        "--trt_max_shape",
-        type=int,
-        default=1280,
-        help="max_shape for TensorRT.")
-    parser.add_argument(
-        "--trt_opt_shape",
-        type=int,
-        default=640,
-        help="opt_shape for TensorRT.")
-    parser.add_argument(
-        "--trt_calib_mode",
-        type=bool,
-        default=False,
-        help="If the model is produced by TRT offline quantitative "
-        "calibration, trt_calib_mode need to set True.")
-    parser.add_argument(
-        "--do_entrance_counting",
-        action='store_true',
-        help="Whether counting the numbers of identifiers entering "
-        "or getting out from the entrance. Note that only support one-class"
-        "counting, multi-class counting is coming soon.")
-    parser.add_argument(
-        "--secs_interval",
-        type=int,
-        default=2,
-        help="The seconds interval to count after tracking")
-    parser.add_argument(
-        "--draw_center_traj",
-        action='store_true',
-        help="Whether drawing the trajectory of center")
-    return parser
 
 
 def create_predictor(args, cfg, mode):
     if mode == "det":
         model_dir = cfg['det_model_dir']
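
One detail worth noting in the parser removed above (and presumably preserved by its replacement in `pphuman.pipe_utils`): `--enable_mkldnn` uses `type=ast.literal_eval` rather than `type=bool`, because argparse passes the raw command-line string to the `type` callable and `bool("False")` is truthy. A quick illustration:

    import ast

    # argparse calls type(<raw string>) on the command-line token.
    print(ast.literal_eval("True"), ast.literal_eval("False"))  # True False (real bools)
    print(bool("False"))  # True -- why a plain `type=bool` flag would silently misparse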
@@ -169,9 +69,8 @@ def create_predictor(args, cfg, mode):
                 precision_mode=precision_map[args.run_mode],
                 use_static=False,
                 use_calib_mode=trt_calib_mode)
 
-        # skip the minmum trt subgraph
         use_dynamic_shape = True
         if mode == "det":
             min_input_shape = {
                 "x": [1, 3, 50, 50],
@@ -248,8 +147,8 @@ def create_predictor(args, cfg, mode):
         else:
             use_dynamic_shape = False
         if use_dynamic_shape:
-            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
-                                              opt_input_shape)
+            config.set_trt_dynamic_shape_info(
+                min_input_shape, max_input_shape, opt_input_shape)
     else:
         config.disable_gpu()
...
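
For context on the `set_trt_dynamic_shape_info` call reformatted above: Paddle Inference takes three dicts keyed by input tensor name, giving the minimum, maximum, and optimal shapes the TensorRT engine must cover. A minimal sketch with placeholder model files and illustrative ranges; the `min_input_shape` entry for "x" matches the detector hunk above, while the max/opt bounds are assumptions:

    # Sketch only: model paths and the max/opt shape ranges are placeholders.
    from paddle.inference import Config, PrecisionType

    config = Config("det_inference.pdmodel", "det_inference.pdiparams")
    config.enable_use_gpu(200, 0)
    config.enable_tensorrt_engine(
        workspace_size=1 << 25,
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False)
    # One dict per bound; keys are input tensor names ("x" for the detector).
    min_input_shape = {"x": [1, 3, 50, 50]}
    max_input_shape = {"x": [1, 3, 1536, 1536]}
    opt_input_shape = {"x": [1, 3, 960, 960]}
    config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                      opt_input_shape)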