From d4650341b093941424e30064a72353bd34747578 Mon Sep 17 00:00:00 2001
From: chenjian
Date: Fri, 17 Dec 2021 14:29:54 +0800
Subject: [PATCH] add wav2lip module (#1733)

* add wav2lip module
* fix a bug when output dir not exists
* fix readme
* fix readme
* fix readme
* modify demo image

Co-authored-by: wuzewu
Co-authored-by: KP <109694228@qq.com>
---
 modules/image/Image_gan/gan/wav2lip/README.md |  94 +++++++
 modules/image/Image_gan/gan/wav2lip/model.py  | 259 ++++++++++++++++++
 modules/image/Image_gan/gan/wav2lip/module.py | 101 +++++++
 .../Image_gan/gan/wav2lip/requirements.txt    |   1 +
 4 files changed, 455 insertions(+)
 create mode 100644 modules/image/Image_gan/gan/wav2lip/README.md
 create mode 100644 modules/image/Image_gan/gan/wav2lip/model.py
 create mode 100644 modules/image/Image_gan/gan/wav2lip/module.py
 create mode 100644 modules/image/Image_gan/gan/wav2lip/requirements.txt

diff --git a/modules/image/Image_gan/gan/wav2lip/README.md b/modules/image/Image_gan/gan/wav2lip/README.md
new file mode 100644
index 00000000..5305725a
--- /dev/null
+++ b/modules/image/Image_gan/gan/wav2lip/README.md
@@ -0,0 +1,94 @@
+# wav2lip
+
+|模型名称|wav2lip|
+| :--- | :---: |
+|类别|图像 - 视频生成|
+|网络|Wav2Lip|
+|数据集|LRS2|
+|是否支持Fine-tuning|否|
+|模型大小|139MB|
+|最新更新日期|2021-12-14|
+|数据指标|-|
+
+
+## 一、模型基本信息
+
+- ### 应用效果展示
+  - 样例结果示例:
+    输入图像
+
+    输出视频
+
+
+- ### 模型介绍
+
+  - Wav2Lip实现的是视频人物根据输入音频生成与语音同步的人物唇形，使得生成的视频人物口型与输入语音同步。Wav2Lip不仅可以基于静态图像来输出与目标语音匹配的唇形同步视频，还可以直接将动态的视频进行唇形转换，输出与目标语音匹配的视频。Wav2Lip实现唇形与语音精准同步突破的关键在于，它采用了唇形同步判别器，以强制生成器持续产生准确而逼真的唇部运动。此外，它通过在鉴别器中使用多个连续帧而不是单个帧，并使用视觉质量损失（而不仅仅是对比损失）来考虑时间相关性，从而改善了视觉质量。Wav2Lip适用于任何人脸、任何语言，对任意视频都能达到很高的准确率，可以无缝地与原始视频融合，还可以用于转换动画人脸。
+
+
+
+## 二、安装
+
+- ### 1、环境依赖
+  - ffmpeg
+  - libsndfile
+- ### 2、安装
+
+  - ```shell
+    $ hub install wav2lip
+    ```
+  - 如您安装时遇到问题，可参考：[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+## 三、模型API预测
+
+- ### 1、命令行预测
+
+  - ```shell
+    # Read from a file
+    $ hub run wav2lip --face "/PATH/TO/VIDEO or IMAGE" --audio "/PATH/TO/AUDIO"
+    ```
+  - 通过命令行方式实现人物唇形生成模型的调用，更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
+
+- ### 2、预测代码示例
+
+  - ```python
+    import paddlehub as hub
+
+    module = hub.Module(name="wav2lip")
+    face_input_path = "/PATH/TO/VIDEO or IMAGE"
+    audio_input_path = "/PATH/TO/AUDIO"
+    module.wav2lip_transfer(face=face_input_path, audio=audio_input_path, output_dir='./transfer_result/', use_gpu=True)
+    ```
+
+- ### 3、API
+
+  - ```python
+    def wav2lip_transfer(face, audio, output_dir='./output_result/', use_gpu=False, visualization=True):
+    ```
+  - 人脸唇形生成API。
+
+  - **参数**
+
+    - face (str): 视频或图像文件的路径；
+    - audio (str): 音频文件的路径；
+    - output\_dir (str): 结果保存的路径；
+    - use\_gpu (bool): 是否使用 GPU；
+ - visualization(bool): 是否保存结果到本地文件夹 + + +## 四、更新历史 + +* 1.0.0 + + 初始发布 + + - ```shell + $ hub install wav2lip==1.0.0 + ``` diff --git a/modules/image/Image_gan/gan/wav2lip/model.py b/modules/image/Image_gan/gan/wav2lip/model.py new file mode 100644 index 00000000..3fa32ed9 --- /dev/null +++ b/modules/image/Image_gan/gan/wav2lip/model.py @@ -0,0 +1,259 @@ +from os import listdir, path, makedirs +import platform +import numpy as np +import scipy, cv2, os, sys, argparse +import json, subprocess, random, string +from tqdm import tqdm +from glob import glob +import paddle +from paddle.utils.download import get_weights_path_from_url +from ppgan.faceutils import face_detection +from ppgan.utils import audio +from ppgan.models.generators.wav2lip import Wav2Lip + +WAV2LIP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams' +mel_step_size = 16 + + +class Wav2LipPredictor: + def __init__(self, + checkpoint_path=None, + static=False, + fps=25, + pads=[0, 10, 0, 0], + face_det_batch_size=16, + wav2lip_batch_size=128, + resize_factor=1, + crop=[0, -1, 0, -1], + box=[-1, -1, -1, -1], + rotate=False, + nosmooth=False, + face_detector='sfd', + face_enhancement=False): + self.img_size = 96 + self.checkpoint_path = checkpoint_path + self.static = static + self.fps = fps + self.pads = pads + self.face_det_batch_size = face_det_batch_size + self.wav2lip_batch_size = wav2lip_batch_size + self.resize_factor = resize_factor + self.crop = crop + self.box = box + self.rotate = rotate + self.nosmooth = nosmooth + self.face_detector = face_detector + self.face_enhancement = face_enhancement + if face_enhancement: + from ppgan.faceutils.face_enhancement import FaceEnhancement + self.faceenhancer = FaceEnhancement() + makedirs('./temp', exist_ok=True) + + def get_smoothened_boxes(self, boxes, T): + for i in range(len(boxes)): + if i + T > len(boxes): + window = boxes[len(boxes) - T:] + else: + window = boxes[i:i + T] + boxes[i] = np.mean(window, axis=0) + return boxes + + def face_detect(self, images): + detector = face_detection.FaceAlignment( + face_detection.LandmarksType._2D, flip_input=False, face_detector=self.face_detector) + + batch_size = self.face_det_batch_size + + while 1: + predictions = [] + try: + for i in tqdm(range(0, len(images), batch_size)): + predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) + except RuntimeError: + if batch_size == 1: + raise RuntimeError( + 'Image too big to run face detection on GPU. Please use the --resize_factor argument') + batch_size //= 2 + print('Recovering from OOM error; New batch size: {}'.format(batch_size)) + continue + break + + results = [] + pady1, pady2, padx1, padx2 = self.pads + for rect, image in zip(predictions, images): + if rect is None: + cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. + raise ValueError('Face not detected! 
Ensure the video contains a face in all the frames.') + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + + boxes = np.array(results) + if not self.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5) + results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] + + del detector + return results + + def datagen(self, frames, mels): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if self.box[0] == -1: + if not self.static: + face_det_results = self.face_detect(frames) # BGR2RGB for CNN face detection + else: + face_det_results = self.face_detect([frames[0]]) + else: + print('Using the specified bounding box instead of face detection...') + y1, y2, x1, x2 = self.box + face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames] + + for i, m in enumerate(mels): + idx = 0 if self.static else i % len(frames) + frame_to_save = frames[idx].copy() + face, coords = face_det_results[idx].copy() + + face = cv2.resize(face, (self.img_size, self.img_size)) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= self.wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, self.img_size // 2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, self.img_size // 2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
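+            # add a channel axis to the remaining mel chunks before yielding the final partial batch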
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + + def run(self, face, audio_seq, output_dir, visualization=True): + if os.path.isfile(face) and path.basename(face).split('.')[1] in ['jpg', 'png', 'jpeg']: + self.static = True + + if not os.path.isfile(face): + raise ValueError('--face argument must be a valid path to video/image file') + + elif path.basename(face).split('.')[1] in ['jpg', 'png', 'jpeg']: + full_frames = [cv2.imread(face)] + fps = self.fps + + else: + video_stream = cv2.VideoCapture(face) + fps = video_stream.get(cv2.CAP_PROP_FPS) + + print('Reading video frames...') + + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + if self.resize_factor > 1: + frame = cv2.resize(frame, + (frame.shape[1] // self.resize_factor, frame.shape[0] // self.resize_factor)) + + if self.rotate: + frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + + y1, y2, x1, x2 = self.crop + if x2 == -1: x2 = frame.shape[1] + if y2 == -1: y2 = frame.shape[0] + + frame = frame[y1:y2, x1:x2] + + full_frames.append(frame) + + print("Number of frames available for inference: " + str(len(full_frames))) + + if not audio_seq.endswith('.wav'): + print('Extracting raw audio...') + command = 'ffmpeg -y -i {} -strict -2 {}'.format(audio_seq, 'temp/temp.wav') + + subprocess.call(command, shell=True) + audio_seq = 'temp/temp.wav' + + wav = audio.load_wav(audio_seq, 16000) + mel = audio.melspectrogram(wav) + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError( + 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') + + mel_chunks = [] + mel_idx_multiplier = 80. / fps + i = 0 + while 1: + start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) + break + mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size]) + i += 1 + + print("Length of mel chunks: {}".format(len(mel_chunks))) + + full_frames = full_frames[:len(mel_chunks)] + + batch_size = self.wav2lip_batch_size + gen = self.datagen(full_frames.copy(), mel_chunks) + + model = Wav2Lip() + if self.checkpoint_path is None: + model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL) + weights = paddle.load(model_weights_path) + else: + weights = paddle.load(self.checkpoint_path) + model.load_dict(weights) + model.eval() + print("Model loaded") + for i, (img_batch, mel_batch, frames, coords) in enumerate( + tqdm(gen, total=int(np.ceil(float(len(mel_chunks)) / batch_size)))): + if i == 0: + + frame_h, frame_w = full_frames[0].shape[:-1] + out = cv2.VideoWriter('temp/result.avi', cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) + + img_batch = paddle.to_tensor(np.transpose(img_batch, (0, 3, 1, 2))).astype('float32') + mel_batch = paddle.to_tensor(np.transpose(mel_batch, (0, 3, 1, 2))).astype('float32') + + with paddle.no_grad(): + pred = model(mel_batch, img_batch) + + pred = pred.numpy().transpose(0, 2, 3, 1) * 255. 
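+            # resize each predicted face (optionally enhanced) back to its detected box, paste it into the source frame, then write the frame to the output video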
+ + for p, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + if self.face_enhancement: + p = self.faceenhancer.enhance_from_image(p) + p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) + + f[y1:y2, x1:x2] = p + out.write(f) + + out.release() + os.makedirs(output_dir, exist_ok=True) + if visualization: + command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_seq, 'temp/result.avi', + os.path.join(output_dir, 'result.avi')) + subprocess.call(command, shell=platform.system() != 'Windows') diff --git a/modules/image/Image_gan/gan/wav2lip/module.py b/modules/image/Image_gan/gan/wav2lip/module.py new file mode 100644 index 00000000..f16191d8 --- /dev/null +++ b/modules/image/Image_gan/gan/wav2lip/module.py @@ -0,0 +1,101 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import copy + +import paddle +import paddlehub as hub +from paddlehub.module.module import moduleinfo, runnable, serving +import numpy as np +import cv2 + +from .model import Wav2LipPredictor + + +@moduleinfo(name="wav2lip", type="CV/generation", author="paddlepaddle", author_email="", summary="", version="1.0.0") +class wav2lip: + def __init__(self): + self.pretrained_model = os.path.join(self.directory, "wav2lip_hq.pdparams") + + self.network = Wav2LipPredictor( + checkpoint_path=self.pretrained_model, + static=False, + fps=25, + pads=[0, 10, 0, 0], + face_det_batch_size=16, + wav2lip_batch_size=128, + resize_factor=1, + crop=[0, -1, 0, -1], + box=[-1, -1, -1, -1], + rotate=False, + nosmooth=False, + face_detector='sfd', + face_enhancement=True) + + def wav2lip_transfer(self, face, audio, output_dir='./output_result/', use_gpu=False, visualization=True): + ''' + face (str): path to video/image that contains faces to use. + audio (str): path to input audio. + output_dir: the dir to save the results + use_gpu: if True, use gpu to perform the computation, otherwise cpu. + visualization: if True, save results in output_dir. + ''' + paddle.disable_static() + place = 'gpu:0' if use_gpu else 'cpu' + place = paddle.set_device(place) + self.network.run(face, audio, output_dir, visualization) + + @runnable + def run_cmd(self, argvs: list): + """ + Run as a command. + """ + self.parser = argparse.ArgumentParser( + description="Run the {} module.".format(self.name), + prog='hub run {}'.format(self.name), + usage='%(prog)s', + add_help=True) + + self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. 
Required") + self.arg_config_group = self.parser.add_argument_group( + title="Config options", description="Run configuration for controlling module behavior, not required.") + self.add_module_config_arg() + self.add_module_input_arg() + self.args = self.parser.parse_args(argvs) + self.wav2lip_transfer( + face=self.args.face, + audio=self.args.audio, + output_dir=self.args.output_dir, + use_gpu=self.args.use_gpu, + visualization=self.args.visualization) + return + + def add_module_config_arg(self): + """ + Add the command config options. + """ + self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not") + + self.arg_config_group.add_argument( + '--output_dir', type=str, default='output_result', help='output directory for saving result.') + self.arg_config_group.add_argument('--visualization', type=bool, default=False, help='save results or not.') + + def add_module_input_arg(self): + """ + Add the command input options. + """ + self.arg_input_group.add_argument('--audio', type=str, help="path to input audio.") + self.arg_input_group.add_argument('--face', type=str, help="path to video/image that contains faces to use.") diff --git a/modules/image/Image_gan/gan/wav2lip/requirements.txt b/modules/image/Image_gan/gan/wav2lip/requirements.txt new file mode 100644 index 00000000..67e9bb6f --- /dev/null +++ b/modules/image/Image_gan/gan/wav2lip/requirements.txt @@ -0,0 +1 @@ +ppgan -- GitLab