diff --git a/applications/tools/first-order-demo.py b/applications/tools/first-order-demo.py
index 7a6d64db8db97fecb3358e5bb77e08109d1d1746..2183b7130f3f3034d6ee535453f2ffea6b4264a6 100644
--- a/applications/tools/first-order-demo.py
+++ b/applications/tools/first-order-demo.py
@@ -68,6 +68,11 @@ parser.add_argument("--multi_person",
                     action="store_true",
                     default=False,
                     help="whether there is only one person in the image or not")
+parser.add_argument("--image_size",
+                    dest="image_size",
+                    type=int,
+                    default=256,
+                    help="size of image")

 parser.set_defaults(relative=False)
 parser.set_defaults(adapt_scale=False)
@@ -87,5 +92,6 @@ if __name__ == "__main__":
                                     best_frame=args.best_frame,
                                     ratio=args.ratio,
                                     face_detector=args.face_detector,
-                                    multi_person=args.multi_person)
+                                    multi_person=args.multi_person,
+                                    image_size=args.image_size)
     predictor.run(args.source_image, args.driving_video)
diff --git a/applications/tools/wav2lip.py b/applications/tools/wav2lip.py
index 74f40f6082f4caa9f5b39a44f88567f8f7c14986..e78841cd5c594ed6705e6a8bafa1cf0b1974fa37 100644
--- a/applications/tools/wav2lip.py
+++ b/applications/tools/wav2lip.py
@@ -23,7 +23,7 @@ parser.add_argument('--face',
 parser.add_argument('--outfile',
                     type=str,
                     help='Video path to save result. See default for an e.g.',
-                    default='results/result_voice.mp4')
+                    default='result_voice.mp4')

 parser.add_argument(
     '--static',
@@ -109,5 +109,16 @@ if __name__ == "__main__":
     if args.cpu:
         paddle.set_device('cpu')

-    predictor = Wav2LipPredictor(args)
-    predictor.run()
+    predictor = Wav2LipPredictor(checkpoint_path = args.checkpoint_path,
+                                 static = args.static,
+                                 fps = args.fps,
+                                 pads = args.pads,
+                                 face_det_batch_size = args.face_det_batch_size,
+                                 wav2lip_batch_size = args.wav2lip_batch_size,
+                                 resize_factor = args.resize_factor,
+                                 crop = args.crop,
+                                 box = args.box,
+                                 rotate = args.rotate,
+                                 nosmooth = args.nosmooth,
+                                 face_detector = args.face_detector)
+    predictor.run(args.face, args.audio, args.outfile)
diff --git a/docs/zh_CN/apis/apps.md b/docs/zh_CN/apis/apps.md
index c206d8ac0dfd41c17d70f2ecf21692279d4797d3..73979f39a4018970c88f9779bbcbf7aa337a053a 100644
--- a/docs/zh_CN/apis/apps.md
+++ b/docs/zh_CN/apis/apps.md
@@ -439,7 +439,7 @@ ppgan.apps.MiDaSPredictor(output=None, weight_path=None)
 ## ppgan.apps.Wav2lipPredictor

 ```python
-ppgan.apps.FirstOrderPredictor(args)
+ppgan.apps.Wav2LipPredictor()
 ```

 > 构建Wav2lip模型的实例，此模型用来做唇形合成，即给定一个人物视频和一个音频，实现人物口型与输入语音同步。论文是A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild，论文链接: http://arxiv.org/abs/2008.10010.
@@ -449,8 +449,8 @@ ppgan.apps.FirstOrderPredictor(args)
 > ```
 > from ppgan.apps import Wav2LipPredictor
 > # The args parameter should be specified by argparse
-> predictor = Wav2LipPredictor(args)
-> predictor.run()
+> predictor = Wav2LipPredictor()
+> predictor.run(face, audio, outfile)
 > ```

 > **参数:**
diff --git a/ppgan/apps/first_order_predictor.py b/ppgan/apps/first_order_predictor.py
index 397ede7b48faa81ca608ed662008e3277e0b7d67..e08e3a2d21ca571d3064501798c70ee22f0f9800 100644
--- a/ppgan/apps/first_order_predictor.py
+++ b/ppgan/apps/first_order_predictor.py
@@ -33,8 +33,6 @@ from ppgan.faceutils import face_detection
 from .base_predictor import BasePredictor

-IMAGE_SIZE = 256
-

 class FirstOrderPredictor(BasePredictor):
     def __init__(self,
                  output='output',
@@ -47,7 +45,8 @@ class FirstOrderPredictor(BasePredictor):
                  ratio=1.0,
                  filename='result.mp4',
                  face_detector='sfd',
-                 multi_person=False):
+                 multi_person=False,
+                 image_size = 256):
         if config is not None and isinstance(config, str):
             with open(config) as f:
                 self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
@@ -85,8 +84,12 @@ class FirstOrderPredictor(BasePredictor):
                     }
                 }
             }
+        self.image_size = image_size
         if weight_path is None:
-            vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
+            if self.image_size == 512:
+                vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams'
+            else:
+                vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
             weight_path = get_path_from_url(vox_cpk_weight_url)

         self.weight_path = weight_path
@@ -103,6 +106,7 @@ class FirstOrderPredictor(BasePredictor):
         self.generator, self.kp_detector = self.load_checkpoints(
             self.cfg, self.weight_path)
         self.multi_person = multi_person
+

     def read_img(self, path):
         img = imageio.imread(path)
@@ -161,42 +165,23 @@ class FirstOrderPredictor(BasePredictor):
         reader.close()

         driving_video = [
-            cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0 for frame in driving_video
+            cv2.resize(frame, (self.image_size, self.image_size)) / 255.0 for frame in driving_video
         ]
         results = []

-        # for single person
-        if not self.multi_person:
-            h, w, _ = source_image.shape
-            source_image = cv2.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
-            predictions = get_prediction(source_image)
-            imageio.mimsave(os.path.join(self.output, self.filename), [
-                cv2.resize((frame * 255.0).astype('uint8'), (h, w))
-                for frame in predictions
-            ],
-                            fps=fps)
-            return
-
+
         bboxes = self.extract_bbox(source_image.copy())
         print(str(len(bboxes)) + " persons have been detected")

-        if len(bboxes) <= 1:
-            h, w, _ = source_image.shape
-            source_image = cv2.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
-            predictions = get_prediction(source_image)
-            imageio.mimsave(os.path.join(self.output, self.filename), [
-                cv2.resize((frame * 255.0).astype('uint8'), (h, w))
-                for frame in predictions
-            ],
-                            fps=fps)
-            return
+        # for multi person
         for rec in bboxes:
             face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
-            face_image = cv2.resize(face_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
+            face_image = cv2.resize(face_image, (self.image_size, self.image_size)) / 255.0
             predictions = get_prediction(face_image)
             results.append({'rec': rec, 'predict': predictions})
-
+            if len(bboxes) == 1 or not self.multi_person:
+                break
         out_frame = []

         for i in range(len(driving_video)):
@@ -206,9 +191,19 @@ class FirstOrderPredictor(BasePredictor):
                 h = y2 - y1
                 w = x2 - x1
                 out = result['predict'][i] * 255.0
+                #from ppgan.apps import RealSRPredictor
+                #sr = RealSRPredictor()
+                #sr_img = sr.run(out.astype(np.uint8))
                 out = cv2.resize(out.astype(np.uint8), (x2 - x1, y2 - y1))
+                #out = cv2.resize(np.array(sr_img).astype(np.uint8), (x2 - x1, y2 - y1))
                 if len(results) == 1:
+                    #imageio.imwrite(os.path.join(self.output, "blending_512_realsr","source"+str(i) + ".png"), frame)
                     frame[y1:y2, x1:x2] = out
+                    #imageio.imwrite(os.path.join(self.output, "blending_512_realsr","target"+str(i) + ".png"), frame)
+                    #mask = np.ones(frame.shape).astype('uint8') * 255
+                    #mask[y1:y2, x1:x2] = (0,0,0)
+                    #imageio.imwrite(os.path.join(self.output, "blending_512_realsr","mask"+str(i) + ".png"), mask)
+
                 else:
                     patch = np.zeros(frame.shape).astype('uint8')
                     patch[y1:y2, x1:x2] = out
@@ -218,7 +213,7 @@ class FirstOrderPredictor(BasePredictor):
                     cv2.circle(mask, (cx, cy), math.ceil(h * self.ratio),
                                (255, 255, 255), -1, 8, 0)
                     frame = cv2.copyTo(patch, mask, frame)
-
+
             out_frame.append(frame)
         imageio.mimsave(os.path.join(self.output, self.filename),
                         [frame for frame in out_frame],
diff --git a/ppgan/apps/wav2lip_predictor.py b/ppgan/apps/wav2lip_predictor.py
index a8014bb97150838c454279274462dea2070bb325..b067243a619dbd5aa1f7792ed91815b49d52426a 100644
--- a/ppgan/apps/wav2lip_predictor.py
+++ b/ppgan/apps/wav2lip_predictor.py
@@ -17,12 +17,31 @@ mel_step_size = 16


 class Wav2LipPredictor(BasePredictor):
-    def __init__(self, args):
-        self.args = args
-        if os.path.isfile(self.args.face) and path.basename(
-                self.args.face).split('.')[1] in ['jpg', 'png', 'jpeg']:
-            self.args.static = True
+    def __init__(self, checkpoint_path = None,
+                 static = False,
+                 fps = 25,
+                 pads = [0, 10, 0, 0],
+                 face_det_batch_size = 16,
+                 wav2lip_batch_size = 128,
+                 resize_factor = 1,
+                 crop = [0, -1, 0, -1],
+                 box = [-1, -1, -1, -1],
+                 rotate = False,
+                 nosmooth = False,
+                 face_detector = 'sfd'):
         self.img_size = 96
+        self.checkpoint_path = checkpoint_path
+        self.static = static
+        self.fps = fps
+        self.pads = pads
+        self.face_det_batch_size = face_det_batch_size
+        self.wav2lip_batch_size = wav2lip_batch_size
+        self.resize_factor = resize_factor
+        self.crop = crop
+        self.box = box
+        self.rotate = rotate
+        self.nosmooth = nosmooth
+        self.face_detector = face_detector
         makedirs('./temp', exist_ok=True)

     def get_smoothened_boxes(self, boxes, T):
@@ -38,9 +57,9 @@ class Wav2LipPredictor(BasePredictor):
         detector = face_detection.FaceAlignment(
             face_detection.LandmarksType._2D,
             flip_input=False,
-            face_detector=self.args.face_detector)
+            face_detector=self.face_detector)

-        batch_size = self.args.face_det_batch_size
+        batch_size = self.face_det_batch_size

         while 1:
             predictions = []
@@ -61,7 +80,7 @@ class Wav2LipPredictor(BasePredictor):
             break

         results = []
-        pady1, pady2, padx1, padx2 = self.args.pads
+        pady1, pady2, padx1, padx2 = self.pads
         for rect, image in zip(predictions, images):
             if rect is None:
                 cv2.imwrite(
@@ -79,7 +98,7 @@ class Wav2LipPredictor(BasePredictor):
             results.append([x1, y1, x2, y2])

         boxes = np.array(results)
-        if not self.args.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
+        if not self.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
         results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)]
                    for image, (x1, y1, x2, y2) in zip(images, boxes)]

@@ -89,8 +108,8 @@ class Wav2LipPredictor(BasePredictor):
     def datagen(self, frames, mels):
         img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

-        if self.args.box[0] == -1:
-            if not self.args.static:
+        if self.box[0] == -1:
+            if not self.static:
                 face_det_results = self.face_detect(
                     frames)  # BGR2RGB for CNN face detection
             else:
@@ -98,12 +117,12 @@ class Wav2LipPredictor(BasePredictor):
         else:
             print(
                 'Using the specified bounding box instead of face detection...')
-            y1, y2, x1, x2 = self.args.box
+            y1, y2, x1, x2 = self.box
             face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)]
                                 for f in frames]

         for i, m in enumerate(mels):
-            idx = 0 if self.args.static else i % len(frames)
+            idx = 0 if self.static else i % len(frames)
             frame_to_save = frames[idx].copy()
             face, coords = face_det_results[idx].copy()
@@ -114,7 +133,7 @@ class Wav2LipPredictor(BasePredictor):
             frame_batch.append(frame_to_save)
             coords_batch.append(coords)

-            if len(img_batch) >= self.args.wav2lip_batch_size:
+            if len(img_batch) >= self.wav2lip_batch_size:
                 img_batch, mel_batch = np.asarray(img_batch), np.asarray(
                     mel_batch)
@@ -143,18 +162,22 @@ class Wav2LipPredictor(BasePredictor):

             yield img_batch, mel_batch, frame_batch, coords_batch

-    def run(self):
-        if not os.path.isfile(self.args.face):
+    def run(self, face, audio_seq, outfile):
+        if os.path.isfile(face) and path.basename(
+                face).split('.')[1] in ['jpg', 'png', 'jpeg']:
+            self.static = True
+
+        if not os.path.isfile(face):
             raise ValueError(
                 '--face argument must be a valid path to video/image file')

         elif path.basename(
-                self.args.face).split('.')[1] in ['jpg', 'png', 'jpeg']:
-            full_frames = [cv2.imread(self.args.face)]
-            fps = self.args.fps
+                face).split('.')[1] in ['jpg', 'png', 'jpeg']:
+            full_frames = [cv2.imread(face)]
+            fps = self.fps
         else:
-            video_stream = cv2.VideoCapture(self.args.face)
+            video_stream = cv2.VideoCapture(face)
             fps = video_stream.get(cv2.CAP_PROP_FPS)

             print('Reading video frames...')
@@ -165,15 +188,15 @@ class Wav2LipPredictor(BasePredictor):
                 if not still_reading:
                     video_stream.release()
                     break
-                if self.args.resize_factor > 1:
+                if self.resize_factor > 1:
                     frame = cv2.resize(
-                        frame, (frame.shape[1] // self.args.resize_factor,
-                                frame.shape[0] // self.args.resize_factor))
+                        frame, (frame.shape[1] // self.resize_factor,
+                                frame.shape[0] // self.resize_factor))

-                if self.args.rotate:
+                if self.rotate:
                     frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)

-                y1, y2, x1, x2 = self.args.crop
+                y1, y2, x1, x2 = self.crop
                 if x2 == -1: x2 = frame.shape[1]
                 if y2 == -1: y2 = frame.shape[0]
@@ -184,18 +207,16 @@ class Wav2LipPredictor(BasePredictor):
         print("Number of frames available for inference: " +
               str(len(full_frames)))

-        if not self.args.audio.endswith('.wav'):
+        if not audio_seq.endswith('.wav'):
             print('Extracting raw audio...')
             command = 'ffmpeg -y -i {} -strict -2 {}'.format(
-                self.args.audio, 'temp/temp.wav')
+                audio_seq, 'temp/temp.wav')

             subprocess.call(command, shell=True)
-            self.args.audio = 'temp/temp.wav'
+            audio_seq = 'temp/temp.wav'

-        wav = audio.load_wav(self.args.audio, 16000)
+        wav = audio.load_wav(audio_seq, 16000)
         mel = audio.melspectrogram(wav)
-        print(mel.shape)
-
         if np.isnan(mel.reshape(-1)).sum() > 0:
             raise ValueError(
                 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again'
@@ -216,15 +237,15 @@ class Wav2LipPredictor(BasePredictor):

         full_frames = full_frames[:len(mel_chunks)]

-        batch_size = self.args.wav2lip_batch_size
+        batch_size = self.wav2lip_batch_size
         gen = self.datagen(full_frames.copy(), mel_chunks)

         model = Wav2Lip()
-        if self.args.checkpoint_path is None:
+        if self.checkpoint_path is None:
             model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL)
             weights = paddle.load(model_weights_path)
         else:
-            weights = paddle.load(self.args.checkpoint_path)
+            weights = paddle.load(self.checkpoint_path)
         model.load_dict(weights)
         model.eval()
         print("Model loaded")
@@ -258,5 +279,5 @@ class Wav2LipPredictor(BasePredictor):
         out.release()

         command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(
-            self.args.audio, 'temp/result.avi', self.args.outfile)
+            audio_seq, 'temp/result.avi', outfile)
         subprocess.call(command, shell=platform.system() != 'Windows')
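A minimal usage sketch of the two refactored predictor interfaces above. The constructor keywords and `run()` signatures come from the diff itself; the file paths and the choice of `image_size=512` are illustrative placeholders rather than values taken from this change.

```python
from ppgan.apps import FirstOrderPredictor, Wav2LipPredictor

# First Order Motion: when no explicit weight_path is given, image_size selects the
# checkpoint (512 downloads vox-cpk-512.pdparams, otherwise the 256 vox-cpk weights).
fom = FirstOrderPredictor(output='output',
                          filename='result.mp4',
                          image_size=512,      # placeholder; default is 256
                          multi_person=False)
fom.run('source.png', 'driving.mp4')           # placeholder source image and driving video

# Wav2Lip: inputs are now passed to run() instead of an argparse namespace.
wav2lip = Wav2LipPredictor(face_det_batch_size=16,
                           wav2lip_batch_size=128)
wav2lip.run('face.mp4', 'speech.wav', 'result_voice.mp4')  # face, audio, outfile
```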