未验证 提交 ff8a7b1d 编写于 作者: J JYChen 提交者: GitHub

move initialize part into class (#6621)

上级 6409d062
此差异已折叠。
......@@ -84,6 +84,20 @@ class SkeletonActionRecognizer(Detector):
threshold=threshold,
delete_shuffle_pass=True)
@classmethod
def init_with_cfg(cls, args, cfg):
return cls(model_dir=cfg['model_dir'],
batch_size=cfg['batch_size'],
window_size=cfg['max_frames'],
device=args.device,
run_mode=args.run_mode,
trt_min_shape=args.trt_min_shape,
trt_max_shape=args.trt_max_shape,
trt_opt_shape=args.trt_opt_shape,
trt_calib_mode=args.trt_calib_mode,
cpu_threads=args.cpu_threads,
enable_mkldnn=args.enable_mkldnn)
def predict(self, repeats=1):
'''
Args:
......@@ -322,6 +336,22 @@ class DetActionRecognizer(object):
self.skip_frame_cnt = 0
self.id_in_last_frame = []
@classmethod
def init_with_cfg(cls, args, cfg):
return cls(model_dir=cfg['model_dir'],
batch_size=cfg['batch_size'],
threshold=cfg['threshold'],
display_frames=cfg['display_frames'],
skip_frame_num=cfg['skip_frame_num'],
device=args.device,
run_mode=args.run_mode,
trt_min_shape=args.trt_min_shape,
trt_max_shape=args.trt_max_shape,
trt_opt_shape=args.trt_opt_shape,
trt_calib_mode=args.trt_calib_mode,
cpu_threads=args.cpu_threads,
enable_mkldnn=args.enable_mkldnn)
def predict(self, images, mot_result):
if self.skip_frame_cnt == 0 or (not self.check_id_is_same(mot_result)):
det_result = self.detector.predict_image(images, visual=False)
......@@ -473,6 +503,22 @@ class ClsActionRecognizer(AttrDetector):
self.skip_frame_cnt = 0
self.id_in_last_frame = []
@classmethod
def init_with_cfg(cls, args, cfg):
return cls(model_dir=cfg['model_dir'],
batch_size=cfg['batch_size'],
threshold=cfg['threshold'],
display_frames=cfg['display_frames'],
skip_frame_num=cfg['skip_frame_num'],
device=args.device,
run_mode=args.run_mode,
trt_min_shape=args.trt_min_shape,
trt_max_shape=args.trt_max_shape,
trt_opt_shape=args.trt_opt_shape,
trt_calib_mode=args.trt_calib_mode,
cpu_threads=args.cpu_threads,
enable_mkldnn=args.enable_mkldnn)
def predict_with_mot(self, images, mot_result):
if self.skip_frame_cnt == 0 or (not self.check_id_is_same(mot_result)):
images = self.crop_half_body(images)
......
......@@ -84,6 +84,19 @@ class AttrDetector(Detector):
output_dir=output_dir,
threshold=threshold, )
@classmethod
def init_with_cfg(cls, args, cfg):
return cls(model_dir=cfg['model_dir'],
batch_size=cfg['batch_size'],
device=args.device,
run_mode=args.run_mode,
trt_min_shape=args.trt_min_shape,
trt_max_shape=args.trt_max_shape,
trt_opt_shape=args.trt_opt_shape,
trt_calib_mode=args.trt_calib_mode,
cpu_threads=args.cpu_threads,
enable_mkldnn=args.enable_mkldnn)
def get_label(self):
return self.pred_config.labels
......
......@@ -75,6 +75,19 @@ class ReID(object):
self.batch_size = batch_size
self.input_wh = (128, 256)
@classmethod
def init_with_cfg(cls, args, cfg):
return cls(model_dir=cfg['model_dir'],
batch_size=cfg['batch_size'],
device=args.device,
run_mode=args.run_mode,
trt_min_shape=args.trt_min_shape,
trt_max_shape=args.trt_max_shape,
trt_opt_shape=args.trt_opt_shape,
trt_calib_mode=args.trt_calib_mode,
cpu_threads=args.cpu_threads,
enable_mkldnn=args.enable_mkldnn)
def set_config(self, model_dir):
return PredictConfig(model_dir)
......
......@@ -126,6 +126,21 @@ class VideoActionRecognizer(object):
self.predictor = create_predictor(self.config)
@classmethod
def init_with_cfg(cls, args, cfg):
return cls(model_dir=cfg['model_dir'],
short_size=cfg['short_size'],
target_size=cfg['target_size'],
batch_size=cfg['batch_size'],
device=args.device,
run_mode=args.run_mode,
trt_min_shape=args.trt_min_shape,
trt_max_shape=args.trt_max_shape,
trt_opt_shape=args.trt_opt_shape,
trt_calib_mode=args.trt_calib_mode,
cpu_threads=args.cpu_threads,
enable_mkldnn=args.enable_mkldnn)
def preprocess_batch(self, file_list):
batched_inputs = []
for file in file_list:
......
......@@ -90,6 +90,21 @@ class VehicleAttr(AttrDetector):
"estate"
]
@classmethod
def init_with_cfg(cls, args, cfg):
return cls(model_dir=cfg['model_dir'],
batch_size=cfg['batch_size'],
color_threshold=cfg['color_threshold'],
type_threshold=cfg['type_threshold'],
device=args.device,
run_mode=args.run_mode,
trt_min_shape=args.trt_min_shape,
trt_max_shape=args.trt_max_shape,
trt_opt_shape=args.trt_opt_shape,
trt_calib_mode=args.trt_calib_mode,
cpu_threads=args.cpu_threads,
enable_mkldnn=args.enable_mkldnn)
def postprocess(self, inputs, result):
# postprocess output of predictor
im_results = result['output']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册