未验证 提交 deadba2c 编写于 作者: C chenjian 提交者: GitHub

Fix MOT module bug (#1665)

上级 756d321a
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
- ### 1、环境依赖 - ### 1、环境依赖
- paddledet >= 2.1.0 - paddledet >= 2.2.0
- opencv-python - opencv-python
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
``` ```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
- 在windows下安装,由于paddledet package会依赖cython-bbox以及pycocotools, 这两个包需要windows用户提前装好,可参考[cython-bbox安装](https://blog.csdn.net/qq_24739717/article/details/105588729)[pycocotools安装](https://github.com/PaddlePaddle/PaddleX/blob/release/1.3/docs/install.md#pycocotools安装问题)
## 三、模型API预测 ## 三、模型API预测
- ### 1、命令行预测 - ### 1、命令行预测
......
...@@ -31,12 +31,13 @@ from .tracker import StreamTracker ...@@ -31,12 +31,13 @@ from .tracker import StreamTracker
logger = setup_logger('Predict') logger = setup_logger('Predict')
@moduleinfo(name="fairmot_dla34", @moduleinfo(
type="CV/multiple_object_tracking", name="fairmot_dla34",
author="paddlepaddle", type="CV/multiple_object_tracking",
author_email="", author="paddlepaddle",
summary="Fairmot is a model for multiple object tracking.", author_email="",
version="1.0.0") summary="Fairmot is a model for multiple object tracking.",
version="1.0.0")
class FairmotTracker_1088x608: class FairmotTracker_1088x608:
def __init__(self): def __init__(self):
self.pretrained_model = os.path.join(self.directory, "fairmot_dla34_30e_1088x608") self.pretrained_model = os.path.join(self.directory, "fairmot_dla34_30e_1088x608")
...@@ -70,12 +71,13 @@ class FairmotTracker_1088x608: ...@@ -70,12 +71,13 @@ class FairmotTracker_1088x608:
tracker.load_weights_jde(self.pretrained_model) tracker.load_weights_jde(self.pretrained_model)
signal.signal(signal.SIGINT, self.signalhandler) signal.signal(signal.SIGINT, self.signalhandler)
# inference # inference
tracker.videostream_predict(video_stream=video_stream, tracker.videostream_predict(
output_dir=output_dir, video_stream=video_stream,
data_type='mot', output_dir=output_dir,
model_type='FairMOT', data_type='mot',
visualization=visualization, model_type='FairMOT',
draw_threshold=draw_threshold) visualization=visualization,
draw_threshold=draw_threshold)
def stream_mode(self, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False): def stream_mode(self, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False):
''' '''
...@@ -106,11 +108,12 @@ class FairmotTracker_1088x608: ...@@ -106,11 +108,12 @@ class FairmotTracker_1088x608:
return self return self
def __enter__(self): def __enter__(self):
self.tracker_generator = self.tracker.imagestream_predict(self.output_dir, self.tracker_generator = self.tracker.imagestream_predict(
data_type='mot', self.output_dir,
model_type='FairMOT', data_type='mot',
visualization=self.visualization, model_type='FairMOT',
draw_threshold=self.draw_threshold) visualization=self.visualization,
draw_threshold=self.draw_threshold)
next(self.tracker_generator) next(self.tracker_generator)
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
...@@ -129,10 +132,12 @@ class FairmotTracker_1088x608: ...@@ -129,10 +132,12 @@ class FairmotTracker_1088x608:
logger.info('No output images to save for video') logger.info('No output images to save for video')
return return
img = cv2.imread(os.path.join(save_dir, '00000.jpg')) img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path, video_writer = cv2.VideoWriter(
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), output_video_path,
fps=30, apiPreference=0,
frameSize=[img.shape[1], img.shape[0]]) fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)): for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i)) imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath) img = cv2.imread(imgpath)
...@@ -169,10 +174,11 @@ class FairmotTracker_1088x608: ...@@ -169,10 +174,11 @@ class FairmotTracker_1088x608:
""" """
Run as a command. Run as a command.
""" """
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name), self.parser = argparse.ArgumentParser(
prog='hub run {}'.format(self.name), description="Run the {} module.".format(self.name),
usage='%(prog)s', prog='hub run {}'.format(self.name),
add_help=True) usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
...@@ -204,10 +210,12 @@ class FairmotTracker_1088x608: ...@@ -204,10 +210,12 @@ class FairmotTracker_1088x608:
logger.info('No output images to save for video') logger.info('No output images to save for video')
return return
img = cv2.imread(os.path.join(save_dir, '00000.jpg')) img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path, video_writer = cv2.VideoWriter(
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), output_video_path,
fps=30, apiPreference=0,
frameSize=[img.shape[1], img.shape[0]]) fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)): for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i)) imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath) img = cv2.imread(imgpath)
...@@ -223,22 +231,16 @@ class FairmotTracker_1088x608: ...@@ -223,22 +231,16 @@ class FairmotTracker_1088x608:
""" """
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not") self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument('--output_dir', self.arg_config_group.add_argument(
type=str, '--output_dir', type=str, default='mot_result', help='Directory name for output tracking results.')
default='mot_result', self.arg_config_group.add_argument(
help='Directory name for output tracking results.') '--visualization', action='store_true', help="whether to save output as images.")
self.arg_config_group.add_argument('--visualization', self.arg_config_group.add_argument(
action='store_true', "--draw_threshold", type=float, default=0.5, help="Threshold to reserve the result for visualization.")
help="whether to save output as images.")
self.arg_config_group.add_argument("--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
Add the command input options. Add the command input options.
""" """
self.arg_input_group.add_argument('--video_stream', self.arg_input_group.add_argument(
type=str, '--video_stream', type=str, help="path to video stream, can be a video file or stream device number.")
help="path to video stream, can be a video file or stream device number.")
paddledet >= 2.1.0 cython
paddledet >= 2.2.0
opencv-python opencv-python
imageio
...@@ -159,13 +159,12 @@ class StreamTracker(object): ...@@ -159,13 +159,12 @@ class StreamTracker(object):
yield yield
results = [] results = []
while True: while True:
with paddle.no_grad(): try:
try: results, nf = next(generator)
results, nf = next(generator) yield results
yield results except StopIteration as e:
except StopIteration as e: self.write_mot_results(result_filename, results, data_type)
self.write_mot_results(result_filename, results, data_type) return
return
def videostream_predict(self, def videostream_predict(self,
video_stream, video_stream,
...@@ -175,7 +174,7 @@ class StreamTracker(object): ...@@ -175,7 +174,7 @@ class StreamTracker(object):
visualization=True, visualization=True,
draw_threshold=0.5): draw_threshold=0.5):
assert video_stream is not None, \ assert video_stream is not None, \
"--video_file or --image_dir should be set." "--video_stream should be set."
if not os.path.exists(output_dir): os.makedirs(output_dir) if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results') result_root = os.path.join(output_dir, 'mot_results')
...@@ -215,9 +214,10 @@ class StreamTracker(object): ...@@ -215,9 +214,10 @@ class StreamTracker(object):
img = cv2.imread(os.path.join(save_dir, '00000.jpg')) img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter( video_writer = cv2.VideoWriter(
output_video_path, output_video_path,
apiPreference=0,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30, fps=30,
frameSize=[img.shape[1], img.shape[0]]) frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)): for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i)) imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath) img = cv2.imread(imgpath)
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
- ### 1、环境依赖 - ### 1、环境依赖
- paddledet >= 2.1.0 - paddledet >= 2.2.0
- opencv-python - opencv-python
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
``` ```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
- 在windows下安装,由于paddledet package会依赖cython-bbox以及pycocotools, 这两个包需要windows用户提前装好,可参考[cython-bbox安装](https://blog.csdn.net/qq_24739717/article/details/105588729)[pycocotools安装](https://github.com/PaddlePaddle/PaddleX/blob/release/1.3/docs/install.md#pycocotools安装问题)
## 三、模型API预测 ## 三、模型API预测
......
...@@ -31,12 +31,13 @@ from .tracker import StreamTracker ...@@ -31,12 +31,13 @@ from .tracker import StreamTracker
logger = setup_logger('Predict') logger = setup_logger('Predict')
@moduleinfo(name="jde_darknet53", @moduleinfo(
type="CV/multiple_object_tracking", name="jde_darknet53",
author="paddlepaddle", type="CV/multiple_object_tracking",
author_email="", author="paddlepaddle",
summary="JDE is a joint detection and appearance embedding model for multiple object tracking.", author_email="",
version="1.0.0") summary="JDE is a joint detection and appearance embedding model for multiple object tracking.",
version="1.0.0")
class JDETracker_1088x608: class JDETracker_1088x608:
def __init__(self): def __init__(self):
self.pretrained_model = os.path.join(self.directory, "jde_darknet53_30e_1088x608") self.pretrained_model = os.path.join(self.directory, "jde_darknet53_30e_1088x608")
...@@ -70,12 +71,13 @@ class JDETracker_1088x608: ...@@ -70,12 +71,13 @@ class JDETracker_1088x608:
tracker.load_weights_jde(self.pretrained_model) tracker.load_weights_jde(self.pretrained_model)
signal.signal(signal.SIGINT, self.signalhandler) signal.signal(signal.SIGINT, self.signalhandler)
# inference # inference
tracker.videostream_predict(video_stream=video_stream, tracker.videostream_predict(
output_dir=output_dir, video_stream=video_stream,
data_type='mot', output_dir=output_dir,
model_type='JDE', data_type='mot',
visualization=visualization, model_type='JDE',
draw_threshold=draw_threshold) visualization=visualization,
draw_threshold=draw_threshold)
def stream_mode(self, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False): def stream_mode(self, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False):
''' '''
...@@ -106,11 +108,12 @@ class JDETracker_1088x608: ...@@ -106,11 +108,12 @@ class JDETracker_1088x608:
return self return self
def __enter__(self): def __enter__(self):
self.tracker_generator = self.tracker.imagestream_predict(self.output_dir, self.tracker_generator = self.tracker.imagestream_predict(
data_type='mot', self.output_dir,
model_type='JDE', data_type='mot',
visualization=self.visualization, model_type='JDE',
draw_threshold=self.draw_threshold) visualization=self.visualization,
draw_threshold=self.draw_threshold)
next(self.tracker_generator) next(self.tracker_generator)
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
...@@ -129,10 +132,12 @@ class JDETracker_1088x608: ...@@ -129,10 +132,12 @@ class JDETracker_1088x608:
logger.info('No output images to save for video') logger.info('No output images to save for video')
return return
img = cv2.imread(os.path.join(save_dir, '00000.jpg')) img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path, video_writer = cv2.VideoWriter(
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), output_video_path,
fps=30, apiPreference=0,
frameSize=[img.shape[1], img.shape[0]]) fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)): for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i)) imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath) img = cv2.imread(imgpath)
...@@ -169,10 +174,11 @@ class JDETracker_1088x608: ...@@ -169,10 +174,11 @@ class JDETracker_1088x608:
""" """
Run as a command. Run as a command.
""" """
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name), self.parser = argparse.ArgumentParser(
prog='hub run {}'.format(self.name), description="Run the {} module.".format(self.name),
usage='%(prog)s', prog='hub run {}'.format(self.name),
add_help=True) usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
...@@ -204,10 +210,12 @@ class JDETracker_1088x608: ...@@ -204,10 +210,12 @@ class JDETracker_1088x608:
logger.info('No output images to save for video') logger.info('No output images to save for video')
return return
img = cv2.imread(os.path.join(save_dir, '00000.jpg')) img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path, video_writer = cv2.VideoWriter(
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), output_video_path,
fps=30, apiPreference=0,
frameSize=[img.shape[1], img.shape[0]]) fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)): for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i)) imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath) img = cv2.imread(imgpath)
...@@ -223,22 +231,16 @@ class JDETracker_1088x608: ...@@ -223,22 +231,16 @@ class JDETracker_1088x608:
""" """
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not") self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument('--output_dir', self.arg_config_group.add_argument(
type=str, '--output_dir', type=str, default='mot_result', help='Directory name for output tracking results.')
default='mot_result', self.arg_config_group.add_argument(
help='Directory name for output tracking results.') '--visualization', action='store_true', help="whether to save output as images.")
self.arg_config_group.add_argument('--visualization', self.arg_config_group.add_argument(
action='store_true', "--draw_threshold", type=float, default=0.5, help="Threshold to reserve the result for visualization.")
help="whether to save output as images.")
self.arg_config_group.add_argument("--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
Add the command input options. Add the command input options.
""" """
self.arg_input_group.add_argument('--video_stream', self.arg_input_group.add_argument(
type=str, '--video_stream', type=str, help="path to video stream, can be a video file or stream device number.")
help="path to video stream, can be a video file or stream device number.")
paddledet >= 2.1.0 cython
paddledet >= 2.2.0
opencv-python opencv-python
imageio
...@@ -160,13 +160,12 @@ class StreamTracker(object): ...@@ -160,13 +160,12 @@ class StreamTracker(object):
yield yield
results = [] results = []
while True: while True:
with paddle.no_grad(): try:
try: results, nf = next(generator)
results, nf = next(generator) yield results
yield results except StopIteration as e:
except StopIteration as e: self.write_mot_results(result_filename, results, data_type)
self.write_mot_results(result_filename, results, data_type) return
return
def videostream_predict(self, def videostream_predict(self,
video_stream, video_stream,
...@@ -176,7 +175,7 @@ class StreamTracker(object): ...@@ -176,7 +175,7 @@ class StreamTracker(object):
visualization=True, visualization=True,
draw_threshold=0.5): draw_threshold=0.5):
assert video_stream is not None, \ assert video_stream is not None, \
"--video_file or --image_dir should be set." "--video_stream should be set."
if not os.path.exists(output_dir): os.makedirs(output_dir) if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results') result_root = os.path.join(output_dir, 'mot_results')
...@@ -214,7 +213,12 @@ class StreamTracker(object): ...@@ -214,7 +213,12 @@ class StreamTracker(object):
logger.info('No output images to save for video') logger.info('No output images to save for video')
return return
img = cv2.imread(os.path.join(save_dir, '00000.jpg')) img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path, fourcc=cv2.VideoWriter_fourcc('M','J','P','G'), fps=30, frameSize=[img.shape[1],img.shape[0]]) video_writer = cv2.VideoWriter(
output_video_path,
apiPreference=0,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=(img.shape[1], img.shape[0]))
for i in range(len(imgnames)): for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i)) imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath) img = cv2.imread(imgpath)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册