Unverified · Commit edd62113 authored by lijianshe02, committed by GitHub

add wav2lip training code (#142)

* add wav2lip training code
Parent 776fe801
import argparse
import paddle
from ppgan.apps.wav2lip_predictor import Wav2LipPredictor
parser = argparse.ArgumentParser(
description=
'Inference code to lip-sync videos in the wild using Wav2Lip models')
parser.add_argument('--checkpoint_path',
type=str,
help='Name of saved checkpoint to load weights from',
required=True)
parser.add_argument('--face',
type=str,
help='Filepath of video/image that contains faces to use',
required=True)
parser.add_argument(
'--audio',
type=str,
help='Filepath of video/audio file to use as raw audio source',
required=True)
parser.add_argument('--outfile',
type=str,
help='Path of the output video; see the default value for an example.',
default='results/result_voice.mp4')
parser.add_argument(
'--static',
type=bool,
help='If True, then use only first video frame for inference',
default=False)
parser.add_argument(
'--fps',
type=float,
help='Can be specified only if input is a static image (default: 25)',
default=25.,
required=False)
parser.add_argument(
'--pads',
nargs='+',
type=int,
default=[0, 10, 0, 0],
help=
'Padding (top, bottom, left, right). Please adjust to include chin at least'
)
parser.add_argument('--face_det_batch_size',
type=int,
help='Batch size for face detection',
default=16)
parser.add_argument('--wav2lip_batch_size',
type=int,
help='Batch size for Wav2Lip model(s)',
default=128)
parser.add_argument(
'--resize_factor',
default=1,
type=int,
help=
'Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p'
)
parser.add_argument(
'--crop',
nargs='+',
type=int,
default=[0, -1, 0, -1],
help=
'Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
'Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height and width'
)
parser.add_argument(
'--box',
nargs='+',
type=int,
default=[-1, -1, -1, -1],
help=
'Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).'
)
parser.add_argument(
'--rotate',
default=False,
action='store_true',
help=
'Sometimes videos taken from a phone are rotated by 90deg. If set, the video is rotated right by 90deg. '
'Use if you get a flipped result, despite feeding a normal looking video')
parser.add_argument(
'--nosmooth',
default=False,
action='store_true',
help='Prevent smoothing face detections over a short temporal window')
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
if __name__ == "__main__":
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = Wav2LipPredictor(args)
predictor.run()
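The same flow can be driven without the CLI. A minimal sketch (the checkpoint and media paths below are placeholders; the remaining fields simply mirror the defaults of the parser above):

from argparse import Namespace

from ppgan.apps.wav2lip_predictor import Wav2LipPredictor

# Placeholder inputs; replace with real files.
args = Namespace(checkpoint_path='wav2lip_weights.pdparams',
                 face='input_face.mp4',
                 audio='input_audio.wav',
                 outfile='results/result_voice.mp4',
                 static=False,
                 fps=25.,
                 pads=[0, 10, 0, 0],
                 face_det_batch_size=16,
                 wav2lip_batch_size=128,
                 resize_factor=1,
                 crop=[0, -1, 0, -1],
                 box=[-1, -1, -1, -1],
                 rotate=False,
                 nosmooth=False,
                 cpu=False)

predictor = Wav2LipPredictor(args)
predictor.run()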
total_iters: 200000000
output_dir: output
checkpoints_dir: checkpoints
model:
name: Wav2LipModel
syncnet_wt: 0.0
max_eval_steps: 700
generator:
name: Wav2Lip
discriminator:
name: SyncNetColor
dataset:
train:
name: Wav2LipDataset
dataroot: data/lrs2_preprocessed
filelists_path: ./
img_size: 96
split: train
batch_size: 8
num_workers: 4
use_shared_memory: False
test:
name: Wav2LipDataset
dataroot: data/lrs2_preprocessed
filelists_path: ./
img_size: 96
split: val
batch_size: 16
num_workers: 4
use_shared_memory: False
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5
validate:
interval: 3000
save_img: false
lr_scheduler:
name: LinearDecay
learning_rate: 0.0001
start_epoch: 2000000
decay_epochs: 2000000
# will get from real dataset
iters_per_epoch: 1
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 3000
total_iters: 200000000
output_dir: output_hq
checkpoints_dir: checkpoints_hq
model:
name: Wav2LipModelHq
syncnet_wt: 0.
disc_wt: 0.07
max_eval_steps: 700
generator:
name: Wav2Lip
discriminator_sync:
name: SyncNetColor
discriminator_hq:
name: Wav2LipDiscQual
dataset:
train:
name: Wav2LipDataset
dataroot: data/lrs2_preprocessed
filelists_path: ./
img_size: 96
split: train
batch_size: 8
num_workers: 0
use_shared_memory: False
test:
name: Wav2LipDataset
dataroot: data/lrs2_preprocessed
filelists_path: ./
img_size: 96
split: val
batch_size: 16
num_workers: 0
use_shared_memory: False
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_DS:
name: Adam
net_names:
- netDS
beta1: 0.5
optimizer_DH:
name: Adam
net_names:
- netDH
beta1: 0.5
validate:
interval: 3000
save_img: false
lr_scheduler:
name: LinearDecay
learning_rate: 0.0001
start_epoch: 2000000
decay_epochs: 2000000
# will get from real dataset
iters_per_epoch: 1
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 3000
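For orientation, a minimal sketch of how the dataset sections of these configs are consumed (assuming ppgan is importable, data/lrs2_preprocessed exists and filelists/train.txt lists its clips): build_dataloader pops the loader-level keys (batch_size, num_workers, use_shared_memory) and passes the rest straight to Wav2LipDataset.

from ppgan.datasets.wav2lip_dataset import Wav2LipDataset

dataset = Wav2LipDataset(dataroot='data/lrs2_preprocessed',
                         img_size=96,
                         filelists_path='./',
                         split='train')
sample = dataset[0]
# dict of float32 arrays under the keys 'x', 'indiv_mels', 'mel' and 'y'
print(sorted(sample.keys()))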
import sys
if sys.version_info < (3, 2):
    raise Exception("Must be using >= Python 3.2")
from os import listdir, path
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import argparse, os, cv2, traceback, subprocess
from tqdm import tqdm
from glob import glob
from ppgan.utils import audio
from ppgan.faceutils import face_detection
parser = argparse.ArgumentParser()
parser.add_argument('--ngpu',
help='Number of GPUs across which to run in parallel',
default=1,
type=int)
parser.add_argument('--batch_size',
help='Single GPU Face detection batch size',
default=32,
type=int)
parser.add_argument("--data_root",
help="Root folder of the LRS2 dataset",
required=True)
parser.add_argument("--preprocessed_root",
help="Root folder of the preprocessed dataset",
required=True)
args = parser.parse_args()
fa = [
face_detection.FaceAlignment(face_detection.LandmarksType._2D,
flip_input=False)
]
template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'
def process_video_file(vfile, args, gpu_id):
video_stream = cv2.VideoCapture(vfile)
frames = []
while 1:
still_reading, frame = video_stream.read()
if not still_reading:
video_stream.release()
break
frames.append(frame)
vidname = os.path.basename(vfile).split('.')[0]
dirname = vfile.split('/')[-2]
fulldir = path.join(args.preprocessed_root, dirname, vidname)
os.makedirs(fulldir, exist_ok=True)
batches = [
frames[i:i + args.batch_size]
for i in range(0, len(frames), args.batch_size)
]
i = -1
for fb in batches:
preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))
for j, f in enumerate(preds):
i += 1
if f is None:
continue
x1, y1, x2, y2 = f
cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2,
x1:x2])
def process_audio_file(vfile, args):
vidname = os.path.basename(vfile).split('.')[0]
dirname = vfile.split('/')[-2]
fulldir = path.join(args.preprocessed_root, dirname, vidname)
os.makedirs(fulldir, exist_ok=True)
wavpath = path.join(fulldir, 'audio.wav')
command = template.format(vfile, wavpath)
subprocess.call(command, shell=True)
def mp_handler(job):
vfile, args, gpu_id = job
try:
process_video_file(vfile, args, gpu_id)
except KeyboardInterrupt:
exit(0)
except:
traceback.print_exc()
def main(args):
print('Started processing for {} with {} GPUs'.format(
args.data_root, args.ngpu))
filelist = glob(path.join(args.data_root, '*/*.mp4'))
jobs = [(vfile, args, i % args.ngpu) for i, vfile in enumerate(filelist)]
p = ThreadPoolExecutor(args.ngpu)
futures = [p.submit(mp_handler, j) for j in jobs]
_ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]
print('Dumping audios...')
for vfile in tqdm(filelist):
try:
process_audio_file(vfile, args)
except KeyboardInterrupt:
exit(0)
except:
traceback.print_exc()
continue
if __name__ == '__main__':
main(args)
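For each input clip the script above writes cropped face frames plus the extracted audio to {preprocessed_root}/{speaker_dir}/{video_name}/. A small sketch that checks the expected layout (the root below is a placeholder):

import os
from glob import glob

preprocessed_root = 'data/lrs2_preprocessed'  # placeholder path
for clip_dir in sorted(glob(os.path.join(preprocessed_root, '*/*')))[:5]:
    frames = glob(os.path.join(clip_dir, '*.jpg'))  # 0.jpg, 1.jpg, ... face crops
    has_audio = os.path.isfile(os.path.join(clip_dir, 'audio.wav'))
    print(clip_dir, len(frames), 'frames, audio:', has_audio)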
@@ -24,3 +24,4 @@ from .midas_predictor import MiDaSPredictor
from .photo2cartoon_predictor import Photo2CartoonPredictor
from .styleganv2_predictor import StyleGANv2Predictor
from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
from .wav2lip_predictor import Wav2LipPredictor
from os import listdir, path, makedirs
import platform
import numpy as np
import scipy, cv2, os, sys, argparse
import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import paddle
from ppgan.faceutils import face_detection
from ppgan.utils import audio
from ppgan.models.generators.wav2lip import Wav2Lip
from .base_predictor import BasePredictor
mel_step_size = 16
class Wav2LipPredictor(BasePredictor):
def __init__(self, args):
self.args = args
if os.path.isfile(self.args.face) and self.args.face.split('.')[1] in [
'jpg', 'png', 'jpeg'
]:
self.args.static = True
self.img_size = 96
makedirs('./temp', exist_ok=True)
def get_smoothened_boxes(self, boxes, T):
for i in range(len(boxes)):
if i + T > len(boxes):
window = boxes[len(boxes) - T:]
else:
window = boxes[i:i + T]
boxes[i] = np.mean(window, axis=0)
return boxes
def face_detect(self, images):
detector = face_detection.FaceAlignment(
face_detection.LandmarksType._2D, flip_input=False)
batch_size = self.args.face_det_batch_size
while 1:
predictions = []
try:
for i in tqdm(range(0, len(images), batch_size)):
predictions.extend(
detector.get_detections_for_batch(
np.array(images[i:i + batch_size])))
except RuntimeError:
if batch_size == 1:
raise RuntimeError(
'Image too big to run face detection on GPU. Please use the --resize_factor argument'
)
batch_size //= 2
print('Recovering from OOM error; New batch size: {}'.format(
batch_size))
continue
break
results = []
pady1, pady2, padx1, padx2 = self.args.pads
for rect, image in zip(predictions, images):
if rect is None:
cv2.imwrite(
'temp/faulty_frame.jpg',
image) # check this frame where the face was not detected.
raise ValueError(
'Face not detected! Ensure the video contains a face in all the frames.'
)
y1 = max(0, rect[1] - pady1)
y2 = min(image.shape[0], rect[3] + pady2)
x1 = max(0, rect[0] - padx1)
x2 = min(image.shape[1], rect[2] + padx2)
results.append([x1, y1, x2, y2])
boxes = np.array(results)
if not self.args.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)]
for image, (x1, y1, x2, y2) in zip(images, boxes)]
del detector
return results
def datagen(self, frames, mels):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if self.args.box[0] == -1:
if not self.args.static:
face_det_results = self.face_detect(
frames) # BGR2RGB for CNN face detection
else:
face_det_results = self.face_detect([frames[0]])
else:
print(
'Using the specified bounding box instead of face detection...')
y1, y2, x1, x2 = self.args.box
face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)]
for f in frames]
for i, m in enumerate(mels):
idx = 0 if self.args.static else i % len(frames)
frame_to_save = frames[idx].copy()
face, coords = face_det_results[idx].copy()
face = cv2.resize(face, (self.img_size, self.img_size))
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
if len(img_batch) >= self.args.wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(
mel_batch)
img_masked = img_batch.copy()
img_masked[:, self.img_size // 2:] = 0
img_batch = np.concatenate(
(img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(
mel_batch,
[len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, self.img_size // 2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(
mel_batch,
[len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
def run(self):
if not os.path.isfile(self.args.face):
raise ValueError(
'--face argument must be a valid path to video/image file')
elif self.args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
full_frames = [cv2.imread(self.args.face)]
fps = self.args.fps
else:
video_stream = cv2.VideoCapture(self.args.face)
fps = video_stream.get(cv2.CAP_PROP_FPS)
print('Reading video frames...')
full_frames = []
while 1:
still_reading, frame = video_stream.read()
if not still_reading:
video_stream.release()
break
if self.args.resize_factor > 1:
frame = cv2.resize(
frame, (frame.shape[1] // self.args.resize_factor,
frame.shape[0] // self.args.resize_factor))
if self.args.rotate:
frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
y1, y2, x1, x2 = self.args.crop
if x2 == -1: x2 = frame.shape[1]
if y2 == -1: y2 = frame.shape[0]
frame = frame[y1:y2, x1:x2]
full_frames.append(frame)
print("Number of frames available for inference: " +
str(len(full_frames)))
if not self.args.audio.endswith('.wav'):
print('Extracting raw audio...')
command = 'ffmpeg -y -i {} -strict -2 {}'.format(
self.args.audio, 'temp/temp.wav')
subprocess.call(command, shell=True)
self.args.audio = 'temp/temp.wav'
wav = audio.load_wav(self.args.audio, 16000)
mel = audio.melspectrogram(wav)
print(mel.shape)
if np.isnan(mel.reshape(-1)).sum() > 0:
raise ValueError(
'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again'
)
mel_chunks = []
mel_idx_multiplier = 80. / fps
i = 0
while 1:
start_idx = int(i * mel_idx_multiplier)
if start_idx + mel_step_size > len(mel[0]):
mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
break
mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
i += 1
print("Length of mel chunks: {}".format(len(mel_chunks)))
full_frames = full_frames[:len(mel_chunks)]
batch_size = self.args.wav2lip_batch_size
gen = self.datagen(full_frames.copy(), mel_chunks)
model = Wav2Lip()
weights = paddle.load(self.args.checkpoint_path)
model.load_dict(weights)
model.eval()
print("Model loaded")
for i, (img_batch, mel_batch, frames, coords) in enumerate(
tqdm(gen,
total=int(np.ceil(float(len(mel_chunks)) / batch_size)))):
if i == 0:
frame_h, frame_w = full_frames[0].shape[:-1]
out = cv2.VideoWriter('temp/result.avi',
cv2.VideoWriter_fourcc(*'DIVX'), fps,
(frame_w, frame_h))
img_batch = paddle.to_tensor(np.transpose(
img_batch, (0, 3, 1, 2))).astype('float32')
mel_batch = paddle.to_tensor(np.transpose(
mel_batch, (0, 3, 1, 2))).astype('float32')
with paddle.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.numpy().transpose(0, 2, 3, 1) * 255.
for p, f, c in zip(pred, frames, coords):
y1, y2, x1, x2 = c
p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
f[y1:y2, x1:x2] = p
out.write(f)
out.release()
command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(
self.args.audio, 'temp/result.avi', self.args.outfile)
subprocess.call(command, shell=platform.system() != 'Windows')
@@ -19,3 +19,4 @@ from .base_sr_dataset import SRDataset
from .makeup_dataset import MakeupDataset
from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
from .wav2lip_dataset import Wav2LipDataset
@@ -31,7 +31,6 @@ class DictDataset(paddle.io.Dataset):
        self.tensor_keys_set = set()
        self.non_tensor_keys_set = set()
        self.non_tensor_dict = Manager().dict()
        single_item = dataset[0]
        self.keys = single_item.keys()
@@ -71,6 +70,7 @@ class DictDataLoader():
                 batch_size,
                 is_train,
                 num_workers=4,
                 use_shared_memory=True,
                 distributed=True):
        self.dataset = DictDataset(dataset)
@@ -85,10 +85,12 @@ class DictDataLoader():
                shuffle=True if is_train else False,
                drop_last=True if is_train else False)
            self.dataloader = paddle.io.DataLoader(
                self.dataset,
                batch_sampler=sampler,
                places=place,
                num_workers=num_workers,
                use_shared_memory=use_shared_memory)
        else:
            self.dataloader = paddle.io.DataLoader(
                self.dataset,
@@ -96,6 +98,7 @@ class DictDataLoader():
                shuffle=True if is_train else False,
                drop_last=True if is_train else False,
                places=place,
                use_shared_memory=False,
                num_workers=num_workers)
        self.batch_size = batch_size
@@ -137,15 +140,16 @@ def build_dataloader(cfg, is_train=True, distributed=True):
    batch_size = cfg_.pop('batch_size', 1)
    num_workers = cfg_.pop('num_workers', 0)
    use_shared_memory = cfg_.pop('use_shared_memory', True)
    name = cfg_.pop('name')
    dataset = DATASETS.get(name)(**cfg_)
    dataloader = DictDataLoader(dataset,
                                batch_size,
                                is_train,
                                num_workers,
                                use_shared_memory=use_shared_memory,
                                distributed=distributed)
    return dataloader
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import random
import os.path
import numpy as np
from PIL import Image
from glob import glob
from os.path import dirname, join, basename, isfile
from ppgan.utils import audio
from ppgan.utils.audio_config import get_audio_config
import numpy as np
import paddle
from .builder import DATASETS
def get_image_list(data_root, split):
filelist = []
with open('filelists/{}.txt'.format(split)) as f:
for line in f:
line = line.strip()
if ' ' in line: line = line.split()[0]
filelist.append(os.path.join(data_root, line))
return filelist
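# Note (assumption): filelists/{split}.txt is expected to contain one clip per line,
# given as a path relative to data_root (LRS2-style "<speaker_dir>/<clip_name>");
# anything after the first space on a line is ignored by the parser above.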
syncnet_T = 5
syncnet_mel_step_size = 16
audio_cfg = get_audio_config()
@DATASETS.register()
class Wav2LipDataset(paddle.io.Dataset):
def __init__(self, dataroot, img_size, filelists_path, split):
"""Initialize Wav2Lip dataset class.
Args:
dataroot (str): Directory of dataset.
"""
self.image_path = dataroot
self.img_size = img_size
self.split = split
self.all_videos = get_image_list(self.image_path, self.split)
def get_frame_id(self, frame):
return int(basename(frame).split('.')[0])
def get_window(self, start_frame):
start_id = self.get_frame_id(start_frame)
vidname = dirname(start_frame)
window_fnames = []
for frame_id in range(start_id, start_id + syncnet_T):
frame = join(vidname, '{}.jpg'.format(frame_id))
if not isfile(frame):
return None
window_fnames.append(frame)
return window_fnames
def read_window(self, window_fnames):
if window_fnames is None: return None
window = []
for fname in window_fnames:
img = cv2.imread(fname)
if img is None:
return None
try:
img = cv2.resize(img, (self.img_size, self.img_size))
except Exception as e:
return None
window.append(img)
return window
def crop_audio_window(self, spec, start_frame):
if type(start_frame) == int:
start_frame_num = start_frame
else:
start_frame_num = self.get_frame_id(
start_frame) # 0-indexing ---> 1-indexing
start_idx = int(80. * (start_frame_num / float(audio_cfg["fps"])))
end_idx = start_idx + syncnet_mel_step_size
return spec[start_idx:end_idx, :]
def get_segmented_mels(self, spec, start_frame):
mels = []
assert syncnet_T == 5
start_frame_num = self.get_frame_id(
start_frame) + 1 # 0-indexing ---> 1-indexing
if start_frame_num - 2 < 0: return None
for i in range(start_frame_num, start_frame_num + syncnet_T):
m = self.crop_audio_window(spec, i - 2)
if m.shape[0] != syncnet_mel_step_size:
return None
mels.append(m.T)
mels = np.asarray(mels)
return mels
def prepare_window(self, window):
# 3 x T x H x W
x = np.asarray(window) / 255.
x = np.transpose(x, (3, 0, 1, 2))
return x
    def __len__(self):
        """Return the total number of videos in the dataset."""
        return len(self.all_videos)
def __getitem__(self, idx):
while 1:
idx = random.randint(0, len(self.all_videos) - 1)
vidname = self.all_videos[idx]
img_names = list(glob(join(vidname, '*.jpg')))
if len(img_names) <= 3 * syncnet_T:
continue
img_name = random.choice(img_names)
wrong_img_name = random.choice(img_names)
while wrong_img_name == img_name:
wrong_img_name = random.choice(img_names)
window_fnames = self.get_window(img_name)
wrong_window_fnames = self.get_window(wrong_img_name)
if window_fnames is None or wrong_window_fnames is None:
continue
window = self.read_window(window_fnames)
if window is None:
continue
wrong_window = self.read_window(wrong_window_fnames)
if wrong_window is None:
continue
try:
wavpath = join(vidname, "audio.wav")
wav = audio.load_wav(wavpath, audio_cfg["sample_rate"])
orig_mel = audio.melspectrogram(wav).T
except Exception as e:
continue
mel = self.crop_audio_window(orig_mel.copy(), img_name)
if (mel.shape[0] != syncnet_mel_step_size):
continue
indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
if indiv_mels is None: continue
window = self.prepare_window(window)
y = window.copy()
window[:, :, window.shape[2] // 2:] = 0.
wrong_window = self.prepare_window(wrong_window)
x = np.concatenate([window, wrong_window], axis=0)
x = np.float32(x)
mel = np.transpose(mel)
mel = np.expand_dims(mel, 0)
indiv_mels = np.expand_dims(indiv_mels, 1)
#np.random.seed(200)
#x = np.random.rand(*x.shape).astype('float32')
#np.random.seed(200)
#mel = np.random.rand(*mel.shape)
#np.random.seed(200)
#indiv_mels = np.random.rand(*indiv_mels.shape)
#np.random.seed(200)
#y = np.random.rand(*y.shape)
return {
'x': x,
'indiv_mels': np.float32(indiv_mels),
'mel': np.float32(mel),
'y': np.float32(y)
}
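# Shape summary for the sample returned above (a sketch, assuming img_size=96,
# syncnet_T=5 and the audio config's num_mels=80 with 16-frame mel windows):
#   x:          (6, 5, 96, 96)  mouth-masked window stacked with a "wrong" window on channels
#   indiv_mels: (5, 1, 80, 16)  one mel window per target frame
#   mel:        (1, 80, 16)     mel window aligned with the whole 5-frame clip
#   y:          (3, 5, 96, 96)  ground-truth frames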
@@ -107,6 +107,7 @@ class Trainer:
        # base config
        self.output_dir = cfg.output_dir
        self.max_eval_steps = cfg.model.get('max_eval_steps', None)
        self.epochs = cfg.get('epochs', None)
        if self.epochs:
            self.total_iters = self.epochs * self.iters_per_epoch
@@ -194,15 +195,16 @@ class Trainer:
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
                                                    is_train=False,
                                                    distributed=False)
        iter_loader = IterLoader(self.test_dataloader)
        if self.max_eval_steps is None:
            self.max_eval_steps = len(self.test_dataloader)
        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()
        for i in range(self.max_eval_steps):
            data = next(iter_loader)
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)
@@ -236,7 +238,7 @@ class Trainer:
            if i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' %
                                 (i, self.max_eval_steps))
        if self.metrics:
            for metric_name, metric in self.metrics.items():
@@ -340,6 +342,7 @@ class Trainer:
        else:
            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)
        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()
@@ -379,6 +382,8 @@ class Trainer:
            self.start_epoch = state_dicts['epoch'] + 1
            self.global_steps = self.iters_per_epoch * state_dicts['epoch']
            self.current_iter = state_dicts['epoch'] + 1
            for net_name, net in self.model.nets.items():
                net.set_state_dict(state_dicts[net_name])
...
@@ -16,3 +16,4 @@ from . import dlibutils as dlib
from . import mask
from . import image
from . import face_segmentation
from . import face_detection
@@ -22,3 +22,5 @@ from .esrgan_model import ESRGAN
from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
from .wav2lip_model import Wav2LipModel
from .wav2lip_hq_model import Wav2LipModelHq
@@ -18,3 +18,5 @@ from .discriminator_ugatit import UGATITDiscriminator
from .dcdiscriminator import DCDiscriminator
from .discriminator_animegan import AnimeDiscriminator
from .discriminator_styleganv2 import StyleGANv2Discriminator
from .syncnet import SyncNetColor
from .wav2lip_disc_qual import Wav2LipDiscQual
import paddle
from paddle import nn
from paddle.nn import functional as F
import sys
from ...modules.conv import ConvBNRelu
#from conv import ConvBNRelu
from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class SyncNetColor(nn.Layer):
    def __init__(self):
        super(SyncNetColor, self).__init__()
...
import paddle
from paddle import nn
from paddle.nn import functional as F
from ...modules.conv import ConvBNRelu, NonNormConv2d, Conv2dTransposeRelu
from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class Wav2LipDiscQual(nn.Layer):
def __init__(self):
super(Wav2LipDiscQual, self).__init__()
self.face_encoder_blocks = nn.LayerList([
nn.Sequential(
NonNormConv2d(3, 32, kernel_size=7, stride=1,
padding=3)), # 48,96
nn.Sequential(
NonNormConv2d(32, 64, kernel_size=5, stride=(1, 2),
padding=2), # 48,48
NonNormConv2d(64, 64, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(64, 128, kernel_size=5, stride=2,
padding=2), # 24,24
NonNormConv2d(128, 128, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(128, 256, kernel_size=5, stride=2,
padding=2), # 12,12
NonNormConv2d(256, 256, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(256, 512, kernel_size=3, stride=2,
padding=1), # 6,6
NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1)),
nn.Sequential(
NonNormConv2d(512, 512, kernel_size=3, stride=2,
padding=1), # 3,3
NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1),
),
nn.Sequential(
NonNormConv2d(512, 512, kernel_size=3, stride=1,
padding=0), # 1, 1
NonNormConv2d(512, 512, kernel_size=1, stride=1, padding=0)),
])
self.binary_pred = nn.Sequential(
nn.Conv2D(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
self.label_noise = .0
def get_lower_half(self, face_sequences):
return face_sequences[:, :, face_sequences.shape[2] // 2:]
def to_2d(self, face_sequences):
B = face_sequences.shape[0]
face_sequences = paddle.concat(
[face_sequences[:, :, i] for i in range(face_sequences.shape[2])],
axis=0)
return face_sequences
def perceptual_forward(self, false_face_sequences):
false_face_sequences = self.to_2d(false_face_sequences)
false_face_sequences = self.get_lower_half(false_face_sequences)
false_feats = false_face_sequences
for f in self.face_encoder_blocks:
false_feats = f(false_feats)
binary_pred = self.binary_pred(false_feats).reshape(
(len(false_feats), -1))
false_pred_loss = F.binary_cross_entropy(
binary_pred, paddle.ones((len(false_feats), 1)))
return false_pred_loss
def forward(self, face_sequences):
face_sequences = self.to_2d(face_sequences)
face_sequences = self.get_lower_half(face_sequences)
x = face_sequences
for f in self.face_encoder_blocks:
x = f(x)
return paddle.reshape(self.binary_pred(x), (len(x), -1))
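# Rough shape walk-through (a sketch, assuming 96x96 face crops and T frames per sample):
# forward() flattens (B, 3, T, 96, 96) to (B*T, 3, 96, 96) via to_2d, keeps the lower
# mouth half (B*T, 3, 48, 96), encodes it down to (B*T, 512, 1, 1), and binary_pred
# returns (B*T, 1) real/fake probabilities; perceptual_forward runs the same path on
# generated faces and scores them against an all-ones target with binary cross-entropy.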
@@ -18,7 +18,6 @@ from paddle.nn import functional as F
from .builder import GENERATORS
from ...modules.conv import ConvBNRelu
from ...modules.conv import NonNormConv2d
from ...modules.conv import Conv2dTransposeRelu
@@ -27,7 +26,7 @@ class Wav2Lip(nn.Layer):
    def __init__(self):
        super(Wav2Lip, self).__init__()
        self.face_encoder_blocks = nn.LayerList([
            nn.Sequential(ConvBNRelu(6, 16, kernel_size=7, stride=1,
                                     padding=3)),  # 96,96
            nn.Sequential(
@@ -106,7 +105,7 @@ class Wav2Lip(nn.Layer):
                ConvBNRelu(512, 512, kernel_size=3, stride=1,
                           padding=0),  # 1, 1
                ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0)),
        ])
        self.audio_encoder = nn.Sequential(
            ConvBNRelu(1, 32, kernel_size=3, stride=1, padding=1),
@@ -159,7 +158,7 @@ class Wav2Lip(nn.Layer):
            ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0),
        )
        self.face_decoder_blocks = nn.LayerList([
            nn.Sequential(
                ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0), ),
            nn.Sequential(
@@ -275,7 +274,7 @@ class Wav2Lip(nn.Layer):
                          padding=1,
                          residual=True),
            ),
        ])  # 96,96
        self.output_block = nn.Sequential(
            ConvBNRelu(80, 32, kernel_size=3, stride=1, padding=1),
@@ -319,85 +318,10 @@ class Wav2Lip(nn.Layer):
        x = self.output_block(x)
        if input_dim_size > 4:
            x = paddle.split(x, int(x.shape[0] / B), axis=0)  # [(B, C, H, W)]
            outputs = paddle.stack(x, axis=2)  # (B, C, T, H, W)
        else:
            outputs = x
        return outputs
class Wav2LipDiscQual(nn.Layer):
def __init__(self):
super(Wav2LipDiscQual, self).__init__()
self.face_encoder_blocks = [
nn.Sequential(
NonNormConv2d(3, 32, kernel_size=7, stride=1,
padding=3)), # 48,96
nn.Sequential(
NonNormConv2d(32, 64, kernel_size=5, stride=(1, 2),
padding=2), # 48,48
NonNormConv2d(64, 64, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(64, 128, kernel_size=5, stride=2,
padding=2), # 24,24
NonNormConv2d(128, 128, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(128, 256, kernel_size=5, stride=2,
padding=2), # 12,12
NonNormConv2d(256, 256, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(256, 512, kernel_size=3, stride=2,
padding=1), # 6,6
NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1)),
nn.Sequential(
NonNormConv2d(512, 512, kernel_size=3, stride=2,
padding=1), # 3,3
NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1),
),
nn.Sequential(
NonNormConv2d(512, 512, kernel_size=3, stride=1,
padding=0), # 1, 1
NonNormConv2d(512, 512, kernel_size=1, stride=1, padding=0)),
]
self.binary_pred = nn.Sequential(
nn.Conv2D(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
self.label_noise = .0
def get_lower_half(self, face_sequences):
return face_sequences[:, :, face_sequences.shape[2] // 2:]
def to_2d(self, face_sequences):
B = face_sequences.shape[0]
face_sequences = paddle.concat(
[face_sequences[:, :, i] for i in range(face_sequences.shape[2])],
axis=0)
return face_sequences
def perceptual_forward(self, false_face_sequences):
false_face_sequences = self.to_2d(false_face_sequences)
false_face_sequences = self.get_lower_half(false_face_sequences)
false_feats = false_face_sequences
for f in self.face_encoder_blocks:
false_feats = f(false_feats)
false_pred_loss = F.binary_cross_entropy(
paddle.reshape(self.binary_pred(false_feats),
(len(false_feats), -1)),
paddle.ones((len(false_feats), 1)))
return false_pred_loss
def forward(self, face_sequences):
face_sequences = self.to_2d(face_sequences)
face_sequences = self.get_lower_half(face_sequences)
x = face_sequences
for f in self.face_encoder_blocks:
x = f(x)
return paddle.reshape(self.binary_pred(x), (len(x), -1))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions import build_criterion
from .wav2lip_model import cosine_loss, get_sync_loss
from ..solver import build_optimizer
from ..modules.init import init_weights
lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams'
@MODELS.register()
class Wav2LipModelHq(BaseModel):
""" This class implements the Wav2lip model, Wav2lip paper: https://arxiv.org/abs/2008.10010.
The model training requires dataset.
By default, it uses a '--netG Wav2lip' generator,
a '--netD SyncNetColor' discriminator.
"""
def __init__(self,
generator,
discriminator_sync=None,
discriminator_hq=None,
syncnet_wt=1.0,
disc_wt=0.07,
max_eval_steps=700,
is_train=True):
"""Initialize the Wav2lip class.
Parameters:
opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
"""
super(Wav2LipModelHq, self).__init__()
self.syncnet_wt = syncnet_wt
self.disc_wt = disc_wt
self.is_train = is_train
self.eval_step = 0
self.max_eval_steps = max_eval_steps
self.eval_sync_losses, self.eval_recon_losses = [], []
self.eval_disc_real_losses, self.eval_disc_fake_losses = [], []
self.eval_perceptual_losses = []
# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG'],
init_type='kaiming',
distribution='uniform')
if self.is_train:
self.nets['netDS'] = build_discriminator(discriminator_sync)
params = paddle.load(lipsync_weight_path)
self.nets['netDS'].load_dict(params)
self.nets['netDH'] = build_discriminator(discriminator_hq)
init_weights(self.nets['netDH'],
init_type='kaiming',
distribution='uniform')
if self.is_train:
self.recon_loss = paddle.nn.L1Loss()
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
"""
self.x = paddle.to_tensor(input['x'])
self.indiv_mels = paddle.to_tensor(input['indiv_mels'])
self.mel = paddle.to_tensor(input['mel'])
self.y = paddle.to_tensor(input['y'])
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.g = self.nets['netG'](self.indiv_mels, self.x)
def backward_G(self):
"""Calculate GAN loss for the generator"""
if self.syncnet_wt > 0.:
self.sync_loss = get_sync_loss(self.mel, self.g, self.nets['netDS'])
else:
self.sync_loss = 0.
self.l1_loss = self.recon_loss(self.g, self.y)
if self.disc_wt > 0.:
if isinstance(self.nets['netDH'], paddle.DataParallel
): #paddle.fluid.dygraph.parallel.DataParallel)
self.perceptual_loss = self.nets[
'netDH']._layers.perceptual_forward(self.g)
else:
self.perceptual_loss = self.nets['netDH'].perceptual_forward(
self.g)
else:
self.perceptual_loss = 0.
self.losses['sync_loss'] = self.sync_loss
self.losses['l1_loss'] = self.l1_loss
self.losses['perceptual_loss'] = self.perceptual_loss
self.loss_G = self.syncnet_wt * self.sync_loss + self.disc_wt * self.perceptual_loss + (
1 - self.syncnet_wt - self.disc_wt) * self.l1_loss
self.loss_G.backward()
def backward_D(self):
self.pred_real = self.nets['netDH'](self.y)
self.disc_real_loss = F.binary_cross_entropy(
self.pred_real, paddle.ones((len(self.pred_real), 1)))
self.losses['disc_real_loss'] = self.disc_real_loss
self.disc_real_loss.backward()
self.pred_fake = self.nets['netDH'](self.g.detach())
self.disc_fake_loss = F.binary_cross_entropy(
self.pred_fake, paddle.zeros((len(self.pred_fake), 1)))
self.losses['disc_fake_loss'] = self.disc_fake_loss
self.disc_fake_loss.backward()
def train_iter(self, optimizers=None):
# forward
self.forward()
# update G
self.set_requires_grad(self.nets['netDS'], False)
self.set_requires_grad(self.nets['netG'], True)
self.set_requires_grad(self.nets['netDH'], True)
self.optimizers['optimizer_G'].clear_grad()
self.optimizers['optimizer_DH'].clear_grad()
self.backward_G()
self.optimizers['optimizer_G'].step()
self.optimizers['optimizer_DH'].clear_grad()
self.backward_D()
self.optimizers['optimizer_DH'].step()
def test_iter(self, metrics=None):
self.eval_step += 1
self.nets['netG'].eval()
self.nets['netDH'].eval()
with paddle.no_grad():
self.forward()
sync_loss = get_sync_loss(self.mel, self.g, self.nets['netDS'])
l1loss = self.recon_loss(self.g, self.y)
pred_real = self.nets['netDH'](self.y)
pred_fake = self.nets['netDH'](self.g)
disc_real_loss = F.binary_cross_entropy(
pred_real, paddle.ones((len(pred_real), 1)))
disc_fake_loss = F.binary_cross_entropy(
pred_fake, paddle.zeros((len(pred_fake), 1)))
self.eval_disc_fake_losses.append(disc_fake_loss.numpy().item())
self.eval_disc_real_losses.append(disc_real_loss.numpy().item())
self.eval_sync_losses.append(sync_loss.numpy().item())
self.eval_recon_losses.append(l1loss.numpy().item())
if self.disc_wt > 0.:
if isinstance(self.nets['netDH'], paddle.DataParallel
): #paddle.fluid.dygraph.parallel.DataParallel)
perceptual_loss = self.nets[
'netDH']._layers.perceptual_forward(
self.g).numpy().item()
else:
perceptual_loss = self.nets['netDH'].perceptual_forward(
self.g).numpy().item()
else:
perceptual_loss = 0.
self.eval_perceptual_losses.append(perceptual_loss)
if self.eval_step == self.max_eval_steps:
averaged_sync_loss = sum(self.eval_sync_losses) / len(
self.eval_sync_losses)
averaged_recon_loss = sum(self.eval_recon_losses) / len(
self.eval_recon_losses)
averaged_perceptual_loss = sum(self.eval_perceptual_losses) / len(
self.eval_perceptual_losses)
averaged_disc_fake_loss = sum(self.eval_disc_fake_losses) / len(
self.eval_disc_fake_losses)
averaged_disc_real_loss = sum(self.eval_disc_real_losses) / len(
self.eval_disc_real_losses)
if averaged_sync_loss < .75:
self.syncnet_wt = 0.01
print(
'L1: {}, Sync loss: {}, Percep: {}, Fake: {}, Real: {}'.format(
averaged_recon_loss, averaged_sync_loss,
averaged_perceptual_loss, averaged_disc_fake_loss,
averaged_disc_real_loss))
self.eval_sync_losses, self.eval_recon_losses = [], []
self.eval_disc_real_losses, self.eval_disc_fake_losses = [], []
self.eval_perceptual_losses = []
self.eval_step = 0
self.nets['netG'].train()
self.nets['netDH'].train()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_optimizer
from ..modules.init import init_weights
syncnet_T = 5
syncnet_mel_step_size = 16
def cosine_loss(a, v, y):
logloss = paddle.nn.BCELoss()
d = paddle.nn.functional.cosine_similarity(a, v)
loss = logloss(d.unsqueeze(1), y)
return loss
def get_sync_loss(mel, g, netD):
g = g[:, :, :, g.shape[3] // 2:]
g = paddle.concat([g[:, :, i] for i in range(syncnet_T)], axis=1)
a, v = netD(mel, g)
y = paddle.ones((g.shape[0], 1)).astype('float32')
return cosine_loss(a, v, y)
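# Shape notes (a sketch): g from the generator is (B, 3, T, H, W); slicing the lower half
# of H and concatenating the syncnet_T frames along channels gives (B, 3 * syncnet_T, H // 2, W),
# which SyncNetColor embeds jointly with the (B, 1, 80, 16) mel window into (a, v);
# cosine_loss then pushes their cosine similarity towards the all-ones target y.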
lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams'
@MODELS.register()
class Wav2LipModel(BaseModel):
""" This class implements the Wav2lip model, Wav2lip paper: https://arxiv.org/abs/2008.10010.
The model training requires dataset.
By default, it uses a '--netG Wav2lip' generator,
a '--netD SyncNetColor' discriminator.
"""
def __init__(self,
generator,
discriminator=None,
syncnet_wt=1.0,
max_eval_steps=700,
is_train=True):
"""Initialize the Wav2lip class.
Parameters:
opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
"""
super(Wav2LipModel, self).__init__()
self.syncnet_wt = syncnet_wt
self.is_train = is_train
self.eval_step = 0
self.max_eval_steps = max_eval_steps
self.eval_sync_losses, self.eval_recon_losses = [], []
# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG'], distribution='uniform')
if self.is_train:
self.nets['netD'] = build_discriminator(discriminator)
params = paddle.load(lipsync_weight_path)
self.nets['netD'].load_dict(params)
if self.is_train:
self.recon_loss = paddle.nn.L1Loss()
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
"""
self.x = paddle.to_tensor(input['x'])
self.indiv_mels = paddle.to_tensor(input['indiv_mels'])
self.mel = paddle.to_tensor(input['mel'])
self.y = paddle.to_tensor(input['y'])
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.g = self.nets['netG'](self.indiv_mels, self.x)
def backward_G(self):
"""Calculate GAN loss for the generator"""
if self.syncnet_wt > 0.:
self.sync_loss = get_sync_loss(self.mel, self.g, self.nets['netD'])
else:
self.sync_loss = 0.
self.l1_loss = self.recon_loss(self.g, self.y)
self.losses['sync_loss'] = self.sync_loss
self.losses['l1_loss'] = self.l1_loss
self.loss_G = self.syncnet_wt * self.sync_loss + (
1 - self.syncnet_wt) * self.l1_loss
self.loss_G.backward()
def train_iter(self, optimizers=None):
# forward
self.forward()
# update G
self.set_requires_grad(self.nets['netD'], False)
self.set_requires_grad(self.nets['netG'], True)
self.optimizers['optimizer_G'].clear_grad()
self.backward_G()
self.optimizers['optimizer_G'].step()
def test_iter(self, metrics=None):
self.eval_step += 1
self.nets['netG'].eval()
with paddle.no_grad():
self.forward()
sync_loss = get_sync_loss(self.mel, self.g, self.nets['netD'])
l1loss = self.recon_loss(self.g, self.y)
self.eval_sync_losses.append(sync_loss.numpy().item())
self.eval_recon_losses.append(l1loss.numpy().item())
if self.eval_step == self.max_eval_steps:
averaged_sync_loss = sum(self.eval_sync_losses) / len(
self.eval_sync_losses)
averaged_recon_loss = sum(self.eval_recon_losses) / len(
self.eval_recon_losses)
if averaged_sync_loss < .75:
self.syncnet_wt = 0.01
print('L1: {}, Sync loss: {}'.format(averaged_recon_loss,
averaged_sync_loss))
self.eval_step = 0
self.eval_sync_losses, self.eval_recon_losses = [], []
self.nets['netG'].train()
@@ -40,7 +40,7 @@ class NonNormConv2d(nn.Layer):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2D(cin, cout, kernel_size, stride, padding), )
        self.act = nn.LeakyReLU(0.01)
    def forward(self, x):
        out = self.conv_block(x)
...
@@ -281,7 +281,10 @@ def kaiming_init(layer,
        constant_(layer.bias, bias)
def init_weights(net,
                 init_type='normal',
                 init_gain=0.02,
                 distribution='normal'):
    """Initialize network weights.
    Args:
        net (nn.Layer): network to be initialized
@@ -297,9 +300,16 @@ def init_weights(net, init_type='normal', init_gain=0.02):
            if init_type == 'normal':
                normal_(m.weight, 0.0, init_gain)
            elif init_type == 'xavier':
                if distribution == 'normal':
                    xavier_normal_(m.weight, gain=init_gain)
                else:
                    xavier_uniform_(m.weight, gain=init_gain)
            elif init_type == 'kaiming':
                if distribution == 'normal':
                    kaiming_normal_(m.weight, a=0, mode='fan_in')
                else:
                    kaiming_uniform_(m.weight, a=0, mode='fan_in')
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' % init_type)
...
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
from .audio_config import get_audio_config
audio_config = get_audio_config()
def load_wav(path, sr):
return librosa.core.load(path, sr=sr)[0]
def save_wav(wav, path, sr):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
#proposed by @dsmiller
wavfile.write(path, sr, wav.astype(np.int16))
def save_wavenet_wav(wav, path, sr):
librosa.output.write_wav(path, wav, sr=sr)
def preemphasis(wav, k, preemphasize=True):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k, inv_preemphasize=True):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def get_hop_size():
hop_size = audio_config.hop_size
if hop_size is None:
assert audio_config.frame_shift_ms is not None
hop_size = int(audio_config.frame_shift_ms / 1000 *
audio_config.sample_rate)
return hop_size
def linearspectrogram(wav):
D = _stft(
preemphasis(wav, audio_config.preemphasis, audio_config.preemphasize))
S = _amp_to_db(np.abs(D)) - audio_config.ref_level_db
if audio_config.signal_normalization:
return _normalize(S)
return S
def melspectrogram(wav):
D = _stft(
preemphasis(wav, audio_config.preemphasis, audio_config.preemphasize))
S = _amp_to_db(_linear_to_mel(np.abs(D))) - audio_config.ref_level_db
if audio_config.signal_normalization:
return _normalize(S)
return S
def _lws_processor():
import lws
return lws.lws(audio_config.n_fft,
get_hop_size(),
fftsize=audio_config.win_size,
mode="speech")
def _stft(y):
if audio_config.use_lws:
return _lws_processor().stft(y).T
else:
return librosa.stft(y=y,
n_fft=audio_config.n_fft,
hop_length=get_hop_size(),
win_length=audio_config.win_size)
##########################################################
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
"""Compute number of time frames of spectrogram
"""
pad = (fsize - fshift)
if length % fshift == 0:
M = (length + pad * 2 - fsize) // fshift + 1
else:
M = (length + pad * 2 - fsize) // fshift + 2
return M
def pad_lr(x, fsize, fshift):
"""Compute left and right padding
"""
M = num_frames(len(x), fsize, fshift)
pad = (fsize - fshift)
T = len(x) + 2 * pad
r = (M - 1) * fshift + fsize - T
return pad, pad + r
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
# Conversions
_mel_basis = None
def _linear_to_mel(spectogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectogram)
def _build_mel_basis():
assert audio_config.fmax <= audio_config.sample_rate // 2
return librosa.filters.mel(audio_config.sample_rate,
audio_config.n_fft,
n_mels=audio_config.num_mels,
fmin=audio_config.fmin,
fmax=audio_config.fmax)
def _amp_to_db(x):
min_level = np.exp(audio_config.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
def _normalize(S):
if audio_config.allow_clipping_in_normalization:
if audio_config.symmetric_mels:
return np.clip(
(2 * audio_config.max_abs_value) *
((S - audio_config.min_level_db) /
(-audio_config.min_level_db)) - audio_config.max_abs_value,
-audio_config.max_abs_value, audio_config.max_abs_value)
else:
return np.clip(
audio_config.max_abs_value * ((S - audio_config.min_level_db) /
(-audio_config.min_level_db)), 0,
audio_config.max_abs_value)
assert S.max() <= 0 and S.min() - audio_config.min_level_db >= 0
if audio_config.symmetric_mels:
return (2 * audio_config.max_abs_value) * (
(S - audio_config.min_level_db) /
(-audio_config.min_level_db)) - audio_config.max_abs_value
else:
return audio_config.max_abs_value * ((S - audio_config.min_level_db) /
(-audio_config.min_level_db))
def _denormalize(D):
if audio_config.allow_clipping_in_normalization:
if audio_config.symmetric_mels:
return (((np.clip(D, -audio_config.max_abs_value,
audio_config.max_abs_value) +
audio_config.max_abs_value) * -audio_config.min_level_db /
(2 * audio_config.max_abs_value)) +
audio_config.min_level_db)
else:
return ((np.clip(D, 0, audio_config.max_abs_value) *
-audio_config.min_level_db / audio_config.max_abs_value) +
audio_config.min_level_db)
if audio_config.symmetric_mels:
return (((D + audio_config.max_abs_value) * -audio_config.min_level_db /
(2 * audio_config.max_abs_value)) + audio_config.min_level_db)
else:
return ((D * -audio_config.min_level_db / audio_config.max_abs_value) +
audio_config.min_level_db)
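A minimal usage sketch of the helpers above (the wav path is a placeholder; 16000 matches the sample_rate in the audio config that follows):

wav = load_wav('temp/temp.wav', 16000)
mel = melspectrogram(wav)
print(mel.shape)  # (num_mels=80, n_frames), values clipped to [-max_abs_value, max_abs_value]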
from easydict import EasyDict as edict
_C = edict()
_C.num_mels = 80
_C.rescale = True
_C.rescaling_max = 0.9
_C.use_lws = False
_C.n_fft = 800
_C.hop_size = 200
_C.win_size = 800
_C.sample_rate = 16000
_C.frame_shift_ms = None
_C.signal_normalization = True
_C.allow_clipping_in_normalization = True
_C.symmetric_mels = True
_C.max_abs_value = 4.
_C.preemphasize = True
_C.preemphasis = 0.97
_C.min_level_db = -100
_C.ref_level_db = 20
_C.fmin = 55
_C.fmax = 7600
_C.fps = 25
def get_audio_config():
return _C
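A quick sanity check of the timing constants these values imply; this is the arithmetic behind the hard-coded 80. in crop_audio_window and mel_idx_multiplier, and behind the 16-frame mel windows:

cfg = get_audio_config()
mel_frames_per_sec = cfg.sample_rate / cfg.hop_size   # 16000 / 200 = 80
window_sec = 16 / mel_frames_per_sec                  # 16 mel frames = 0.2 s
frames_per_window = window_sec * cfg.fps              # = 5 video frames (syncnet_T)
print(mel_frames_per_sec, window_sec, frames_per_window)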
@@ -28,7 +28,6 @@ from ppgan.engine.trainer import Trainer
def main(args, cfg):
    # init environment, include logger, dynamic graph, seed, device, train or test mode...
    setup(args, cfg)
    # build trainer
    trainer = Trainer(cfg)
...