Commit 14b7a24b authored by zhiboniu, committed by zhiboniu

del enable-xxx command

Parent 81236740
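Summary: this commit removes the `--enable_attr`, `--enable_falling`, and `--enable_mtmct` command-line flags. Each module is now switched on through an `enable` field in its own section of the pipeline config, and the falling-specific naming (`FALLING`, `falling_res`, `FallingRecognizer`) is generalized to `SKELETON_ACTION` / `skeleton_action` / `SkeletonActionRecognizer`. A minimal standalone sketch of the new toggle pattern used in `PipePredictor.__init__` below; the config text is an assumed excerpt, and `module_enabled` is a hypothetical helper, not code from this commit:

```python
import yaml

# Assumed excerpt of the new config layout introduced by this commit.
CFG_TEXT = """
ATTR:
  model_dir: output_inference/strongbaseline_r50_30e_pa100k/
  batch_size: 8
  basemode: "idbased"
  enable: False

SKELETON_ACTION:
  model_dir: output_inference/STGCN
  batch_size: 1
  basemode: "skeletonbased"
  enable: False
"""

def module_enabled(cfg: dict, section: str) -> bool:
    # Mirrors the commit's pattern:
    # cfg.get(X, False)['enable'] if cfg.get(X, False) else False
    # i.e. a module runs only when its section exists AND enable is True.
    section_cfg = cfg.get(section, False)
    return section_cfg['enable'] if section_cfg else False

cfg = yaml.safe_load(CFG_TEXT)
print(module_enabled(cfg, 'ATTR'))             # False
print(module_enabled(cfg, 'SKELETON_ACTION'))  # False
print(module_enabled(cfg, 'REID'))             # False (section absent)
```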
@@ -8,11 +8,6 @@ DET:
   model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
   batch_size: 1
-ATTR:
-  model_dir: output_inference/strongbaseline_r50_30e_pa100k/
-  batch_size: 8
-  basemode: "idbased"
 MOT:
   model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
   tracker_config: deploy/pphuman/config/tracker_config.yml
@@ -23,15 +18,23 @@ KPT:
   model_dir: output_inference/dark_hrnet_w32_256x192/
   batch_size: 8
-FALLING:
+ATTR:
+  model_dir: output_inference/strongbaseline_r50_30e_pa100k/
+  batch_size: 8
+  basemode: "idbased"
+  enable: False
+
+SKELETON_ACTION:
   model_dir: output_inference/STGCN
   batch_size: 1
   max_frames: 50
   display_frames: 80
   coord_size: [384, 512]
   basemode: "skeletonbased"
+  enable: False
 
 REID:
   model_dir: output_inference/reid_model/
   batch_size: 16
   basemode: "idbased"
+  enable: False
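With this layout, turning a feature on means editing the config rather than passing a flag. A hedged sketch of flipping a module's switch in place (`enable_module` is a hypothetical helper, not part of this commit; note that round-tripping through PyYAML drops comments from the real file):

```python
import yaml

def enable_module(cfg_path: str, section: str) -> None:
    """Hypothetical helper: set `enable: True` for one module section,
    replacing the removed --enable_xxx command-line flags."""
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)
    cfg[section]['enable'] = True
    with open(cfg_path, 'w') as f:
        yaml.safe_dump(cfg, f)

# e.g. what `--enable_falling=True` used to do now becomes:
enable_module('deploy/pphuman/config/infer_cfg.yml', 'SKELETON_ACTION')
```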
@@ -23,7 +23,7 @@ class Result(object):
             'mot': dict(),
             'attr': dict(),
             'kpt': dict(),
-            'falling': dict(),
+            'skeleton_action': dict(),
             'reid': dict()
         }
@@ -53,13 +53,13 @@ class DataCollector(object):
           - qualities(list of float): Nx[float]
           - attrs(list of attr): refer to attrs for details
           - kpts(list of kpts): refer to kpts for details
-          - falling(list of falling): refer to falling for details
+          - skeleton_action(list of skeleton_action): refer to skeleton_action for details
        ...
        - [idN]
     """

     def __init__(self):
-        #id, frame, rect, score, label, attrs, kpts, falling
+        #id, frame, rect, score, label, attrs, kpts, skeleton_action
         self.mots = {
             "frames": [],
             "rects": [],
@@ -67,7 +67,7 @@ class DataCollector(object):
             "kpts": [],
             "features": [],
             "qualities": [],
-            "falling": []
+            "skeleton_action": []
         }
         self.collector = {}
@@ -75,7 +75,7 @@ class DataCollector(object):
         mot_res = Result.get('mot')
         attr_res = Result.get('attr')
         kpt_res = Result.get('kpt')
-        falling_res = Result.get('falling')
+        skeleton_action_res = Result.get('skeleton_action')
         reid_res = Result.get('reid')
         rects = []
@@ -95,11 +95,12 @@ class DataCollector(object):
             if kpt_res:
                 self.collector[ids]["kpts"].append(
                     [kpt_res['keypoint'][0][idx], kpt_res['keypoint'][1][idx]])
-            if falling_res and (idx + 1) in falling_res:
-                self.collector[ids]["falling"].append(falling_res[idx + 1])
+            if skeleton_action_res and (idx + 1) in skeleton_action_res:
+                self.collector[ids]["skeleton_action"].append(
+                    skeleton_action_res[idx + 1])
             else:
                 # action model generate result per X frames, Not available every frames
-                self.collector[ids]["falling"].append(None)
+                self.collector[ids]["skeleton_action"].append(None)
             if reid_res:
                 self.collector[ids]["features"].append(reid_res['features'][
                     idx])
...
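Note the `None` padding above: the skeleton-action model emits a result only once per keypoint window, so frames without a fresh result append `None` to keep each track's per-frame lists aligned. A toy illustration with hypothetical data:

```python
# The action model produces one result per window (e.g. every 50 frames);
# all other frames store None so that list index == frame index per track.
skeleton_action = []
for frame_id in range(100):
    res = {"class": 0, "score": 0.9} if (frame_id + 1) % 50 == 0 else None
    skeleton_action.append(res)

assert skeleton_action[49] is not None   # window completed at frame 50
assert skeleton_action[0] is None        # no result yet on early frames
```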
@@ -57,21 +57,6 @@ def argsparser():
         type=int,
         default=-1,
         help="device id of camera to predict.")
-    parser.add_argument(
-        "--enable_attr",
-        type=ast.literal_eval,
-        default=False,
-        help="Whether use attribute recognition.")
-    parser.add_argument(
-        "--enable_falling",
-        type=ast.literal_eval,
-        default=False,
-        help="Whether use action recognition.")
-    parser.add_argument(
-        "--enable_mtmct",
-        type=ast.literal_eval,
-        default=False,
-        help="Whether to enable multi-camera reid track.")
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -167,7 +152,7 @@ class PipeTimer(Times):
             'mot': Times(),
             'attr': Times(),
             'kpt': Times(),
-            'falling': Times(),
+            'skeleton_action': Times(),
             'reid': Times()
         }
         self.img_num = 0
@@ -212,9 +197,9 @@ class PipeTimer(Times):
         dic['kpt'] = round(self.module_time['kpt'].value() /
                            max(1, self.img_num),
                            4) if average else self.module_time['kpt'].value()
-        dic['falling'] = round(
-            self.module_time['falling'].value() / max(1, self.img_num),
-            4) if average else self.module_time['falling'].value()
+        dic['skeleton_action'] = round(
+            self.module_time['skeleton_action'].value() / max(1, self.img_num),
+            4) if average else self.module_time['skeleton_action'].value()
         dic['img_num'] = self.img_num
         return dic
@@ -222,7 +207,7 @@ class PipeTimer(Times):

 def merge_model_dir(args, model_dir):
     # set --model_dir DET=ppyoloe/ to overwrite the model_dir in config file
-    task_set = ['DET', 'ATTR', 'MOT', 'KPT', 'FALLING', 'REID']
+    task_set = ['DET', 'ATTR', 'MOT', 'KPT', 'SKELETON_ACTION', 'REID']
     if not model_dir:
         return args
     for md in model_dir:
...
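For reference, the `--model_dir` override mentioned in the comment above is keyed by task name; a minimal sketch of the assumed `TASK=path` parsing (a hypothetical reimplementation for illustration, not the commit's code):

```python
task_set = ['DET', 'ATTR', 'MOT', 'KPT', 'SKELETON_ACTION', 'REID']

def parse_model_dir_overrides(model_dir_args):
    """Map e.g. ['DET=ppyoloe/', 'KPT=hrnet/'] to {'DET': 'ppyoloe/', ...},
    ignoring entries whose task key is not in task_set."""
    overrides = {}
    for md in model_dir_args:
        task, sep, path = md.partition('=')
        if sep and task in task_set:
            overrides[task] = path
    return overrides

print(parse_model_dir_overrides(['DET=ppyoloe/', 'KPT=hrnet/']))
# -> {'DET': 'ppyoloe/', 'KPT': 'hrnet/'}
```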
@@ -36,8 +36,8 @@ from python.infer import Detector, DetectorPicoDet
 from python.attr_infer import AttrDetector
 from python.keypoint_infer import KeyPointDetector
 from python.keypoint_postprocess import translate_to_ori_images
-from python.action_infer import FallingRecognizer
-from python.action_utils import KeyPointBuff, FallingVisualHelper
+from python.action_infer import SkeletonActionRecognizer
+from python.action_utils import KeyPointBuff, SkeletonActionVisualHelper
 from pipe_utils import argsparser, print_arguments, merge_cfg, PipeTimer
 from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint
@@ -60,8 +60,6 @@ class Pipeline(object):
             then all the images in directory will be predicted, default as None
         video_file (string|None): the path of video file, default as None
         camera_id (int): the device id of camera to predict, default as -1
-        enable_attr (bool): whether use attribute recognition, default as false
-        enable_falling (bool): whether use action recognition, default as false
         device (string): the device to predict, options are: CPU/GPU/XPU,
             default as CPU
         run_mode (string): the mode of prediction, options are:
@@ -88,9 +86,6 @@ class Pipeline(object):
                  video_file=None,
                  video_dir=None,
                  camera_id=-1,
-                 enable_attr=False,
-                 enable_falling=False,
-                 enable_mtmct=False,
                  device='CPU',
                  run_mode='paddle',
                  trt_min_shape=1,
@@ -104,7 +99,8 @@ class Pipeline(object):
                  secs_interval=10,
                  do_entrance_counting=False):
         self.multi_camera = False
-        self.enable_mtmct = enable_mtmct
+        reid_cfg = cfg.get('REID', False)
+        self.enable_mtmct = reid_cfg['enable'] if reid_cfg else False
         self.is_video = False
         self.output_dir = output_dir
         self.vis_result = cfg['visual']
@@ -117,9 +113,6 @@ class Pipeline(object):
                 cfg,
                 is_video=True,
                 multi_camera=True,
-                enable_attr=enable_attr,
-                enable_falling=enable_falling,
-                enable_mtmct=enable_mtmct,
                 device=device,
                 run_mode=run_mode,
                 trt_min_shape=trt_min_shape,
@@ -135,9 +128,6 @@ class Pipeline(object):
             self.predictor = PipePredictor(
                 cfg,
                 self.is_video,
-                enable_attr=enable_attr,
-                enable_falling=enable_falling,
-                enable_mtmct=enable_mtmct,
                 device=device,
                 run_mode=run_mode,
                 trt_min_shape=trt_min_shape,
@@ -227,7 +217,7 @@ class PipePredictor(object):
         1. Tracking
         2. Tracking -> Attribute
-        3. Tracking -> KeyPoint -> Falling Recognition
+        3. Tracking -> KeyPoint -> SkeletonAction Recognition

     Args:
         cfg (dict): config of models in pipeline
@@ -235,8 +225,6 @@ class PipePredictor(object):
         multi_camera (bool): whether to use multi camera in pipeline,
             default as False
         camera_id (int): the device id of camera to predict, default as -1
-        enable_attr (bool): whether use attribute recognition, default as false
-        enable_falling (bool): whether use action recognition, default as false
         device (string): the device to predict, options are: CPU/GPU/XPU,
             default as CPU
         run_mode (string): the mode of prediction, options are:
@@ -260,9 +248,6 @@ class PipePredictor(object):
                  cfg,
                  is_video=True,
                  multi_camera=False,
-                 enable_attr=False,
-                 enable_falling=False,
-                 enable_mtmct=False,
                  device='CPU',
                  run_mode='paddle',
                  trt_min_shape=1,
@@ -276,29 +261,19 @@ class PipePredictor(object):
                  secs_interval=10,
                  do_entrance_counting=False):
-        if enable_attr and not cfg.get('ATTR', False):
-            ValueError(
-                'enable_attr is set to True, please set ATTR in config file')
-        if enable_falling and (not cfg.get('FALLING', False) or
-                               not cfg.get('KPT', False)):
-            ValueError(
-                'enable_falling is set to True, please set KPT and FALLING in config file'
-            )
-
-        self.with_attr = cfg.get('ATTR', False) and enable_attr
-        self.with_falling = cfg.get('FALLING', False) and enable_falling
-        self.with_mtmct = cfg.get('REID', False) and enable_mtmct
+        self.with_attr = cfg.get('ATTR', False)['enable'] if cfg.get(
+            'ATTR', False) else False
+        self.with_skeleton_action = cfg.get(
+            'SKELETON_ACTION', False)['enable'] if cfg.get('SKELETON_ACTION',
+                                                           False) else False
+        self.with_mtmct = cfg.get('REID', False)['enable'] if cfg.get(
+            'REID', False) else False
         if self.with_attr:
             print('Attribute Recognition enabled')
-        if self.with_falling:
-            print('Falling Recognition enabled')
-        if enable_mtmct:
-            if not self.with_mtmct:
-                print(
-                    'Warning!!! MTMCT enabled, but cannot find REID config in [infer_cfg.yml], please check!'
-                )
-            else:
-                print("MTMCT enabled")
+        if self.with_skeleton_action:
+            print('SkeletonAction Recognition enabled')
+        if self.with_mtmct:
+            print("MTMCT enabled")

         self.modebase = {
             "framebased": False,
@@ -371,29 +346,30 @@ class PipePredictor(object):
                 model_dir, device, run_mode, batch_size, trt_min_shape,
                 trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                 enable_mkldnn)
-        if self.with_falling:
-            falling_cfg = self.cfg['FALLING']
-            falling_model_dir = falling_cfg['model_dir']
-            falling_batch_size = falling_cfg['batch_size']
-            falling_frames = falling_cfg['max_frames']
-            display_frames = falling_cfg['display_frames']
-            self.coord_size = falling_cfg['coord_size']
-            basemode = falling_cfg['basemode']
+        if self.with_skeleton_action:
+            skeleton_action_cfg = self.cfg['SKELETON_ACTION']
+            skeleton_action_model_dir = skeleton_action_cfg['model_dir']
+            skeleton_action_batch_size = skeleton_action_cfg['batch_size']
+            skeleton_action_frames = skeleton_action_cfg['max_frames']
+            display_frames = skeleton_action_cfg['display_frames']
+            self.coord_size = skeleton_action_cfg['coord_size']
+            basemode = skeleton_action_cfg['basemode']
             self.modebase[basemode] = True
-            self.falling_predictor = FallingRecognizer(
-                falling_model_dir,
+            self.skeleton_action_predictor = SkeletonActionRecognizer(
+                skeleton_action_model_dir,
                 device,
                 run_mode,
-                falling_batch_size,
+                skeleton_action_batch_size,
                 trt_min_shape,
                 trt_max_shape,
                 trt_opt_shape,
                 trt_calib_mode,
                 cpu_threads,
                 enable_mkldnn,
-                window_size=falling_frames)
-            self.falling_visual_helper = FallingVisualHelper(display_frames)
+                window_size=skeleton_action_frames)
+            self.skeleton_action_visual_helper = SkeletonActionVisualHelper(
+                display_frames)

         if self.modebase["skeletonbased"]:
             kpt_cfg = self.cfg['KPT']
@@ -411,7 +387,7 @@ class PipePredictor(object):
                 cpu_threads,
                 enable_mkldnn,
                 use_dark=False)
-            self.kpt_buff = KeyPointBuff(falling_frames)
+            self.kpt_buff = KeyPointBuff(skeleton_action_frames)

         if self.with_mtmct:
             reid_cfg = self.cfg['REID']
continue continue
self.pipeline_res.update(mot_res, 'mot') self.pipeline_res.update(mot_res, 'mot')
if self.with_attr or self.with_falling: if self.with_attr or self.with_skeleton_action:
crop_input, new_bboxes, ori_bboxes = crop_image_with_mot( crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
frame, mot_res) frame, mot_res)
...@@ -583,7 +559,7 @@ class PipePredictor(object): ...@@ -583,7 +559,7 @@ class PipePredictor(object):
self.pipe_timer.module_time['attr'].end() self.pipe_timer.module_time['attr'].end()
self.pipeline_res.update(attr_res, 'attr') self.pipeline_res.update(attr_res, 'attr')
if self.with_falling: if self.with_skeleton_action:
if self.modebase["skeletonbased"]: if self.modebase["skeletonbased"]:
if frame_id > self.warmup_frame: if frame_id > self.warmup_frame:
self.pipe_timer.module_time['kpt'].start() self.pipe_timer.module_time['kpt'].start()
...@@ -606,22 +582,25 @@ class PipePredictor(object): ...@@ -606,22 +582,25 @@ class PipePredictor(object):
state = self.kpt_buff.get_state( state = self.kpt_buff.get_state(
) # whether frame num is enough or lost tracker ) # whether frame num is enough or lost tracker
falling_res = {} skeleton_action_res = {}
if state: if state:
if frame_id > self.warmup_frame: if frame_id > self.warmup_frame:
self.pipe_timer.module_time['falling'].start() self.pipe_timer.module_time[
'skeleton_action'].start()
collected_keypoint = self.kpt_buff.get_collected_keypoint( collected_keypoint = self.kpt_buff.get_collected_keypoint(
) # reoragnize kpt output with ID ) # reoragnize kpt output with ID
falling_input = parse_mot_keypoint(collected_keypoint, skeleton_action_input = parse_mot_keypoint(
self.coord_size) collected_keypoint, self.coord_size)
falling_res = self.falling_predictor.predict_skeleton_with_mot( skeleton_action_res = self.skeleton_action_predictor.predict_skeleton_with_mot(
falling_input) skeleton_action_input)
if frame_id > self.warmup_frame: if frame_id > self.warmup_frame:
self.pipe_timer.module_time['falling'].end() self.pipe_timer.module_time['skeleton_action'].end()
self.pipeline_res.update(falling_res, 'falling') self.pipeline_res.update(skeleton_action_res,
'skeleton_action')
if self.cfg['visual']: if self.cfg['visual']:
self.falling_visual_helper.update(falling_res) self.skeleton_action_visual_helper.update(
skeleton_action_res)
if self.with_mtmct and frame_id % 10 == 0: if self.with_mtmct and frame_id % 10 == 0:
crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot( crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot(
...@@ -726,10 +705,11 @@ class PipePredictor(object): ...@@ -726,10 +705,11 @@ class PipePredictor(object):
visual_thresh=self.cfg['kpt_thresh'], visual_thresh=self.cfg['kpt_thresh'],
returnimg=True) returnimg=True)
falling_res = result.get('falling') skeleton_action_res = result.get('skeleton_action')
if falling_res is not None: if skeleton_action_res is not None:
image = visualize_action(image, mot_res['boxes'], image = visualize_action(image, mot_res['boxes'],
self.falling_visual_helper, "Falling") self.skeleton_action_visual_helper,
"SkeletonAction")
return image return image
@@ -768,8 +748,7 @@ def main():
     print_arguments(cfg)
     pipeline = Pipeline(
         cfg, FLAGS.image_file, FLAGS.image_dir, FLAGS.video_file,
-        FLAGS.video_dir, FLAGS.camera_id, FLAGS.enable_attr,
-        FLAGS.enable_falling, FLAGS.enable_mtmct, FLAGS.device, FLAGS.run_mode,
+        FLAGS.video_dir, FLAGS.camera_id, FLAGS.device, FLAGS.run_mode,
         FLAGS.trt_min_shape, FLAGS.trt_max_shape, FLAGS.trt_opt_shape,
         FLAGS.trt_calib_mode, FLAGS.cpu_threads, FLAGS.enable_mkldnn,
         FLAGS.output_dir, FLAGS.draw_center_traj, FLAGS.secs_interval,
...