未验证 提交 74410ff9 编写于 作者: G Guanghua Yu 提交者: GitHub

support depoly infer for fcos (#1330)

上级 419671a8
...@@ -74,7 +74,7 @@ class Resize(object): ...@@ -74,7 +74,7 @@ class Resize(object):
self.arch = arch self.arch = arch
self.use_cv2 = use_cv2 self.use_cv2 = use_cv2
self.interp = interp self.interp = interp
self.scale_set = {'RCNN', 'RetinaNet'} self.scale_set = {'RCNN', 'RetinaNet', 'FCOS'}
def __call__(self, im, im_info): def __call__(self, im, im_info):
""" """
...@@ -259,7 +259,7 @@ def create_inputs(im, im_info, model_arch='YOLO'): ...@@ -259,7 +259,7 @@ def create_inputs(im, im_info, model_arch='YOLO'):
scale = scale_x scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32') im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info inputs['im_info'] = im_info
elif 'RCNN' in model_arch: elif ('RCNN' in model_arch) or ('FCOS' in model_arch):
scale = scale_x scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32') im_info = np.array([resize_shape + [scale]]).astype('float32')
im_shape = np.array([origin_shape + [1.]]).astype('float32') im_shape = np.array([origin_shape + [1.]]).astype('float32')
...@@ -276,7 +276,15 @@ class Config(): ...@@ -276,7 +276,15 @@ class Config():
Args: Args:
model_dir (str): root path of model.yml model_dir (str): root path of model.yml
""" """
support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face', 'TTF'] support_models = [
'YOLO',
'SSD',
'RetinaNet',
'RCNN',
'Face',
'TTF',
'FCOS',
]
def __init__(self, model_dir): def __init__(self, model_dir):
# parsing Yaml config for Preprocess # parsing Yaml config for Preprocess
...@@ -566,15 +574,19 @@ def predict_image(): ...@@ -566,15 +574,19 @@ def predict_image():
output_dir=FLAGS.output_dir) output_dir=FLAGS.output_dir)
def predict_video(): def predict_video(camera_id):
detector = Detector( detector = Detector(
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode) FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
capture = cv2.VideoCapture(FLAGS.video_file) if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
fps = 30 fps = 30
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v') fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_name = os.path.split(FLAGS.video_file)[-1]
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
...@@ -594,6 +606,10 @@ def predict_video(): ...@@ -594,6 +606,10 @@ def predict_video():
mask_resolution=detector.config.mask_resolution) mask_resolution=detector.config.mask_resolution)
im = np.array(im) im = np.array(im)
writer.write(im) writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release() writer.release()
...@@ -617,6 +633,11 @@ if __name__ == '__main__': ...@@ -617,6 +633,11 @@ if __name__ == '__main__':
"--image_file", type=str, default='', help="Path of image file.") "--image_file", type=str, default='', help="Path of image file.")
parser.add_argument( parser.add_argument(
"--video_file", type=str, default='', help="Path of video file.") "--video_file", type=str, default='', help="Path of video file.")
parser.add_argument(
"--camera_id",
type=int,
default=-1,
help="device id of camera to predict.")
parser.add_argument( parser.add_argument(
"--run_mode", "--run_mode",
type=str, type=str,
...@@ -647,5 +668,5 @@ if __name__ == '__main__': ...@@ -647,5 +668,5 @@ if __name__ == '__main__':
assert "Cannot predict image and video at the same time" assert "Cannot predict image and video at the same time"
if FLAGS.image_file != '': if FLAGS.image_file != '':
predict_image() predict_image()
if FLAGS.video_file != '': if FLAGS.video_file != '' or FLAGS.camera_id != -1:
predict_video() predict_video(FLAGS.camera_id)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册