#  Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import random
import os.path
import numpy as np
from PIL import Image
from glob import glob
from os.path import dirname, join, basename, isfile

from ppgan.utils import audio
from ppgan.utils.audio_config import get_audio_config

import paddle

from .builder import DATASETS


def get_image_list(data_root, split):
    """Read filelists/<split>.txt and return the video directories joined with data_root."""
    filelist = []

    with open('filelists/{}.txt'.format(split)) as f:
        for line in f:
            line = line.strip()
            if ' ' in line:
                line = line.split()[0]
            filelist.append(os.path.join(data_root, line))

    return filelist


syncnet_T = 5  # number of consecutive video frames per training window
syncnet_mel_step_size = 16  # mel-spectrogram frames per audio chunk (16 / 80 fps = 0.2 s)
audio_cfg = get_audio_config()


@DATASETS.register()
class Wav2LipDataset(paddle.io.Dataset):
    def __init__(self, dataroot, img_size, filelists_path, split):
        """Initialize the Wav2Lip dataset.

        Args:
            dataroot (str): Directory of the preprocessed dataset.
            img_size (int): Size (in pixels) that face crops are resized to.
            filelists_path (str): Path to the file lists (not used directly;
                the list is read from 'filelists/<split>.txt').
            split (str): Dataset split, e.g. 'train' or 'val'.
        """
        self.image_path = dataroot
        self.img_size = img_size
        self.split = split
        self.all_videos = get_image_list(self.image_path, self.split)

    def get_frame_id(self, frame):
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        """Return the paths of syncnet_T consecutive frames starting at start_frame."""
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def read_window(self, window_fnames):
        if window_fnames is None:
            return None
        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (self.img_size, self.img_size))
            except Exception:
                return None

            window.append(img)

        return window

    def crop_audio_window(self, spec, start_frame):
        if isinstance(start_frame, int):
            start_frame_num = start_frame
        else:
            start_frame_num = self.get_frame_id(start_frame)
        # Map the video frame index to a mel-spectrogram row: the spectrogram
        # has 80 frames per second of audio, the video runs at audio_cfg["fps"].
        start_idx = int(80. * (start_frame_num / float(audio_cfg["fps"])))

        end_idx = start_idx + syncnet_mel_step_size

        return spec[start_idx:end_idx, :]

    def get_segmented_mels(self, spec, start_frame):
        mels = []
        assert syncnet_T == 5
        start_frame_num = self.get_frame_id(
            start_frame) + 1  # 0-indexing ---> 1-indexing
        if start_frame_num - 2 < 0:
            return None
        for i in range(start_frame_num, start_frame_num + syncnet_T):
            m = self.crop_audio_window(spec, i - 2)
            if m.shape[0] != syncnet_mel_step_size:
                return None
            mels.append(m.T)

        mels = np.asarray(mels)

        return mels

    def prepare_window(self, window):
        # T x H x W x 3 ---> 3 x T x H x W, scaled to [0, 1]
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))

        return x

    def __len__(self):
        """Return the total number of videos in the dataset."""
        return len(self.all_videos)

    def __getitem__(self, idx):
        # Samples are drawn at random; the incoming idx is intentionally ignored.
        while True:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue

            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            window_fnames = self.get_window(img_name)
            wrong_window_fnames = self.get_window(wrong_img_name)
            if window_fnames is None or wrong_window_fnames is None:
                continue

            window = self.read_window(window_fnames)
            if window is None:
                continue

            wrong_window = self.read_window(wrong_window_fnames)
            if wrong_window is None:
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, audio_cfg["sample_rate"])

                orig_mel = audio.melspectrogram(wav).T
            except Exception:
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)

            if mel.shape[0] != syncnet_mel_step_size:
                continue

            indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
            if indiv_mels is None:
                continue

            window = self.prepare_window(window)
            y = window.copy()
            # Mask the lower half of each target frame; the generator must fill it in.
            window[:, :, window.shape[2] // 2:] = 0.

            wrong_window = self.prepare_window(wrong_window)
            # Concatenate the masked window with a temporally wrong window along channels.
            x = np.concatenate([window, wrong_window], axis=0)

            x = np.float32(x)
            mel = np.transpose(mel)
            mel = np.expand_dims(mel, 0)
            indiv_mels = np.expand_dims(indiv_mels, 1)

            return {
                'x': x,
                'indiv_mels': np.float32(indiv_mels),
                'mel': np.float32(mel),
                'y': np.float32(y)
            }
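

# Usage sketch (not part of the original module): a minimal, hypothetical
# example of wrapping this dataset in a paddle.io.DataLoader. The dataroot,
# batch size and img_size below are placeholders; the data layout assumed by
# the code above (one folder per video containing numbered .jpg frames plus
# audio.wav, and a 'filelists/train.txt' listing those folders) must exist.
if __name__ == '__main__':
    dataset = Wav2LipDataset(
        dataroot='data/preprocessed',  # placeholder path
        img_size=96,
        filelists_path='filelists/',
        split='train')
    loader = paddle.io.DataLoader(dataset, batch_size=8, shuffle=True)

    batch = next(iter(loader))
    # Expected shapes for img_size=96 and the default 80 mel bins (N = batch size):
    #   batch['x']:          (N, 6, 5, 96, 96)   masked window + wrong window, channel-concatenated
    #   batch['y']:          (N, 3, 5, 96, 96)   ground-truth window
    #   batch['mel']:        (N, 1, 80, 16)      mel chunk aligned with the window
    #   batch['indiv_mels']: (N, 5, 1, 80, 16)   per-frame mel chunks
    print({k: v.shape for k, v in batch.items()})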