未验证 提交 4baea348 编写于 作者: L leesusu 提交者: GitHub

Add Wav2Lip generator. (#105)

* Add Wav2Lip generator.
上级 2a092607
......@@ -18,3 +18,4 @@ from .rrdb_net import RRDBNet
from .makeup import GeneratorPSGANAttention
from .resnet_ugatit import ResnetUGATITGenerator
from .dcgenerator import DCGenerator
from .wav2lip import Wav2Lip
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from paddle.nn import functional as F
from .builder import GENERATORS
from ...modules.conv import ConvBNRelu
from ...modules.conv import NonNormConv2d
from ...modules.conv import Conv2dTransposeRelu
@GENERATORS.register()
class Wav2Lip(nn.Layer):
def __init__(self):
super(Wav2Lip, self).__init__()
self.face_encoder_blocks = [
nn.Sequential(ConvBNRelu(6, 16, kernel_size=7, stride=1,
padding=3)), # 96,96
nn.Sequential(
ConvBNRelu(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
ConvBNRelu(32,
32,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(32,
32,
kernel_size=3,
stride=1,
padding=1,
residual=True)),
nn.Sequential(
ConvBNRelu(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
ConvBNRelu(64,
64,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(64,
64,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(64,
64,
kernel_size=3,
stride=1,
padding=1,
residual=True)),
nn.Sequential(
ConvBNRelu(64, 128, kernel_size=3, stride=2,
padding=1), # 12,12
ConvBNRelu(128,
128,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(128,
128,
kernel_size=3,
stride=1,
padding=1,
residual=True)),
nn.Sequential(
ConvBNRelu(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
ConvBNRelu(256,
256,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(256,
256,
kernel_size=3,
stride=1,
padding=1,
residual=True)),
nn.Sequential(
ConvBNRelu(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
ConvBNRelu(512,
512,
kernel_size=3,
stride=1,
padding=1,
residual=True),
),
nn.Sequential(
ConvBNRelu(512, 512, kernel_size=3, stride=1,
padding=0), # 1, 1
ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0)),
]
self.audio_encoder = nn.Sequential(
ConvBNRelu(1, 32, kernel_size=3, stride=1, padding=1),
ConvBNRelu(32,
32,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(32,
32,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(32, 64, kernel_size=3, stride=(3, 1), padding=1),
ConvBNRelu(64,
64,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(64,
64,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(64, 128, kernel_size=3, stride=3, padding=1),
ConvBNRelu(128,
128,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(128,
128,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(128, 256, kernel_size=3, stride=(3, 2), padding=1),
ConvBNRelu(256,
256,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(256, 512, kernel_size=3, stride=1, padding=0),
ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0),
)
self.face_decoder_blocks = [
nn.Sequential(
ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0), ),
nn.Sequential(
Conv2dTransposeRelu(1024,
512,
kernel_size=3,
stride=1,
padding=0), # 3,3
ConvBNRelu(512,
512,
kernel_size=3,
stride=1,
padding=1,
residual=True),
),
nn.Sequential(
Conv2dTransposeRelu(1024,
512,
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
ConvBNRelu(512,
512,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(512,
512,
kernel_size=3,
stride=1,
padding=1,
residual=True),
), # 6, 6
nn.Sequential(
Conv2dTransposeRelu(768,
384,
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
ConvBNRelu(384,
384,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(384,
384,
kernel_size=3,
stride=1,
padding=1,
residual=True),
), # 12, 12
nn.Sequential(
Conv2dTransposeRelu(512,
256,
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
ConvBNRelu(256,
256,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(256,
256,
kernel_size=3,
stride=1,
padding=1,
residual=True),
), # 24, 24
nn.Sequential(
Conv2dTransposeRelu(320,
128,
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
ConvBNRelu(128,
128,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(128,
128,
kernel_size=3,
stride=1,
padding=1,
residual=True),
), # 48, 48
nn.Sequential(
Conv2dTransposeRelu(160,
64,
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
ConvBNRelu(64,
64,
kernel_size=3,
stride=1,
padding=1,
residual=True),
ConvBNRelu(64,
64,
kernel_size=3,
stride=1,
padding=1,
residual=True),
),
] # 96,96
self.output_block = nn.Sequential(
ConvBNRelu(80, 32, kernel_size=3, stride=1, padding=1),
nn.Conv2D(32, 3, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
def forward(self, audio_sequences, face_sequences):
# audio_sequences = (B, T, 1, 80, 16)
B = audio_sequences.shape[0]
input_dim_size = len(face_sequences.shape)
if input_dim_size > 4:
audio_sequences = paddle.concat([
audio_sequences[:, i] for i in range(audio_sequences.shape[1])
],
axis=0)
face_sequences = paddle.concat([
face_sequences[:, :, i] for i in range(face_sequences.shape[2])
],
axis=0)
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
feats = []
x = face_sequences
for f in self.face_encoder_blocks:
x = f(x)
feats.append(x)
x = audio_embedding
for f in self.face_decoder_blocks:
x = f(x)
try:
x = paddle.concat((x, feats[-1]), axis=1)
except Exception as e:
print(x.shape)
print(feats[-1].shape)
raise e
feats.pop()
x = self.output_block(x)
if input_dim_size > 4:
x = paddle.split(x, B, axis=0) # [(B, C, H, W)]
outputs = paddle.stack(x, axis=2) # (B, C, T, H, W)
else:
outputs = x
return outputs
class Wav2LipDiscQual(nn.Layer):
def __init__(self):
super(Wav2LipDiscQual, self).__init__()
self.face_encoder_blocks = [
nn.Sequential(
NonNormConv2d(3, 32, kernel_size=7, stride=1,
padding=3)), # 48,96
nn.Sequential(
NonNormConv2d(32, 64, kernel_size=5, stride=(1, 2),
padding=2), # 48,48
NonNormConv2d(64, 64, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(64, 128, kernel_size=5, stride=2,
padding=2), # 24,24
NonNormConv2d(128, 128, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(128, 256, kernel_size=5, stride=2,
padding=2), # 12,12
NonNormConv2d(256, 256, kernel_size=5, stride=1, padding=2)),
nn.Sequential(
NonNormConv2d(256, 512, kernel_size=3, stride=2,
padding=1), # 6,6
NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1)),
nn.Sequential(
NonNormConv2d(512, 512, kernel_size=3, stride=2,
padding=1), # 3,3
NonNormConv2d(512, 512, kernel_size=3, stride=1, padding=1),
),
nn.Sequential(
NonNormConv2d(512, 512, kernel_size=3, stride=1,
padding=0), # 1, 1
NonNormConv2d(512, 512, kernel_size=1, stride=1, padding=0)),
]
self.binary_pred = nn.Sequential(
nn.Conv2D(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
self.label_noise = .0
def get_lower_half(self, face_sequences):
return face_sequences[:, :, face_sequences.shape[2] // 2:]
def to_2d(self, face_sequences):
B = face_sequences.shape[0]
face_sequences = paddle.concat(
[face_sequences[:, :, i] for i in range(face_sequences.shape[2])],
axis=0)
return face_sequences
def perceptual_forward(self, false_face_sequences):
false_face_sequences = self.to_2d(false_face_sequences)
false_face_sequences = self.get_lower_half(false_face_sequences)
false_feats = false_face_sequences
for f in self.face_encoder_blocks:
false_feats = f(false_feats)
false_pred_loss = F.binary_cross_entropy(
paddle.reshape(self.binary_pred(false_feats),
(len(false_feats), -1)),
paddle.ones((len(false_feats), 1)))
return false_pred_loss
def forward(self, face_sequences):
face_sequences = self.to_2d(face_sequences)
face_sequences = self.get_lower_half(face_sequences)
x = face_sequences
for f in self.face_encoder_blocks:
x = f(x)
return paddle.reshape(self.binary_pred(x), (len(x), -1))
......@@ -47,7 +47,7 @@ class NonNormConv2d(nn.Layer):
return self.act(out)
class Conv2dTranspseRelu(nn.Layer):
class Conv2dTransposeRelu(nn.Layer):
def __init__(self,
cin,
cout,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册