提交 d401a15e 编写于 作者: F FNRE 提交者: LielinJiang

fix bug of first order multi person (#329)

上级 4a538cc4
...@@ -33,6 +33,7 @@ from ppgan.faceutils import face_detection ...@@ -33,6 +33,7 @@ from ppgan.faceutils import face_detection
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
IMAGE_SIZE = 256
class FirstOrderPredictor(BasePredictor): class FirstOrderPredictor(BasePredictor):
def __init__(self, def __init__(self,
...@@ -105,7 +106,6 @@ class FirstOrderPredictor(BasePredictor): ...@@ -105,7 +106,6 @@ class FirstOrderPredictor(BasePredictor):
def read_img(self, path): def read_img(self, path):
img = imageio.imread(path) img = imageio.imread(path)
img = img.astype(np.float32)
if img.ndim == 2: if img.ndim == 2:
img = np.expand_dims(img, axis=2) img = np.expand_dims(img, axis=2)
# som images have 4 channels # som images have 4 channels
...@@ -161,14 +161,14 @@ class FirstOrderPredictor(BasePredictor): ...@@ -161,14 +161,14 @@ class FirstOrderPredictor(BasePredictor):
reader.close() reader.close()
driving_video = [ driving_video = [
cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0 for frame in driving_video
] ]
results = [] results = []
# for single person # for single person
if not self.multi_person: if not self.multi_person:
h, w, _ = source_image.shape h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (256, 256)) / 255.0 source_image = cv2.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
predictions = get_prediction(source_image) predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [ imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w)) cv2.resize((frame * 255.0).astype('uint8'), (h, w))
...@@ -181,7 +181,7 @@ class FirstOrderPredictor(BasePredictor): ...@@ -181,7 +181,7 @@ class FirstOrderPredictor(BasePredictor):
print(str(len(bboxes)) + " persons have been detected") print(str(len(bboxes)) + " persons have been detected")
if len(bboxes) <= 1: if len(bboxes) <= 1:
h, w, _ = source_image.shape h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (256, 256)) / 255.0 source_image = cv2.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
predictions = get_prediction(source_image) predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [ imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w)) cv2.resize((frame * 255.0).astype('uint8'), (h, w))
...@@ -193,7 +193,7 @@ class FirstOrderPredictor(BasePredictor): ...@@ -193,7 +193,7 @@ class FirstOrderPredictor(BasePredictor):
# for multi person # for multi person
for rec in bboxes: for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]] face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = cv2.resize(face_image, (256, 256)) / 255.0 face_image = cv2.resize(face_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
predictions = get_prediction(face_image) predictions = get_prediction(face_image)
results.append({'rec': rec, 'predict': predictions}) results.append({'rec': rec, 'predict': predictions})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册