提交 31a69d80 编写于 作者: L lzzyzlbb

1.modify wav2lip api; 2.add 512*512 of fom

上级 870c902a
...@@ -68,6 +68,11 @@ parser.add_argument("--multi_person", ...@@ -68,6 +68,11 @@ parser.add_argument("--multi_person",
action="store_true", action="store_true",
default=False, default=False,
help="whether there is only one person in the image or not") help="whether there is only one person in the image or not")
parser.add_argument("--image_size",
dest="image_size",
type=int,
default=256,
help="size of image")
parser.set_defaults(relative=False) parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False) parser.set_defaults(adapt_scale=False)
...@@ -87,5 +92,6 @@ if __name__ == "__main__": ...@@ -87,5 +92,6 @@ if __name__ == "__main__":
best_frame=args.best_frame, best_frame=args.best_frame,
ratio=args.ratio, ratio=args.ratio,
face_detector=args.face_detector, face_detector=args.face_detector,
multi_person=args.multi_person) multi_person=args.multi_person,
image_size=args.image_size)
predictor.run(args.source_image, args.driving_video) predictor.run(args.source_image, args.driving_video)
...@@ -23,7 +23,7 @@ parser.add_argument('--face', ...@@ -23,7 +23,7 @@ parser.add_argument('--face',
parser.add_argument('--outfile', parser.add_argument('--outfile',
type=str, type=str,
help='Video path to save result. See default for an e.g.', help='Video path to save result. See default for an e.g.',
default='results/result_voice.mp4') default='result_voice.mp4')
parser.add_argument( parser.add_argument(
'--static', '--static',
...@@ -109,5 +109,16 @@ if __name__ == "__main__": ...@@ -109,5 +109,16 @@ if __name__ == "__main__":
if args.cpu: if args.cpu:
paddle.set_device('cpu') paddle.set_device('cpu')
predictor = Wav2LipPredictor(args) predictor = Wav2LipPredictor(checkpoint_path = args.checkpoint_path,
predictor.run() static = args.static,
fps = args.fps,
pads = args.pads,
face_det_batch_size = args.face_det_batch_size,
wav2lip_batch_size = args.wav2lip_batch_size,
resize_factor = args.resize_factor,
crop = args.crop,
box = args.box,
rotate = args.rotate,
nosmooth = args.nosmooth,
face_detector = args.face_detector)
predictor.run(args.face, args.audio, args.outfile)
...@@ -439,7 +439,7 @@ ppgan.apps.MiDaSPredictor(output=None, weight_path=None) ...@@ -439,7 +439,7 @@ ppgan.apps.MiDaSPredictor(output=None, weight_path=None)
## ppgan.apps.Wav2lipPredictor ## ppgan.apps.Wav2lipPredictor
```python ```python
ppgan.apps.FirstOrderPredictor(args) ppgan.apps.FirstOrderPredictor()
``` ```
> 构建Wav2lip模型的实例,此模型用来做唇形合成,即给定一个人物视频和一个音频,实现人物口型与输入语音同步。论文是A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild,论文链接: http://arxiv.org/abs/2008.10010. > 构建Wav2lip模型的实例,此模型用来做唇形合成,即给定一个人物视频和一个音频,实现人物口型与输入语音同步。论文是A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild,论文链接: http://arxiv.org/abs/2008.10010.
...@@ -449,8 +449,8 @@ ppgan.apps.FirstOrderPredictor(args) ...@@ -449,8 +449,8 @@ ppgan.apps.FirstOrderPredictor(args)
> ``` > ```
> from ppgan.apps import Wav2LipPredictor > from ppgan.apps import Wav2LipPredictor
> # The args parameter should be specified by argparse > # The args parameter should be specified by argparse
> predictor = Wav2LipPredictor(args) > predictor = Wav2LipPredictor()
> predictor.run() > predictor.run(face, audio, outfile)
> ``` > ```
> **参数:** > **参数:**
......
...@@ -33,8 +33,6 @@ from ppgan.faceutils import face_detection ...@@ -33,8 +33,6 @@ from ppgan.faceutils import face_detection
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
IMAGE_SIZE = 256
class FirstOrderPredictor(BasePredictor): class FirstOrderPredictor(BasePredictor):
def __init__(self, def __init__(self,
output='output', output='output',
...@@ -47,7 +45,8 @@ class FirstOrderPredictor(BasePredictor): ...@@ -47,7 +45,8 @@ class FirstOrderPredictor(BasePredictor):
ratio=1.0, ratio=1.0,
filename='result.mp4', filename='result.mp4',
face_detector='sfd', face_detector='sfd',
multi_person=False): multi_person=False,
image_size = 256):
if config is not None and isinstance(config, str): if config is not None and isinstance(config, str):
with open(config) as f: with open(config) as f:
self.cfg = yaml.load(f, Loader=yaml.SafeLoader) self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
...@@ -85,7 +84,11 @@ class FirstOrderPredictor(BasePredictor): ...@@ -85,7 +84,11 @@ class FirstOrderPredictor(BasePredictor):
} }
} }
} }
self.image_size = image_size
if weight_path is None: if weight_path is None:
if self.image_size == 512:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams'
else:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams' vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
weight_path = get_path_from_url(vox_cpk_weight_url) weight_path = get_path_from_url(vox_cpk_weight_url)
...@@ -104,6 +107,7 @@ class FirstOrderPredictor(BasePredictor): ...@@ -104,6 +107,7 @@ class FirstOrderPredictor(BasePredictor):
self.cfg, self.weight_path) self.cfg, self.weight_path)
self.multi_person = multi_person self.multi_person = multi_person
def read_img(self, path): def read_img(self, path):
img = imageio.imread(path) img = imageio.imread(path)
if img.ndim == 2: if img.ndim == 2:
...@@ -161,42 +165,23 @@ class FirstOrderPredictor(BasePredictor): ...@@ -161,42 +165,23 @@ class FirstOrderPredictor(BasePredictor):
reader.close() reader.close()
driving_video = [ driving_video = [
cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0 for frame in driving_video cv2.resize(frame, (self.image_size, self.image_size)) / 255.0 for frame in driving_video
] ]
results = [] results = []
# for single person
if not self.multi_person:
h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w))
for frame in predictions
],
fps=fps)
return
bboxes = self.extract_bbox(source_image.copy()) bboxes = self.extract_bbox(source_image.copy())
print(str(len(bboxes)) + " persons have been detected") print(str(len(bboxes)) + " persons have been detected")
if len(bboxes) <= 1:
h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w))
for frame in predictions
],
fps=fps)
return
# for multi person # for multi person
for rec in bboxes: for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]] face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = cv2.resize(face_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0 face_image = cv2.resize(face_image, (self.image_size, self.image_size)) / 255.0
predictions = get_prediction(face_image) predictions = get_prediction(face_image)
results.append({'rec': rec, 'predict': predictions}) results.append({'rec': rec, 'predict': predictions})
if len(bboxes) == 1 or not self.multi_person:
break
out_frame = [] out_frame = []
for i in range(len(driving_video)): for i in range(len(driving_video)):
...@@ -206,9 +191,19 @@ class FirstOrderPredictor(BasePredictor): ...@@ -206,9 +191,19 @@ class FirstOrderPredictor(BasePredictor):
h = y2 - y1 h = y2 - y1
w = x2 - x1 w = x2 - x1
out = result['predict'][i] * 255.0 out = result['predict'][i] * 255.0
#from ppgan.apps import RealSRPredictor
#sr = RealSRPredictor()
#sr_img = sr.run(out.astype(np.uint8))
out = cv2.resize(out.astype(np.uint8), (x2 - x1, y2 - y1)) out = cv2.resize(out.astype(np.uint8), (x2 - x1, y2 - y1))
#out = cv2.resize(np.array(sr_img).astype(np.uint8), (x2 - x1, y2 - y1))
if len(results) == 1: if len(results) == 1:
#imageio.imwrite(os.path.join(self.output, "blending_512_realsr","source"+str(i) + ".png"), frame)
frame[y1:y2, x1:x2] = out frame[y1:y2, x1:x2] = out
#imageio.imwrite(os.path.join(self.output, "blending_512_realsr","target"+str(i) + ".png"), frame)
#mask = np.ones(frame.shape).astype('uint8') * 255
#mask[y1:y2, x1:x2] = (0,0,0)
#imageio.imwrite(os.path.join(self.output, "blending_512_realsr","mask"+str(i) + ".png"), mask)
else: else:
patch = np.zeros(frame.shape).astype('uint8') patch = np.zeros(frame.shape).astype('uint8')
patch[y1:y2, x1:x2] = out patch[y1:y2, x1:x2] = out
......
...@@ -17,12 +17,31 @@ mel_step_size = 16 ...@@ -17,12 +17,31 @@ mel_step_size = 16
class Wav2LipPredictor(BasePredictor): class Wav2LipPredictor(BasePredictor):
def __init__(self, args): def __init__(self, checkpoint_path = None,
self.args = args static = False,
if os.path.isfile(self.args.face) and path.basename( fps = 25,
self.args.face).split('.')[1] in ['jpg', 'png', 'jpeg']: pads = [0, 10, 0, 0],
self.args.static = True face_det_batch_size = 16,
wav2lip_batch_size = 128,
resize_factor = 1,
crop = [0, -1, 0, -1],
box = [-1, -1, -1, -1],
rotate = False,
nosmooth = False,
face_detector = 'sfd'):
self.img_size = 96 self.img_size = 96
self.checkpoint_path = checkpoint_path
self.static = static
self.fps = fps,
self.pads = pads
self.face_det_batch_size = face_det_batch_size
self.wav2lip_batch_size = wav2lip_batch_size
self.resize_factor = resize_factor
self.crop = crop
self.box = box
self.rotate = rotate
self.nosmooth = nosmooth
self.face_detector = face_detector
makedirs('./temp', exist_ok=True) makedirs('./temp', exist_ok=True)
def get_smoothened_boxes(self, boxes, T): def get_smoothened_boxes(self, boxes, T):
...@@ -38,9 +57,9 @@ class Wav2LipPredictor(BasePredictor): ...@@ -38,9 +57,9 @@ class Wav2LipPredictor(BasePredictor):
detector = face_detection.FaceAlignment( detector = face_detection.FaceAlignment(
face_detection.LandmarksType._2D, face_detection.LandmarksType._2D,
flip_input=False, flip_input=False,
face_detector=self.args.face_detector) face_detector=self.face_detector)
batch_size = self.args.face_det_batch_size batch_size = self.face_det_batch_size
while 1: while 1:
predictions = [] predictions = []
...@@ -61,7 +80,7 @@ class Wav2LipPredictor(BasePredictor): ...@@ -61,7 +80,7 @@ class Wav2LipPredictor(BasePredictor):
break break
results = [] results = []
pady1, pady2, padx1, padx2 = self.args.pads pady1, pady2, padx1, padx2 = self.pads
for rect, image in zip(predictions, images): for rect, image in zip(predictions, images):
if rect is None: if rect is None:
cv2.imwrite( cv2.imwrite(
...@@ -79,7 +98,7 @@ class Wav2LipPredictor(BasePredictor): ...@@ -79,7 +98,7 @@ class Wav2LipPredictor(BasePredictor):
results.append([x1, y1, x2, y2]) results.append([x1, y1, x2, y2])
boxes = np.array(results) boxes = np.array(results)
if not self.args.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5) if not self.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)]
for image, (x1, y1, x2, y2) in zip(images, boxes)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
...@@ -89,8 +108,8 @@ class Wav2LipPredictor(BasePredictor): ...@@ -89,8 +108,8 @@ class Wav2LipPredictor(BasePredictor):
def datagen(self, frames, mels): def datagen(self, frames, mels):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if self.args.box[0] == -1: if self.box[0] == -1:
if not self.args.static: if not self.static:
face_det_results = self.face_detect( face_det_results = self.face_detect(
frames) # BGR2RGB for CNN face detection frames) # BGR2RGB for CNN face detection
else: else:
...@@ -98,12 +117,12 @@ class Wav2LipPredictor(BasePredictor): ...@@ -98,12 +117,12 @@ class Wav2LipPredictor(BasePredictor):
else: else:
print( print(
'Using the specified bounding box instead of face detection...') 'Using the specified bounding box instead of face detection...')
y1, y2, x1, x2 = self.args.box y1, y2, x1, x2 = self.box
face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)]
for f in frames] for f in frames]
for i, m in enumerate(mels): for i, m in enumerate(mels):
idx = 0 if self.args.static else i % len(frames) idx = 0 if self.static else i % len(frames)
frame_to_save = frames[idx].copy() frame_to_save = frames[idx].copy()
face, coords = face_det_results[idx].copy() face, coords = face_det_results[idx].copy()
...@@ -114,7 +133,7 @@ class Wav2LipPredictor(BasePredictor): ...@@ -114,7 +133,7 @@ class Wav2LipPredictor(BasePredictor):
frame_batch.append(frame_to_save) frame_batch.append(frame_to_save)
coords_batch.append(coords) coords_batch.append(coords)
if len(img_batch) >= self.args.wav2lip_batch_size: if len(img_batch) >= self.wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray( img_batch, mel_batch = np.asarray(img_batch), np.asarray(
mel_batch) mel_batch)
...@@ -143,18 +162,22 @@ class Wav2LipPredictor(BasePredictor): ...@@ -143,18 +162,22 @@ class Wav2LipPredictor(BasePredictor):
yield img_batch, mel_batch, frame_batch, coords_batch yield img_batch, mel_batch, frame_batch, coords_batch
def run(self): def run(self, face, audio_seq, outfile):
if not os.path.isfile(self.args.face): if os.path.isfile(face) and path.basename(
face).split('.')[1] in ['jpg', 'png', 'jpeg']:
self.static = True
if not os.path.isfile(face):
raise ValueError( raise ValueError(
'--face argument must be a valid path to video/image file') '--face argument must be a valid path to video/image file')
elif path.basename( elif path.basename(
self.args.face).split('.')[1] in ['jpg', 'png', 'jpeg']: face).split('.')[1] in ['jpg', 'png', 'jpeg']:
full_frames = [cv2.imread(self.args.face)] full_frames = [cv2.imread(face)]
fps = self.args.fps fps = self.fps
else: else:
video_stream = cv2.VideoCapture(self.args.face) video_stream = cv2.VideoCapture(face)
fps = video_stream.get(cv2.CAP_PROP_FPS) fps = video_stream.get(cv2.CAP_PROP_FPS)
print('Reading video frames...') print('Reading video frames...')
...@@ -165,15 +188,15 @@ class Wav2LipPredictor(BasePredictor): ...@@ -165,15 +188,15 @@ class Wav2LipPredictor(BasePredictor):
if not still_reading: if not still_reading:
video_stream.release() video_stream.release()
break break
if self.args.resize_factor > 1: if self.resize_factor > 1:
frame = cv2.resize( frame = cv2.resize(
frame, (frame.shape[1] // self.args.resize_factor, frame, (frame.shape[1] // self.resize_factor,
frame.shape[0] // self.args.resize_factor)) frame.shape[0] // self.resize_factor))
if self.args.rotate: if self.rotate:
frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
y1, y2, x1, x2 = self.args.crop y1, y2, x1, x2 = self.crop
if x2 == -1: x2 = frame.shape[1] if x2 == -1: x2 = frame.shape[1]
if y2 == -1: y2 = frame.shape[0] if y2 == -1: y2 = frame.shape[0]
...@@ -184,18 +207,16 @@ class Wav2LipPredictor(BasePredictor): ...@@ -184,18 +207,16 @@ class Wav2LipPredictor(BasePredictor):
print("Number of frames available for inference: " + print("Number of frames available for inference: " +
str(len(full_frames))) str(len(full_frames)))
if not self.args.audio.endswith('.wav'): if not audio_seq.endswith('.wav'):
print('Extracting raw audio...') print('Extracting raw audio...')
command = 'ffmpeg -y -i {} -strict -2 {}'.format( command = 'ffmpeg -y -i {} -strict -2 {}'.format(
self.args.audio, 'temp/temp.wav') audio_seq, 'temp/temp.wav')
subprocess.call(command, shell=True) subprocess.call(command, shell=True)
self.args.audio = 'temp/temp.wav' audio_seq = 'temp/temp.wav'
wav = audio.load_wav(self.args.audio, 16000) wav = audio.load_wav(audio_seq, 16000)
mel = audio.melspectrogram(wav) mel = audio.melspectrogram(wav)
print(mel.shape)
if np.isnan(mel.reshape(-1)).sum() > 0: if np.isnan(mel.reshape(-1)).sum() > 0:
raise ValueError( raise ValueError(
'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again' 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again'
...@@ -216,15 +237,15 @@ class Wav2LipPredictor(BasePredictor): ...@@ -216,15 +237,15 @@ class Wav2LipPredictor(BasePredictor):
full_frames = full_frames[:len(mel_chunks)] full_frames = full_frames[:len(mel_chunks)]
batch_size = self.args.wav2lip_batch_size batch_size = self.wav2lip_batch_size
gen = self.datagen(full_frames.copy(), mel_chunks) gen = self.datagen(full_frames.copy(), mel_chunks)
model = Wav2Lip() model = Wav2Lip()
if self.args.checkpoint_path is None: if self.checkpoint_path is None:
model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL) model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL)
weights = paddle.load(model_weights_path) weights = paddle.load(model_weights_path)
else: else:
weights = paddle.load(self.args.checkpoint_path) weights = paddle.load(self.checkpoint_path)
model.load_dict(weights) model.load_dict(weights)
model.eval() model.eval()
print("Model loaded") print("Model loaded")
...@@ -258,5 +279,5 @@ class Wav2LipPredictor(BasePredictor): ...@@ -258,5 +279,5 @@ class Wav2LipPredictor(BasePredictor):
out.release() out.release()
command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format( command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(
self.args.audio, 'temp/result.avi', self.args.outfile) audio_seq, 'temp/result.avi', outfile)
subprocess.call(command, shell=platform.system() != 'Windows') subprocess.call(command, shell=platform.system() != 'Windows')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册