Unverified commit b29aa30f authored by XYZ_916, committed by GitHub

given inference model path, download model automatically (#6403)

* given inference model path, download model automatically

* encapsulate auto download model as a function

* given note that default model dir is download link

* change attr model download path to https://bj.bcebos.com/v1/paddledet/models/pipeline/PPLCNet_x1_0_person_attribute_945_infer.zip
Parent e6d4d2bc
@@ -5,28 +5,28 @@ visual: True
   warmup_frame: 50
 
 DET:
-  model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
+  model_dir: https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip
   batch_size: 1
 
 MOT:
-  model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
+  model_dir: https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip
   tracker_config: deploy/pipeline/config/tracker_config.yml
   batch_size: 1
   basemode: "idbased"
   enable: False
 
 KPT:
-  model_dir: output_inference/dark_hrnet_w32_256x192/
+  model_dir: https://bj.bcebos.com/v1/paddledet/models/pipeline/dark_hrnet_w32_256x192.zip
   batch_size: 8
 
 ATTR:
-  model_dir: output_inference/PPLCNet_x1_0_person_attribute_945_infer/
+  model_dir: https://bj.bcebos.com/v1/paddledet/models/pipeline/PPLCNet_x1_0_person_attribute_945_infer.zip
   batch_size: 8
   basemode: "idbased"
   enable: False
 
 VIDEO_ACTION:
-  model_dir: output_inference/ppTSM
+  model_dir: https://videotag.bj.bcebos.com/PaddleVideo-release2.3/ppTSM_fight.zip
   batch_size: 1
   frame_len: 8
   sample_freq: 7
@@ -36,7 +36,7 @@ VIDEO_ACTION:
   enable: False
 
 SKELETON_ACTION:
-  model_dir: output_inference/STGCN
+  model_dir: https://bj.bcebos.com/v1/paddledet/models/pipeline/STGCN.zip
   batch_size: 1
   max_frames: 50
   display_frames: 80
@@ -45,7 +45,7 @@ SKELETON_ACTION:
   enable: False
 
 ID_BASED_DETACTION:
-  model_dir: output_inference/ppyoloe_crn_s_80e_smoking_visdrone
+  model_dir: https://bj.bcebos.com/v1/paddledet/models/pipeline/ppyoloe_crn_s_80e_smoking_visdrone.zip
   batch_size: 8
   basemode: "idbased"
   threshold: 0.6
@@ -54,7 +54,7 @@ ID_BASED_DETACTION:
   enable: False
 
 ID_BASED_CLSACTION:
-  model_dir: output_inference/PPHGNet_tiny_calling_halfbody
+  model_dir: https://bj.bcebos.com/v1/paddledet/models/pipeline/PPHGNet_tiny_calling_halfbody.zip
   batch_size: 8
   basemode: "idbased"
   threshold: 0.8
@@ -63,7 +63,7 @@ ID_BASED_CLSACTION:
   enable: False
 
 REID:
-  model_dir: output_inference/reid_model/
+  model_dir: https://bj.bcebos.com/v1/paddledet/models/pipeline/reid_model.zip
   batch_size: 16
   basemode: "idbased"
   enable: False
@@ -40,7 +40,9 @@ PP-Human provides pretrained models for object detection, attribute recognition, action recognition, and ReID
 | Action recognition | Video input, action recognition | Accuracy: 96.43 | 2.7 ms per person | - | [Download link](https://bj.bcebos.com/v1/paddledet/models/pipeline/STGCN.zip) |
 | ReID | Video input, cross-camera tracking | mAP: 98.8 | 1.5 ms per person | - | [Download link](https://bj.bcebos.com/v1/paddledet/models/pipeline/reid_model.zip) |
 
-After downloading a model, extract it into the `./output_inference` folder
+After downloading a model, extract it into the `./output_inference` folder.
+In the configuration file, each model path defaults to the model's download link; if you do not change it, the corresponding model is downloaded automatically at inference time.
 
 **Note:**
...
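The documentation change above describes the behaviour of the new `auto_download_model` helper added by this commit in `deploy/pipeline/download.py` (shown further down). A minimal sketch, assuming the working directory is `deploy/pipeline` so that `download.py` is importable and that network access is available:

```python
# Minimal sketch of the auto-download behaviour described above (illustration only).
from download import auto_download_model

url_dir = ("https://bj.bcebos.com/v1/paddledet/models/pipeline/"
           "PPLCNet_x1_0_person_attribute_945_infer.zip")
local_dir = "output_inference/PPLCNet_x1_0_person_attribute_945_infer/"

# A URL is downloaded and extracted under ~/.cache/paddle/infer_weights,
# and the local directory of the extracted model is returned.
print(auto_download_model(url_dir))

# A plain local path is not a URL, so the helper returns None and the
# pipeline keeps using the configured path as-is.
print(auto_download_model(local_dir))
```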
@@ -47,7 +47,7 @@ SKELETON_ACTION: # skeleton-based action recognition model configuration
 ```
 ### Usage
-1. Download the model from the link in the table above and extract it to the ```./output_inference``` directory.
+1. Download the model from the link in the table above and extract it to the ```./output_inference``` directory. By default the model is downloaded automatically; if you download it manually, change the model directory in the config to the path where the model is stored.
 2. The action recognition module currently only supports video input. Depending on which action recognition scheme you want to enable, set enable: True under `SKELETON_ACTION` in infer_cfg_pphuman.yml, then start with the following command:
 ```python
 python deploy/pipeline/pipeline.py --config deploy/pipeline/config/infer_cfg_pphuman.yml \
...
@@ -17,7 +17,7 @@
 ## Usage
-1. Download the model from the link in the table above and extract it to the ```PaddleDetection/output_inference``` directory, and set enable: True for `ATTR` in ```deploy/pipeline/config/infer_cfg_pphuman.yml```
+1. Download the model from the link in the table above and extract it to the ```PaddleDetection/output_inference``` directory, and update the model path in the configuration file; alternatively, keep the default and the model is downloaded automatically. Set enable: True for `ATTR` in ```deploy/pipeline/config/infer_cfg_pphuman.yml```
 Description of the configuration items in `infer_cfg_pphuman.yml`:
 ```
...
@@ -14,7 +14,7 @@
 ## Usage
-1. Download the model from the link in the table above and extract it to the ```./output_inference``` directory
+1. Download the model from the link in the table above and extract it to the ```./output_inference``` directory, and update the model path in the configuration file. By default the model is downloaded automatically, so no change is needed.
 2. For image input the task is pure detection; start it with the following command
 ```python
 python deploy/pipeline/pipeline.py --config deploy/pipeline/config/infer_cfg_pphuman.yml \
...
@@ -7,7 +7,7 @@ The PP-Human cross-camera tracking module aims to provide a simple and efficient cross-camera
 ## Usage
-1. Download the [REID model](https://bj.bcebos.com/v1/paddledet/models/pipeline/reid_model.zip) and extract it to the ```./output_inference``` directory. For the MOT model, see the [MOT documentation](./mot.md).
+1. Download the [REID model](https://bj.bcebos.com/v1/paddledet/models/pipeline/reid_model.zip) and extract it to the ```./output_inference``` directory, then update the model path in the configuration file. For simplicity you can also keep the default configuration, which downloads the model automatically. For the MOT model, see the [MOT documentation](./mot.md).
 2. In cross-camera tracking mode, all input videos must be placed in the same directory, and enable=True must be set under REID in infer_cfg_pphuman.yml. The command is as follows:
 ```python
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import os.path as osp
import hashlib
import requests
import shutil
import tqdm
import time
import tarfile
import zipfile
from paddle.utils.download import _get_unique_endpoints
PPDET_WEIGHTS_DOWNLOAD_URL_PREFIX = 'https://paddledet.bj.bcebos.com/'
DOWNLOAD_RETRY_LIMIT = 3
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/infer_weights")
def is_url(path):
    """
    Whether path is a URL.
    Args:
        path (string): the path to check.
    """
    return path.startswith('http://') \
            or path.startswith('https://') \
            or path.startswith('ppdet://')


def parse_url(url):
    url = url.replace("ppdet://", PPDET_WEIGHTS_DOWNLOAD_URL_PREFIX)
    return url
def map_path(url, root_dir, path_depth=1):
    # map a download url to the local path it will be decompressed to under root_dir
    assert path_depth > 0, "path_depth should be a positive integer"
    dirname = url
    for _ in range(path_depth):
        dirname = osp.dirname(dirname)
    fpath = osp.relpath(url, dirname)

    zip_formats = ['.zip', '.tar', '.gz']
    for zip_format in zip_formats:
        fpath = fpath.replace(zip_format, '')
    return osp.join(root_dir, fpath)
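
# For example (illustrative values, not part of the original module):
#   map_path("https://bj.bcebos.com/v1/paddledet/models/pipeline/STGCN.zip", WEIGHTS_HOME)
#   -> "~/.cache/paddle/infer_weights/STGCN"
# i.e. the archive name without its extension, placed under the cache root.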
def _md5check(fullname, md5sum=None):
    if md5sum is None:
        return True

    md5 = hashlib.md5()
    with open(fullname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        return False
    return True


def _check_exist_file_md5(filename, md5sum, url):
    return _md5check(filename, md5sum)
def _download(url, path, md5sum=None):
    """
    Download from url, save to path.
    url (str): download url
    path (str): download to given path
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0
    while not (osp.exists(fullname) and _check_exist_file_md5(fullname, md5sum,
                                                              url)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        # NOTE: joining paths on Windows may introduce '\', which is invalid in a url
        if sys.platform == "win32":
            url = url.replace('\\', '/')

        req = requests.get(url, stream=True)
        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against interrupted downloads, download to
        # tmp_fullname first and move tmp_fullname to fullname
        # once the download has finished
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                for chunk in tqdm.tqdm(
                        req.iter_content(chunk_size=1024),
                        total=(int(total_size) + 1023) // 1024,
                        unit='KB'):
                    f.write(chunk)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname
def _download_dist(url, path, md5sum=None):
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
        if num_trainers <= 1:
            return _download(url, path, md5sum)
        else:
            fname = osp.split(url)[-1]
            fullname = osp.join(path, fname)
            lock_path = fullname + '.download.lock'

            if not osp.isdir(path):
                os.makedirs(path)

            if not osp.exists(fullname):
                from paddle.distributed import ParallelEnv
                unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                         .trainer_endpoints[:])
                with open(lock_path, 'w'):  # touch
                    os.utime(lock_path, None)
                if ParallelEnv().current_endpoint in unique_endpoints:
                    _download(url, path, md5sum)
                    os.remove(lock_path)
                else:
                    while os.path.exists(lock_path):
                        time.sleep(0.5)
            return fullname
    else:
        return _download(url, path, md5sum)
def _move_and_merge_tree(src, dst):
    """
    Move the src directory to dst; if dst already exists,
    merge src into dst.
    """
    if not osp.exists(dst):
        shutil.move(src, dst)
    elif osp.isfile(src):
        shutil.move(src, dst)
    else:
        for fp in os.listdir(src):
            src_fp = osp.join(src, fp)
            dst_fp = osp.join(dst, fp)
            if osp.isdir(src_fp):
                if osp.isdir(dst_fp):
                    _move_and_merge_tree(src_fp, dst_fp)
                else:
                    shutil.move(src_fp, dst_fp)
            elif osp.isfile(src_fp) and \
                    not osp.isfile(dst_fp):
                shutil.move(src_fp, dst_fp)
def _decompress(fname):
    """
    Decompress zip and tar files.
    """
    # To protect against interrupted decompression,
    # decompress into the fpath_tmp directory first; if decompression
    # succeeds, move the decompressed files to fpath, delete
    # fpath_tmp and remove the downloaded archive.
    fpath = osp.split(fname)[0]
    fpath_tmp = osp.join(fpath, 'tmp')
    if osp.isdir(fpath_tmp):
        shutil.rmtree(fpath_tmp)
    os.makedirs(fpath_tmp)

    if fname.find('tar') >= 0:
        with tarfile.open(fname) as tf:
            tf.extractall(path=fpath_tmp)
    elif fname.find('zip') >= 0:
        with zipfile.ZipFile(fname) as zf:
            zf.extractall(path=fpath_tmp)
    elif fname.find('.txt') >= 0:
        return
    else:
        raise TypeError("Unsupported compressed file type {}".format(fname))

    for f in os.listdir(fpath_tmp):
        src_dir = osp.join(fpath_tmp, f)
        dst_dir = osp.join(fpath, f)
        _move_and_merge_tree(src_dir, dst_dir)

    shutil.rmtree(fpath_tmp)
    os.remove(fname)
def _decompress_dist(fname):
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
        if num_trainers <= 1:
            _decompress(fname)
        else:
            lock_path = fname + '.decompress.lock'
            from paddle.distributed import ParallelEnv
            unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                     .trainer_endpoints[:])
            # NOTE(dkp): _decompress_dist is always performed after
            # _download_dist; in _download_dist the sub-trainers wait for
            # the download lock file to be released by sleeping. If
            # decompression is very fast and finishes within the sleeping
            # gap (e.g. on tiny datasets such as coco_ce or spine_coco),
            # the main trainer may finish decompressing and release the
            # lock file early, so we only create the lock file in the main
            # trainer, and all sub-trainers wait 1s for the main trainer to
            # create it. Since 1s is twice the sleeping gap, this waiting
            # time keeps the whole trainer pipeline in order.
            # **change this if you have a more elegant method**
            if ParallelEnv().current_endpoint in unique_endpoints:
                with open(lock_path, 'w'):  # touch
                    os.utime(lock_path, None)
                _decompress(fname)
                os.remove(lock_path)
            else:
                time.sleep(1)
                while os.path.exists(lock_path):
                    time.sleep(0.5)
    else:
        _decompress(fname)
def get_path(url, root_dir=WEIGHTS_HOME, md5sum=None, check_exist=True):
    """ Download from the given url to root_dir.
    If the file or directory specified by url already exists under
    root_dir, return the path directly; otherwise download it from
    url, decompress it, and return the path.
    url (str): download url
    root_dir (str): root dir for downloading
    md5sum (str): md5 sum of the downloaded package
    Returns a tuple (fullpath, already_existed).
    """
    # parse path after download to decompress under root_dir
    fullpath = map_path(url, root_dir)

    # For some zip files the decompressed directory name differs
    # from the zip file name; rename according to the following map
    decompress_name_map = {"ppTSM_fight": "ppTSM", }
    for k, v in decompress_name_map.items():
        if fullpath.find(k) >= 0:
            fullpath = osp.join(osp.split(fullpath)[0], v)

    if osp.exists(fullpath) and check_exist:
        if not osp.isfile(fullpath) or \
                _check_exist_file_md5(fullpath, md5sum, url):
            return fullpath, True
        else:
            os.remove(fullpath)

    fullname = _download_dist(url, root_dir, md5sum)

    # weights in the new format (with a '.pdparams' suffix) do not
    # need to be decompressed
    if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
        _decompress_dist(fullname)

    return fullpath, False
def get_weights_path(url):
    """Get the weights path under WEIGHTS_HOME; if it does not exist,
    download it from url.
    """
    url = parse_url(url)
    path, _ = get_path(url, WEIGHTS_HOME)
    return path


def auto_download_model(model_path):
    # auto download
    if is_url(model_path):
        weight = get_weights_path(model_path)
        return weight
    return None
if __name__ == "__main__":
    model_path = "https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip"
    auto_download_model(model_path)
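To illustrate the caching behaviour of `get_path`/`get_weights_path` defined above, here is a hedged sketch (not part of the commit; it assumes network access and that `download.py` is importable from the current directory):

```python
# Illustrative sketch of the cache semantics of get_path (assumed usage only).
from download import get_path

url = "https://bj.bcebos.com/v1/paddledet/models/pipeline/STGCN.zip"

path, existed = get_path(url)   # first call: downloads the zip and extracts it
print(path, existed)            # e.g. ~/.cache/paddle/infer_weights/STGCN  False

path, existed = get_path(url)   # second call: the extracted directory already exists
print(path, existed)            # e.g. ~/.cache/paddle/infer_weights/STGCN  True
```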
@@ -51,6 +51,8 @@ from pphuman.mtmct import mtmct_process
 from ppvehicle.vehicle_plate import PlateRecognizer
 from ppvehicle.vehicle_attr import VehicleAttr
 
+from download import auto_download_model
+
 
 class Pipeline(object):
     """
@@ -166,6 +168,45 @@ class Pipeline(object):
             self.predictor.run(self.input)
 
 
+def get_model_dir(cfg):
+    # auto download inference model
+    model_dir_dict = {}
+    for key in cfg.keys():
+        if type(cfg[key]) == dict and \
+                ("enable" in cfg[key].keys() and cfg[key]['enable']
+                 or "enable" not in cfg[key].keys()):
+            if "model_dir" in cfg[key].keys():
+                model_dir = cfg[key]["model_dir"]
+                downloaded_model_dir = auto_download_model(model_dir)
+                if downloaded_model_dir:
+                    model_dir = downloaded_model_dir
+                model_dir_dict[key] = model_dir
+                print(key, " model dir:", model_dir)
+            elif key == "VEHICLE_PLATE":
+                det_model_dir = cfg[key]["det_model_dir"]
+                downloaded_det_model_dir = auto_download_model(det_model_dir)
+                if downloaded_det_model_dir:
+                    det_model_dir = downloaded_det_model_dir
+                model_dir_dict["det_model_dir"] = det_model_dir
+                print("det_model_dir model dir:", det_model_dir)
+
+                rec_model_dir = cfg[key]["rec_model_dir"]
+                downloaded_rec_model_dir = auto_download_model(rec_model_dir)
+                if downloaded_rec_model_dir:
+                    rec_model_dir = downloaded_rec_model_dir
+                model_dir_dict["rec_model_dir"] = rec_model_dir
+                print("rec_model_dir model dir:", rec_model_dir)
+
+        elif key == "MOT":  # for idbased and skeletonbased actions
+            model_dir = cfg[key]["model_dir"]
+            downloaded_model_dir = auto_download_model(model_dir)
+            if downloaded_model_dir:
+                model_dir = downloaded_model_dir
+            model_dir_dict[key] = model_dir
+
+    return model_dir_dict
+
+
 class PipePredictor(object):
     """
     Predictor in single camera
@@ -292,9 +333,12 @@ class PipePredictor(object):
         self.file_name = None
         self.collector = DataCollector()
 
+        # auto download inference model
+        model_dir_dict = get_model_dir(self.cfg)
+
         if not is_video:
             det_cfg = self.cfg['DET']
-            model_dir = det_cfg['model_dir']
+            model_dir = model_dir_dict['DET']
             batch_size = det_cfg['batch_size']
             self.det_predictor = Detector(
                 model_dir, device, run_mode, batch_size, trt_min_shape,
@@ -302,7 +346,7 @@ class PipePredictor(object):
                 enable_mkldnn)
             if self.with_human_attr:
                 attr_cfg = self.cfg['ATTR']
-                model_dir = attr_cfg['model_dir']
+                model_dir = model_dir_dict['ATTR']
                 batch_size = attr_cfg['batch_size']
                 basemode = attr_cfg['basemode']
                 self.modebase[basemode] = True
@@ -313,7 +357,7 @@ class PipePredictor(object):
 
             if self.with_vehicle_attr:
                 vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
-                model_dir = vehicleattr_cfg['model_dir']
+                model_dir = model_dir_dict['VEHICLE_ATTR']
                 batch_size = vehicleattr_cfg['batch_size']
                 color_threshold = vehicleattr_cfg['color_threshold']
                 type_threshold = vehicleattr_cfg['type_threshold']
@@ -327,7 +371,7 @@ class PipePredictor(object):
         else:
             if self.with_human_attr:
                 attr_cfg = self.cfg['ATTR']
-                model_dir = attr_cfg['model_dir']
+                model_dir = model_dir_dict['ATTR']
                 batch_size = attr_cfg['batch_size']
                 basemode = attr_cfg['basemode']
                 self.modebase[basemode] = True
@@ -337,7 +381,7 @@ class PipePredictor(object):
                     enable_mkldnn)
             if self.with_idbased_detaction:
                 idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION']
-                model_dir = idbased_detaction_cfg['model_dir']
+                model_dir = model_dir_dict['ID_BASED_DETACTION']
                 batch_size = idbased_detaction_cfg['batch_size']
                 basemode = idbased_detaction_cfg['basemode']
                 threshold = idbased_detaction_cfg['threshold']
@@ -363,7 +407,7 @@ class PipePredictor(object):
 
            if self.with_idbased_clsaction:
                 idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION']
-                model_dir = idbased_clsaction_cfg['model_dir']
+                model_dir = model_dir_dict['ID_BASED_CLSACTION']
                 batch_size = idbased_clsaction_cfg['batch_size']
                 basemode = idbased_clsaction_cfg['basemode']
                 threshold = idbased_clsaction_cfg['threshold']
@@ -389,7 +433,7 @@ class PipePredictor(object):
 
             if self.with_skeleton_action:
                 skeleton_action_cfg = self.cfg['SKELETON_ACTION']
-                skeleton_action_model_dir = skeleton_action_cfg['model_dir']
+                skeleton_action_model_dir = model_dir_dict['SKELETON_ACTION']
                 skeleton_action_batch_size = skeleton_action_cfg['batch_size']
                 skeleton_action_frames = skeleton_action_cfg['max_frames']
                 display_frames = skeleton_action_cfg['display_frames']
@@ -414,7 +458,7 @@ class PipePredictor(object):
 
                 if self.modebase["skeletonbased"]:
                     kpt_cfg = self.cfg['KPT']
-                    kpt_model_dir = kpt_cfg['model_dir']
+                    kpt_model_dir = model_dir_dict['KPT']
                     kpt_batch_size = kpt_cfg['batch_size']
                     self.kpt_predictor = KeyPointDetector(
                         kpt_model_dir,
@@ -439,7 +483,7 @@ class PipePredictor(object):
 
             if self.with_vehicle_attr:
                 vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
-                model_dir = vehicleattr_cfg['model_dir']
+                model_dir = model_dir_dict['VEHICLE_ATTR']
                 batch_size = vehicleattr_cfg['batch_size']
                 color_threshold = vehicleattr_cfg['color_threshold']
                 type_threshold = vehicleattr_cfg['type_threshold']
@@ -452,7 +496,7 @@ class PipePredictor(object):
 
             if self.with_mtmct:
                 reid_cfg = self.cfg['REID']
-                model_dir = reid_cfg['model_dir']
+                model_dir = model_dir_dict['REID']
                 batch_size = reid_cfg['batch_size']
                 basemode = reid_cfg['basemode']
                 self.modebase[basemode] = True
@@ -464,7 +508,7 @@ class PipePredictor(object):
             if self.with_mot or self.modebase["idbased"] or self.modebase[
                     "skeletonbased"]:
                 mot_cfg = self.cfg['MOT']
-                model_dir = mot_cfg['model_dir']
+                model_dir = model_dir_dict['MOT']
                 tracker_config = mot_cfg['tracker_config']
                 batch_size = mot_cfg['batch_size']
                 basemode = mot_cfg['basemode']
@@ -491,7 +535,7 @@ class PipePredictor(object):
                 basemode = video_action_cfg['basemode']
                 self.modebase[basemode] = True
 
-                video_action_model_dir = video_action_cfg['model_dir']
+                video_action_model_dir = model_dir_dict['VIDEO_ACTION']
                 video_action_batch_size = video_action_cfg['batch_size']
                 short_size = video_action_cfg["short_size"]
                 target_size = video_action_cfg["target_size"]
...
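To make the effect of the new `get_model_dir` concrete, here is a hedged sketch (illustration only; the dict below mimics the structure of the pipeline config shown earlier, real runs load it from YAML, and the import assumes `deploy/pipeline/pipeline.py` and its dependencies are available):

```python
# Hedged sketch of what get_model_dir produces for a trimmed-down config.
from pipeline import get_model_dir  # deploy/pipeline/pipeline.py (assumed importable)

cfg = {
    "DET": {   # no "enable" key: always resolved
        "model_dir": "https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip",
        "batch_size": 1,
    },
    "ATTR": {  # enabled module with a URL model_dir: downloaded and cached
        "model_dir": "https://bj.bcebos.com/v1/paddledet/models/pipeline/PPLCNet_x1_0_person_attribute_945_infer.zip",
        "batch_size": 8,
        "enable": True,
    },
    "MOT": {   # disabled, but still resolved: ID-based actions need the tracker
        "model_dir": "https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip",
        "batch_size": 1,
        "enable": False,
    },
}

model_dir_dict = get_model_dir(cfg)
# model_dir_dict now maps "DET", "ATTR" and "MOT" to local directories under
# ~/.cache/paddle/infer_weights, and PipePredictor reads model paths from this
# dict instead of from cfg[...]["model_dir"].
```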
@@ -118,8 +118,7 @@ class VideoActionRecognizer(object):
         }
         if run_mode in precision_map.keys():
             self.config.enable_tensorrt_engine(
-                max_batch_size=self.batch_size,
-                precision_mode=precision_map[run_mode])
+                max_batch_size=8, precision_mode=precision_map[run_mode])
 
         self.config.enable_memory_optim()
         # use zero copy
...