提交 31a69d80 编写于 作者: L lzzyzlbb

1.modify wav2lip api; 2.add 512*512 of fom

上级 870c902a
......@@ -68,6 +68,11 @@ parser.add_argument("--multi_person",
action="store_true",
default=False,
help="whether there is only one person in the image or not")
parser.add_argument("--image_size",
dest="image_size",
type=int,
default=256,
help="size of image")
parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
......@@ -87,5 +92,6 @@ if __name__ == "__main__":
best_frame=args.best_frame,
ratio=args.ratio,
face_detector=args.face_detector,
multi_person=args.multi_person)
multi_person=args.multi_person,
image_size=args.image_size)
predictor.run(args.source_image, args.driving_video)
......@@ -23,7 +23,7 @@ parser.add_argument('--face',
parser.add_argument('--outfile',
type=str,
help='Video path to save result. See default for an e.g.',
default='results/result_voice.mp4')
default='result_voice.mp4')
parser.add_argument(
'--static',
......@@ -109,5 +109,16 @@ if __name__ == "__main__":
if args.cpu:
paddle.set_device('cpu')
predictor = Wav2LipPredictor(args)
predictor.run()
predictor = Wav2LipPredictor(checkpoint_path = args.checkpoint_path,
static = args.static,
fps = args.fps,
pads = args.pads,
face_det_batch_size = args.face_det_batch_size,
wav2lip_batch_size = args.wav2lip_batch_size,
resize_factor = args.resize_factor,
crop = args.crop,
box = args.box,
rotate = args.rotate,
nosmooth = args.nosmooth,
face_detector = args.face_detector)
predictor.run(args.face, args.audio, args.outfile)
......@@ -439,7 +439,7 @@ ppgan.apps.MiDaSPredictor(output=None, weight_path=None)
## ppgan.apps.Wav2lipPredictor
```python
ppgan.apps.FirstOrderPredictor(args)
ppgan.apps.FirstOrderPredictor()
```
> 构建Wav2lip模型的实例,此模型用来做唇形合成,即给定一个人物视频和一个音频,实现人物口型与输入语音同步。论文是A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild,论文链接: http://arxiv.org/abs/2008.10010.
......@@ -449,8 +449,8 @@ ppgan.apps.FirstOrderPredictor(args)
> ```
> from ppgan.apps import Wav2LipPredictor
> # The args parameter should be specified by argparse
> predictor = Wav2LipPredictor(args)
> predictor.run()
> predictor = Wav2LipPredictor()
> predictor.run(face, audio, outfile)
> ```
> **参数:**
......
......@@ -33,8 +33,6 @@ from ppgan.faceutils import face_detection
from .base_predictor import BasePredictor
IMAGE_SIZE = 256
class FirstOrderPredictor(BasePredictor):
def __init__(self,
output='output',
......@@ -47,7 +45,8 @@ class FirstOrderPredictor(BasePredictor):
ratio=1.0,
filename='result.mp4',
face_detector='sfd',
multi_person=False):
multi_person=False,
image_size = 256):
if config is not None and isinstance(config, str):
with open(config) as f:
self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
......@@ -85,8 +84,12 @@ class FirstOrderPredictor(BasePredictor):
}
}
}
self.image_size = image_size
if weight_path is None:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
if self.image_size == 512:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams'
else:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
weight_path = get_path_from_url(vox_cpk_weight_url)
self.weight_path = weight_path
......@@ -103,6 +106,7 @@ class FirstOrderPredictor(BasePredictor):
self.generator, self.kp_detector = self.load_checkpoints(
self.cfg, self.weight_path)
self.multi_person = multi_person
def read_img(self, path):
img = imageio.imread(path)
......@@ -161,42 +165,23 @@ class FirstOrderPredictor(BasePredictor):
reader.close()
driving_video = [
cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0 for frame in driving_video
cv2.resize(frame, (self.image_size, self.image_size)) / 255.0 for frame in driving_video
]
results = []
# for single person
if not self.multi_person:
h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w))
for frame in predictions
],
fps=fps)
return
bboxes = self.extract_bbox(source_image.copy())
print(str(len(bboxes)) + " persons have been detected")
if len(bboxes) <= 1:
h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w))
for frame in predictions
],
fps=fps)
return
# for multi person
for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = cv2.resize(face_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
face_image = cv2.resize(face_image, (self.image_size, self.image_size)) / 255.0
predictions = get_prediction(face_image)
results.append({'rec': rec, 'predict': predictions})
if len(bboxes) == 1 or not self.multi_person:
break
out_frame = []
for i in range(len(driving_video)):
......@@ -206,9 +191,19 @@ class FirstOrderPredictor(BasePredictor):
h = y2 - y1
w = x2 - x1
out = result['predict'][i] * 255.0
#from ppgan.apps import RealSRPredictor
#sr = RealSRPredictor()
#sr_img = sr.run(out.astype(np.uint8))
out = cv2.resize(out.astype(np.uint8), (x2 - x1, y2 - y1))
#out = cv2.resize(np.array(sr_img).astype(np.uint8), (x2 - x1, y2 - y1))
if len(results) == 1:
#imageio.imwrite(os.path.join(self.output, "blending_512_realsr","source"+str(i) + ".png"), frame)
frame[y1:y2, x1:x2] = out
#imageio.imwrite(os.path.join(self.output, "blending_512_realsr","target"+str(i) + ".png"), frame)
#mask = np.ones(frame.shape).astype('uint8') * 255
#mask[y1:y2, x1:x2] = (0,0,0)
#imageio.imwrite(os.path.join(self.output, "blending_512_realsr","mask"+str(i) + ".png"), mask)
else:
patch = np.zeros(frame.shape).astype('uint8')
patch[y1:y2, x1:x2] = out
......@@ -218,7 +213,7 @@ class FirstOrderPredictor(BasePredictor):
cv2.circle(mask, (cx, cy), math.ceil(h * self.ratio),
(255, 255, 255), -1, 8, 0)
frame = cv2.copyTo(patch, mask, frame)
out_frame.append(frame)
imageio.mimsave(os.path.join(self.output, self.filename),
[frame for frame in out_frame],
......
......@@ -17,12 +17,31 @@ mel_step_size = 16
class Wav2LipPredictor(BasePredictor):
def __init__(self, args):
self.args = args
if os.path.isfile(self.args.face) and path.basename(
self.args.face).split('.')[1] in ['jpg', 'png', 'jpeg']:
self.args.static = True
def __init__(self, checkpoint_path = None,
static = False,
fps = 25,
pads = [0, 10, 0, 0],
face_det_batch_size = 16,
wav2lip_batch_size = 128,
resize_factor = 1,
crop = [0, -1, 0, -1],
box = [-1, -1, -1, -1],
rotate = False,
nosmooth = False,
face_detector = 'sfd'):
self.img_size = 96
self.checkpoint_path = checkpoint_path
self.static = static
self.fps = fps,
self.pads = pads
self.face_det_batch_size = face_det_batch_size
self.wav2lip_batch_size = wav2lip_batch_size
self.resize_factor = resize_factor
self.crop = crop
self.box = box
self.rotate = rotate
self.nosmooth = nosmooth
self.face_detector = face_detector
makedirs('./temp', exist_ok=True)
def get_smoothened_boxes(self, boxes, T):
......@@ -38,9 +57,9 @@ class Wav2LipPredictor(BasePredictor):
detector = face_detection.FaceAlignment(
face_detection.LandmarksType._2D,
flip_input=False,
face_detector=self.args.face_detector)
face_detector=self.face_detector)
batch_size = self.args.face_det_batch_size
batch_size = self.face_det_batch_size
while 1:
predictions = []
......@@ -61,7 +80,7 @@ class Wav2LipPredictor(BasePredictor):
break
results = []
pady1, pady2, padx1, padx2 = self.args.pads
pady1, pady2, padx1, padx2 = self.pads
for rect, image in zip(predictions, images):
if rect is None:
cv2.imwrite(
......@@ -79,7 +98,7 @@ class Wav2LipPredictor(BasePredictor):
results.append([x1, y1, x2, y2])
boxes = np.array(results)
if not self.args.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
if not self.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)]
for image, (x1, y1, x2, y2) in zip(images, boxes)]
......@@ -89,8 +108,8 @@ class Wav2LipPredictor(BasePredictor):
def datagen(self, frames, mels):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if self.args.box[0] == -1:
if not self.args.static:
if self.box[0] == -1:
if not self.static:
face_det_results = self.face_detect(
frames) # BGR2RGB for CNN face detection
else:
......@@ -98,12 +117,12 @@ class Wav2LipPredictor(BasePredictor):
else:
print(
'Using the specified bounding box instead of face detection...')
y1, y2, x1, x2 = self.args.box
y1, y2, x1, x2 = self.box
face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)]
for f in frames]
for i, m in enumerate(mels):
idx = 0 if self.args.static else i % len(frames)
idx = 0 if self.static else i % len(frames)
frame_to_save = frames[idx].copy()
face, coords = face_det_results[idx].copy()
......@@ -114,7 +133,7 @@ class Wav2LipPredictor(BasePredictor):
frame_batch.append(frame_to_save)
coords_batch.append(coords)
if len(img_batch) >= self.args.wav2lip_batch_size:
if len(img_batch) >= self.wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(
mel_batch)
......@@ -143,18 +162,22 @@ class Wav2LipPredictor(BasePredictor):
yield img_batch, mel_batch, frame_batch, coords_batch
def run(self):
if not os.path.isfile(self.args.face):
def run(self, face, audio_seq, outfile):
if os.path.isfile(face) and path.basename(
face).split('.')[1] in ['jpg', 'png', 'jpeg']:
self.static = True
if not os.path.isfile(face):
raise ValueError(
'--face argument must be a valid path to video/image file')
elif path.basename(
self.args.face).split('.')[1] in ['jpg', 'png', 'jpeg']:
full_frames = [cv2.imread(self.args.face)]
fps = self.args.fps
face).split('.')[1] in ['jpg', 'png', 'jpeg']:
full_frames = [cv2.imread(face)]
fps = self.fps
else:
video_stream = cv2.VideoCapture(self.args.face)
video_stream = cv2.VideoCapture(face)
fps = video_stream.get(cv2.CAP_PROP_FPS)
print('Reading video frames...')
......@@ -165,15 +188,15 @@ class Wav2LipPredictor(BasePredictor):
if not still_reading:
video_stream.release()
break
if self.args.resize_factor > 1:
if self.resize_factor > 1:
frame = cv2.resize(
frame, (frame.shape[1] // self.args.resize_factor,
frame.shape[0] // self.args.resize_factor))
frame, (frame.shape[1] // self.resize_factor,
frame.shape[0] // self.resize_factor))
if self.args.rotate:
if self.rotate:
frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
y1, y2, x1, x2 = self.args.crop
y1, y2, x1, x2 = self.crop
if x2 == -1: x2 = frame.shape[1]
if y2 == -1: y2 = frame.shape[0]
......@@ -184,18 +207,16 @@ class Wav2LipPredictor(BasePredictor):
print("Number of frames available for inference: " +
str(len(full_frames)))
if not self.args.audio.endswith('.wav'):
if not audio_seq.endswith('.wav'):
print('Extracting raw audio...')
command = 'ffmpeg -y -i {} -strict -2 {}'.format(
self.args.audio, 'temp/temp.wav')
audio_seq, 'temp/temp.wav')
subprocess.call(command, shell=True)
self.args.audio = 'temp/temp.wav'
audio_seq = 'temp/temp.wav'
wav = audio.load_wav(self.args.audio, 16000)
wav = audio.load_wav(audio_seq, 16000)
mel = audio.melspectrogram(wav)
print(mel.shape)
if np.isnan(mel.reshape(-1)).sum() > 0:
raise ValueError(
'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again'
......@@ -216,15 +237,15 @@ class Wav2LipPredictor(BasePredictor):
full_frames = full_frames[:len(mel_chunks)]
batch_size = self.args.wav2lip_batch_size
batch_size = self.wav2lip_batch_size
gen = self.datagen(full_frames.copy(), mel_chunks)
model = Wav2Lip()
if self.args.checkpoint_path is None:
if self.checkpoint_path is None:
model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL)
weights = paddle.load(model_weights_path)
else:
weights = paddle.load(self.args.checkpoint_path)
weights = paddle.load(self.checkpoint_path)
model.load_dict(weights)
model.eval()
print("Model loaded")
......@@ -258,5 +279,5 @@ class Wav2LipPredictor(BasePredictor):
out.release()
command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(
self.args.audio, 'temp/result.avi', self.args.outfile)
audio_seq, 'temp/result.avi', outfile)
subprocess.call(command, shell=platform.system() != 'Windows')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册