Unverified commit ff8a7b1d, authored by JYChen, committed by GitHub

move initialize part into class (#6621)

Parent 6409d062
@@ -60,29 +60,8 @@ class Pipeline(object):
     Pipeline

     Args:
+        args (argparse.Namespace): arguments in pipeline, which contains environment and runtime settings
         cfg (dict): config of models in pipeline
-        image_file (string|None): the path of image file, default as None
-        image_dir (string|None): the path of image directory, if not None,
-            then all the images in directory will be predicted, default as None
-        video_file (string|None): the path of video file, default as None
-        camera_id (int): the device id of camera to predict, default as -1
-        device (string): the device to predict, options are: CPU/GPU/XPU,
-            default as CPU
-        run_mode (string): the mode of prediction, options are:
-            paddle/trt_fp32/trt_fp16, default as paddle
-        trt_min_shape (int): min shape for dynamic shape in trt, default as 1
-        trt_max_shape (int): max shape for dynamic shape in trt, default as 1280
-        trt_opt_shape (int): opt shape for dynamic shape in trt, default as 640
-        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
-            calibration, trt_calib_mode need to set True. default as False
-        cpu_threads (int): cpu threads, default as 1
-        enable_mkldnn (bool): whether to open MKLDNN, default as False
-        output_dir (string): The path of output, default as 'output'
-        draw_center_traj (bool): Whether drawing the trajectory of center, default as False
-        secs_interval (int): The seconds interval to count after tracking, default as 10
-        do_entrance_counting(bool): Whether counting the numbers of identifiers entering
-            or getting out from the entrance, default as False, only support single class
-            counting in MOT.
     """

     def __init__(self, args, cfg):
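
The trimmed docstring mirrors the API change: every runtime flag now travels inside a single `argparse.Namespace` instead of being documented and threaded through one by one. A minimal construction sketch under the new contract; all field values below are illustrative, not taken from this diff:

```python
import argparse

# Pipeline is the class patched above; import it from your checkout.
args = argparse.Namespace(
    video_file='test.mp4',                  # input selection
    device='GPU',                           # CPU / GPU / XPU
    run_mode='paddle',                      # paddle / trt_fp32 / trt_fp16
    trt_min_shape=1, trt_max_shape=1280, trt_opt_shape=640,
    trt_calib_mode=False, cpu_threads=1, enable_mkldnn=False,
    output_dir='output', draw_center_traj=False, secs_interval=10,
    do_entrance_counting=False, do_break_in_counting=False,
    region_type='horizontal', region_polygon=[])

cfg = {'warmup_frame': 50}                  # plus per-model sections (MOT, ATTR, ...)
pipeline = Pipeline(args, cfg)
```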
@@ -108,18 +87,6 @@ class Pipeline(object):
         if self.is_video:
             self.predictor.set_file_name(args.video_file)

-        self.output_dir = args.output_dir
-        self.draw_center_traj = args.draw_center_traj
-        self.secs_interval = args.secs_interval
-        self.do_entrance_counting = args.do_entrance_counting
-        self.do_break_in_counting = args.do_break_in_counting
-        self.region_type = args.region_type
-        self.region_polygon = args.region_polygon
-        if self.region_type == 'custom':
-            assert len(
-                self.region_polygon
-            ) > 6, 'region_type is custom, region_polygon should be at least 3 pairs of point coords.'
-
     def _parse_input(self, image_file, image_dir, video_file, video_dir,
                      camera_id):
@@ -179,8 +146,10 @@ class Pipeline(object):
 def get_model_dir(cfg):
-    # auto download inference model
-    model_dir_dict = {}
+    """
+    Auto download inference model if the model_path is a url link.
+    Otherwise it will use the model_path directly.
+    """
     for key in cfg.keys():
         if type(cfg[key]) == dict and \
                 ("enable" in cfg[key].keys() and cfg[key]['enable']
@@ -191,30 +160,30 @@ def get_model_dir(cfg):
             downloaded_model_dir = auto_download_model(model_dir)
             if downloaded_model_dir:
                 model_dir = downloaded_model_dir
-            model_dir_dict[key] = model_dir
-            print(key, " model dir:", model_dir)
+            cfg[key]["model_dir"] = model_dir
+            print(key, " model dir: ", model_dir)
         elif key == "VEHICLE_PLATE":
             det_model_dir = cfg[key]["det_model_dir"]
             downloaded_det_model_dir = auto_download_model(det_model_dir)
             if downloaded_det_model_dir:
                 det_model_dir = downloaded_det_model_dir
-            model_dir_dict["det_model_dir"] = det_model_dir
-            print("det_model_dir model dir:", det_model_dir)
+            cfg[key]["det_model_dir"] = det_model_dir
+            print("det_model_dir model dir: ", det_model_dir)

             rec_model_dir = cfg[key]["rec_model_dir"]
             downloaded_rec_model_dir = auto_download_model(rec_model_dir)
             if downloaded_rec_model_dir:
                 rec_model_dir = downloaded_rec_model_dir
-            model_dir_dict["rec_model_dir"] = rec_model_dir
-            print("rec_model_dir model dir:", rec_model_dir)
+            cfg[key]["rec_model_dir"] = rec_model_dir
+            print("rec_model_dir model dir: ", rec_model_dir)

         elif key == "MOT":  # for idbased and skeletonbased actions
             model_dir = cfg[key]["model_dir"]
             downloaded_model_dir = auto_download_model(model_dir)
             if downloaded_model_dir:
                 model_dir = downloaded_model_dir
-            model_dir_dict[key] = model_dir
-
-    return model_dir_dict
+            cfg[key]["model_dir"] = model_dir
+            print("mot_model_dir model_dir: ", model_dir)
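
After this hunk, `get_model_dir` no longer returns a mapping; callers observe it only through the in-place update of `cfg`. A self-contained sketch of that contract, with `auto_download_model` stubbed and the URL purely illustrative:

```python
def auto_download_model(path):
    # stand-in for the real downloader: URL-like paths resolve to a local
    # cache directory, plain local paths resolve to None (use as-is)
    return '/root/.cache/infer_weights/mot_model' if path.startswith('http') else None

cfg = {'MOT': {'enable': True,
               'model_dir': 'https://example.com/mot_model.zip'}}  # illustrative

model_dir = cfg['MOT']['model_dir']
downloaded_model_dir = auto_download_model(model_dir)
if downloaded_model_dir:
    model_dir = downloaded_model_dir
cfg['MOT']['model_dir'] = model_dir   # written back in place, nothing returned

print(cfg['MOT']['model_dir'])        # -> /root/.cache/infer_weights/mot_model
```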
@@ -234,47 +203,14 @@ class PipePredictor(object):
         4. VideoAction Recognition

     Args:
+        args (argparse.Namespace): arguments in pipeline, which contains environment and runtime settings
         cfg (dict): config of models in pipeline
         is_video (bool): whether the input is video, default as False
         multi_camera (bool): whether to use multi camera in pipeline,
             default as False
-        camera_id (int): the device id of camera to predict, default as -1
-        device (string): the device to predict, options are: CPU/GPU/XPU,
-            default as CPU
-        run_mode (string): the mode of prediction, options are:
-            paddle/trt_fp32/trt_fp16, default as paddle
-        trt_min_shape (int): min shape for dynamic shape in trt, default as 1
-        trt_max_shape (int): max shape for dynamic shape in trt, default as 1280
-        trt_opt_shape (int): opt shape for dynamic shape in trt, default as 640
-        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
-            calibration, trt_calib_mode need to set True. default as False
-        cpu_threads (int): cpu threads, default as 1
-        enable_mkldnn (bool): whether to open MKLDNN, default as False
-        output_dir (string): The path of output, default as 'output'
-        draw_center_traj (bool): Whether drawing the trajectory of center, default as False
-        secs_interval (int): The seconds interval to count after tracking, default as 10
-        do_entrance_counting(bool): Whether counting the numbers of identifiers entering
-            or getting out from the entrance, default as False, only support single class
-            counting in MOT.
     """

     def __init__(self, args, cfg, is_video=True, multi_camera=False):
-        device = args.device
-        run_mode = args.run_mode
-        trt_min_shape = args.trt_min_shape
-        trt_max_shape = args.trt_max_shape
-        trt_opt_shape = args.trt_opt_shape
-        trt_calib_mode = args.trt_calib_mode
-        cpu_threads = args.cpu_threads
-        enable_mkldnn = args.enable_mkldnn
-        output_dir = args.output_dir
-        draw_center_traj = args.draw_center_traj
-        secs_interval = args.secs_interval
-        do_entrance_counting = args.do_entrance_counting
-        do_break_in_counting = args.do_break_in_counting
-        region_type = args.region_type
-        region_polygon = args.region_polygon
-
         # general module for pphuman and ppvehicle
         self.with_mot = cfg.get('MOT', False)['enable'] if cfg.get(
             'MOT', False) else False
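
The unchanged `with_mot` line above packs a null-check and a flag read into one expression: `cfg.get('MOT', False)` yields the MOT sub-dict when the section exists, so `['enable']` is only evaluated in that case. A minimal equivalent, assuming `cfg` is the parsed pipeline config:

```python
cfg = {'MOT': {'enable': True}}          # illustrative config

mot_section = cfg.get('MOT')             # None when the key is absent
with_mot = bool(mot_section) and mot_section['enable']
assert with_mot is True
```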
@@ -347,13 +283,13 @@ class PipePredictor(object):
         self.is_video = is_video
         self.multi_camera = multi_camera
         self.cfg = cfg
-        self.output_dir = output_dir
-        self.draw_center_traj = draw_center_traj
-        self.secs_interval = secs_interval
-        self.do_entrance_counting = do_entrance_counting
-        self.do_break_in_counting = do_break_in_counting
-        self.region_type = region_type
-        self.region_polygon = region_polygon
+        self.output_dir = args.output_dir
+        self.draw_center_traj = args.draw_center_traj
+        self.secs_interval = args.secs_interval
+        self.do_entrance_counting = args.do_entrance_counting
+        self.do_break_in_counting = args.do_break_in_counting
+        self.region_type = args.region_type
+        self.region_polygon = args.region_polygon
         self.warmup_frame = self.cfg['warmup_frame']
         self.pipeline_res = Result()
@@ -362,7 +298,7 @@ class PipePredictor(object):
         self.collector = DataCollector()

         # auto download inference model
-        model_dir_dict = get_model_dir(self.cfg)
+        get_model_dir(self.cfg)

         if self.with_vehicleplate:
             vehicleplate_cfg = self.cfg['VEHICLE_PLATE']
@@ -372,148 +308,84 @@ class PipePredictor(object):
         if self.with_human_attr:
             attr_cfg = self.cfg['ATTR']
-            model_dir = model_dir_dict['ATTR']
-            batch_size = attr_cfg['batch_size']
             basemode = self.basemode['ATTR']
             self.modebase[basemode] = True
-            self.attr_predictor = AttrDetector(
-                model_dir, device, run_mode, batch_size, trt_min_shape,
-                trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
-                enable_mkldnn)
+            self.attr_predictor = AttrDetector.init_with_cfg(args, attr_cfg)

         if self.with_vehicle_attr:
             vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
-            model_dir = model_dir_dict['VEHICLE_ATTR']
-            batch_size = vehicleattr_cfg['batch_size']
-            color_threshold = vehicleattr_cfg['color_threshold']
-            type_threshold = vehicleattr_cfg['type_threshold']
             basemode = self.basemode['VEHICLE_ATTR']
             self.modebase[basemode] = True
-            self.vehicle_attr_predictor = VehicleAttr(
-                model_dir, device, run_mode, batch_size, trt_min_shape,
-                trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
-                enable_mkldnn, color_threshold, type_threshold)
+            self.vehicle_attr_predictor = VehicleAttr.init_with_cfg(
+                args, vehicleattr_cfg)

         if not is_video:
             det_cfg = self.cfg['DET']
-            model_dir = model_dir_dict['DET']
+            model_dir = det_cfg['model_dir']
             batch_size = det_cfg['batch_size']
             self.det_predictor = Detector(
-                model_dir, device, run_mode, batch_size, trt_min_shape,
-                trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
-                enable_mkldnn)
+                model_dir, args.device, args.run_mode, batch_size,
+                args.trt_min_shape, args.trt_max_shape, args.trt_opt_shape,
+                args.trt_calib_mode, args.cpu_threads, args.enable_mkldnn)
         else:
             if self.with_idbased_detaction:
                 idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION']
-                model_dir = model_dir_dict['ID_BASED_DETACTION']
-                batch_size = idbased_detaction_cfg['batch_size']
                 basemode = self.basemode['ID_BASED_DETACTION']
-                threshold = idbased_detaction_cfg['threshold']
-                display_frames = idbased_detaction_cfg['display_frames']
-                skip_frame_num = idbased_detaction_cfg['skip_frame_num']
                 self.modebase[basemode] = True
-                self.det_action_predictor = DetActionRecognizer(
-                    model_dir,
-                    device,
-                    run_mode,
-                    batch_size,
-                    trt_min_shape,
-                    trt_max_shape,
-                    trt_opt_shape,
-                    trt_calib_mode,
-                    cpu_threads,
-                    enable_mkldnn,
-                    threshold=threshold,
-                    display_frames=display_frames,
-                    skip_frame_num=skip_frame_num)
+                self.det_action_predictor = DetActionRecognizer.init_with_cfg(
+                    args, idbased_detaction_cfg)
                 self.det_action_visual_helper = ActionVisualHelper(1)

             if self.with_idbased_clsaction:
                 idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION']
-                model_dir = model_dir_dict['ID_BASED_CLSACTION']
-                batch_size = idbased_clsaction_cfg['batch_size']
                 basemode = self.basemode['ID_BASED_CLSACTION']
-                threshold = idbased_clsaction_cfg['threshold']
                 self.modebase[basemode] = True
-                display_frames = idbased_clsaction_cfg['display_frames']
-                skip_frame_num = idbased_clsaction_cfg['skip_frame_num']
-                self.cls_action_predictor = ClsActionRecognizer(
-                    model_dir,
-                    device,
-                    run_mode,
-                    batch_size,
-                    trt_min_shape,
-                    trt_max_shape,
-                    trt_opt_shape,
-                    trt_calib_mode,
-                    cpu_threads,
-                    enable_mkldnn,
-                    threshold=threshold,
-                    display_frames=display_frames,
-                    skip_frame_num=skip_frame_num)
+                self.cls_action_predictor = ClsActionRecognizer.init_with_cfg(
+                    args, idbased_clsaction_cfg)
                 self.cls_action_visual_helper = ActionVisualHelper(1)

             if self.with_skeleton_action:
                 skeleton_action_cfg = self.cfg['SKELETON_ACTION']
-                skeleton_action_model_dir = model_dir_dict['SKELETON_ACTION']
-                skeleton_action_batch_size = skeleton_action_cfg['batch_size']
-                skeleton_action_frames = skeleton_action_cfg['max_frames']
                 display_frames = skeleton_action_cfg['display_frames']
                 self.coord_size = skeleton_action_cfg['coord_size']
                 basemode = self.basemode['SKELETON_ACTION']
                 self.modebase[basemode] = True
+                skeleton_action_frames = skeleton_action_cfg['max_frames']
-                self.skeleton_action_predictor = SkeletonActionRecognizer(
-                    skeleton_action_model_dir,
-                    device,
-                    run_mode,
-                    skeleton_action_batch_size,
-                    trt_min_shape,
-                    trt_max_shape,
-                    trt_opt_shape,
-                    trt_calib_mode,
-                    cpu_threads,
-                    enable_mkldnn,
-                    window_size=skeleton_action_frames)
+                self.skeleton_action_predictor = SkeletonActionRecognizer.init_with_cfg(
+                    args, skeleton_action_cfg)
                 self.skeleton_action_visual_helper = ActionVisualHelper(
                     display_frames)

                 if self.modebase["skeletonbased"]:
                     kpt_cfg = self.cfg['KPT']
-                    kpt_model_dir = model_dir_dict['KPT']
+                    kpt_model_dir = kpt_cfg['model_dir']
                     kpt_batch_size = kpt_cfg['batch_size']
                     self.kpt_predictor = KeyPointDetector(
                         kpt_model_dir,
-                        device,
-                        run_mode,
+                        args.device,
+                        args.run_mode,
                         kpt_batch_size,
-                        trt_min_shape,
-                        trt_max_shape,
-                        trt_opt_shape,
-                        trt_calib_mode,
-                        cpu_threads,
-                        enable_mkldnn,
+                        args.trt_min_shape,
+                        args.trt_max_shape,
+                        args.trt_opt_shape,
+                        args.trt_calib_mode,
+                        args.cpu_threads,
+                        args.enable_mkldnn,
                         use_dark=False)
                     self.kpt_buff = KeyPointBuff(skeleton_action_frames)

             if self.with_mtmct:
                 reid_cfg = self.cfg['REID']
-                model_dir = model_dir_dict['REID']
-                batch_size = reid_cfg['batch_size']
                 basemode = self.basemode['REID']
                 self.modebase[basemode] = True
-                self.reid_predictor = ReID(
-                    model_dir, device, run_mode, batch_size, trt_min_shape,
-                    trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
-                    enable_mkldnn)
+                self.reid_predictor = ReID.init_with_cfg(args, reid_cfg)

             if self.with_mot or self.modebase["idbased"] or self.modebase[
                     "skeletonbased"]:
                 mot_cfg = self.cfg['MOT']
-                model_dir = model_dir_dict['MOT']
+                model_dir = mot_cfg['model_dir']
                 tracker_config = mot_cfg['tracker_config']
                 batch_size = mot_cfg['batch_size']
                 basemode = self.basemode['MOT']
@@ -521,46 +393,28 @@ class PipePredictor(object):
                 self.mot_predictor = SDE_Detector(
                     model_dir,
                     tracker_config,
-                    device,
-                    run_mode,
+                    args.device,
+                    args.run_mode,
                     batch_size,
-                    trt_min_shape,
-                    trt_max_shape,
-                    trt_opt_shape,
-                    trt_calib_mode,
-                    cpu_threads,
-                    enable_mkldnn,
-                    draw_center_traj=draw_center_traj,
-                    secs_interval=secs_interval,
-                    do_entrance_counting=do_entrance_counting,
-                    do_break_in_counting=do_break_in_counting,
-                    region_type=region_type,
-                    region_polygon=region_polygon)
+                    args.trt_min_shape,
+                    args.trt_max_shape,
+                    args.trt_opt_shape,
+                    args.trt_calib_mode,
+                    args.cpu_threads,
+                    args.enable_mkldnn,
+                    draw_center_traj=self.draw_center_traj,
+                    secs_interval=self.secs_interval,
+                    do_entrance_counting=self.do_entrance_counting,
+                    do_break_in_counting=self.do_break_in_counting,
+                    region_type=self.region_type,
+                    region_polygon=self.region_polygon)

             if self.with_video_action:
                 video_action_cfg = self.cfg['VIDEO_ACTION']
                 basemode = self.basemode['VIDEO_ACTION']
                 self.modebase[basemode] = True
-                video_action_model_dir = model_dir_dict['VIDEO_ACTION']
-                video_action_batch_size = video_action_cfg['batch_size']
-                short_size = video_action_cfg["short_size"]
-                target_size = video_action_cfg["target_size"]
-                self.video_action_predictor = VideoActionRecognizer(
-                    model_dir=video_action_model_dir,
-                    short_size=short_size,
-                    target_size=target_size,
-                    device=device,
-                    run_mode=run_mode,
-                    batch_size=video_action_batch_size,
-                    trt_min_shape=trt_min_shape,
-                    trt_max_shape=trt_max_shape,
-                    trt_opt_shape=trt_opt_shape,
-                    trt_calib_mode=trt_calib_mode,
-                    cpu_threads=cpu_threads,
-                    enable_mkldnn=enable_mkldnn)
+                self.video_action_predictor = VideoActionRecognizer.init_with_cfg(
+                    args, video_action_cfg)
     def set_file_name(self, path):
         if path is not None:
@@ -701,6 +555,10 @@ class PipePredictor(object):
                 assert len(
                     self.region_polygon
                 ) % 2 == 0, "region_polygon should be pairs of coords points when do break_in counting."
+                assert len(
+                    self.region_polygon
+                ) > 6, 'region_type is custom, region_polygon should be at least 3 pairs of point coords.'
                 for i in range(0, len(self.region_polygon), 2):
                     entrance.append(
                         [self.region_polygon[i], self.region_polygon[i + 1]])
......
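
The assertion relocated here sits next to the existing parity check, so both region constraints fail fast together. A sketch of the checks and the entrance polygon they guard; coordinates are illustrative:

```python
region_polygon = [780, 330, 1150, 330, 1150, 570, 780, 570]  # x1, y1, ..., x4, y4

assert len(region_polygon) % 2 == 0, \
    "region_polygon should be pairs of coords points when do break_in counting."
assert len(region_polygon) > 6, \
    "region_type is custom, region_polygon should be at least 3 pairs of point coords."

# same pairing the loop above performs
entrance = [[region_polygon[i], region_polygon[i + 1]]
            for i in range(0, len(region_polygon), 2)]
print(entrance)  # [[780, 330], [1150, 330], [1150, 570], [780, 570]]
```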
@@ -84,6 +84,20 @@ class SkeletonActionRecognizer(Detector):
             threshold=threshold,
             delete_shuffle_pass=True)

+    @classmethod
+    def init_with_cfg(cls, args, cfg):
+        return cls(model_dir=cfg['model_dir'],
+                   batch_size=cfg['batch_size'],
+                   window_size=cfg['max_frames'],
+                   device=args.device,
+                   run_mode=args.run_mode,
+                   trt_min_shape=args.trt_min_shape,
+                   trt_max_shape=args.trt_max_shape,
+                   trt_opt_shape=args.trt_opt_shape,
+                   trt_calib_mode=args.trt_calib_mode,
+                   cpu_threads=args.cpu_threads,
+                   enable_mkldnn=args.enable_mkldnn)
+
     def predict(self, repeats=1):
         '''
         Args:
@@ -322,6 +336,22 @@ class DetActionRecognizer(object):
         self.skip_frame_cnt = 0
         self.id_in_last_frame = []

+    @classmethod
+    def init_with_cfg(cls, args, cfg):
+        return cls(model_dir=cfg['model_dir'],
+                   batch_size=cfg['batch_size'],
+                   threshold=cfg['threshold'],
+                   display_frames=cfg['display_frames'],
+                   skip_frame_num=cfg['skip_frame_num'],
+                   device=args.device,
+                   run_mode=args.run_mode,
+                   trt_min_shape=args.trt_min_shape,
+                   trt_max_shape=args.trt_max_shape,
+                   trt_opt_shape=args.trt_opt_shape,
+                   trt_calib_mode=args.trt_calib_mode,
+                   cpu_threads=args.cpu_threads,
+                   enable_mkldnn=args.enable_mkldnn)
+
     def predict(self, images, mot_result):
         if self.skip_frame_cnt == 0 or (not self.check_id_is_same(mot_result)):
             det_result = self.detector.predict_image(images, visual=False)
@@ -473,6 +503,22 @@ class ClsActionRecognizer(AttrDetector):
         self.skip_frame_cnt = 0
         self.id_in_last_frame = []

+    @classmethod
+    def init_with_cfg(cls, args, cfg):
+        return cls(model_dir=cfg['model_dir'],
+                   batch_size=cfg['batch_size'],
+                   threshold=cfg['threshold'],
+                   display_frames=cfg['display_frames'],
+                   skip_frame_num=cfg['skip_frame_num'],
+                   device=args.device,
+                   run_mode=args.run_mode,
+                   trt_min_shape=args.trt_min_shape,
+                   trt_max_shape=args.trt_max_shape,
+                   trt_opt_shape=args.trt_opt_shape,
+                   trt_calib_mode=args.trt_calib_mode,
+                   cpu_threads=args.cpu_threads,
+                   enable_mkldnn=args.enable_mkldnn)
+
     def predict_with_mot(self, images, mot_result):
         if self.skip_frame_cnt == 0 or (not self.check_id_is_same(mot_result)):
             images = self.crop_half_body(images)
......
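
All three recognizers in this file gain the same two-argument constructor: model settings come from the class's own config section, runtime settings from `args`. A usage sketch with illustrative values; the recognizer classes come from the module patched above:

```python
import argparse

args = argparse.Namespace(
    device='GPU', run_mode='paddle',
    trt_min_shape=1, trt_max_shape=1280, trt_opt_shape=640,
    trt_calib_mode=False, cpu_threads=1, enable_mkldnn=False)

idbased_detaction_cfg = {
    'model_dir': 'output_inference/smoking_det_model',  # illustrative path
    'batch_size': 8,
    'threshold': 0.6,
    'display_frames': 80,
    'skip_frame_num': 2,
}
det_action = DetActionRecognizer.init_with_cfg(args, idbased_detaction_cfg)
```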
@@ -84,6 +84,19 @@ class AttrDetector(Detector):
             output_dir=output_dir,
             threshold=threshold, )

+    @classmethod
+    def init_with_cfg(cls, args, cfg):
+        return cls(model_dir=cfg['model_dir'],
+                   batch_size=cfg['batch_size'],
+                   device=args.device,
+                   run_mode=args.run_mode,
+                   trt_min_shape=args.trt_min_shape,
+                   trt_max_shape=args.trt_max_shape,
+                   trt_opt_shape=args.trt_opt_shape,
+                   trt_calib_mode=args.trt_calib_mode,
+                   cpu_threads=args.cpu_threads,
+                   enable_mkldnn=args.enable_mkldnn)
+
     def get_label(self):
         return self.pred_config.labels
......
@@ -75,6 +75,19 @@ class ReID(object):
         self.batch_size = batch_size
         self.input_wh = (128, 256)

+    @classmethod
+    def init_with_cfg(cls, args, cfg):
+        return cls(model_dir=cfg['model_dir'],
+                   batch_size=cfg['batch_size'],
+                   device=args.device,
+                   run_mode=args.run_mode,
+                   trt_min_shape=args.trt_min_shape,
+                   trt_max_shape=args.trt_max_shape,
+                   trt_opt_shape=args.trt_opt_shape,
+                   trt_calib_mode=args.trt_calib_mode,
+                   cpu_threads=args.cpu_threads,
+                   enable_mkldnn=args.enable_mkldnn)
+
     def set_config(self, model_dir):
         return PredictConfig(model_dir)
......
@@ -126,6 +126,21 @@ class VideoActionRecognizer(object):
         self.predictor = create_predictor(self.config)

+    @classmethod
+    def init_with_cfg(cls, args, cfg):
+        return cls(model_dir=cfg['model_dir'],
+                   short_size=cfg['short_size'],
+                   target_size=cfg['target_size'],
+                   batch_size=cfg['batch_size'],
+                   device=args.device,
+                   run_mode=args.run_mode,
+                   trt_min_shape=args.trt_min_shape,
+                   trt_max_shape=args.trt_max_shape,
+                   trt_opt_shape=args.trt_opt_shape,
+                   trt_calib_mode=args.trt_calib_mode,
+                   cpu_threads=args.cpu_threads,
+                   enable_mkldnn=args.enable_mkldnn)
+
     def preprocess_batch(self, file_list):
         batched_inputs = []
         for file in file_list:
......
@@ -90,6 +90,21 @@ class VehicleAttr(AttrDetector):
             "estate"
         ]

+    @classmethod
+    def init_with_cfg(cls, args, cfg):
+        return cls(model_dir=cfg['model_dir'],
+                   batch_size=cfg['batch_size'],
+                   color_threshold=cfg['color_threshold'],
+                   type_threshold=cfg['type_threshold'],
+                   device=args.device,
+                   run_mode=args.run_mode,
+                   trt_min_shape=args.trt_min_shape,
+                   trt_max_shape=args.trt_max_shape,
+                   trt_opt_shape=args.trt_opt_shape,
+                   trt_calib_mode=args.trt_calib_mode,
+                   cpu_threads=args.cpu_threads,
+                   enable_mkldnn=args.enable_mkldnn)
+
     def postprocess(self, inputs, result):
         # postprocess output of predictor
         im_results = result['output']
......
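
VehicleAttr follows the same pattern; its two subclass-specific thresholds now ride in the VEHICLE_ATTR config section instead of trailing a long positional call. Reusing the `args` namespace from the earlier sketch, with illustrative values:

```python
vehicleattr_cfg = {
    'model_dir': 'output_inference/vehicle_attr_model',  # illustrative path
    'batch_size': 8,
    'color_threshold': 0.5,
    'type_threshold': 0.5,
}
vehicle_attr = VehicleAttr.init_with_cfg(args, vehicleattr_cfg)  # args as above
```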