未验证 提交 deadba2c 编写于 作者: C chenjian 提交者: GitHub

Fix MOT module bug (#1665)

上级 756d321a
......@@ -31,7 +31,7 @@
- ### 1、环境依赖
- paddledet >= 2.1.0
- paddledet >= 2.2.0
- opencv-python
......@@ -42,6 +42,7 @@
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
- 在windows下安装,由于paddledet package会依赖cython-bbox以及pycocotools, 这两个包需要windows用户提前装好,可参考[cython-bbox安装](https://blog.csdn.net/qq_24739717/article/details/105588729)[pycocotools安装](https://github.com/PaddlePaddle/PaddleX/blob/release/1.3/docs/install.md#pycocotools安装问题)
## 三、模型API预测
- ### 1、命令行预测
......
......@@ -31,12 +31,13 @@ from .tracker import StreamTracker
logger = setup_logger('Predict')
@moduleinfo(name="fairmot_dla34",
type="CV/multiple_object_tracking",
author="paddlepaddle",
author_email="",
summary="Fairmot is a model for multiple object tracking.",
version="1.0.0")
@moduleinfo(
name="fairmot_dla34",
type="CV/multiple_object_tracking",
author="paddlepaddle",
author_email="",
summary="Fairmot is a model for multiple object tracking.",
version="1.0.0")
class FairmotTracker_1088x608:
def __init__(self):
self.pretrained_model = os.path.join(self.directory, "fairmot_dla34_30e_1088x608")
......@@ -70,12 +71,13 @@ class FairmotTracker_1088x608:
tracker.load_weights_jde(self.pretrained_model)
signal.signal(signal.SIGINT, self.signalhandler)
# inference
tracker.videostream_predict(video_stream=video_stream,
output_dir=output_dir,
data_type='mot',
model_type='FairMOT',
visualization=visualization,
draw_threshold=draw_threshold)
tracker.videostream_predict(
video_stream=video_stream,
output_dir=output_dir,
data_type='mot',
model_type='FairMOT',
visualization=visualization,
draw_threshold=draw_threshold)
def stream_mode(self, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False):
'''
......@@ -106,11 +108,12 @@ class FairmotTracker_1088x608:
return self
def __enter__(self):
self.tracker_generator = self.tracker.imagestream_predict(self.output_dir,
data_type='mot',
model_type='FairMOT',
visualization=self.visualization,
draw_threshold=self.draw_threshold)
self.tracker_generator = self.tracker.imagestream_predict(
self.output_dir,
data_type='mot',
model_type='FairMOT',
visualization=self.visualization,
draw_threshold=self.draw_threshold)
next(self.tracker_generator)
def __exit__(self, exc_type, exc_value, traceback):
......@@ -129,10 +132,12 @@ class FairmotTracker_1088x608:
logger.info('No output images to save for video')
return
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=[img.shape[1], img.shape[0]])
video_writer = cv2.VideoWriter(
output_video_path,
apiPreference=0,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
......@@ -169,10 +174,11 @@ class FairmotTracker_1088x608:
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
......@@ -204,10 +210,12 @@ class FairmotTracker_1088x608:
logger.info('No output images to save for video')
return
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=[img.shape[1], img.shape[0]])
video_writer = cv2.VideoWriter(
output_video_path,
apiPreference=0,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
......@@ -223,22 +231,16 @@ class FairmotTracker_1088x608:
"""
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument('--output_dir',
type=str,
default='mot_result',
help='Directory name for output tracking results.')
self.arg_config_group.add_argument('--visualization',
action='store_true',
help="whether to save output as images.")
self.arg_config_group.add_argument("--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
self.arg_config_group.add_argument(
'--output_dir', type=str, default='mot_result', help='Directory name for output tracking results.')
self.arg_config_group.add_argument(
'--visualization', action='store_true', help="whether to save output as images.")
self.arg_config_group.add_argument(
"--draw_threshold", type=float, default=0.5, help="Threshold to reserve the result for visualization.")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--video_stream',
type=str,
help="path to video stream, can be a video file or stream device number.")
self.arg_input_group.add_argument(
'--video_stream', type=str, help="path to video stream, can be a video file or stream device number.")
paddledet >= 2.1.0
cython
paddledet >= 2.2.0
opencv-python
imageio
......@@ -159,13 +159,12 @@ class StreamTracker(object):
yield
results = []
while True:
with paddle.no_grad():
try:
results, nf = next(generator)
yield results
except StopIteration as e:
self.write_mot_results(result_filename, results, data_type)
return
try:
results, nf = next(generator)
yield results
except StopIteration as e:
self.write_mot_results(result_filename, results, data_type)
return
def videostream_predict(self,
video_stream,
......@@ -175,7 +174,7 @@ class StreamTracker(object):
visualization=True,
draw_threshold=0.5):
assert video_stream is not None, \
"--video_file or --image_dir should be set."
"--video_stream should be set."
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
......@@ -215,9 +214,10 @@ class StreamTracker(object):
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(
output_video_path,
apiPreference=0,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=[img.shape[1], img.shape[0]])
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
......
......@@ -31,7 +31,7 @@
- ### 1、环境依赖
- paddledet >= 2.1.0
- paddledet >= 2.2.0
- opencv-python
......@@ -42,6 +42,7 @@
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
- 在windows下安装,由于paddledet package会依赖cython-bbox以及pycocotools, 这两个包需要windows用户提前装好,可参考[cython-bbox安装](https://blog.csdn.net/qq_24739717/article/details/105588729)[pycocotools安装](https://github.com/PaddlePaddle/PaddleX/blob/release/1.3/docs/install.md#pycocotools安装问题)
## 三、模型API预测
......
......@@ -31,12 +31,13 @@ from .tracker import StreamTracker
logger = setup_logger('Predict')
@moduleinfo(name="jde_darknet53",
type="CV/multiple_object_tracking",
author="paddlepaddle",
author_email="",
summary="JDE is a joint detection and appearance embedding model for multiple object tracking.",
version="1.0.0")
@moduleinfo(
name="jde_darknet53",
type="CV/multiple_object_tracking",
author="paddlepaddle",
author_email="",
summary="JDE is a joint detection and appearance embedding model for multiple object tracking.",
version="1.0.0")
class JDETracker_1088x608:
def __init__(self):
self.pretrained_model = os.path.join(self.directory, "jde_darknet53_30e_1088x608")
......@@ -70,12 +71,13 @@ class JDETracker_1088x608:
tracker.load_weights_jde(self.pretrained_model)
signal.signal(signal.SIGINT, self.signalhandler)
# inference
tracker.videostream_predict(video_stream=video_stream,
output_dir=output_dir,
data_type='mot',
model_type='JDE',
visualization=visualization,
draw_threshold=draw_threshold)
tracker.videostream_predict(
video_stream=video_stream,
output_dir=output_dir,
data_type='mot',
model_type='JDE',
visualization=visualization,
draw_threshold=draw_threshold)
def stream_mode(self, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False):
'''
......@@ -106,11 +108,12 @@ class JDETracker_1088x608:
return self
def __enter__(self):
self.tracker_generator = self.tracker.imagestream_predict(self.output_dir,
data_type='mot',
model_type='JDE',
visualization=self.visualization,
draw_threshold=self.draw_threshold)
self.tracker_generator = self.tracker.imagestream_predict(
self.output_dir,
data_type='mot',
model_type='JDE',
visualization=self.visualization,
draw_threshold=self.draw_threshold)
next(self.tracker_generator)
def __exit__(self, exc_type, exc_value, traceback):
......@@ -129,10 +132,12 @@ class JDETracker_1088x608:
logger.info('No output images to save for video')
return
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=[img.shape[1], img.shape[0]])
video_writer = cv2.VideoWriter(
output_video_path,
apiPreference=0,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
......@@ -169,10 +174,11 @@ class JDETracker_1088x608:
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
......@@ -204,10 +210,12 @@ class JDETracker_1088x608:
logger.info('No output images to save for video')
return
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=[img.shape[1], img.shape[0]])
video_writer = cv2.VideoWriter(
output_video_path,
apiPreference=0,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
......@@ -223,22 +231,16 @@ class JDETracker_1088x608:
"""
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument('--output_dir',
type=str,
default='mot_result',
help='Directory name for output tracking results.')
self.arg_config_group.add_argument('--visualization',
action='store_true',
help="whether to save output as images.")
self.arg_config_group.add_argument("--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
self.arg_config_group.add_argument(
'--output_dir', type=str, default='mot_result', help='Directory name for output tracking results.')
self.arg_config_group.add_argument(
'--visualization', action='store_true', help="whether to save output as images.")
self.arg_config_group.add_argument(
"--draw_threshold", type=float, default=0.5, help="Threshold to reserve the result for visualization.")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--video_stream',
type=str,
help="path to video stream, can be a video file or stream device number.")
self.arg_input_group.add_argument(
'--video_stream', type=str, help="path to video stream, can be a video file or stream device number.")
paddledet >= 2.1.0
cython
paddledet >= 2.2.0
opencv-python
imageio
......@@ -160,13 +160,12 @@ class StreamTracker(object):
yield
results = []
while True:
with paddle.no_grad():
try:
results, nf = next(generator)
yield results
except StopIteration as e:
self.write_mot_results(result_filename, results, data_type)
return
try:
results, nf = next(generator)
yield results
except StopIteration as e:
self.write_mot_results(result_filename, results, data_type)
return
def videostream_predict(self,
video_stream,
......@@ -176,7 +175,7 @@ class StreamTracker(object):
visualization=True,
draw_threshold=0.5):
assert video_stream is not None, \
"--video_file or --image_dir should be set."
"--video_stream should be set."
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
......@@ -214,7 +213,12 @@ class StreamTracker(object):
logger.info('No output images to save for video')
return
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path, fourcc=cv2.VideoWriter_fourcc('M','J','P','G'), fps=30, frameSize=[img.shape[1],img.shape[0]])
video_writer = cv2.VideoWriter(
output_video_path,
apiPreference=0,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册