未验证 提交 a82faaa2 编写于 作者: G George Ni 提交者: GitHub

[MOT] fix deepsort deploy trt infer (#3652)

上级 8a66e2d7
......@@ -193,6 +193,7 @@ class SDE_ReID(object):
model_dir,
device='CPU',
run_mode='fluid',
batch_size=50,
trt_min_shape=1,
trt_max_shape=1088,
trt_opt_shape=608,
......@@ -203,6 +204,7 @@ class SDE_ReID(object):
self.predictor, self.config = load_predictor(
model_dir,
run_mode=run_mode,
batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
......@@ -214,10 +216,12 @@ class SDE_ReID(object):
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.batch_size = batch_size
assert pred_config.tracker, "Tracking model should have tracker"
self.tracker = DeepSORTTracker()
def preprocess(self, crops):
crops = crops[:self.batch_size]
inputs = {}
inputs['crops'] = np.array(crops).astype('float32')
return inputs
......@@ -423,6 +427,7 @@ def main():
FLAGS.reid_model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.reid_batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
......
......@@ -114,6 +114,11 @@ def argsparser():
default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."))
parser.add_argument(
"--reid_batch_size",
type=int,
default=50,
help="max batch_size for reid model inference.")
parser.add_argument(
'--use_dark',
type=bool,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册