提交 41183ea7 编写于 作者: F FlyingQianMM

fix prediction in bg_replace

上级 cd8e9ced
...@@ -19,7 +19,7 @@ import os.path as osp ...@@ -19,7 +19,7 @@ import os.path as osp
import cv2 import cv2
import numpy as np import numpy as np
from utils.humanseg_postprocess import postprocess, threshold_mask from postprocess import postprocess, threshold_mask
import paddlex as pdx import paddlex as pdx
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
from paddlex.seg import transforms from paddlex.seg import transforms
...@@ -74,44 +74,29 @@ def parse_args(): ...@@ -74,44 +74,29 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def predict(img, model, test_transforms): def bg_replace(label_map, img, bg):
model.arrange_transforms(transforms=test_transforms, mode='test') h, w, _ = img.shape
img, im_info = test_transforms(img.astype('float32')) bg = cv2.resize(bg, (w, h))
img = np.expand_dims(img, axis=0) label_map = np.repeat(label_map[:, :, np.newaxis], 3, axis=2)
result = model.exe.run(model.test_prog, comb = (label_map * img + (1 - label_map) * bg).astype(np.uint8)
feed={'image': img}, return comb
fetch_list=list(model.test_outputs.values()))
score_map = result[1]
score_map = np.squeeze(score_map, axis=0)
score_map = np.transpose(score_map, (1, 2, 0))
return score_map, im_info
def recover(img, im_info): def recover(img, im_info):
for info in im_info[::-1]: if im_info[0] == 'resize':
if info[0] == 'resize': w, h = im_info[1][1], im_info[1][0]
w, h = info[1][1], info[1][0]
img = cv2.resize(img, (w, h), cv2.INTER_LINEAR) img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
elif info[0] == 'padding': elif im_info[0] == 'padding':
w, h = info[1][0], info[1][0] w, h = im_info[1][0], im_info[1][0]
img = img[0:h, 0:w, :] img = img[0:h, 0:w, :]
return img return img
def bg_replace(score_map, img, bg):
h, w, _ = img.shape
bg = cv2.resize(bg, (w, h))
score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
return comb
def infer(args): def infer(args):
resize_h = args.image_shape[1] resize_h = args.image_shape[1]
resize_w = args.image_shape[0] resize_w = args.image_shape[0]
test_transforms = transforms.Compose( test_transforms = transforms.Compose([transforms.Normalize()])
[transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
model = pdx.load_model(args.model_dir) model = pdx.load_model(args.model_dir)
if not osp.exists(args.save_dir): if not osp.exists(args.save_dir):
...@@ -130,14 +115,27 @@ def infer(args): ...@@ -130,14 +115,27 @@ def infer(args):
raise Exception( raise Exception(
'The --background_image_path is not existed: {}'.format( 'The --background_image_path is not existed: {}'.format(
args.background_image_path)) args.background_image_path))
img = cv2.imread(args.image_path) img = cv2.imread(args.image_path)
score_map, im_info = predict(img, model, test_transforms) im_shape = img.shape
score_map = score_map[:, :, 1] im_scale_x = float(resize_w) / float(im_shape[1])
score_map = recover(score_map, im_info) im_scale_y = float(resize_h) / float(im_shape[0])
im = cv2.resize(
img,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
image = im.astype('float32')
im_info = ('resize', im_shape[0:2])
pred = model.predict(image, test_transforms)
label_map = pred['label_map']
label_map = recover(label_map, im_info)
bg = cv2.imread(args.background_image_path) bg = cv2.imread(args.background_image_path)
save_name = osp.basename(args.image_path) save_name = osp.basename(args.image_path)
save_path = osp.join(args.save_dir, save_name) save_path = osp.join(args.save_dir, save_name)
result = bg_replace(score_map, img, bg) result = bg_replace(label_map, img, bg)
cv2.imwrite(save_path, result) cv2.imwrite(save_path, result)
# 视频背景替换,如果提供背景视频则以背景视频作为背景,否则采用提供的背景图片 # 视频背景替换,如果提供背景视频则以背景视频作为背景,否则采用提供的背景图片
...@@ -192,8 +190,21 @@ def infer(args): ...@@ -192,8 +190,21 @@ def infer(args):
while cap_video.isOpened(): while cap_video.isOpened():
ret, frame = cap_video.read() ret, frame = cap_video.read()
if ret: if ret:
score_map, im_info = predict(frame, model, test_transforms) im_shape = frame.shape
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) im_scale_x = float(resize_w) / float(im_shape[1])
im_scale_y = float(resize_h) / float(im_shape[0])
im = cv2.resize(
frame,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
image = im.astype('float32')
im_info = ('resize', im_shape[0:2])
pred = model.predict(image, test_transforms)
score_map = pred['score_map']
cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h)) cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1] score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \ optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
...@@ -248,8 +259,21 @@ def infer(args): ...@@ -248,8 +259,21 @@ def infer(args):
while cap_video.isOpened(): while cap_video.isOpened():
ret, frame = cap_video.read() ret, frame = cap_video.read()
if ret: if ret:
score_map, im_info = predict(frame, model, test_transforms) im_shape = frame.shape
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) im_scale_x = float(resize_w) / float(im_shape[1])
im_scale_y = float(resize_h) / float(im_shape[0])
im = cv2.resize(
frame,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
image = im.astype('float32')
im_info = ('resize', im_shape[0:2])
pred = model.predict(image, test_transforms)
score_map = pred['score_map']
cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h)) cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1] score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \ optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
......
...@@ -70,8 +70,8 @@ def video_infer(args): ...@@ -70,8 +70,8 @@ def video_infer(args):
resize_h = args.image_shape[1] resize_h = args.image_shape[1]
resize_w = args.image_shape[0] resize_w = args.image_shape[0]
test_transforms = transforms.Compose([transforms.Normalize()])
model = pdx.load_model(args.model_dir) model = pdx.load_model(args.model_dir)
test_transforms = transforms.Compose([transforms.Normalize()])
if not args.video_path: if not args.video_path:
cap = cv2.VideoCapture(0) cap = cv2.VideoCapture(0)
else: else:
...@@ -115,7 +115,7 @@ def video_infer(args): ...@@ -115,7 +115,7 @@ def video_infer(args):
interpolation=cv2.INTER_LINEAR) interpolation=cv2.INTER_LINEAR)
image = im.astype('float32') image = im.astype('float32')
im_info = ('resize', im_shape[0:2]) im_info = ('resize', im_shape[0:2])
pred = model.predict(image) pred = model.predict(image, test_transforms)
score_map = pred['score_map'] score_map = pred['score_map']
cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
score_map = 255 * score_map[:, :, 1] score_map = 255 * score_map[:, :, 1]
...@@ -155,7 +155,7 @@ def video_infer(args): ...@@ -155,7 +155,7 @@ def video_infer(args):
interpolation=cv2.INTER_LINEAR) interpolation=cv2.INTER_LINEAR)
image = im.astype('float32') image = im.astype('float32')
im_info = ('resize', im_shape[0:2]) im_info = ('resize', im_shape[0:2])
pred = model.predict(image) pred = model.predict(image, test_transforms)
score_map = pred['score_map'] score_map = pred['score_map']
cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h)) cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册