提交 41183ea7 编写于 作者: F FlyingQianMM

fix prediction in bg_replace

上级 cd8e9ced
......@@ -19,7 +19,7 @@ import os.path as osp
import cv2
import numpy as np
from utils.humanseg_postprocess import postprocess, threshold_mask
from postprocess import postprocess, threshold_mask
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms
......@@ -74,44 +74,29 @@ def parse_args():
return parser.parse_args()
def predict(img, model, test_transforms):
model.arrange_transforms(transforms=test_transforms, mode='test')
img, im_info = test_transforms(img.astype('float32'))
img = np.expand_dims(img, axis=0)
result = model.exe.run(model.test_prog,
feed={'image': img},
fetch_list=list(model.test_outputs.values()))
score_map = result[1]
score_map = np.squeeze(score_map, axis=0)
score_map = np.transpose(score_map, (1, 2, 0))
return score_map, im_info
def bg_replace(label_map, img, bg):
h, w, _ = img.shape
bg = cv2.resize(bg, (w, h))
label_map = np.repeat(label_map[:, :, np.newaxis], 3, axis=2)
comb = (label_map * img + (1 - label_map) * bg).astype(np.uint8)
return comb
def recover(img, im_info):
for info in im_info[::-1]:
if info[0] == 'resize':
w, h = info[1][1], info[1][0]
if im_info[0] == 'resize':
w, h = im_info[1][1], im_info[1][0]
img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
elif info[0] == 'padding':
w, h = info[1][0], info[1][0]
elif im_info[0] == 'padding':
w, h = im_info[1][0], im_info[1][0]
img = img[0:h, 0:w, :]
return img
def bg_replace(score_map, img, bg):
h, w, _ = img.shape
bg = cv2.resize(bg, (w, h))
score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
return comb
def infer(args):
resize_h = args.image_shape[1]
resize_w = args.image_shape[0]
test_transforms = transforms.Compose(
[transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
test_transforms = transforms.Compose([transforms.Normalize()])
model = pdx.load_model(args.model_dir)
if not osp.exists(args.save_dir):
......@@ -130,14 +115,27 @@ def infer(args):
raise Exception(
'The --background_image_path is not existed: {}'.format(
args.background_image_path))
img = cv2.imread(args.image_path)
score_map, im_info = predict(img, model, test_transforms)
score_map = score_map[:, :, 1]
score_map = recover(score_map, im_info)
im_shape = img.shape
im_scale_x = float(resize_w) / float(im_shape[1])
im_scale_y = float(resize_h) / float(im_shape[0])
im = cv2.resize(
img,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
image = im.astype('float32')
im_info = ('resize', im_shape[0:2])
pred = model.predict(image, test_transforms)
label_map = pred['label_map']
label_map = recover(label_map, im_info)
bg = cv2.imread(args.background_image_path)
save_name = osp.basename(args.image_path)
save_path = osp.join(args.save_dir, save_name)
result = bg_replace(score_map, img, bg)
result = bg_replace(label_map, img, bg)
cv2.imwrite(save_path, result)
# 视频背景替换,如果提供背景视频则以背景视频作为背景,否则采用提供的背景图片
......@@ -192,8 +190,21 @@ def infer(args):
while cap_video.isOpened():
ret, frame = cap_video.read()
if ret:
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
im_shape = frame.shape
im_scale_x = float(resize_w) / float(im_shape[1])
im_scale_y = float(resize_h) / float(im_shape[0])
im = cv2.resize(
frame,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
image = im.astype('float32')
im_info = ('resize', im_shape[0:2])
pred = model.predict(image, test_transforms)
score_map = pred['score_map']
cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
......@@ -248,8 +259,21 @@ def infer(args):
while cap_video.isOpened():
ret, frame = cap_video.read()
if ret:
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
im_shape = frame.shape
im_scale_x = float(resize_w) / float(im_shape[1])
im_scale_y = float(resize_h) / float(im_shape[0])
im = cv2.resize(
frame,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
image = im.astype('float32')
im_info = ('resize', im_shape[0:2])
pred = model.predict(image, test_transforms)
score_map = pred['score_map']
cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
......
......@@ -70,8 +70,8 @@ def video_infer(args):
resize_h = args.image_shape[1]
resize_w = args.image_shape[0]
test_transforms = transforms.Compose([transforms.Normalize()])
model = pdx.load_model(args.model_dir)
test_transforms = transforms.Compose([transforms.Normalize()])
if not args.video_path:
cap = cv2.VideoCapture(0)
else:
......@@ -115,7 +115,7 @@ def video_infer(args):
interpolation=cv2.INTER_LINEAR)
image = im.astype('float32')
im_info = ('resize', im_shape[0:2])
pred = model.predict(image)
pred = model.predict(image, test_transforms)
score_map = pred['score_map']
cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
score_map = 255 * score_map[:, :, 1]
......@@ -155,7 +155,7 @@ def video_infer(args):
interpolation=cv2.INTER_LINEAR)
image = im.astype('float32')
im_info = ('resize', im_shape[0:2])
pred = model.predict(image)
pred = model.predict(image, test_transforms)
score_map = pred['score_map']
cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册