diff --git a/contrib/RealTimeHumanSeg/python/infer.py b/contrib/RealTimeHumanSeg/python/infer.py index 280091ac8b94e40b71379c96cb1462961ca63521..8f638c1af2519b69a6513c2b1ae975adcca7c09a 100644 --- a/contrib/RealTimeHumanSeg/python/infer.py +++ b/contrib/RealTimeHumanSeg/python/infer.py @@ -89,7 +89,28 @@ def PredictImage(seg, image_path): # Do Predicting on a video def PredictVideo(seg, video_path): - + cap = cv2.VideoCapture(video_path) + if cap.isOpened() == False: + print("Error opening video stream or file") + return + w = cap.get(cv2.CAP_PROP_FRAME_WIDTH) + h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + fps = cap.get(cv2.CAP_PROP_FPS) + # Result Video Writer + out = cv2.VideoWriter('result.avi', + cv2.VideoWriter_fourcc('M','J','P','G'), + fps, + (int(w), int(h))) + # Start capturing from video + while(cap.isOpened()): + ret, frame = cap.read() + if ret == True: + im = seg.Predict(frame) + out.write(im); + else: + break + cap.release() + out.release() if __name__ == "__main__": if len(sys.argv) < 3: @@ -104,3 +125,4 @@ if __name__ == "__main__": scale = [1.0, 1.0, 1.0] eval_size = (192, 192) seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu) + PredictVideo(seg, input_path)