未验证 提交 a82faaa2 编写于 作者: G George Ni 提交者: GitHub

[MOT] fix deepsort deploy trt infer (#3652)

上级 8a66e2d7
...@@ -193,6 +193,7 @@ class SDE_ReID(object): ...@@ -193,6 +193,7 @@ class SDE_ReID(object):
model_dir, model_dir,
device='CPU', device='CPU',
run_mode='fluid', run_mode='fluid',
batch_size=50,
trt_min_shape=1, trt_min_shape=1,
trt_max_shape=1088, trt_max_shape=1088,
trt_opt_shape=608, trt_opt_shape=608,
...@@ -203,6 +204,7 @@ class SDE_ReID(object): ...@@ -203,6 +204,7 @@ class SDE_ReID(object):
self.predictor, self.config = load_predictor( self.predictor, self.config = load_predictor(
model_dir, model_dir,
run_mode=run_mode, run_mode=run_mode,
batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size, min_subgraph_size=self.pred_config.min_subgraph_size,
device=device, device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape, use_dynamic_shape=self.pred_config.use_dynamic_shape,
...@@ -214,10 +216,12 @@ class SDE_ReID(object): ...@@ -214,10 +216,12 @@ class SDE_ReID(object):
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn)
self.det_times = Timer() self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.batch_size = batch_size
assert pred_config.tracker, "Tracking model should have tracker" assert pred_config.tracker, "Tracking model should have tracker"
self.tracker = DeepSORTTracker() self.tracker = DeepSORTTracker()
def preprocess(self, crops): def preprocess(self, crops):
crops = crops[:self.batch_size]
inputs = {} inputs = {}
inputs['crops'] = np.array(crops).astype('float32') inputs['crops'] = np.array(crops).astype('float32')
return inputs return inputs
...@@ -423,6 +427,7 @@ def main(): ...@@ -423,6 +427,7 @@ def main():
FLAGS.reid_model_dir, FLAGS.reid_model_dir,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
batch_size=FLAGS.reid_batch_size,
trt_min_shape=FLAGS.trt_min_shape, trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape, trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
......
...@@ -114,6 +114,11 @@ def argsparser(): ...@@ -114,6 +114,11 @@ def argsparser():
default=None, default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', " help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py.")) "'infer_cfg.yml', created by tools/export_model.py."))
parser.add_argument(
"--reid_batch_size",
type=int,
default=50,
help="max batch_size for reid model inference.")
parser.add_argument( parser.add_argument(
'--use_dark', '--use_dark',
type=bool, type=bool,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册