From edd6211310c9e5beedb9816a22159499df418ae2 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Fri, 15 Jan 2021 23:26:01 +0800 Subject: [PATCH] add wav2lip training code (#142) * add wav2lip trainning code --- applications/tools/wav2lip.py | 108 ++++++++ configs/wav2lip.yaml | 63 +++++ configs/wav2lip_hq.yaml | 71 +++++ lsr2_preprocess.py | 129 +++++++++ ppgan/apps/__init__.py | 1 + ppgan/apps/wav2lip_predictor.py | 254 ++++++++++++++++++ ppgan/datasets/__init__.py | 1 + ppgan/datasets/builder.py | 16 +- ppgan/datasets/wav2lip_dataset.py | 201 ++++++++++++++ ppgan/engine/trainer.py | 15 +- ppgan/faceutils/__init__.py | 1 + ppgan/models/__init__.py | 2 + ppgan/models/discriminators/__init__.py | 2 + ppgan/models/discriminators/syncnet.py | 4 + .../discriminators/wav2lip_disc_qual.py | 83 ++++++ ppgan/models/generators/wav2lip.py | 86 +----- ppgan/models/wav2lip_hq_model.py | 212 +++++++++++++++ ppgan/models/wav2lip_model.py | 148 ++++++++++ ppgan/modules/conv.py | 2 +- ppgan/modules/init.py | 16 +- ppgan/utils/audio.py | 184 +++++++++++++ ppgan/utils/audio_config.py | 28 ++ tools/main.py | 1 - 23 files changed, 1531 insertions(+), 97 deletions(-) create mode 100644 applications/tools/wav2lip.py create mode 100644 configs/wav2lip.yaml create mode 100644 configs/wav2lip_hq.yaml create mode 100644 lsr2_preprocess.py create mode 100644 ppgan/apps/wav2lip_predictor.py create mode 100644 ppgan/datasets/wav2lip_dataset.py create mode 100644 ppgan/models/discriminators/wav2lip_disc_qual.py create mode 100644 ppgan/models/wav2lip_hq_model.py create mode 100644 ppgan/models/wav2lip_model.py create mode 100644 ppgan/utils/audio.py create mode 100644 ppgan/utils/audio_config.py diff --git a/applications/tools/wav2lip.py b/applications/tools/wav2lip.py new file mode 100644 index 0000000..97def54 --- /dev/null +++ b/applications/tools/wav2lip.py @@ -0,0 +1,108 @@ +import argparse + +import paddle +from ppgan.apps.wav2lip_predictor import Wav2LipPredictor + +parser = argparse.ArgumentParser( + description= + 'Inference code to lip-sync videos in the wild using Wav2Lip models') + +parser.add_argument('--checkpoint_path', + type=str, + help='Name of saved checkpoint to load weights from', + required=True) + +parser.add_argument('--face', + type=str, + help='Filepath of video/image that contains faces to use', + required=True) +parser.add_argument( + '--audio', + type=str, + help='Filepath of video/audio file to use as raw audio source', + required=True) +parser.add_argument('--outfile', + type=str, + help='Video path to save result. See default for an e.g.', + default='results/result_voice.mp4') + +parser.add_argument( + '--static', + type=bool, + help='If True, then use only first video frame for inference', + default=False) +parser.add_argument( + '--fps', + type=float, + help='Can be specified only if input is a static image (default: 25)', + default=25., + required=False) + +parser.add_argument( + '--pads', + nargs='+', + type=int, + default=[0, 10, 0, 0], + help= + 'Padding (top, bottom, left, right). Please adjust to include chin at least' +) + +parser.add_argument('--face_det_batch_size', + type=int, + help='Batch size for face detection', + default=16) +parser.add_argument('--wav2lip_batch_size', + type=int, + help='Batch size for Wav2Lip model(s)', + default=128) + +parser.add_argument( + '--resize_factor', + default=1, + type=int, + help= + 'Reduce the resolution by this factor. 
Sometimes, best results are obtained at 480p or 720p' +) + +parser.add_argument( + '--crop', + nargs='+', + type=int, + default=[0, -1, 0, -1], + help= + 'Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. ' + 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width' +) + +parser.add_argument( + '--box', + nargs='+', + type=int, + default=[-1, -1, -1, -1], + help= + 'Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.' + 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).' +) + +parser.add_argument( + '--rotate', + default=False, + action='store_true', + help= + 'Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.' + 'Use if you get a flipped result, despite feeding a normal looking video') + +parser.add_argument( + '--nosmooth', + default=False, + action='store_true', + help='Prevent smoothing face detections over a short temporal window') +parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") + +if __name__ == "__main__": + args = parser.parse_args() + if args.cpu: + paddle.set_device('cpu') + + predictor = Wav2LipPredictor(args) + predictor.run() diff --git a/configs/wav2lip.yaml b/configs/wav2lip.yaml new file mode 100644 index 0000000..fe180be --- /dev/null +++ b/configs/wav2lip.yaml @@ -0,0 +1,63 @@ +total_iters: 200000000 +output_dir: output +checkpoints_dir: checkpoints + +model: + name: Wav2LipModel + syncnet_wt: 0.0 + max_eval_steps: 700 + generator: + name: Wav2Lip + discriminator: + name: SyncNetColor + +dataset: + train: + name: Wav2LipDataset + dataroot: data/lrs2_preprocessed + filelists_path: ./ + img_size: 96 + split: train + batch_size: 8 + num_workers: 4 + use_shared_memory: False + test: + name: Wav2LipDataset + dataroot: data/lrs2_preprocessed + filelists_path: ./ + img_size: 96 + split: val + batch_size: 16 + num_workers: 4 + use_shared_memory: False + +optimizer: + optimizer_G: + name: Adam + net_names: + - netG + beta1: 0.5 + optimizer_D: + name: Adam + net_names: + - netD + beta1: 0.5 + +validate: + interval: 3000 + save_img: false + +lr_scheduler: + name: LinearDecay + learning_rate: 0.0001 + start_epoch: 2000000 + decay_epochs: 2000000 + # will get from real dataset + iters_per_epoch: 1 + +log_config: + interval: 10 + visiual_interval: 500 + +snapshot_config: + interval: 3000 diff --git a/configs/wav2lip_hq.yaml b/configs/wav2lip_hq.yaml new file mode 100644 index 0000000..a6a4f1c --- /dev/null +++ b/configs/wav2lip_hq.yaml @@ -0,0 +1,71 @@ +total_iters: 200000000 +output_dir: output_hq +checkpoints_dir: checkpoints_hq + +model: + name: Wav2LipModelHq + syncnet_wt: 0. 
+ disc_wt: 0.07 + max_eval_steps: 700 + generator: + name: Wav2Lip + discriminator_sync: + name: SyncNetColor + discriminator_hq: + name: Wav2LipDiscQual + +dataset: + train: + name: Wav2LipDataset + dataroot: data/lrs2_preprocessed + filelists_path: ./ + img_size: 96 + split: train + batch_size: 8 + num_workers: 0 + use_shared_memory: False + test: + name: Wav2LipDataset + dataroot: data/lrs2_preprocessed + filelists_path: ./ + img_size: 96 + split: val + batch_size: 16 + num_workers: 0 + use_shared_memory: False + +optimizer: + optimizer_G: + name: Adam + net_names: + - netG + beta1: 0.5 + optimizer_DS: + name: Adam + net_names: + - netDS + beta1: 0.5 + optimizer_DH: + name: Adam + net_names: + - netDH + beta1: 0.5 + +validate: + interval: 3000 + save_img: false + +lr_scheduler: + name: LinearDecay + learning_rate: 0.0001 + start_epoch: 2000000 + decay_epochs: 2000000 + # will get from real dataset + iters_per_epoch: 1 + +log_config: + interval: 10 + visiual_interval: 500 + +snapshot_config: + interval: 3000 diff --git a/lsr2_preprocess.py b/lsr2_preprocess.py new file mode 100644 index 0000000..aa6e1a1 --- /dev/null +++ b/lsr2_preprocess.py @@ -0,0 +1,129 @@ +import sys + +if sys.version_info[0] < 3 and sys.version_info[1] < 2: + raise Exception("Must be using >= Python 3.2") + +from os import listdir, path + +import multiprocessing as mp +from concurrent.futures import ThreadPoolExecutor, as_completed +import numpy as np +import argparse, os, cv2, traceback, subprocess +from tqdm import tqdm +from glob import glob + +from ppgan.utils import audio +from ppgan.faceutils import face_detection + +parser = argparse.ArgumentParser() + +parser.add_argument('--ngpu', + help='Number of GPUs across which to run in parallel', + default=1, + type=int) +parser.add_argument('--batch_size', + help='Single GPU Face detection batch size', + default=32, + type=int) +parser.add_argument("--data_root", + help="Root folder of the LRS2 dataset", + required=True) +parser.add_argument("--preprocessed_root", + help="Root folder of the preprocessed dataset", + required=True) + +args = parser.parse_args() + +fa = [ + face_detection.FaceAlignment(face_detection.LandmarksType._2D, + flip_input=False) +] + +template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}' +# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}' + + +def process_video_file(vfile, args, gpu_id): + video_stream = cv2.VideoCapture(vfile) + + frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + frames.append(frame) + + vidname = os.path.basename(vfile).split('.')[0] + dirname = vfile.split('/')[-2] + + fulldir = path.join(args.preprocessed_root, dirname, vidname) + os.makedirs(fulldir, exist_ok=True) + + batches = [ + frames[i:i + args.batch_size] + for i in range(0, len(frames), args.batch_size) + ] + + i = -1 + for fb in batches: + preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb)) + + for j, f in enumerate(preds): + i += 1 + if f is None: + continue + + x1, y1, x2, y2 = f + cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, + x1:x2]) + + +def process_audio_file(vfile, args): + vidname = os.path.basename(vfile).split('.')[0] + dirname = vfile.split('/')[-2] + + fulldir = path.join(args.preprocessed_root, dirname, vidname) + os.makedirs(fulldir, exist_ok=True) + + wavpath = path.join(fulldir, 'audio.wav') + + command = template.format(vfile, wavpath) + subprocess.call(command, 
shell=True) + + +def mp_handler(job): + vfile, args, gpu_id = job + try: + process_video_file(vfile, args, gpu_id) + except KeyboardInterrupt: + exit(0) + except: + traceback.print_exc() + + +def main(args): + print('Started processing for {} with {} GPUs'.format( + args.data_root, args.ngpu)) + + filelist = glob(path.join(args.data_root, '*/*.mp4')) + + jobs = [(vfile, args, i % args.ngpu) for i, vfile in enumerate(filelist)] + p = ThreadPoolExecutor(args.ngpu) + futures = [p.submit(mp_handler, j) for j in jobs] + _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))] + + print('Dumping audios...') + + for vfile in tqdm(filelist): + try: + process_audio_file(vfile, args) + except KeyboardInterrupt: + exit(0) + except: + traceback.print_exc() + continue + + +if __name__ == '__main__': + main(args) diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py index 486851a..647a212 100644 --- a/ppgan/apps/__init__.py +++ b/ppgan/apps/__init__.py @@ -24,3 +24,4 @@ from .midas_predictor import MiDaSPredictor from .photo2cartoon_predictor import Photo2CartoonPredictor from .styleganv2_predictor import StyleGANv2Predictor from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor +from .wav2lip_predictor import Wav2LipPredictor diff --git a/ppgan/apps/wav2lip_predictor.py b/ppgan/apps/wav2lip_predictor.py new file mode 100644 index 0000000..063c230 --- /dev/null +++ b/ppgan/apps/wav2lip_predictor.py @@ -0,0 +1,254 @@ +from os import listdir, path, makedirs +import platform +import numpy as np +import scipy, cv2, os, sys, argparse +import json, subprocess, random, string +from tqdm import tqdm +from glob import glob +import paddle +from ppgan.faceutils import face_detection +from ppgan.utils import audio +from ppgan.models.generators.wav2lip import Wav2Lip +from .base_predictor import BasePredictor + +mel_step_size = 16 + + +class Wav2LipPredictor(BasePredictor): + def __init__(self, args): + self.args = args + if os.path.isfile(self.args.face) and self.args.face.split('.')[1] in [ + 'jpg', 'png', 'jpeg' + ]: + self.args.static = True + self.img_size = 96 + makedirs('./temp', exist_ok=True) + + def get_smoothened_boxes(self, boxes, T): + for i in range(len(boxes)): + if i + T > len(boxes): + window = boxes[len(boxes) - T:] + else: + window = boxes[i:i + T] + boxes[i] = np.mean(window, axis=0) + return boxes + + def face_detect(self, images): + detector = face_detection.FaceAlignment( + face_detection.LandmarksType._2D, flip_input=False) + + batch_size = self.args.face_det_batch_size + + while 1: + predictions = [] + try: + for i in tqdm(range(0, len(images), batch_size)): + predictions.extend( + detector.get_detections_for_batch( + np.array(images[i:i + batch_size]))) + except RuntimeError: + if batch_size == 1: + raise RuntimeError( + 'Image too big to run face detection on GPU. Please use the --resize_factor argument' + ) + batch_size //= 2 + print('Recovering from OOM error; New batch size: {}'.format( + batch_size)) + continue + break + + results = [] + pady1, pady2, padx1, padx2 = self.args.pads + for rect, image in zip(predictions, images): + if rect is None: + cv2.imwrite( + 'temp/faulty_frame.jpg', + image) # check this frame where the face was not detected. + raise ValueError( + 'Face not detected! Ensure the video contains a face in all the frames.' 
+ ) + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + + boxes = np.array(results) + if not self.args.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5) + results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] + for image, (x1, y1, x2, y2) in zip(images, boxes)] + + del detector + return results + + def datagen(self, frames, mels): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if self.args.box[0] == -1: + if not self.args.static: + face_det_results = self.face_detect( + frames) # BGR2RGB for CNN face detection + else: + face_det_results = self.face_detect([frames[0]]) + else: + print( + 'Using the specified bounding box instead of face detection...') + y1, y2, x1, x2 = self.args.box + face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] + for f in frames] + + for i, m in enumerate(mels): + idx = 0 if self.args.static else i % len(frames) + frame_to_save = frames[idx].copy() + face, coords = face_det_results[idx].copy() + + face = cv2.resize(face, (self.img_size, self.img_size)) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= self.args.wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray( + mel_batch) + + img_masked = img_batch.copy() + img_masked[:, self.img_size // 2:] = 0 + + img_batch = np.concatenate( + (img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape( + mel_batch, + [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, self.img_size // 2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
+ mel_batch = np.reshape( + mel_batch, + [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + + def run(self): + if not os.path.isfile(self.args.face): + raise ValueError( + '--face argument must be a valid path to video/image file') + + elif self.args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: + full_frames = [cv2.imread(self.args.face)] + fps = self.args.fps + + else: + video_stream = cv2.VideoCapture(self.args.face) + fps = video_stream.get(cv2.CAP_PROP_FPS) + + print('Reading video frames...') + + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + if self.args.resize_factor > 1: + frame = cv2.resize( + frame, (frame.shape[1] // self.args.resize_factor, + frame.shape[0] // self.args.resize_factor)) + + if self.args.rotate: + frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + + y1, y2, x1, x2 = self.args.crop + if x2 == -1: x2 = frame.shape[1] + if y2 == -1: y2 = frame.shape[0] + + frame = frame[y1:y2, x1:x2] + + full_frames.append(frame) + + print("Number of frames available for inference: " + + str(len(full_frames))) + + if not self.args.audio.endswith('.wav'): + print('Extracting raw audio...') + command = 'ffmpeg -y -i {} -strict -2 {}'.format( + self.args.audio, 'temp/temp.wav') + + subprocess.call(command, shell=True) + self.args.audio = 'temp/temp.wav' + + wav = audio.load_wav(self.args.audio, 16000) + mel = audio.melspectrogram(wav) + print(mel.shape) + + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError( + 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again' + ) + + mel_chunks = [] + mel_idx_multiplier = 80. / fps + i = 0 + while 1: + start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) + break + mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size]) + i += 1 + + print("Length of mel chunks: {}".format(len(mel_chunks))) + + full_frames = full_frames[:len(mel_chunks)] + + batch_size = self.args.wav2lip_batch_size + gen = self.datagen(full_frames.copy(), mel_chunks) + + model = Wav2Lip() + weights = paddle.load(self.args.checkpoint_path) + model.load_dict(weights) + model.eval() + print("Model loaded") + for i, (img_batch, mel_batch, frames, coords) in enumerate( + tqdm(gen, + total=int(np.ceil(float(len(mel_chunks)) / batch_size)))): + if i == 0: + + frame_h, frame_w = full_frames[0].shape[:-1] + out = cv2.VideoWriter('temp/result.avi', + cv2.VideoWriter_fourcc(*'DIVX'), fps, + (frame_w, frame_h)) + + img_batch = paddle.to_tensor(np.transpose( + img_batch, (0, 3, 1, 2))).astype('float32') + mel_batch = paddle.to_tensor(np.transpose( + mel_batch, (0, 3, 1, 2))).astype('float32') + + with paddle.no_grad(): + pred = model(mel_batch, img_batch) + + pred = pred.numpy().transpose(0, 2, 3, 1) * 255. 
+ + for p, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) + + f[y1:y2, x1:x2] = p + out.write(f) + + out.release() + + command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format( + self.args.audio, 'temp/result.avi', self.args.outfile) + subprocess.call(command, shell=platform.system() != 'Windows') diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index 59d6d4c..4761233 100644 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -19,3 +19,4 @@ from .base_sr_dataset import SRDataset from .makeup_dataset import MakeupDataset from .common_vision_dataset import CommonVisionDataset from .animeganv2_dataset import AnimeGANV2Dataset +from .wav2lip_dataset import Wav2LipDataset diff --git a/ppgan/datasets/builder.py b/ppgan/datasets/builder.py index 8fe6c7e..da582bb 100644 --- a/ppgan/datasets/builder.py +++ b/ppgan/datasets/builder.py @@ -31,7 +31,6 @@ class DictDataset(paddle.io.Dataset): self.tensor_keys_set = set() self.non_tensor_keys_set = set() self.non_tensor_dict = Manager().dict() - single_item = dataset[0] self.keys = single_item.keys() @@ -71,6 +70,7 @@ class DictDataLoader(): batch_size, is_train, num_workers=4, + use_shared_memory=True, distributed=True): self.dataset = DictDataset(dataset) @@ -85,10 +85,12 @@ class DictDataLoader(): shuffle=True if is_train else False, drop_last=True if is_train else False) - self.dataloader = paddle.io.DataLoader(self.dataset, - batch_sampler=sampler, - places=place, - num_workers=num_workers) + self.dataloader = paddle.io.DataLoader( + self.dataset, + batch_sampler=sampler, + places=place, + num_workers=num_workers, + use_shared_memory=use_shared_memory) else: self.dataloader = paddle.io.DataLoader( self.dataset, @@ -96,6 +98,7 @@ class DictDataLoader(): shuffle=True if is_train else False, drop_last=True if is_train else False, places=place, + use_shared_memory=False, num_workers=num_workers) self.batch_size = batch_size @@ -137,15 +140,16 @@ def build_dataloader(cfg, is_train=True, distributed=True): batch_size = cfg_.pop('batch_size', 1) num_workers = cfg_.pop('num_workers', 0) + use_shared_memory = cfg_.pop('use_shared_memory', True) name = cfg_.pop('name') dataset = DATASETS.get(name)(**cfg_) - dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers, + use_shared_memory=use_shared_memory, distributed=distributed) return dataloader diff --git a/ppgan/datasets/wav2lip_dataset.py b/ppgan/datasets/wav2lip_dataset.py new file mode 100644 index 0000000..fde1fb6 --- /dev/null +++ b/ppgan/datasets/wav2lip_dataset.py @@ -0,0 +1,201 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import cv2 +import random +import os.path +import numpy as np +from PIL import Image +from glob import glob +from os.path import dirname, join, basename, isfile +from ppgan.utils import audio +from ppgan.utils.audio_config import get_audio_config +import numpy as np + +import paddle +from .builder import DATASETS + + +def get_image_list(data_root, split): + filelist = [] + + with open('filelists/{}.txt'.format(split)) as f: + for line in f: + line = line.strip() + if ' ' in line: line = line.split()[0] + filelist.append(os.path.join(data_root, line)) + + return filelist + + +syncnet_T = 5 +syncnet_mel_step_size = 16 +audio_cfg = get_audio_config() + + +@DATASETS.register() +class Wav2LipDataset(paddle.io.Dataset): + def __init__(self, dataroot, img_size, filelists_path, split): + """Initialize Wav2Lip dataset class. + + Args: + dataroot (str): Directory of dataset. + """ + self.image_path = dataroot + self.img_size = img_size + self.split = split + self.all_videos = get_image_list(self.image_path, self.split) + + def get_frame_id(self, frame): + return int(basename(frame).split('.')[0]) + + def get_window(self, start_frame): + start_id = self.get_frame_id(start_frame) + vidname = dirname(start_frame) + + window_fnames = [] + for frame_id in range(start_id, start_id + syncnet_T): + frame = join(vidname, '{}.jpg'.format(frame_id)) + if not isfile(frame): + return None + window_fnames.append(frame) + return window_fnames + + def read_window(self, window_fnames): + if window_fnames is None: return None + window = [] + for fname in window_fnames: + img = cv2.imread(fname) + if img is None: + return None + try: + img = cv2.resize(img, (self.img_size, self.img_size)) + except Exception as e: + return None + + window.append(img) + + return window + + def crop_audio_window(self, spec, start_frame): + if type(start_frame) == int: + start_frame_num = start_frame + else: + start_frame_num = self.get_frame_id( + start_frame) # 0-indexing ---> 1-indexing + start_idx = int(80. * (start_frame_num / float(audio_cfg["fps"]))) + + end_idx = start_idx + syncnet_mel_step_size + + return spec[start_idx:end_idx, :] + + def get_segmented_mels(self, spec, start_frame): + mels = [] + assert syncnet_T == 5 + start_frame_num = self.get_frame_id( + start_frame) + 1 # 0-indexing ---> 1-indexing + if start_frame_num - 2 < 0: return None + for i in range(start_frame_num, start_frame_num + syncnet_T): + m = self.crop_audio_window(spec, i - 2) + if m.shape[0] != syncnet_mel_step_size: + return None + mels.append(m.T) + + mels = np.asarray(mels) + + return mels + + def prepare_window(self, window): + # 3 x T x H x W + x = np.asarray(window) / 255. 
+ x = np.transpose(x, (3, 0, 1, 2)) + + return x + + def __len__(self): + return len(self.all_videos) + + def __getitem__(self, idx): + while 1: + idx = random.randint(0, len(self.all_videos) - 1) + vidname = self.all_videos[idx] + img_names = list(glob(join(vidname, '*.jpg'))) + if len(img_names) <= 3 * syncnet_T: + continue + + img_name = random.choice(img_names) + wrong_img_name = random.choice(img_names) + while wrong_img_name == img_name: + wrong_img_name = random.choice(img_names) + + window_fnames = self.get_window(img_name) + wrong_window_fnames = self.get_window(wrong_img_name) + if window_fnames is None or wrong_window_fnames is None: + continue + + window = self.read_window(window_fnames) + if window is None: + continue + + wrong_window = self.read_window(wrong_window_fnames) + if wrong_window is None: + continue + + try: + wavpath = join(vidname, "audio.wav") + wav = audio.load_wav(wavpath, audio_cfg["sample_rate"]) + + orig_mel = audio.melspectrogram(wav).T + except Exception as e: + continue + + mel = self.crop_audio_window(orig_mel.copy(), img_name) + + if (mel.shape[0] != syncnet_mel_step_size): + continue + + indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name) + if indiv_mels is None: continue + + window = self.prepare_window(window) + y = window.copy() + window[:, :, window.shape[2] // 2:] = 0. + + wrong_window = self.prepare_window(wrong_window) + x = np.concatenate([window, wrong_window], axis=0) + + x = np.float32(x) + mel = np.transpose(mel) + mel = np.expand_dims(mel, 0) + indiv_mels = np.expand_dims(indiv_mels, 1) + #np.random.seed(200) + #x = np.random.rand(*x.shape).astype('float32') + #np.random.seed(200) + #mel = np.random.rand(*mel.shape) + #np.random.seed(200) + #indiv_mels = np.random.rand(*indiv_mels.shape) + #np.random.seed(200) + #y = np.random.rand(*y.shape) + + return { + 'x': x, + 'indiv_mels': np.float32(indiv_mels), + 'mel': np.float32(mel), + 'y': np.float32(y) + } + + def __len__(self): + """Return the total number of images in the dataset. 
+ """ + return len(self.all_videos) diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index ab75cb0..18c2772 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -107,6 +107,7 @@ class Trainer: # base config self.output_dir = cfg.output_dir + self.max_eval_steps = cfg.model.get('max_eval_steps', None) self.epochs = cfg.get('epochs', None) if self.epochs: self.total_iters = self.epochs * self.iters_per_epoch @@ -194,15 +195,16 @@ class Trainer: self.test_dataloader = build_dataloader(self.cfg.dataset.test, is_train=False, distributed=False) + iter_loader = IterLoader(self.test_dataloader) + if self.max_eval_steps is None: + self.max_eval_steps = len(self.test_dataloader) if self.metrics: for metric in self.metrics.values(): metric.reset() - # data[0]: img, data[1]: img path index - # test batch size must be 1 - for i, data in enumerate(self.test_dataloader): - + for i in range(self.max_eval_steps): + data = next(iter_loader) self.model.setup_input(data) self.model.test_iter(metrics=self.metrics) @@ -236,7 +238,7 @@ class Trainer: if i % self.log_interval == 0: self.logger.info('Test iter: [%d/%d]' % - (i, len(self.test_dataloader))) + (i, self.max_eval_steps)) if self.metrics: for metric_name, metric in self.metrics.items(): @@ -340,6 +342,7 @@ class Trainer: else: save_filename = 'iter_%s_%s.pdparams' % (epoch, name) + os.makedirs(self.output_dir, exist_ok=True) save_path = os.path.join(self.output_dir, save_filename) for net_name, net in self.model.nets.items(): state_dicts[net_name] = net.state_dict() @@ -379,6 +382,8 @@ class Trainer: self.start_epoch = state_dicts['epoch'] + 1 self.global_steps = self.iters_per_epoch * state_dicts['epoch'] + self.current_iter = state_dicts['epoch'] + 1 + for net_name, net in self.model.nets.items(): net.set_state_dict(state_dicts[net_name]) diff --git a/ppgan/faceutils/__init__.py b/ppgan/faceutils/__init__.py index 898b224..326ddeb 100644 --- a/ppgan/faceutils/__init__.py +++ b/ppgan/faceutils/__init__.py @@ -16,3 +16,4 @@ from . import dlibutils as dlib from . import mask from . import image from . import face_segmentation +from . 
import face_detection diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index a675130..bb1f278 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -22,3 +22,5 @@ from .esrgan_model import ESRGAN from .ugatit_model import UGATITModel from .dc_gan_model import DCGANModel from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel +from .wav2lip_model import Wav2LipModel +from .wav2lip_hq_model import Wav2LipModelHq diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py index f7af297..cbdbc5e 100644 --- a/ppgan/models/discriminators/__init__.py +++ b/ppgan/models/discriminators/__init__.py @@ -18,3 +18,5 @@ from .discriminator_ugatit import UGATITDiscriminator from .dcdiscriminator import DCDiscriminator from .discriminator_animegan import AnimeDiscriminator from .discriminator_styleganv2 import StyleGANv2Discriminator +from .syncnet import SyncNetColor +from .wav2lip_disc_qual import Wav2LipDiscQual diff --git a/ppgan/models/discriminators/syncnet.py b/ppgan/models/discriminators/syncnet.py index d07b4fd..e597844 100644 --- a/ppgan/models/discriminators/syncnet.py +++ b/ppgan/models/discriminators/syncnet.py @@ -1,9 +1,13 @@ import paddle from paddle import nn from paddle.nn import functional as F +import sys from ...modules.conv import ConvBNRelu +#from conv import ConvBNRelu +from .builder import DISCRIMINATORS +@DISCRIMINATORS.register() class SyncNetColor(nn.Layer): def __init__(self): super(SyncNetColor, self).__init__() diff --git a/ppgan/models/discriminators/wav2lip_disc_qual.py b/ppgan/models/discriminators/wav2lip_disc_qual.py new file mode 100644 index 0000000..30dfa5d --- /dev/null +++ b/ppgan/models/discriminators/wav2lip_disc_qual.py @@ -0,0 +1,83 @@ +import paddle +from paddle import nn +from paddle.nn import functional as F + +from ...modules.conv import ConvBNRelu, NonNormConv2d, Conv2dTransposeRelu +from .builder import DISCRIMINATORS + + +@DISCRIMINATORS.register() +class Wav2LipDiscQual(nn.Layer): + def __init__(self): + super(Wav2LipDiscQual, self).__init__() + + self.face_encoder_blocks = nn.LayerList([ + nn.Sequential( + NonNormConv2d(3, 32, kernel_size=7, stride=1, + padding=3)), # 48,96 + nn.Sequential( + NonNormConv2d(32, 64, kernel_size=5, stride=(1, 2), + padding=2), # 48,48 + NonNormConv2d(64, 64, kernel_size=5, stride=1, padding=2)), + nn.Sequential( + NonNormConv2d(64, 128, kernel_size=5, stride=2, + padding=2), # 24,24 + NonNormConv2d(128, 128, kernel_size=5, stride=1, padding=2)), + nn.Sequential( + NonNormConv2d(128, 256, kernel_size=5, stride=2, + padding=2), # 12,12 + NonNormConv2d(256, 256, kernel_size=5, stride=1, padding=2)), + nn.Sequential( + NonNormConv2d(256, 512, kernel_size=3, stride=2, + padding=1), # 6,6 + NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1)), + nn.Sequential( + NonNormConv2d(512, 512, kernel_size=3, stride=2, + padding=1), # 3,3 + NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1), + ), + nn.Sequential( + NonNormConv2d(512, 512, kernel_size=3, stride=1, + padding=0), # 1, 1 + NonNormConv2d(512, 512, kernel_size=1, stride=1, padding=0)), + ]) + + self.binary_pred = nn.Sequential( + nn.Conv2D(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) + self.label_noise = .0 + + def get_lower_half(self, face_sequences): + return face_sequences[:, :, face_sequences.shape[2] // 2:] + + def to_2d(self, face_sequences): + B = face_sequences.shape[0] + face_sequences = paddle.concat( + [face_sequences[:, :, i] for i in 
range(face_sequences.shape[2])], + axis=0) + return face_sequences + + def perceptual_forward(self, false_face_sequences): + false_face_sequences = self.to_2d(false_face_sequences) + false_face_sequences = self.get_lower_half(false_face_sequences) + + false_feats = false_face_sequences + for f in self.face_encoder_blocks: + false_feats = f(false_feats) + + binary_pred = self.binary_pred(false_feats).reshape( + (len(false_feats), -1)) + + false_pred_loss = F.binary_cross_entropy( + binary_pred, paddle.ones((len(false_feats), 1))) + + return false_pred_loss + + def forward(self, face_sequences): + face_sequences = self.to_2d(face_sequences) + face_sequences = self.get_lower_half(face_sequences) + + x = face_sequences + for f in self.face_encoder_blocks: + x = f(x) + + return paddle.reshape(self.binary_pred(x), (len(x), -1)) diff --git a/ppgan/models/generators/wav2lip.py b/ppgan/models/generators/wav2lip.py index d419d89..c4b9fc3 100644 --- a/ppgan/models/generators/wav2lip.py +++ b/ppgan/models/generators/wav2lip.py @@ -18,7 +18,6 @@ from paddle.nn import functional as F from .builder import GENERATORS from ...modules.conv import ConvBNRelu -from ...modules.conv import NonNormConv2d from ...modules.conv import Conv2dTransposeRelu @@ -27,7 +26,7 @@ class Wav2Lip(nn.Layer): def __init__(self): super(Wav2Lip, self).__init__() - self.face_encoder_blocks = [ + self.face_encoder_blocks = nn.LayerList([ nn.Sequential(ConvBNRelu(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96 nn.Sequential( @@ -106,7 +105,7 @@ class Wav2Lip(nn.Layer): ConvBNRelu(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0)), - ] + ]) self.audio_encoder = nn.Sequential( ConvBNRelu(1, 32, kernel_size=3, stride=1, padding=1), @@ -159,7 +158,7 @@ class Wav2Lip(nn.Layer): ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0), ) - self.face_decoder_blocks = [ + self.face_decoder_blocks = nn.LayerList([ nn.Sequential( ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0), ), nn.Sequential( @@ -275,7 +274,7 @@ class Wav2Lip(nn.Layer): padding=1, residual=True), ), - ] # 96,96 + ]) # 96,96 self.output_block = nn.Sequential( ConvBNRelu(80, 32, kernel_size=3, stride=1, padding=1), @@ -319,85 +318,10 @@ class Wav2Lip(nn.Layer): x = self.output_block(x) if input_dim_size > 4: - x = paddle.split(x, B, axis=0) # [(B, C, H, W)] + x = paddle.split(x, int(x.shape[0] / B), axis=0) # [(B, C, H, W)] outputs = paddle.stack(x, axis=2) # (B, C, T, H, W) else: outputs = x return outputs - - -class Wav2LipDiscQual(nn.Layer): - def __init__(self): - super(Wav2LipDiscQual, self).__init__() - - self.face_encoder_blocks = [ - nn.Sequential( - NonNormConv2d(3, 32, kernel_size=7, stride=1, - padding=3)), # 48,96 - nn.Sequential( - NonNormConv2d(32, 64, kernel_size=5, stride=(1, 2), - padding=2), # 48,48 - NonNormConv2d(64, 64, kernel_size=5, stride=1, padding=2)), - nn.Sequential( - NonNormConv2d(64, 128, kernel_size=5, stride=2, - padding=2), # 24,24 - NonNormConv2d(128, 128, kernel_size=5, stride=1, padding=2)), - nn.Sequential( - NonNormConv2d(128, 256, kernel_size=5, stride=2, - padding=2), # 12,12 - NonNormConv2d(256, 256, kernel_size=5, stride=1, padding=2)), - nn.Sequential( - NonNormConv2d(256, 512, kernel_size=3, stride=2, - padding=1), # 6,6 - NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1)), - nn.Sequential( - NonNormConv2d(512, 512, kernel_size=3, stride=2, - padding=1), # 3,3 - NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1), - ), - nn.Sequential( - 
NonNormConv2d(512, 512, kernel_size=3, stride=1, - padding=0), # 1, 1 - NonNormConv2d(512, 512, kernel_size=1, stride=1, padding=0)), - ] - - self.binary_pred = nn.Sequential( - nn.Conv2D(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) - self.label_noise = .0 - - def get_lower_half(self, face_sequences): - return face_sequences[:, :, face_sequences.shape[2] // 2:] - - def to_2d(self, face_sequences): - B = face_sequences.shape[0] - face_sequences = paddle.concat( - [face_sequences[:, :, i] for i in range(face_sequences.shape[2])], - axis=0) - return face_sequences - - def perceptual_forward(self, false_face_sequences): - false_face_sequences = self.to_2d(false_face_sequences) - false_face_sequences = self.get_lower_half(false_face_sequences) - - false_feats = false_face_sequences - for f in self.face_encoder_blocks: - false_feats = f(false_feats) - - false_pred_loss = F.binary_cross_entropy( - paddle.reshape(self.binary_pred(false_feats), - (len(false_feats), -1)), - paddle.ones((len(false_feats), 1))) - - return false_pred_loss - - def forward(self, face_sequences): - face_sequences = self.to_2d(face_sequences) - face_sequences = self.get_lower_half(face_sequences) - - x = face_sequences - for f in self.face_encoder_blocks: - x = f(x) - - return paddle.reshape(self.binary_pred(x), (len(x), -1)) diff --git a/ppgan/models/wav2lip_hq_model.py b/ppgan/models/wav2lip_hq_model.py new file mode 100644 index 0000000..07ffdbc --- /dev/null +++ b/ppgan/models/wav2lip_hq_model.py @@ -0,0 +1,212 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn.functional as F +from .base_model import BaseModel + +from .builder import MODELS +from .generators.builder import build_generator +from .discriminators.builder import build_discriminator +from .criterions import build_criterion +from .wav2lip_model import cosine_loss, get_sync_loss + +from ..solver import build_optimizer +from ..modules.init import init_weights + +lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams' + + +@MODELS.register() +class Wav2LipModelHq(BaseModel): + """ This class implements the Wav2lip model, Wav2lip paper: https://arxiv.org/abs/2008.10010. + + The model training requires dataset. + By default, it uses a '--netG Wav2lip' generator, + a '--netD SyncNetColor' discriminator. + """ + def __init__(self, + generator, + discriminator_sync=None, + discriminator_hq=None, + syncnet_wt=1.0, + disc_wt=0.07, + max_eval_steps=700, + is_train=True): + """Initialize the Wav2lip class. 
+ + Parameters: + opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict + """ + super(Wav2LipModelHq, self).__init__() + self.syncnet_wt = syncnet_wt + self.disc_wt = disc_wt + self.is_train = is_train + self.eval_step = 0 + self.max_eval_steps = max_eval_steps + self.eval_sync_losses, self.eval_recon_losses = [], [] + self.eval_disc_real_losses, self.eval_disc_fake_losses = [], [] + self.eval_perceptual_losses = [] + # define networks (both generator and discriminator) + self.nets['netG'] = build_generator(generator) + init_weights(self.nets['netG'], + init_type='kaiming', + distribution='uniform') + if self.is_train: + self.nets['netDS'] = build_discriminator(discriminator_sync) + params = paddle.load(lipsync_weight_path) + self.nets['netDS'].load_dict(params) + + self.nets['netDH'] = build_discriminator(discriminator_hq) + init_weights(self.nets['netDH'], + init_type='kaiming', + distribution='uniform') + + if self.is_train: + self.recon_loss = paddle.nn.L1Loss() + + def setup_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): include the data itself and its metadata information. + """ + self.x = paddle.to_tensor(input['x']) + self.indiv_mels = paddle.to_tensor(input['indiv_mels']) + self.mel = paddle.to_tensor(input['mel']) + self.y = paddle.to_tensor(input['y']) + + def forward(self): + """Run forward pass; called by both functions and .""" + + self.g = self.nets['netG'](self.indiv_mels, self.x) + + def backward_G(self): + """Calculate GAN loss for the generator""" + if self.syncnet_wt > 0.: + self.sync_loss = get_sync_loss(self.mel, self.g, self.nets['netDS']) + else: + self.sync_loss = 0. + self.l1_loss = self.recon_loss(self.g, self.y) + + if self.disc_wt > 0.: + if isinstance(self.nets['netDH'], paddle.DataParallel + ): #paddle.fluid.dygraph.parallel.DataParallel) + self.perceptual_loss = self.nets[ + 'netDH']._layers.perceptual_forward(self.g) + else: + self.perceptual_loss = self.nets['netDH'].perceptual_forward( + self.g) + else: + self.perceptual_loss = 0. 
+ + self.losses['sync_loss'] = self.sync_loss + self.losses['l1_loss'] = self.l1_loss + self.losses['perceptual_loss'] = self.perceptual_loss + + self.loss_G = self.syncnet_wt * self.sync_loss + self.disc_wt * self.perceptual_loss + ( + 1 - self.syncnet_wt - self.disc_wt) * self.l1_loss + self.loss_G.backward() + + def backward_D(self): + self.pred_real = self.nets['netDH'](self.y) + self.disc_real_loss = F.binary_cross_entropy( + self.pred_real, paddle.ones((len(self.pred_real), 1))) + self.losses['disc_real_loss'] = self.disc_real_loss + self.disc_real_loss.backward() + + self.pred_fake = self.nets['netDH'](self.g.detach()) + self.disc_fake_loss = F.binary_cross_entropy( + self.pred_fake, paddle.zeros((len(self.pred_fake), 1))) + self.losses['disc_fake_loss'] = self.disc_fake_loss + self.disc_fake_loss.backward() + + def train_iter(self, optimizers=None): + # forward + self.forward() + + # update G + self.set_requires_grad(self.nets['netDS'], False) + self.set_requires_grad(self.nets['netG'], True) + self.set_requires_grad(self.nets['netDH'], True) + + self.optimizers['optimizer_G'].clear_grad() + self.optimizers['optimizer_DH'].clear_grad() + self.backward_G() + self.optimizers['optimizer_G'].step() + + self.optimizers['optimizer_DH'].clear_grad() + self.backward_D() + self.optimizers['optimizer_DH'].step() + + def test_iter(self, metrics=None): + self.eval_step += 1 + self.nets['netG'].eval() + self.nets['netDH'].eval() + with paddle.no_grad(): + self.forward() + sync_loss = get_sync_loss(self.mel, self.g, self.nets['netDS']) + l1loss = self.recon_loss(self.g, self.y) + + pred_real = self.nets['netDH'](self.y) + pred_fake = self.nets['netDH'](self.g) + disc_real_loss = F.binary_cross_entropy( + pred_real, paddle.ones((len(pred_real), 1))) + disc_fake_loss = F.binary_cross_entropy( + pred_fake, paddle.zeros((len(pred_fake), 1))) + + self.eval_disc_fake_losses.append(disc_fake_loss.numpy().item()) + self.eval_disc_real_losses.append(disc_real_loss.numpy().item()) + + self.eval_sync_losses.append(sync_loss.numpy().item()) + self.eval_recon_losses.append(l1loss.numpy().item()) + + if self.disc_wt > 0.: + if isinstance(self.nets['netDH'], paddle.DataParallel + ): #paddle.fluid.dygraph.parallel.DataParallel) + perceptual_loss = self.nets[ + 'netDH']._layers.perceptual_forward( + self.g).numpy().item() + else: + perceptual_loss = self.nets['netDH'].perceptual_forward( + self.g).numpy().item() + else: + perceptual_loss = 0. 
+ self.eval_perceptual_losses.append(perceptual_loss) + + if self.eval_step == self.max_eval_steps: + averaged_sync_loss = sum(self.eval_sync_losses) / len( + self.eval_sync_losses) + averaged_recon_loss = sum(self.eval_recon_losses) / len( + self.eval_recon_losses) + averaged_perceptual_loss = sum(self.eval_perceptual_losses) / len( + self.eval_perceptual_losses) + averaged_disc_fake_loss = sum(self.eval_disc_fake_losses) / len( + self.eval_disc_fake_losses) + averaged_disc_real_loss = sum(self.eval_disc_real_losses) / len( + self.eval_disc_real_losses) + if averaged_sync_loss < .75: + self.syncnet_wt = 0.01 + + print( + 'L1: {}, Sync loss: {}, Percep: {}, Fake: {}, Real: {}'.format( + averaged_recon_loss, averaged_sync_loss, + averaged_perceptual_loss, averaged_disc_fake_loss, + averaged_disc_real_loss)) + self.eval_sync_losses, self.eval_recon_losses = [], [] + self.eval_disc_real_losses, self.eval_disc_fake_losses = [], [] + self.eval_perceptual_losses = [] + self.eval_step = 0 + self.nets['netG'].train() + self.nets['netDH'].train() diff --git a/ppgan/models/wav2lip_model.py b/ppgan/models/wav2lip_model.py new file mode 100644 index 0000000..d5a2c36 --- /dev/null +++ b/ppgan/models/wav2lip_model.py @@ -0,0 +1,148 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from .base_model import BaseModel + +from .builder import MODELS +from .generators.builder import build_generator +from .discriminators.builder import build_discriminator + +from ..solver import build_optimizer +from ..modules.init import init_weights + +syncnet_T = 5 +syncnet_mel_step_size = 16 + + +def cosine_loss(a, v, y): + logloss = paddle.nn.BCELoss() + d = paddle.nn.functional.cosine_similarity(a, v) + loss = logloss(d.unsqueeze(1), y) + return loss + + +def get_sync_loss(mel, g, netD): + g = g[:, :, :, g.shape[3] // 2:] + g = paddle.concat([g[:, :, i] for i in range(syncnet_T)], axis=1) + a, v = netD(mel, g) + y = paddle.ones((g.shape[0], 1)).astype('float32') + return cosine_loss(a, v, y) + + +lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams' + + +@MODELS.register() +class Wav2LipModel(BaseModel): + """ This class implements the Wav2lip model, Wav2lip paper: https://arxiv.org/abs/2008.10010. + + The model training requires dataset. + By default, it uses a '--netG Wav2lip' generator, + a '--netD SyncNetColor' discriminator. + """ + def __init__(self, + generator, + discriminator=None, + syncnet_wt=1.0, + max_eval_steps=700, + is_train=True): + """Initialize the Wav2lip class. 
+ + Parameters: + opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict + """ + super(Wav2LipModel, self).__init__() + self.syncnet_wt = syncnet_wt + self.is_train = is_train + self.eval_step = 0 + self.max_eval_steps = max_eval_steps + self.eval_sync_losses, self.eval_recon_losses = [], [] + # define networks (both generator and discriminator) + self.nets['netG'] = build_generator(generator) + init_weights(self.nets['netG'], distribution='uniform') + if self.is_train: + self.nets['netD'] = build_discriminator(discriminator) + params = paddle.load(lipsync_weight_path) + self.nets['netD'].load_dict(params) + + if self.is_train: + self.recon_loss = paddle.nn.L1Loss() + + def setup_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): include the data itself and its metadata information. + """ + self.x = paddle.to_tensor(input['x']) + self.indiv_mels = paddle.to_tensor(input['indiv_mels']) + self.mel = paddle.to_tensor(input['mel']) + self.y = paddle.to_tensor(input['y']) + + def forward(self): + """Run forward pass; called by both functions and .""" + + self.g = self.nets['netG'](self.indiv_mels, self.x) + + def backward_G(self): + """Calculate GAN loss for the generator""" + if self.syncnet_wt > 0.: + self.sync_loss = get_sync_loss(self.mel, self.g, self.nets['netD']) + else: + self.sync_loss = 0. + self.l1_loss = self.recon_loss(self.g, self.y) + + self.losses['sync_loss'] = self.sync_loss + self.losses['l1_loss'] = self.l1_loss + + self.loss_G = self.syncnet_wt * self.sync_loss + ( + 1 - self.syncnet_wt) * self.l1_loss + self.loss_G.backward() + + def train_iter(self, optimizers=None): + # forward + self.forward() + + # update G + self.set_requires_grad(self.nets['netD'], False) + self.set_requires_grad(self.nets['netG'], True) + self.optimizers['optimizer_G'].clear_grad() + self.backward_G() + self.optimizers['optimizer_G'].step() + + def test_iter(self, metrics=None): + self.eval_step += 1 + self.nets['netG'].eval() + with paddle.no_grad(): + self.forward() + + sync_loss = get_sync_loss(self.mel, self.g, self.nets['netD']) + l1loss = self.recon_loss(self.g, self.y) + + self.eval_sync_losses.append(sync_loss.numpy().item()) + self.eval_recon_losses.append(l1loss.numpy().item()) + if self.eval_step == self.max_eval_steps: + averaged_sync_loss = sum(self.eval_sync_losses) / len( + self.eval_sync_losses) + averaged_recon_loss = sum(self.eval_recon_losses) / len( + self.eval_recon_losses) + if averaged_sync_loss < .75: + self.syncnet_wt = 0.01 + + print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, + averaged_sync_loss)) + self.eval_step = 0 + self.eval_sync_losses, self.eval_recon_losses = [], [] + self.nets['netG'].train() diff --git a/ppgan/modules/conv.py b/ppgan/modules/conv.py index 2fc09c5..ec28614 100644 --- a/ppgan/modules/conv.py +++ b/ppgan/modules/conv.py @@ -40,7 +40,7 @@ class NonNormConv2d(nn.Layer): super().__init__(*args, **kwargs) self.conv_block = nn.Sequential( nn.Conv2D(cin, cout, kernel_size, stride, padding), ) - self.act = nn.LeakyReLU(0.01, inplace=True) + self.act = nn.LeakyReLU(0.01) def forward(self, x): out = self.conv_block(x) diff --git a/ppgan/modules/init.py b/ppgan/modules/init.py index 4a4cc16..91dfd06 100644 --- a/ppgan/modules/init.py +++ b/ppgan/modules/init.py @@ -281,7 +281,10 @@ def kaiming_init(layer, constant_(layer.bias, bias) -def init_weights(net, init_type='normal', init_gain=0.02): +def init_weights(net, + init_type='normal', 
+ init_gain=0.02, + distribution='normal'): """Initialize network weights. Args: net (nn.Layer): network to be initialized @@ -297,9 +300,16 @@ def init_weights(net, init_type='normal', init_gain=0.02): if init_type == 'normal': normal_(m.weight, 0.0, init_gain) elif init_type == 'xavier': - xavier_normal_(m.weight, gain=init_gain) + if distribution == 'normal': + xavier_normal_(m.weight, gain=init_gain) + else: + xavier_uniform_(m.weight, gain=init_gain) + elif init_type == 'kaiming': - kaiming_normal_(m.weight, a=0, mode='fan_in') + if distribution == 'normal': + kaiming_normal_(m.weight, a=0, mode='fan_in') + else: + kaiming_uniform_(m.weight, a=0, mode='fan_in') else: raise NotImplementedError( 'initialization method [%s] is not implemented' % init_type) diff --git a/ppgan/utils/audio.py b/ppgan/utils/audio.py new file mode 100644 index 0000000..2863418 --- /dev/null +++ b/ppgan/utils/audio.py @@ -0,0 +1,184 @@ +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from scipy.io import wavfile +from .audio_config import get_audio_config + +audio_config = get_audio_config() + + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr)[0] + + +def save_wav(wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + + +def save_wavenet_wav(wav, path, sr): + librosa.output.write_wav(path, wav, sr=sr) + + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + + +def get_hop_size(): + hop_size = audio_config.hop_size + if hop_size is None: + assert audio_config.frame_shift_ms is not None + hop_size = int(audio_config.frame_shift_ms / 1000 * + audio_config.sample_rate) + return hop_size + + +def linearspectrogram(wav): + D = _stft( + preemphasis(wav, audio_config.preemphasis, audio_config.preemphasize)) + S = _amp_to_db(np.abs(D)) - audio_config.ref_level_db + + if audio_config.signal_normalization: + return _normalize(S) + return S + + +def melspectrogram(wav): + D = _stft( + preemphasis(wav, audio_config.preemphasis, audio_config.preemphasize)) + S = _amp_to_db(_linear_to_mel(np.abs(D))) - audio_config.ref_level_db + + if audio_config.signal_normalization: + return _normalize(S) + return S + + +def _lws_processor(): + import lws + return lws.lws(audio_config.n_fft, + get_hop_size(), + fftsize=audio_config.win_size, + mode="speech") + + +def _stft(y): + if audio_config.use_lws: + return _lws_processor(audio_config).stft(y).T + else: + return librosa.stft(y=y, + n_fft=audio_config.n_fft, + hop_length=get_hop_size(), + win_length=audio_config.win_size) + + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
+def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r + + +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + + +# Conversions +_mel_basis = None + + +def _linear_to_mel(spectogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectogram) + + +def _build_mel_basis(): + assert audio_config.fmax <= audio_config.sample_rate // 2 + return librosa.filters.mel(audio_config.sample_rate, + audio_config.n_fft, + n_mels=audio_config.num_mels, + fmin=audio_config.fmin, + fmax=audio_config.fmax) + + +def _amp_to_db(x): + min_level = np.exp(audio_config.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + + +def _normalize(S): + if audio_config.allow_clipping_in_normalization: + if audio_config.symmetric_mels: + return np.clip( + (2 * audio_config.max_abs_value) * + ((S - audio_config.min_level_db) / + (-audio_config.min_level_db)) - audio_config.max_abs_value, + -audio_config.max_abs_value, audio_config.max_abs_value) + else: + return np.clip( + audio_config.max_abs_value * ((S - audio_config.min_level_db) / + (-audio_config.min_level_db)), 0, + audio_config.max_abs_value) + + assert S.max() <= 0 and S.min() - audio_config.min_level_db >= 0 + if audio_config.symmetric_mels: + return (2 * audio_config.max_abs_value) * ( + (S - audio_config.min_level_db) / + (-audio_config.min_level_db)) - audio_config.max_abs_value + else: + return audio_config.max_abs_value * ((S - audio_config.min_level_db) / + (-audio_config.min_level_db)) + + +def _denormalize(D): + if audio_config.allow_clipping_in_normalization: + if audio_config.symmetric_mels: + return (((np.clip(D, -audio_config.max_abs_value, + audio_config.max_abs_value) + + audio_config.max_abs_value) * -audio_config.min_level_db / + (2 * audio_config.max_abs_value)) + + audio_config.min_level_db) + else: + return ((np.clip(D, 0, audio_config.max_abs_value) * + -audio_config.min_level_db / audio_config.max_abs_value) + + audio_config.min_level_db) + + if audio_config.symmetric_mels: + return (((D + audio_config.max_abs_value) * -audio_config.min_level_db / + (2 * audio_config.max_abs_value)) + audio_config.min_level_db) + else: + return ((D * -audio_config.min_level_db / audio_config.max_abs_value) + + audio_config.min_level_db) diff --git a/ppgan/utils/audio_config.py b/ppgan/utils/audio_config.py new file mode 100644 index 0000000..a104499 --- /dev/null +++ b/ppgan/utils/audio_config.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +_C = edict() + +_C.num_mels = 80 +_C.rescale = True +_C.rescaling_max = 0.9 +_C.use_lws = False +_C.n_fft = 800 +_C.hop_size = 200 +_C.win_size = 800 +_C.sample_rate = 16000 +_C.frame_shift_ms = None +_C.signal_normalization = True +_C.allow_clipping_in_normalization = True +_C.symmetric_mels = True +_C.max_abs_value = 4. 
+_C.preemphasize = True +_C.preemphasis = 0.97 +_C.min_level_db = -100 +_C.ref_level_db = 20 +_C.fmin = 55 +_C.fmax = 7600 +_C.fps = 25 + + +def get_audio_config(): + return _C diff --git a/tools/main.py b/tools/main.py index 0918d52..7d40c11 100644 --- a/tools/main.py +++ b/tools/main.py @@ -28,7 +28,6 @@ from ppgan.engine.trainer import Trainer def main(args, cfg): # init environment, include logger, dynamic graph, seed, device, train or test mode... setup(args, cfg) - # build trainer trainer = Trainer(cfg) -- GitLab
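
A minimal end-to-end usage sketch for the training code added above. The preprocessing and inference commands are taken from the scripts in this patch (lsr2_preprocess.py and applications/tools/wav2lip.py); the two training commands assume PaddleGAN's existing tools/main.py entry point accepts a --config-file argument, and all file paths are placeholders, not values from the patch.

    # preprocess the LRS2 dataset into the layout the configs expect (data/lrs2_preprocessed)
    python lsr2_preprocess.py --data_root path/to/LRS2/main --preprocessed_root data/lrs2_preprocessed
    # train the base Wav2Lip model (expert lip-sync discriminator only)
    python -u tools/main.py --config-file configs/wav2lip.yaml
    # train the HQ variant, which adds the Wav2LipDiscQual visual-quality discriminator
    python -u tools/main.py --config-file configs/wav2lip_hq.yaml
    # run lip-sync inference with a trained generator checkpoint
    python applications/tools/wav2lip.py --checkpoint_path path/to/generator.pdparams --face input_video.mp4 --audio input_audio.wav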