未验证 提交 183ffd18 编写于 作者: L lijianshe02 提交者: GitHub

add multipeople inference of fom (#202)

上级 8ece3c2d
...@@ -49,6 +49,11 @@ parser.add_argument("--best_frame", ...@@ -49,6 +49,11 @@ parser.add_argument("--best_frame",
default=None, default=None,
help="Set frame to start from.") help="Set frame to start from.")
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
parser.add_argument("--ratio",
dest="ratio",
type=float,
default=1.0,
help="margin ratio")
parser.set_defaults(relative=False) parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False) parser.set_defaults(adapt_scale=False)
...@@ -65,5 +70,6 @@ if __name__ == "__main__": ...@@ -65,5 +70,6 @@ if __name__ == "__main__":
relative=args.relative, relative=args.relative,
adapt_scale=args.adapt_scale, adapt_scale=args.adapt_scale,
find_best_frame=args.find_best_frame, find_best_frame=args.find_best_frame,
best_frame=args.best_frame) best_frame=args.best_frame,
ratio=args.ratio)
predictor.run(args.source_image, args.driving_video) predictor.run(args.source_image, args.driving_video)
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import os import os
import sys import sys
import cv2
import math
import yaml import yaml
import pickle import pickle
...@@ -29,6 +31,7 @@ from ppgan.utils.download import get_path_from_url ...@@ -29,6 +31,7 @@ from ppgan.utils.download import get_path_from_url
from ppgan.utils.animate import normalize_kp from ppgan.utils.animate import normalize_kp
from ppgan.modules.keypoint_detector import KPDetector from ppgan.modules.keypoint_detector import KPDetector
from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator
from ppgan.faceutils import face_detection
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
...@@ -41,7 +44,8 @@ class FirstOrderPredictor(BasePredictor): ...@@ -41,7 +44,8 @@ class FirstOrderPredictor(BasePredictor):
relative=False, relative=False,
adapt_scale=False, adapt_scale=False,
find_best_frame=False, find_best_frame=False,
best_frame=None): best_frame=None,
ratio=1.0):
if config is not None and isinstance(config, str): if config is not None and isinstance(config, str):
self.cfg = yaml.load(config, Loader=yaml.SafeLoader) self.cfg = yaml.load(config, Loader=yaml.SafeLoader)
elif isinstance(config, dict): elif isinstance(config, dict):
...@@ -88,11 +92,13 @@ class FirstOrderPredictor(BasePredictor): ...@@ -88,11 +92,13 @@ class FirstOrderPredictor(BasePredictor):
self.adapt_scale = adapt_scale self.adapt_scale = adapt_scale
self.find_best_frame = find_best_frame self.find_best_frame = find_best_frame
self.best_frame = best_frame self.best_frame = best_frame
self.ratio = ratio
self.generator, self.kp_detector = self.load_checkpoints( self.generator, self.kp_detector = self.load_checkpoints(
self.cfg, self.weight_path) self.cfg, self.weight_path)
def run(self, source_image, driving_video): def run(self, source_image, driving_video):
source_image = imageio.imread(source_image) source_image = imageio.imread(source_image)
bboxes = self.extract_bbox(source_image.copy())
reader = imageio.get_reader(driving_video) reader = imageio.get_reader(driving_video)
fps = reader.get_meta_data()['fps'] fps = reader.get_meta_data()['fps']
driving_video = [] driving_video = []
...@@ -103,44 +109,70 @@ class FirstOrderPredictor(BasePredictor): ...@@ -103,44 +109,70 @@ class FirstOrderPredictor(BasePredictor):
pass pass
reader.close() reader.close()
source_image = resize(source_image, (256, 256))[..., :3]
driving_video = [ driving_video = [
resize(frame, (256, 256))[..., :3] for frame in driving_video resize(frame, (256, 256))[..., :3] for frame in driving_video
] ]
results = []
for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = resize(face_image, (256, 256))
if self.find_best_frame or self.best_frame is not None: if self.find_best_frame or self.best_frame is not None:
i = self.best_frame if self.best_frame is not None else self.find_best_frame_func( i = self.best_frame if self.best_frame is not None else self.find_best_frame_func(
source_image, driving_video) source_image, driving_video)
print("Best frame: " + str(i)) print("Best frame: " + str(i))
driving_forward = driving_video[i:] driving_forward = driving_video[i:]
driving_backward = driving_video[:(i + 1)][::-1] driving_backward = driving_video[:(i + 1)][::-1]
predictions_forward = self.make_animation( predictions_forward = self.make_animation(
source_image, face_image,
driving_forward, driving_forward,
self.generator, self.generator,
self.kp_detector, self.kp_detector,
relative=self.relative, relative=self.relative,
adapt_movement_scale=self.adapt_scale) adapt_movement_scale=self.adapt_scale)
predictions_backward = self.make_animation( predictions_backward = self.make_animation(
source_image, face_image,
driving_backward, driving_backward,
self.generator, self.generator,
self.kp_detector, self.kp_detector,
relative=self.relative, relative=self.relative,
adapt_movement_scale=self.adapt_scale) adapt_movement_scale=self.adapt_scale)
predictions = predictions_backward[::-1] + predictions_forward[1:] predictions = predictions_backward[::-1] + predictions_forward[
else: 1:]
predictions = self.make_animation( else:
source_image, predictions = self.make_animation(
driving_video, face_image,
self.generator, driving_video,
self.kp_detector, self.generator,
relative=self.relative, self.kp_detector,
adapt_movement_scale=self.adapt_scale) relative=self.relative,
imageio.mimsave(os.path.join(self.output, 'result.mp4'), adapt_movement_scale=self.adapt_scale)
[img_as_ubyte(frame) for frame in predictions],
fps=fps) results.append({'rec': rec, 'predict': predictions})
out_frame = []
for i in range(len(driving_video)):
frame = source_image.copy()
for result in results:
x1, y1, x2, y2 = result['rec']
h = y2 - y1
w = x2 - x1
out = result['predict'][i] * 255.0
out = cv2.resize(out.astype(np.uint8), (x2 - x1, y2 - y1))
patch = np.zeros(frame.shape).astype('uint8')
patch[y1:y2, x1:x2] = out
mask = np.zeros(frame.shape[:2]).astype('uint8')
cx = int((x1 + x2) / 2)
cy = int((y1 + y2) / 2)
cv2.circle(mask, (cx, cy), math.ceil(h * self.ratio),
(255, 255, 255), -1, 8, 0)
frame = cv2.copyTo(patch, mask, frame)
out_frame.append(frame)
imageio.mimsave(os.path.join(self.output, 'result.mp4'),
[frame for frame in out_frame],
fps=fps)
def load_checkpoints(self, config, checkpoint_path): def load_checkpoints(self, config, checkpoint_path):
...@@ -220,3 +252,25 @@ class FirstOrderPredictor(BasePredictor): ...@@ -220,3 +252,25 @@ class FirstOrderPredictor(BasePredictor):
norm = new_norm norm = new_norm
frame_num = i frame_num = i
return frame_num return frame_num
def extract_bbox(self, image):
    """Detect faces in ``image`` and return an enlarged box for each one.

    Every detection is replaced by a square region centred on the
    detection's midpoint. The half-side of that square equals the larger
    of the detection's width and height, and the region is clipped so it
    never leaves the image.

    Args:
        image: array whose ``shape`` unpacks into (height, width, channels).

    Returns:
        ``np.ndarray`` built from one ``[x1, y1, x2, y2]`` entry per
        detection returned by the face detector.
    """
    face_detector = face_detection.FaceAlignment(
        face_detection.LandmarksType._2D, flip_input=False)
    # The detector is handed a one-element batch that wraps the image.
    detections = face_detector.get_detections_for_image(np.array([image]))
    img_h, img_w, _ = image.shape

    def expand(rect):
        # rect is indexed as (x1, y1, x2, y2).
        box_h = rect[3] - rect[1]
        box_w = rect[2] - rect[0]
        center_y = rect[1] + int(box_h / 2)
        center_x = rect[0] + int(box_w / 2)
        # Use the larger dimension as the half-side so the crop keeps
        # context around the detection, then clamp to the image bounds.
        half_side = max(box_h, box_w)
        return [
            max(0, center_x - half_side),
            max(0, center_y - half_side),
            min(img_w, center_x + half_side),
            min(img_h, center_y + half_side),
        ]

    return np.array([expand(rect) for rect in detections])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册