Unverified commit ff8a7b1d, authored by JYChen, committed by GitHub

move initialize part into class (#6621)

Parent 6409d062
This diff is collapsed.
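This change adds an init_with_cfg classmethod to each inference wrapper, so the pipeline can build a model from its own config section plus the shared command-line args instead of unpacking every constructor argument at the call site. A minimal sketch of the call-site pattern this enables (the caller and the 'SKELETON_ACTION' config key are illustrative assumptions, not part of this diff):

    # before: the caller forwards every constructor argument by hand
    # skeleton_action = SkeletonActionRecognizer(
    #     model_dir=skeleton_cfg['model_dir'],
    #     batch_size=skeleton_cfg['batch_size'],
    #     window_size=skeleton_cfg['max_frames'],
    #     device=args.device,
    #     run_mode=args.run_mode,
    #     ...)

    # after: the wrapper builds itself from its config section and the shared args
    skeleton_cfg = cfg['SKELETON_ACTION']  # assumed config section name
    skeleton_action = SkeletonActionRecognizer.init_with_cfg(args, skeleton_cfg)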
@@ -84,6 +84,20 @@ class SkeletonActionRecognizer(Detector):
            threshold=threshold,
            delete_shuffle_pass=True)

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   batch_size=cfg['batch_size'],
                   window_size=cfg['max_frames'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)

    def predict(self, repeats=1):
        '''
        Args:
@@ -322,6 +336,22 @@ class DetActionRecognizer(object):
        self.skip_frame_cnt = 0
        self.id_in_last_frame = []

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   batch_size=cfg['batch_size'],
                   threshold=cfg['threshold'],
                   display_frames=cfg['display_frames'],
                   skip_frame_num=cfg['skip_frame_num'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)

    def predict(self, images, mot_result):
        if self.skip_frame_cnt == 0 or (not self.check_id_is_same(mot_result)):
            det_result = self.detector.predict_image(images, visual=False)
@@ -473,6 +503,22 @@ class ClsActionRecognizer(AttrDetector):
        self.skip_frame_cnt = 0
        self.id_in_last_frame = []

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   batch_size=cfg['batch_size'],
                   threshold=cfg['threshold'],
                   display_frames=cfg['display_frames'],
                   skip_frame_num=cfg['skip_frame_num'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)

    def predict_with_mot(self, images, mot_result):
        if self.skip_frame_cnt == 0 or (not self.check_id_is_same(mot_result)):
            images = self.crop_half_body(images)
...
@@ -84,6 +84,19 @@ class AttrDetector(Detector):
            output_dir=output_dir,
            threshold=threshold, )

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   batch_size=cfg['batch_size'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)

    def get_label(self):
        return self.pred_config.labels
...
@@ -75,6 +75,19 @@ class ReID(object):
        self.batch_size = batch_size
        self.input_wh = (128, 256)

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   batch_size=cfg['batch_size'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)

    def set_config(self, model_dir):
        return PredictConfig(model_dir)
...
@@ -126,6 +126,21 @@ class VideoActionRecognizer(object):
        self.predictor = create_predictor(self.config)

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   short_size=cfg['short_size'],
                   target_size=cfg['target_size'],
                   batch_size=cfg['batch_size'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)

    def preprocess_batch(self, file_list):
        batched_inputs = []
        for file in file_list:
...
@@ -90,6 +90,21 @@ class VehicleAttr(AttrDetector):
            "estate"
        ]

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   batch_size=cfg['batch_size'],
                   color_threshold=cfg['color_threshold'],
                   type_threshold=cfg['type_threshold'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        im_results = result['output']
...
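Because every wrapper now exposes the same init_with_cfg(cls, args, cfg) signature, a caller can construct whichever models a config enables through a simple key-to-class mapping; the registry and keys below are a hypothetical sketch, not code from this commit:

    # hypothetical mapping from config section name to wrapper class
    MODEL_REGISTRY = {
        'ATTR': AttrDetector,
        'REID': ReID,
        'SKELETON_ACTION': SkeletonActionRecognizer,
        'VIDEO_ACTION': VideoActionRecognizer,
        'VEHICLE_ATTR': VehicleAttr,
    }

    # build only the models whose section appears in the loaded config
    models = {
        name: wrapper.init_with_cfg(args, cfg[name])
        for name, wrapper in MODEL_REGISTRY.items()
        if name in cfg
    }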