Commit 04b3f389 authored by zhiboniu, committed by zhiboniu

separate vehicle and human

Parent 5f9e8f01
@@ -9,10 +9,11 @@ DET:
   batch_size: 1
 MOT:
-  model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
+  model_dir: pipeline_model//mot_ppyoloe_s_36e_ppvehicle/
   tracker_config: deploy/pphuman/config/tracker_config.yml
   batch_size: 1
   basemode: "idbased"
+  enable: False
 KPT:
   model_dir: output_inference/dark_hrnet_w32_256x192/
...
crop_thresh: 0.5
attr_thresh: 0.5
kpt_thresh: 0.2
visual: False
warmup_frame: 50

DET:
  model_dir: output_inference/mot_ppyoloe_l_36e_ppvehicle/
  batch_size: 1

MOT:
  model_dir: pipeline_model//mot_ppyoloe_s_36e_ppvehicle/
  tracker_config: deploy/pphuman/config/tracker_config.yml
  batch_size: 1
  basemode: "idbased"
  enable: False

VEHICLE_PLATE:
  det_algorithm: "DB"
  det_model_dir: "output/ch_PP-OCRv3_det_infer/"
  det_limit_side_len: 480
  det_limit_type: "max"
  rec_algorithm: "SVTR_LCNet"
  rec_model_dir: "output/ch_PP-OCRv3_rec_infer/"
  rec_image_shape: [3, 48, 320]
  rec_batch_num: 6
  word_dict_path: "deploy/pphuman/ppvehicle/rec_word_dict.txt"
  basemode: "idbased"
  enable: True

ATTR:
  model_dir: output_inference/strongbaseline_r50_30e/
  batch_size: 8
  basemode: "idbased"
  enable: False

VIDEO_ACTION:
  model_dir: output_inference/ppTSM
  batch_size: 1
  frame_len: 8
  sample_freq: 7
  short_size: 340
  target_size: 320
  basemode: "videobased"
  enable: False

SKELETON_ACTION:
  model_dir: output_inference/STGCN
  batch_size: 1
  max_frames: 50
  display_frames: 80
  coord_size: [384, 512]
  basemode: "skeletonbased"
  enable: False

ID_BASED_DETACTION:
  model_dir: output_inference/detector
  batch_size: 1
  basemode: "idbased"
  enable: False

ID_BASED_CLSACTION:
  model_dir: output_inference/classification
  batch_size: 1
  basemode: "idbased"
  enable: False

REID:
  model_dir: output_inference/reid_model/
  batch_size: 16
  basemode: "idbased"
  enable: False
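Each section of the new ppvehicle config follows one convention: a model path, a basemode tag saying whether results attach to track ids, frames, or whole videos, and an enable switch. A minimal sketch of loading a config in this shape and listing the switched-on modules (PyYAML assumed; the snippet below is a trimmed, hypothetical excerpt, not the repo's loader):

import yaml

# Hypothetical trimmed config in the same shape as the file above.
cfg = yaml.safe_load("""
MOT:
  model_dir: pipeline_model//mot_ppyoloe_s_36e_ppvehicle/
  basemode: "idbased"
  enable: False
VEHICLE_PLATE:
  det_model_dir: "output/ch_PP-OCRv3_det_infer/"
  basemode: "idbased"
  enable: True
""")

# The pipeline's convention: a module runs only when its section exists
# and carries enable: True.
for name, section in cfg.items():
    if isinstance(section, dict) and section.get("enable", False):
        print("{} enabled, basemode={}".format(name, section["basemode"]))
# -> VEHICLE_PLATE enabled, basemode=idbased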
@@ -28,6 +28,7 @@ class Result(object):
             'reid': dict(),
             'det_action': dict(),
             'cls_action': dict(),
+            'vehicleplate': dict()
         }

     def update(self, res, name):
@@ -70,7 +71,8 @@ class DataCollector(object):
             "kpts": [],
             "features": [],
             "qualities": [],
-            "skeleton_action": []
+            "skeleton_action": [],
+            "vehicleplate": []
         }
         self.collector = {}
@@ -80,6 +82,7 @@ class DataCollector(object):
         kpt_res = Result.get('kpt')
         skeleton_action_res = Result.get('skeleton_action')
         reid_res = Result.get('reid')
+        vehicplate_res = Result.get('vehicleplate')

         rects = []
         if reid_res is not None:
@@ -109,6 +112,9 @@ class DataCollector(object):
                     idx])
                 self.collector[ids]["qualities"].append(reid_res['qualities'][
                     idx])
+            if vehicplate_res:
+                self.collector[ids]["vehicleplate"].append(vehicplate_res[
+                    'plate'][idx])

     def get_res(self):
         return self.collector
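The DataCollector change threads plate strings into the same per-track store that already holds rects, features, and qualities. An illustrative sketch of that accumulation pattern (simplified names, not the repo's code):

from collections import defaultdict

# Every tracked id keeps per-frame lists; "vehicleplate" is appended only
# when the plate module produced a result for that frame.
collector = defaultdict(lambda: {"rects": [], "vehicleplate": []})

def collect(track_id, rect, plate=None):
    collector[track_id]["rects"].append(rect)
    if plate is not None:
        collector[track_id]["vehicleplate"].append(plate)

collect(7, [0, 0, 120, 60], plate="京A12345")
collect(7, [2, 1, 121, 61])           # frame with no plate hit
print(collector[7]["vehicleplate"])   # ['京A12345']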
@@ -32,18 +32,6 @@ def argsparser():
         default=None,
         help=("Path of configure"),
         required=True)
-    parser.add_argument("--det_algorithm", type=str, default='DB')
-    parser.add_argument("--det_model_dir", type=str)
-    parser.add_argument("--det_limit_side_len", type=float, default=960)
-    parser.add_argument("--det_limit_type", type=str, default='max')
-    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
-    parser.add_argument("--rec_model_dir", type=str)
-    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
-    parser.add_argument("--rec_batch_num", type=int, default=6)
-    parser.add_argument(
-        "--word_dict_path",
-        type=str,
-        default="deploy/pphuman/rec_word_dict.txt")
     parser.add_argument(
         "--image_file", type=str, default=None, help="Path of image file.")
     parser.add_argument(
...
@@ -27,6 +27,7 @@ from collections import Sequence
 from reid import ReID
 from datacollector import DataCollector, Result
 from mtmct import mtmct_process
+from ppvehicle.vehicle_plate import PlateRecognizer

 # add deploy path of PadleDetection to sys.path
 parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
@@ -81,74 +82,33 @@ class Pipeline(object):
         counting in MOT.
     """

-    def __init__(self,
-                 cfg,
-                 image_file=None,
-                 image_dir=None,
-                 video_file=None,
-                 video_dir=None,
-                 camera_id=-1,
-                 device='CPU',
-                 run_mode='paddle',
-                 trt_min_shape=1,
-                 trt_max_shape=1280,
-                 trt_opt_shape=640,
-                 trt_calib_mode=False,
-                 cpu_threads=1,
-                 enable_mkldnn=False,
-                 output_dir='output',
-                 draw_center_traj=False,
-                 secs_interval=10,
-                 do_entrance_counting=False):
+    def __init__(self, args, cfg):
         self.multi_camera = False
         reid_cfg = cfg.get('REID', False)
         self.enable_mtmct = reid_cfg['enable'] if reid_cfg else False
         self.is_video = False
-        self.output_dir = output_dir
+        self.output_dir = args.output_dir
         self.vis_result = cfg['visual']
-        self.input = self._parse_input(image_file, image_dir, video_file,
-                                       video_dir, camera_id)
+        self.input = self._parse_input(args.image_file, args.image_dir,
+                                       args.video_file, args.video_dir,
+                                       args.camera_id)
         if self.multi_camera:
             self.predictor = []
             for name in self.input:
                 predictor_item = PipePredictor(
-                    cfg,
-                    is_video=True,
-                    multi_camera=True,
-                    device=device,
-                    run_mode=run_mode,
-                    trt_min_shape=trt_min_shape,
-                    trt_max_shape=trt_max_shape,
-                    trt_opt_shape=trt_opt_shape,
-                    cpu_threads=cpu_threads,
-                    enable_mkldnn=enable_mkldnn,
-                    output_dir=output_dir)
+                    args, cfg, is_video=True, multi_camera=True)
                 predictor_item.set_file_name(name)
                 self.predictor.append(predictor_item)
         else:
-            self.predictor = PipePredictor(
-                cfg,
-                self.is_video,
-                device=device,
-                run_mode=run_mode,
-                trt_min_shape=trt_min_shape,
-                trt_max_shape=trt_max_shape,
-                trt_opt_shape=trt_opt_shape,
-                trt_calib_mode=trt_calib_mode,
-                cpu_threads=cpu_threads,
-                enable_mkldnn=enable_mkldnn,
-                output_dir=output_dir,
-                draw_center_traj=draw_center_traj,
-                secs_interval=secs_interval,
-                do_entrance_counting=do_entrance_counting)
+            self.predictor = PipePredictor(args, cfg, self.is_video)
             if self.is_video:
-                self.predictor.set_file_name(video_file)
+                self.predictor.set_file_name(args.video_file)

-        self.output_dir = output_dir
-        self.draw_center_traj = draw_center_traj
-        self.secs_interval = secs_interval
-        self.do_entrance_counting = do_entrance_counting
+        self.output_dir = args.output_dir
+        self.draw_center_traj = args.draw_center_traj
+        self.secs_interval = args.secs_interval
+        self.do_entrance_counting = args.do_entrance_counting

     def _parse_input(self, image_file, image_dir, video_file, video_dir,
                      camera_id):
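The constructor rewrite collapses a roughly 17-parameter signature into (args, cfg): everything that used to be a keyword argument now rides on the argparse Namespace. A small sketch of the pattern with hypothetical flags:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="CPU")
parser.add_argument("--output_dir", type=str, default="output")

class DemoPredictor(object):
    # one Namespace replaces the long keyword-argument tail
    def __init__(self, args, cfg):
        self.device = args.device
        self.output_dir = args.output_dir

args = parser.parse_args(["--device", "GPU"])
print(DemoPredictor(args, cfg={}).device)  # GPU

The trade-off is a tighter contract with the CLI: the predictor now assumes args carries every field it reads, which fits the commit's direction of moving the plate-specific flags out of argsparser and into the VEHICLE_PLATE config section.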
@@ -247,25 +207,31 @@ class PipePredictor(object):
         counting in MOT.
     """

-    def __init__(self,
-                 cfg,
-                 is_video=True,
-                 multi_camera=False,
-                 device='CPU',
-                 run_mode='paddle',
-                 trt_min_shape=1,
-                 trt_max_shape=1280,
-                 trt_opt_shape=640,
-                 trt_calib_mode=False,
-                 cpu_threads=1,
-                 enable_mkldnn=False,
-                 output_dir='output',
-                 draw_center_traj=False,
-                 secs_interval=10,
-                 do_entrance_counting=False):
+    def __init__(self, args, cfg, is_video=True, multi_camera=False):
+        device = args.device
+        run_mode = args.run_mode
+        trt_min_shape = args.trt_min_shape
+        trt_max_shape = args.trt_max_shape
+        trt_opt_shape = args.trt_opt_shape
+        trt_calib_mode = args.trt_calib_mode
+        cpu_threads = args.cpu_threads
+        enable_mkldnn = args.enable_mkldnn
+        output_dir = args.output_dir
+        draw_center_traj = args.draw_center_traj
+        secs_interval = args.secs_interval
+        do_entrance_counting = args.do_entrance_counting
+
+        # general module for pphuman and ppvehicle
+        self.with_mot = cfg.get('MOT', False)['enable'] if cfg.get(
+            'MOT', False) else False
         self.with_attr = cfg.get('ATTR', False)['enable'] if cfg.get(
             'ATTR', False) else False
+        if self.with_mot:
+            print('Multi-Object Tracking enabled')
+        if self.with_attr:
+            print('Attribute Recognition enabled')
+
+        # only for pphuman
         self.with_skeleton_action = cfg.get(
             'SKELETON_ACTION', False)['enable'] if cfg.get('SKELETON_ACTION',
                                                            False) else False
@@ -281,8 +247,6 @@ class PipePredictor(object):
         self.with_mtmct = cfg.get('REID', False)['enable'] if cfg.get(
             'REID', False) else False

-        if self.with_attr:
-            print('Attribute Recognition enabled')
         if self.with_skeleton_action:
             print('SkeletonAction Recognition enabled')
         if self.with_video_action:
@@ -294,6 +258,13 @@ class PipePredictor(object):
         if self.with_mtmct:
             print("MTMCT enabled")

+        # only for ppvehicle
+        self.with_vehicleplate = cfg.get(
+            'VEHICLE_PLATE', False)['enable'] if cfg.get('VEHICLE_PLATE',
+                                                         False) else False
+        if self.with_vehicleplate:
+            print('Vehicle Plate Recognition enabled')
+
         self.modebase = {
             "framebased": False,
             "videobased": False,
@@ -335,27 +306,6 @@ class PipePredictor(object):
                                           enable_mkldnn)
         else:
-            mot_cfg = self.cfg['MOT']
-            model_dir = mot_cfg['model_dir']
-            tracker_config = mot_cfg['tracker_config']
-            batch_size = mot_cfg['batch_size']
-            basemode = mot_cfg['basemode']
-            self.modebase[basemode] = True
-            self.mot_predictor = SDE_Detector(
-                model_dir,
-                tracker_config,
-                device,
-                run_mode,
-                batch_size,
-                trt_min_shape,
-                trt_max_shape,
-                trt_opt_shape,
-                trt_calib_mode,
-                cpu_threads,
-                enable_mkldnn,
-                draw_center_traj=draw_center_traj,
-                secs_interval=secs_interval,
-                do_entrance_counting=do_entrance_counting)
             if self.with_attr:
                 attr_cfg = self.cfg['ATTR']
                 model_dir = attr_cfg['model_dir']
@@ -455,6 +405,37 @@ class PipePredictor(object):
                     use_dark=False)
                 self.kpt_buff = KeyPointBuff(skeleton_action_frames)

+            if self.with_vehicleplate:
+                vehicleplate_cfg = self.cfg['VEHICLE_PLATE']
+                self.vehicleplate_detector = PlateRecognizer(args,
+                                                             vehicleplate_cfg)
+                basemode = vehicleplate_cfg['basemode']
+                self.modebase[basemode] = True
+
+            if self.with_mot or self.modebase["idbased"] or self.modebase[
+                    "skeletonbased"]:
+                mot_cfg = self.cfg['MOT']
+                model_dir = mot_cfg['model_dir']
+                tracker_config = mot_cfg['tracker_config']
+                batch_size = mot_cfg['batch_size']
+                basemode = mot_cfg['basemode']
+                self.modebase[basemode] = True
+                self.mot_predictor = SDE_Detector(
+                    model_dir,
+                    tracker_config,
+                    device,
+                    run_mode,
+                    batch_size,
+                    trt_min_shape,
+                    trt_max_shape,
+                    trt_opt_shape,
+                    trt_calib_mode,
+                    cpu_threads,
+                    enable_mkldnn,
+                    draw_center_traj=draw_center_traj,
+                    secs_interval=secs_interval,
+                    do_entrance_counting=do_entrance_counting)
+
             if self.with_video_action:
                 video_action_cfg = self.cfg['VIDEO_ACTION']
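Moving the SDE_Detector construction behind this check means the tracker is built only when MOT is explicitly enabled or some id-/skeleton-based module (attributes, plate recognition, skeleton actions) needs track ids. A simplified, runnable sketch of that gating logic (helper name hypothetical):

def needs_tracker(with_mot, modebase):
    # track ids are required either explicitly (MOT enabled) or
    # implicitly by any id-/skeleton-based downstream module
    return with_mot or modebase["idbased"] or modebase["skeletonbased"]

modebase = {"framebased": False, "videobased": False,
            "idbased": True, "skeletonbased": False}
print(needs_tracker(False, modebase))  # True: e.g. VEHICLE_PLATE is idbased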
@@ -484,10 +465,10 @@ class PipePredictor(object):
             reid_cfg = self.cfg['REID']
             model_dir = reid_cfg['model_dir']
             batch_size = reid_cfg['batch_size']
-            self.reid_predictor = ReID(model_dir, device, run_mode, batch_size,
-                                       trt_min_shape, trt_max_shape,
-                                       trt_opt_shape, trt_calib_mode,
-                                       cpu_threads, enable_mkldnn)
+            self.reid_predictor = ReID(
+                model_dir, device, run_mode, batch_size, trt_min_shape,
+                trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
+                enable_mkldnn)

     def set_file_name(self, path):
         if path is not None:
@@ -640,15 +621,18 @@ class PipePredictor(object):
                         cv2.imshow('PPHuman', im)
                         if cv2.waitKey(1) & 0xFF == ord('q'):
                             break
                     continue

                 self.pipeline_res.update(mot_res, 'mot')

+                #todo: move this code to each class's predeal function
                 crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
                     frame, mot_res)

+                if self.with_vehicleplate:
+                    platelicense = self.vehicleplate_detector.get_platelicense(
+                        crop_input)
+                    print("find plate license:{}".format(platelicense))
+                    self.pipeline_res.update(platelicense, 'vehicleplate')
+
                 if self.with_attr:
                     if frame_id > self.warmup_frame:
                         self.pipe_timer.module_time['attr'].start()
@@ -924,14 +908,7 @@ def main():
     cfg = merge_cfg(FLAGS)
     print_arguments(cfg)

-    pipeline = Pipeline(
-        cfg, FLAGS.image_file, FLAGS.image_dir, FLAGS.video_file,
-        FLAGS.video_dir, FLAGS.camera_id, FLAGS.device, FLAGS.run_mode,
-        FLAGS.trt_min_shape, FLAGS.trt_max_shape, FLAGS.trt_opt_shape,
-        FLAGS.trt_calib_mode, FLAGS.cpu_threads, FLAGS.enable_mkldnn,
-        FLAGS.output_dir, FLAGS.draw_center_traj, FLAGS.secs_interval,
-        FLAGS.do_entrance_counting)
+    pipeline = Pipeline(FLAGS, cfg)

     pipeline.run()
...
@@ -30,20 +30,19 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 3)))
 sys.path.insert(0, parent_path)

 from python.infer import get_test_images, print_arguments
-from vechile_plateutils import create_predictor, get_infer_gpuid, get_rotate_crop_image, draw_boxes
-from vecplatepostprocess import build_post_process
+from pphuman.ppvehicle.vehicle_plateutils import create_predictor, get_infer_gpuid, get_rotate_crop_image, draw_boxes, argsparser
+from pphuman.ppvehicle.vecplatepostprocess import build_post_process
 from python.preprocess import preprocess, NormalizeImage, Permute, Resize_Mult32
-from vechile_plateutils import argsparser


 class PlateDetector(object):
-    def __init__(self, args):
+    def __init__(self, args, cfg):
         self.args = args
-        self.det_algorithm = args.det_algorithm
+        self.det_algorithm = cfg['det_algorithm']
         self.pre_process_list = {
             'Resize_Mult32': {
-                'limit_side_len': args.det_limit_side_len,
-                'limit_type': args.det_limit_type,
+                'limit_side_len': cfg['det_limit_side_len'],
+                'limit_type': cfg['det_limit_type'],
             },
             'NormalizeImage': {
                 'mean': [0.485, 0.456, 0.406],
@@ -63,7 +62,7 @@ class PlateDetector(object):
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = create_predictor(
-            args, 'det')
+            args, cfg, 'det')

     def preprocess(self, image_list):
         preprocess_ops = []
@@ -151,13 +150,11 @@ class PlateDetector(object):

 class TextRecognizer(object):
-    def __init__(self, FLAGS, use_gpu=True):
-        self.rec_image_shape = [
-            int(v) for v in FLAGS.rec_image_shape.split(",")
-        ]
-        self.rec_batch_num = FLAGS.rec_batch_num
-        self.rec_algorithm = FLAGS.rec_algorithm
-        word_dict_path = FLAGS.word_dict_path
+    def __init__(self, args, cfg, use_gpu=True):
+        self.rec_image_shape = cfg['rec_image_shape']
+        self.rec_batch_num = cfg['rec_batch_num']
+        self.rec_algorithm = cfg['rec_algorithm']
+        word_dict_path = cfg['word_dict_path']
         isuse_space_char = True

         postprocess_params = {
@@ -191,7 +188,7 @@ class TextRecognizer(object):
         }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
-            create_predictor(FLAGS, 'rec')
+            create_predictor(args, cfg, 'rec')
         self.use_onnx = False

     def resize_norm_img(self, img, max_wh_ratio):
@@ -488,22 +485,24 @@ class TextRecognizer(object):

 class PlateRecognizer(object):
-    def __init__(self):
-        use_gpu = FLAGS.device.lower() == "gpu"
-        self.platedetector = PlateDetector(FLAGS)
-        self.textrecognizer = TextRecognizer(FLAGS, use_gpu=use_gpu)
+    def __init__(self, args, cfg):
+        use_gpu = args.device.lower() == "gpu"
+        self.platedetector = PlateDetector(args, cfg)
+        self.textrecognizer = TextRecognizer(args, cfg, use_gpu=use_gpu)

     def get_platelicense(self, image_list):
         plate_text_list = []
         plateboxes, det_time = self.platedetector.predict_image(image_list)
         for idx, boxes_pcar in enumerate(plateboxes):
+            plate_pcar_list = []
             for box in boxes_pcar:
                 plate_images = get_rotate_crop_image(image_list[idx], box)
                 plate_texts = self.textrecognizer.predict_text([plate_images])
-                plate_text_list.append(plate_texts)
-                print("plate text:{}".format(plate_texts))
-            newimg = draw_boxes(image_list[idx], boxes_pcar)
-            cv2.imwrite("vechile_plate.jpg", newimg)
+                plate_pcar_list.append(plate_texts)
+                # print("plate text:{}".format(plate_texts))
+            plate_text_list.append(plate_pcar_list)
+            # newimg = draw_boxes(image_list[idx], boxes_pcar)
+            # cv2.imwrite("vehicle_plate.jpg", newimg)
         return self.check_plate(plate_text_list)

     def check_plate(self, text_list):
def check_plate(self, text_list): def check_plate(self, text_list):
...@@ -512,16 +511,20 @@ class PlateRecognizer(object): ...@@ -512,16 +511,20 @@ class PlateRecognizer(object):
'赣', '鲁', '豫', '鄂', '湘', '桂', '琼', '渝', '川', '贵', '云', '藏', '陕', '赣', '鲁', '豫', '鄂', '湘', '桂', '琼', '渝', '川', '贵', '云', '藏', '陕',
'甘', '青', '宁' '甘', '青', '宁'
] ]
for text_info in text_list: plate_all = {"plate": []}
# import pdb;pdb.set_trace() for text_pcar in text_list:
platelicense = None
for text_info in text_pcar:
text = text_info[0][0][0] text = text_info[0][0][0]
if len(text) > 2 and text[0] in simcode and len(text) < 10: if len(text) > 2 and text[0] in simcode and len(text) < 10:
print("text:{} length:{}".format(text, len(text))) # print("text:{} length:{}".format(text, len(text)))
return text platelicense = text
plate_all["plate"].append(platelicense)
return plate_all
def main(): def main():
detector = PlateRecognizer() detector = PlateRecognizer(FLAGS)
# predict from image # predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
for img in img_list: for img in img_list:
......
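After this rewrite, get_platelicense returns a dict with one entry per input car crop instead of a flat text list, and check_plate emits None for cars whose recognized text fails the plate heuristic (first character a province code, length between 3 and 9). A standalone illustration of the new check_plate contract (simcode truncated here for brevity; the nested [0][0][0] indexing mirrors the recognizer's batched (text, score) output):

def check_plate(text_list, simcode=('京', '沪', '粤')):
    plate_all = {"plate": []}
    for text_pcar in text_list:          # one inner list per car
        platelicense = None
        for text_info in text_pcar:
            text = text_info[0][0][0]
            if len(text) > 2 and text[0] in simcode and len(text) < 10:
                platelicense = text
        plate_all["plate"].append(platelicense)
    return plate_all

# car 0 has a valid plate, car 1 only noise text
cars = [[[[("京A12345", 0.98)]]], [[[("ABC", 0.40)]]]]
print(check_plate(cars))  # {'plate': ['京A12345', None]}

Index alignment is the point of the per-car inner list: plate_all["plate"][i] belongs to crop i, which is what lets DataCollector append plates by track index.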
@@ -41,7 +41,7 @@ def argsparser():
     parser.add_argument(
         "--word_dict_path",
         type=str,
-        default="deploy/pphuman/rec_word_dict.txt")
+        default="deploy/pphuman/ppvehicle/rec_word_dict.txt")
     parser.add_argument(
         "--image_file", type=str, default=None, help="Path of image file.")
     parser.add_argument(
@@ -126,17 +126,11 @@ def argsparser():
     return parser


-def create_predictor(args, mode):
+def create_predictor(args, cfg, mode):
     if mode == "det":
-        model_dir = args.det_model_dir
-    elif mode == 'cls':
-        model_dir = args.cls_model_dir
-    elif mode == 'rec':
-        model_dir = args.rec_model_dir
-    elif mode == 'table':
-        model_dir = args.table_model_dir
+        model_dir = cfg['det_model_dir']
     else:
-        model_dir = args.e2e_model_dir
+        model_dir = cfg['rec_model_dir']

     if model_dir is None:
         print("not find {} model file path {}".format(mode, model_dir))
@@ -243,7 +237,7 @@ def create_predictor(args, mode):
                 max_input_shape.update(max_pact_shape)
                 opt_input_shape.update(opt_pact_shape)
         elif mode == "rec":
-            imgH = int(args.rec_image_shape.split(',')[-2])
+            imgH = int(cfg['rec_image_shape'][-2])
             min_input_shape = {"x": [1, 3, imgH, 10]}
             max_input_shape = {"x": [batch_size, 3, imgH, 2304]}
             opt_input_shape = {"x": [batch_size, 3, imgH, 320]}
@@ -285,14 +279,14 @@ def create_predictor(args, mode):
     input_names = predictor.get_input_names()
     for name in input_names:
         input_tensor = predictor.get_input_handle(name)
-    output_tensors = get_output_tensors(args, mode, predictor)
+    output_tensors = get_output_tensors(cfg, mode, predictor)
     return predictor, input_tensor, output_tensors, config


-def get_output_tensors(args, mode, predictor):
+def get_output_tensors(cfg, mode, predictor):
     output_names = predictor.get_output_names()
     output_tensors = []
-    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
+    if mode == "rec" and cfg['rec_algorithm'] in ["CRNN", "SVTR_LCNet"]:
         output_name = 'softmax_0.tmp_0'
         if output_name in output_names:
             return [predictor.get_output_handle(output_name)]
...
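create_predictor now resolves model paths from the cfg dict, and since the plate pipeline only ever builds text detectors and recognizers, the cls/table/e2e branches are dropped. A simplified sketch of the narrowed dispatch (helper name hypothetical):

def pick_model_dir(cfg, mode):
    # only 'det' and 'rec' survive the refactor; the other modes were
    # OCR toolkit baggage the plate pipeline never used
    if mode == "det":
        return cfg['det_model_dir']
    return cfg['rec_model_dir']

cfg = {'det_model_dir': 'output/ch_PP-OCRv3_det_infer/',
       'rec_model_dir': 'output/ch_PP-OCRv3_rec_infer/'}
print(pick_model_dir(cfg, 'rec'))  # output/ch_PP-OCRv3_rec_infer/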