提交 3cfe7bc3 编写于 作者: C chenguowei01

update video_infer.py

上级 f10970d7
......@@ -33,6 +33,33 @@ def parse_args():
return parser.parse_args()
def predict(img, model, test_transforms):
model.arrange_transform(transforms=test_transforms, mode='test')
img, im_info = test_transforms(img)
img = np.expand_dims(img, axis=0)
result = model.exe.run(
model.test_prog,
feed={'image': img},
fetch_list=list(model.test_outputs.values()))
score_map = result[1]
print(score_map)
score_map = np.squeeze(score_map, axis=0)
score_map = np.transpose(score_map, (1, 2, 0))
return score_map, im_info
def recover(img, im_info):
keys = list(im_info.keys())
for k in keys[::-1]:
if k == 'shape_before_resize':
h, w = im_info[k][0], im_info[k][1]
img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
elif k == 'shape_before_padding':
h, w = im_info[k][0], im_info[k][1]
img = img[0:h, 0:w]
return img
def video_infer(args):
test_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
......@@ -52,6 +79,8 @@ def video_infer(args):
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
# 用于保存预测结果视频
if not osp.exists(args.save_dir):
os.makedirs(args.save_dir)
out = cv2.VideoWriter(
osp.join(args.save_dir, 'result.avi'),
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))
......@@ -59,8 +88,10 @@ def video_infer(args):
while cap.isOpened():
ret, frame = cap.read()
if ret:
results = model.predict(frame, test_transforms)
img_mat = postprocess(frame, results['score_map'])
score_map, im_info = predict(frame, model, test_transforms)
img = cv2.resize(frame, (192, 192))
img_mat = postprocess(img, score_map)
img_mat = recover(img_mat, im_info)
out.write(img_mat)
else:
break
......@@ -71,8 +102,11 @@ def video_infer(args):
while cap.isOpened():
ret, frame = cap.read()
if ret:
results = model.predict(frame, test_transforms)
img_mat = postprocess(frame, results['score_map'])
score_map, im_info = predict(frame, model, test_transforms)
img = cv2.resize(frame, (192, 192))
img_mat = postprocess(img, score_map)
img_mat = recover(img_mat, im_info)
print(img_mat.shape)
cv2.imshow('HumanSegmentation', img_mat)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册