diff --git a/modules/image/Image_gan/gan/wav2lip/README.md b/modules/image/Image_gan/gan/wav2lip/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5305725a65bb12a8d4cf4c0f18c655b4c07c2841
--- /dev/null
+++ b/modules/image/Image_gan/gan/wav2lip/README.md
@@ -0,0 +1,94 @@
+# wav2lip
+
+|Module Name|wav2lip|
+| :--- | :---: |
+|Category|Image - Video generation|
+|Network|Wav2Lip|
+|Dataset|LRS2|
+|Fine-tuning supported or not|No|
+|Module Size|139MB|
+|Latest update date|2021-12-14|
+|Data indicators|-|
+
+
+## I. Basic Information
+
+- ### Application Effect Display
+  - Sample results:
+
+
+
+    Input image
+
+
+
+    Output video
+
+
+
+
+- ### Module Introduction
+
+  - Wav2Lip generates talking-face videos whose lip movements are synchronized with an input speech audio clip. It can lip-sync a static image to a target audio track, and it can also re-dub an existing video so that the speaker's mouth matches new speech. The key to its accurate synchronization is a pre-trained lip-sync discriminator that forces the generator to keep producing accurate, realistic lip motion. Visual quality is further improved by feeding several consecutive frames (rather than a single frame) to the discriminator and by adding a visual quality loss on top of the contrastive sync loss, which lets the model exploit temporal correlation. Wav2Lip works for any face and any language, reaches high accuracy on arbitrary videos, blends the generated mouth region seamlessly into the original video, and can also be applied to animated faces.
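+
+  - At inference time, the module hides the lower half of each detected face and lets the generator redraw the mouth region from the corresponding audio window. The sketch below illustrates this input construction (a simplified, illustrative rendering of the logic in this module's `datagen`; the placeholder arrays are assumptions for the example, not the module's actual code path):
+
+  - ```python
+    import numpy as np
+
+    img_size = 96  # face crops are resized to 96 x 96 before being fed to the generator
+    face = np.zeros((img_size, img_size, 3), dtype=np.float32)  # stand-in for one detected face crop
+
+    masked = face.copy()
+    masked[img_size // 2:] = 0  # zero out the mouth half; it must be re-synthesized from the audio
+
+    # The masked face (pose reference) and the full face (identity reference) are stacked
+    # along the channel axis and scaled to [0, 1], then passed to the generator together
+    # with a mel-spectrogram window of the audio.
+    gen_input = np.concatenate((masked, face), axis=2) / 255.
+    ```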
+
+
+
+## II. Installation
+
+- ### 1. Environmental Dependence
+  - ffmpeg
+  - libsndfile
+- ### 2. Installation
+
+ - ```shell
+ $ hub install wav2lip
+ ```
+  - In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [Linux_Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+## III. Module API Prediction
+
+- ### 1. Command Line Prediction
+
+ - ```shell
+ # Read from a file
+ $ hub run wav2lip --face "/PATH/TO/VIDEO or IMAGE" --audio "/PATH/TO/AUDIO"
+ ```
+  - This command runs lip-sync video generation from the command line. For more information, please refer to [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst).
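+
+  - The configuration flags defined by this module's command-line interface (see `module.py` in this PR) can be appended. Note that `--visualization` defaults to `False` on the command line, so pass `--visualization True` to have the final video written under `--output_dir`, for example:
+
+  - ```shell
+    $ hub run wav2lip --face "/PATH/TO/VIDEO or IMAGE" --audio "/PATH/TO/AUDIO" --output_dir ./transfer_result --visualization True --use_gpu
+    ```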
+
+- ### 2. Prediction Code Example
+
+ - ```python
+ import paddlehub as hub
+
+ module = hub.Module(name="wav2lip")
+ face_input_path = "/PATH/TO/VIDEO or IMAGE"
+ audio_input_path = "/PATH/TO/AUDIO"
+ module.wav2lip_transfer(face=face_input_path, audio=audio_input_path, output_dir='./transfer_result/', use_gpu=True)
+ ```
+
+- ### 3. API
+
+ - ```python
+    def wav2lip_transfer(face, audio, output_dir='./output_result/', use_gpu=False, visualization=True):
+ ```
+  - Lip-sync video generation API.
+
+  - **Parameters**
+
+    - face (str): path to the input video or image file.
+    - audio (str): path to the input audio file.
+    - output\_dir (str): directory in which the result is saved.
+    - use\_gpu (bool): whether to use GPU.
+    - visualization (bool): whether to save the result to a local folder.
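+
+  - **Note**: when `visualization` is `True`, the final lip-synced video (with the audio track muxed back in) is written to `result.avi` under `output_dir`; intermediate files are placed in a local `temp/` directory.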
+
+
+## IV. Release Note
+
+* 1.0.0
+
+  First release
+
+ - ```shell
+ $ hub install wav2lip==1.0.0
+ ```
diff --git a/modules/image/Image_gan/gan/wav2lip/model.py b/modules/image/Image_gan/gan/wav2lip/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fa32ed9c384e74cf569ef0daa09215539355d8e
--- /dev/null
+++ b/modules/image/Image_gan/gan/wav2lip/model.py
@@ -0,0 +1,259 @@
+import os
+import platform
+import subprocess
+from os import path, makedirs
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+import paddle
+from paddle.utils.download import get_weights_path_from_url
+from ppgan.faceutils import face_detection
+from ppgan.utils import audio
+from ppgan.models.generators.wav2lip import Wav2Lip
+
+WAV2LIP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams'
+mel_step_size = 16
+
+
+class Wav2LipPredictor:
+ def __init__(self,
+ checkpoint_path=None,
+ static=False,
+ fps=25,
+ pads=[0, 10, 0, 0],
+ face_det_batch_size=16,
+ wav2lip_batch_size=128,
+ resize_factor=1,
+ crop=[0, -1, 0, -1],
+ box=[-1, -1, -1, -1],
+ rotate=False,
+ nosmooth=False,
+ face_detector='sfd',
+ face_enhancement=False):
+ self.img_size = 96
+ self.checkpoint_path = checkpoint_path
+ self.static = static
+ self.fps = fps
+ self.pads = pads
+ self.face_det_batch_size = face_det_batch_size
+ self.wav2lip_batch_size = wav2lip_batch_size
+ self.resize_factor = resize_factor
+ self.crop = crop
+ self.box = box
+ self.rotate = rotate
+ self.nosmooth = nosmooth
+ self.face_detector = face_detector
+ self.face_enhancement = face_enhancement
+ if face_enhancement:
+ from ppgan.faceutils.face_enhancement import FaceEnhancement
+ self.faceenhancer = FaceEnhancement()
+ makedirs('./temp', exist_ok=True)
+
+ def get_smoothened_boxes(self, boxes, T):
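+        # Smooth the detected boxes over a sliding window of T frames to reduce jitter.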
+ for i in range(len(boxes)):
+ if i + T > len(boxes):
+ window = boxes[len(boxes) - T:]
+ else:
+ window = boxes[i:i + T]
+ boxes[i] = np.mean(window, axis=0)
+ return boxes
+
+ def face_detect(self, images):
+ detector = face_detection.FaceAlignment(
+ face_detection.LandmarksType._2D, flip_input=False, face_detector=self.face_detector)
+
+ batch_size = self.face_det_batch_size
+
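+        # Run detection in batches; if the GPU runs out of memory, halve the batch size and retry.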
+ while 1:
+ predictions = []
+ try:
+ for i in tqdm(range(0, len(images), batch_size)):
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+ except RuntimeError:
+ if batch_size == 1:
+ raise RuntimeError(
+ 'Image too big to run face detection on GPU. Please use the --resize_factor argument')
+ batch_size //= 2
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+ continue
+ break
+
+ results = []
+ pady1, pady2, padx1, padx2 = self.pads
+ for rect, image in zip(predictions, images):
+ if rect is None:
+ cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+ y1 = max(0, rect[1] - pady1)
+ y2 = min(image.shape[0], rect[3] + pady2)
+ x1 = max(0, rect[0] - padx1)
+ x2 = min(image.shape[1], rect[2] + padx2)
+
+ results.append([x1, y1, x2, y2])
+
+ boxes = np.array(results)
+ if not self.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
+ results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+ del detector
+ return results
+
+ def datagen(self, frames, mels):
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+ if self.box[0] == -1:
+ if not self.static:
+                face_det_results = self.face_detect(frames)  # detect a face box for every frame
+ else:
+ face_det_results = self.face_detect([frames[0]])
+ else:
+ print('Using the specified bounding box instead of face detection...')
+ y1, y2, x1, x2 = self.box
+ face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
+
+ for i, m in enumerate(mels):
+ idx = 0 if self.static else i % len(frames)
+ frame_to_save = frames[idx].copy()
+ face, coords = face_det_results[idx].copy()
+
+ face = cv2.resize(face, (self.img_size, self.img_size))
+
+ img_batch.append(face)
+ mel_batch.append(m)
+ frame_batch.append(frame_to_save)
+ coords_batch.append(coords)
+
+ if len(img_batch) >= self.wav2lip_batch_size:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
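+                # Mask the lower half of each face; the generator must reconstruct the mouth region from the audio.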
+ img_masked = img_batch.copy()
+ img_masked[:, self.img_size // 2:] = 0
+
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+ yield img_batch, mel_batch, frame_batch, coords_batch
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+ if len(img_batch) > 0:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+ img_masked = img_batch.copy()
+ img_masked[:, self.img_size // 2:] = 0
+
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+ yield img_batch, mel_batch, frame_batch, coords_batch
+
+ def run(self, face, audio_seq, output_dir, visualization=True):
+        if os.path.isfile(face) and path.splitext(face)[1].lower() in ['.jpg', '.png', '.jpeg']:
+ self.static = True
+
+ if not os.path.isfile(face):
+ raise ValueError('--face argument must be a valid path to video/image file')
+
+        elif path.splitext(face)[1].lower() in ['.jpg', '.png', '.jpeg']:
+ full_frames = [cv2.imread(face)]
+ fps = self.fps
+
+ else:
+ video_stream = cv2.VideoCapture(face)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+
+ print('Reading video frames...')
+
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ if self.resize_factor > 1:
+ frame = cv2.resize(frame,
+ (frame.shape[1] // self.resize_factor, frame.shape[0] // self.resize_factor))
+
+ if self.rotate:
+                    frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
+
+ y1, y2, x1, x2 = self.crop
+ if x2 == -1: x2 = frame.shape[1]
+ if y2 == -1: y2 = frame.shape[0]
+
+ frame = frame[y1:y2, x1:x2]
+
+ full_frames.append(frame)
+
+ print("Number of frames available for inference: " + str(len(full_frames)))
+
+ if not audio_seq.endswith('.wav'):
+ print('Extracting raw audio...')
+ command = 'ffmpeg -y -i {} -strict -2 {}'.format(audio_seq, 'temp/temp.wav')
+
+ subprocess.call(command, shell=True)
+ audio_seq = 'temp/temp.wav'
+
+ wav = audio.load_wav(audio_seq, 16000)
+ mel = audio.melspectrogram(wav)
+ if np.isnan(mel.reshape(-1)).sum() > 0:
+ raise ValueError(
+ 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
+
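+        # The mel spectrogram has roughly 80 frames per second of audio, so step through it at
+        # 80 / fps mel frames per video frame, taking a mel_step_size-frame window for each frame.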
+ mel_chunks = []
+ mel_idx_multiplier = 80. / fps
+ i = 0
+ while 1:
+ start_idx = int(i * mel_idx_multiplier)
+ if start_idx + mel_step_size > len(mel[0]):
+ mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+ break
+ mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
+ i += 1
+
+ print("Length of mel chunks: {}".format(len(mel_chunks)))
+
+ full_frames = full_frames[:len(mel_chunks)]
+
+ batch_size = self.wav2lip_batch_size
+ gen = self.datagen(full_frames.copy(), mel_chunks)
+
+ model = Wav2Lip()
+ if self.checkpoint_path is None:
+ model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL)
+ weights = paddle.load(model_weights_path)
+ else:
+ weights = paddle.load(self.checkpoint_path)
+ model.load_dict(weights)
+ model.eval()
+ print("Model loaded")
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(
+ tqdm(gen, total=int(np.ceil(float(len(mel_chunks)) / batch_size)))):
+ if i == 0:
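+                # Create the intermediate (silent) video writer once, using the original frame size.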
+
+ frame_h, frame_w = full_frames[0].shape[:-1]
+ out = cv2.VideoWriter('temp/result.avi', cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
+
+ img_batch = paddle.to_tensor(np.transpose(img_batch, (0, 3, 1, 2))).astype('float32')
+ mel_batch = paddle.to_tensor(np.transpose(mel_batch, (0, 3, 1, 2))).astype('float32')
+
+ with paddle.no_grad():
+ pred = model(mel_batch, img_batch)
+
+ pred = pred.numpy().transpose(0, 2, 3, 1) * 255.
+
+ for p, f, c in zip(pred, frames, coords):
+ y1, y2, x1, x2 = c
+ if self.face_enhancement:
+ p = self.faceenhancer.enhance_from_image(p)
+ p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
+
+ f[y1:y2, x1:x2] = p
+ out.write(f)
+
+ out.release()
+ os.makedirs(output_dir, exist_ok=True)
+ if visualization:
+ command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_seq, 'temp/result.avi',
+ os.path.join(output_dir, 'result.avi'))
+ subprocess.call(command, shell=platform.system() != 'Windows')
diff --git a/modules/image/Image_gan/gan/wav2lip/module.py b/modules/image/Image_gan/gan/wav2lip/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..f16191d8984e33f38246e7985a8bb3f7f2aa74b0
--- /dev/null
+++ b/modules/image/Image_gan/gan/wav2lip/module.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import ast
+import argparse
+
+import paddle
+from paddlehub.module.module import moduleinfo, runnable
+
+from .model import Wav2LipPredictor
+
+
+@moduleinfo(name="wav2lip", type="CV/generation", author="paddlepaddle", author_email="", summary="", version="1.0.0")
+class wav2lip:
+ def __init__(self):
+ self.pretrained_model = os.path.join(self.directory, "wav2lip_hq.pdparams")
+
+ self.network = Wav2LipPredictor(
+ checkpoint_path=self.pretrained_model,
+ static=False,
+ fps=25,
+ pads=[0, 10, 0, 0],
+ face_det_batch_size=16,
+ wav2lip_batch_size=128,
+ resize_factor=1,
+ crop=[0, -1, 0, -1],
+ box=[-1, -1, -1, -1],
+ rotate=False,
+ nosmooth=False,
+ face_detector='sfd',
+ face_enhancement=True)
+
+ def wav2lip_transfer(self, face, audio, output_dir='./output_result/', use_gpu=False, visualization=True):
+ '''
+ face (str): path to video/image that contains faces to use.
+ audio (str): path to input audio.
+ output_dir: the dir to save the results
+ use_gpu: if True, use gpu to perform the computation, otherwise cpu.
+ visualization: if True, save results in output_dir.
+ '''
+ paddle.disable_static()
+ place = 'gpu:0' if use_gpu else 'cpu'
+ place = paddle.set_device(place)
+ self.network.run(face, audio, output_dir, visualization)
+
+ @runnable
+ def run_cmd(self, argvs: list):
+ """
+ Run as a command.
+ """
+ self.parser = argparse.ArgumentParser(
+ description="Run the {} module.".format(self.name),
+ prog='hub run {}'.format(self.name),
+ usage='%(prog)s',
+ add_help=True)
+
+ self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
+ self.arg_config_group = self.parser.add_argument_group(
+ title="Config options", description="Run configuration for controlling module behavior, not required.")
+ self.add_module_config_arg()
+ self.add_module_input_arg()
+ self.args = self.parser.parse_args(argvs)
+ self.wav2lip_transfer(
+ face=self.args.face,
+ audio=self.args.audio,
+ output_dir=self.args.output_dir,
+ use_gpu=self.args.use_gpu,
+ visualization=self.args.visualization)
+ return
+
+ def add_module_config_arg(self):
+ """
+ Add the command config options.
+ """
+ self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
+
+ self.arg_config_group.add_argument(
+ '--output_dir', type=str, default='output_result', help='output directory for saving result.')
+        self.arg_config_group.add_argument(
+            '--visualization', type=ast.literal_eval, default=False, help='whether to save the result into output_dir.')
+
+ def add_module_input_arg(self):
+ """
+ Add the command input options.
+ """
+ self.arg_input_group.add_argument('--audio', type=str, help="path to input audio.")
+ self.arg_input_group.add_argument('--face', type=str, help="path to video/image that contains faces to use.")
diff --git a/modules/image/Image_gan/gan/wav2lip/requirements.txt b/modules/image/Image_gan/gan/wav2lip/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..67e9bb6fa840355e9ed0d44b7134850f1fe22fe1
--- /dev/null
+++ b/modules/image/Image_gan/gan/wav2lip/requirements.txt
@@ -0,0 +1 @@
+ppgan