Unverified Commit 2ab96cb8 authored by lzzyzlbb, committed by GitHub

Add static model and inference of stylegan2, fom, basicvsr (#491)

* Update benchmark.yaml

* Update benchmark.yaml

* add static model and inference of fom, basicvsr, stylegan2

* add static model and inference of fom, basicvsr, stylegan2

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets
Parent 6e3dad37
@@ -12,7 +12,7 @@ FOMM:
     fp_item: fp32
     bs_item: 8 16
     epochs: 1
-    log_interval: 11
+    log_interval: 1
 esrgan:
     dataset_web: https://paddlegan.bj.bcebos.com/datasets/DIV2KandSet14.tar
......
@@ -38,6 +38,7 @@ dataset:
     use_rot: True
     scale: 4
     val_partition: REDS4
+    num_clips: 270
   test:
     name: SRREDSMultipleGTDataset
@@ -90,3 +91,6 @@ log_config:
 snapshot_config:
   interval: 5000
+
+export_model:
+  - {name: 'generator', inputs_num: 1}
@@ -119,3 +119,7 @@ log_config:
 snapshot_config:
   interval: 5
+
+export_model:
+  - {name: 'netG_A', inputs_num: 1}
+  - {name: 'netG_B', inputs_num: 1}
@@ -123,3 +123,6 @@ snapshot_config:
 optimizer:
   name: Adam
+
+export_model:
+  - {}
@@ -115,3 +115,6 @@ validate:
     fid: # metric name, can be arbitrary
       name: FID
       batch_size: 8
+
+export_model:
+  - {name: 'netG', inputs_num: 1}
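Taken together, these config hunks introduce a uniform `export_model` section: each entry names a sub-network to trace into a static graph and declares how many inputs that graph takes (the empty `- {}` entry presumably falls back to the model's defaults). As a rough sketch, a config-driven exporter could consume these entries along the following lines; this is illustrative only, not PaddleGAN's actual exporter, and the helper name and `input_shapes` are assumptions:

```python
import os

import paddle


def export_from_config(model, export_entries, output_dir, input_shapes):
    # Illustrative sketch: trace each configured sub-network to a static
    # graph with paddle.jit.save. `model.nets` holds the sub-networks,
    # mirroring the `self.nets['gen']` access visible later in this commit.
    for entry in export_entries:
        name = entry['name']                     # e.g. 'generator', 'netG_A'
        inputs_num = entry.get('inputs_num', 1)
        # One InputSpec per declared input; the concrete shapes are assumptions.
        specs = [
            paddle.static.InputSpec(shape=input_shapes[i], dtype='float32')
            for i in range(inputs_num)
        ]
        paddle.jit.save(model.nets[name],
                        os.path.join(output_dir, name),
                        input_spec=specs)
```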
@@ -55,7 +55,8 @@ class SRREDSMultipleGTDataset(Dataset):
                  use_rot=False,
                  scale=4,
                  val_partition='REDS4',
-                 batch_size=4):
+                 batch_size=4,
+                 num_clips=270):
         super(SRREDSMultipleGTDataset, self).__init__()
         self.mode = mode
         self.fileroot = str(lq_folder)
@@ -69,6 +70,7 @@ class SRREDSMultipleGTDataset(Dataset):
         self.scale = scale
         self.val_partition = val_partition
         self.batch_size = batch_size
+        self.num_clips = num_clips  # training num of LQ and GT pairs
         self.data_infos = self.load_annotations()

     def __getitem__(self, idx):
@@ -93,7 +95,7 @@ class SRREDSMultipleGTDataset(Dataset):
             dict: Returned dict for LQ and GT pairs.
         """
         # generate keys
-        keys = [f'{i:03d}' for i in range(0, 270)]
+        keys = [f'{i:03d}' for i in range(0, self.num_clips)]
         if self.val_partition == 'REDS4':
             val_partition = ['000', '011', '015', '020']
@@ -170,7 +172,9 @@ class SRREDSMultipleGTDataset(Dataset):
         gt_list = rlt[number_frames:]

         # stack LQ images to NHWC, N is the frame number
-        frame_list = [v.transpose(2, 0, 1).astype('float32') for v in frame_list]
+        frame_list = [
+            v.transpose(2, 0, 1).astype('float32') for v in frame_list
+        ]
         gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list]
         img_LQs = np.stack(frame_list, axis=0)
......
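The new `num_clips` option lets the key list track datasets smaller than the full 270-clip REDS training set, which is what the "fix basicvsr dataset for small datasets" commits above are about. A quick worked example of the key generation from this hunk:

```python
# With num_clips=4 (a small REDS subset) only clips 000-003 are enumerated,
# and the surrounding dataset code excludes the REDS4 validation clips
# from the training keys.
num_clips = 4
keys = [f'{i:03d}' for i in range(0, num_clips)]  # ['000', '001', '002', '003']
val_partition = ['000', '011', '015', '020']      # the REDS4 split
train_keys = [k for k in keys if k not in val_partition]
print(train_keys)                                 # ['001', '002', '003']
```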
@@ -160,7 +160,9 @@ def flow_warp(x,
     Returns:
         Tensor: Warped image or feature map.
     """
-    if x.shape[-2:] != flow.shape[1:3]:
+    x_h, x_w = x.shape[-2:]
+    flow_h, flow_w = flow.shape[1:3]
+    if x_h != flow_h or x_w != flow_w:
         raise ValueError(f'The spatial sizes of input ({x.shape[-2:]}) and '
                          f'flow ({flow.shape[1:3]}) are not the same.')
     _, _, h, w = x.shape
@@ -293,7 +295,7 @@ class SPyNet(nn.Layer):
         supp = supp[::-1]

         # flow computation
-        flow = paddle.to_tensor(np.zeros([n, 2, h // 32, w // 32], 'float32'))
+        flow = paddle.zeros([n, 2, h // 32, w // 32])

         # level=0
         flow_up = flow
@@ -555,6 +557,7 @@ class BasicVSRNet(nn.Layer):
         """
         n, t, c, h, w = lrs.shape
+        t = paddle.to_tensor(t)
         assert h >= 64 and w >= 64, (
             'The height and width of inputs should be at least 64, '
             f'but got {h} and {w}.')
@@ -567,12 +570,11 @@ class BasicVSRNet(nn.Layer):
         # backward-time propgation
         outputs = []
-        feat_prop = paddle.to_tensor(
-            np.zeros([n, self.mid_channels, h, w], 'float32'))
+        feat_prop = paddle.zeros([n, self.mid_channels, h, w])
         for i in range(t - 1, -1, -1):
             if i < t - 1:  # no warping required for the last timestep
-                flow = flows_backward[:, i, :, :, :]
-                feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1]))
+                flow1 = flows_backward[:, i, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow1.transpose([0, 2, 3, 1]))
             feat_prop = paddle.concat([lrs[:, i, :, :, :], feat_prop], axis=1)
             feat_prop = self.backward_resblocks(feat_prop)
@@ -610,6 +612,7 @@ class BasicVSRNet(nn.Layer):

 class SecondOrderDeformableAlignment(nn.Layer):
     """Second-order deformable alignment module.
+
     Args:
         in_channels (int): Same as nn.Conv2d.
         out_channels (int): Same as nn.Conv2d.
......
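The changes in this file appear aimed at dynamic-to-static export: `paddle.zeros` replaces `paddle.to_tensor(np.zeros(...))`, the frame count `t` becomes a tensor, the shape check is unpacked into scalar comparisons, and the loop-carried flow tensor gets a distinct name (`flow1`), all typical accommodations for `paddle.jit` tracing. A hedged sketch of the export this enables (the import path, constructor defaults, and input shape are assumptions; a real export would load trained weights first):

```python
import paddle
from ppgan.models.generators.basicvsr import BasicVSRNet  # path assumed

model = BasicVSRNet()  # constructor arguments left at their defaults
model.eval()
# lrs is [N, T, C, H, W]; None leaves batch size and frame count dynamic.
spec = paddle.static.InputSpec(shape=[None, None, 3, 180, 320],
                               dtype='float32')
paddle.jit.save(model, 'inference/basicvsr/generator', input_spec=[spec])
```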
@@ -32,9 +32,9 @@ class PixelNorm(nn.Layer):
     def __init__(self):
         super().__init__()

-    def forward(self, input):
-        return input * paddle.rsqrt(
-            paddle.mean(input * input, 1, keepdim=True) + 1e-8)
+    def forward(self, inputs):
+        return inputs * paddle.rsqrt(
+            paddle.mean(inputs * inputs, 1, keepdim=True) + 1e-8)

 class ModulatedConv2D(nn.Layer):
@@ -93,8 +93,8 @@ class ModulatedConv2D(nn.Layer):
             f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
             f"upsample={self.upsample}, downsample={self.downsample})")

-    def forward(self, input, style):
-        batch, in_channel, height, width = input.shape
+    def forward(self, inputs, style):
+        batch, in_channel, height, width = inputs.shape
         style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
         weight = self.scale * self.weight * style
@@ -107,13 +107,13 @@ class ModulatedConv2D(nn.Layer):
                                  self.kernel_size, self.kernel_size))

         if self.upsample:
-            input = input.reshape((1, batch * in_channel, height, width))
+            inputs = inputs.reshape((1, batch * in_channel, height, width))
             weight = weight.reshape((batch, self.out_channel, in_channel,
                                      self.kernel_size, self.kernel_size))
             weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                 (batch * in_channel, self.out_channel, self.kernel_size,
                  self.kernel_size))
-            out = F.conv2d_transpose(input,
+            out = F.conv2d_transpose(inputs,
                                      weight,
                                      padding=0,
                                      stride=2,
@@ -123,16 +123,16 @@ class ModulatedConv2D(nn.Layer):
             out = self.blur(out)

         elif self.downsample:
-            input = self.blur(input)
-            _, _, height, width = input.shape
-            input = input.reshape((1, batch * in_channel, height, width))
-            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+            inputs = self.blur(inputs)
+            _, _, height, width = inputs.shape
+            inputs = inputs.reshape((1, batch * in_channel, height, width))
+            out = F.conv2d(inputs, weight, padding=0, stride=2, groups=batch)
             _, _, height, width = out.shape
             out = out.reshape((batch, self.out_channel, height, width))

         else:
-            input = input.reshape((1, batch * in_channel, height, width))
-            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+            inputs = inputs.reshape((1, batch * in_channel, height, width))
+            out = F.conv2d(inputs, weight, padding=self.padding, groups=batch)
             _, _, height, width = out.shape
             out = out.reshape((batch, self.out_channel, height, width))
@@ -165,8 +165,8 @@ class ConstantInput(nn.Layer):
             (1, channel, size, size),
             default_initializer=nn.initializer.Normal())

-    def forward(self, input):
-        batch = input.shape[0]
+    def forward(self, inputs):
+        batch = inputs.shape[0]
         out = self.input.tile((batch, 1, 1, 1))

         return out
@@ -198,8 +198,8 @@ class StyledConv(nn.Layer):
         self.activate = FusedLeakyReLU(out_channel *
                                        2 if is_concat else out_channel)

-    def forward(self, input, style, noise=None):
-        out = self.conv(input, style)
+    def forward(self, inputs, style, noise=None):
+        out = self.conv(inputs, style)
         out = self.noise(out, noise=noise)
         out = self.activate(out)
@@ -225,8 +225,8 @@ class ToRGB(nn.Layer):
         self.bias = self.create_parameter((1, 3, 1, 1),
                                           nn.initializer.Constant(0.0))

-    def forward(self, input, style, skip=None):
-        out = self.conv(input, style)
+    def forward(self, inputs, style, skip=None):
+        out = self.conv(inputs, style)
         out = out + self.bias

         if skip is not None:
@@ -349,15 +349,28 @@ class StyleGANv2Generator(nn.Layer):
         return latent

-    def get_latent(self, input):
-        return self.style(input)
+    def get_latent(self, inputs):
+        return self.style(inputs)
+
+    def get_mean_style(self):
+        mean_style = None
+        with paddle.no_grad():
+            for i in range(10):
+                style = self.mean_latent(1024)
+                if mean_style is None:
+                    mean_style = style
+                else:
+                    mean_style += style
+        mean_style /= 10
+        return mean_style

     def forward(
         self,
         styles,
         return_latents=False,
         inject_index=None,
-        truncation=1,
+        truncation=1.0,
         truncation_latent=None,
         input_is_latent=False,
         noise=None,
@@ -375,9 +388,10 @@ class StyleGANv2Generator(nn.Layer):
                 for i in range(self.num_layers)
             ]

-        if truncation < 1:
+        if truncation < 1.0:
             style_t = []
+            if truncation_latent is None:
+                truncation_latent = self.get_mean_style()
             for style in styles:
                 style_t.append(truncation_latent + truncation *
                                (style - truncation_latent))
......
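`get_mean_style` averages ten calls to `mean_latent(1024)` under `paddle.no_grad()`, and `forward` now falls back to it whenever truncation is requested without an explicit `truncation_latent`. The truncation trick itself is the interpolation visible in the hunk above; on toy tensors:

```python
import paddle

# Stand-ins: in the generator, mean_style comes from get_mean_style().
mean_style = paddle.zeros([1, 512])
style = paddle.randn([1, 512])
truncation = 0.7
# Pull the sampled style toward the average style; smaller truncation
# values trade diversity for fidelity.
truncated = mean_style + truncation * (style - mean_style)
```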
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import math
 import random
 import paddle
@@ -24,6 +25,7 @@ from .discriminators.builder import build_discriminator
 from ..solver import build_lr_scheduler, build_optimizer

+
 def r1_penalty(real_pred, real_img):
     """
     R1 regularization for discriminator. The core idea is to
@@ -195,7 +197,6 @@ class StyleGAN2Model(BaseModel):
         noises = []
         for _ in range(num_noise):
             noises.append(paddle.randn([batch, self.num_style_feat]))
-
         return noises

     def mixing_noise(self, batch, prob):
@@ -294,3 +295,25 @@ class StyleGAN2Model(BaseModel):
                 metric.update(fake_img, self.real_img)
             self.nets['gen_ema'].train()
+
+    class InferGenerator(paddle.nn.Layer):
+        def set_generator(self, generator):
+            self.generator = generator
+
+        def forward(self, style, truncation):
+            truncation_latent = self.generator.get_mean_style()
+            out = self.generator(styles=style,
+                                 truncation=truncation,
+                                 truncation_latent=truncation_latent)
+            return out[0]
+
+    def export_model(self,
+                     export_model=None,
+                     output_dir=None,
+                     inputs_size=[[1, 1, 512], [1, 1]]):
+        infer_generator = self.InferGenerator()
+        infer_generator.set_generator(self.nets['gen'])
+        style = paddle.rand(shape=inputs_size[0], dtype='float32')
+        truncation = paddle.rand(shape=inputs_size[1], dtype='float32')
+        paddle.jit.save(infer_generator,
+                        os.path.join(output_dir, "stylegan2model_gen"),
+                        input_spec=[style, truncation])
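`paddle.jit.save` traces `InferGenerator.forward(style, truncation)` with example tensors of shapes [1, 1, 512] and [1, 1], so the exported graph can be loaded back for a quick smoke test, e.g. (the load path mirrors the save call above; the concrete `output_dir` is whatever was passed in):

```python
import paddle

# Load the static generator exported by export_model above.
gen = paddle.jit.load("output_dir/stylegan2model_gen")  # output_dir as passed
style = paddle.randn([1, 1, 512])
truncation = paddle.full([1, 1], 0.7)
img = gen(style, truncation)  # generated image batch; forward returns out[0]
```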
@@ -9,6 +9,9 @@ import paddle.fluid as fluid
 import os
 from functools import reduce
 import paddle
+from ppgan.utils.filesystem import makedirs
+from pathlib import Path
+

 def read_img(path):
     img = imageio.imread(path)
@@ -16,9 +19,10 @@ def read_img(path):
         img = np.expand_dims(img, axis=2)
     # som images have 4 channels
     if img.shape[2] > 3:
-        img = img[:,:,:3]
+        img = img[:, :, :3]
     return img

+
 def read_video(path):
     reader = imageio.get_reader(path)
     fps = reader.get_meta_data()['fps']
@@ -32,9 +36,9 @@ def read_video(path):
     reader.close()
     return driving_video, fps

 def face_detection(img_ori, weight_path):
-    config = paddle_infer.Config(
-        os.path.join(weight_path, '__model__'),
+    config = paddle_infer.Config(os.path.join(weight_path, '__model__'),
                                  os.path.join(weight_path, '__params__'))
     config.disable_gpu()
     # disable print log when predict
@@ -47,7 +51,8 @@ def face_detection(img_ori, weight_path):
     img = img_ori.astype(np.float32)
     mean = np.array([123, 117, 104])[np.newaxis, np.newaxis, :]
-    std = np.array([127.502231, 127.502231, 127.502231])[np.newaxis, np.newaxis, :]
+    std = np.array([127.502231, 127.502231, 127.502231])[np.newaxis,
+                                                         np.newaxis, :]
     img -= mean
     img /= std
     img = img[:, :, [2, 1, 0]]
@@ -82,31 +87,28 @@ def face_detection(img_ori, weight_path):
     return int(y1), int(y2), int(x1), int(x2)

 def main():
     args = parse_args()
     source_path = args.source_path
-    driving_path = args.driving_path
-    source_img = read_img(source_path)
-
-    #Todo:add blazeface static model
-    #left, right, up, bottom = face_detection(source_img, "/workspace/PaddleDetection/static/inference_model/blazeface/")
-    source = source_img  #[left:right, up:bottom]
-    source = cv2.resize(source, (256, 256)) / 255.0
-    source = source[np.newaxis].astype(np.float32).transpose([0, 3, 1, 2])
-    driving_video, fps = read_video(driving_path)
-    driving_video = [cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video]
-    driving_len = len(driving_video)
-    driving_video = np.array(driving_video).astype(np.float32).transpose([0, 3, 1, 2])
+    driving_path = Path(args.driving_path)
+    makedirs(args.output_path)
+    if driving_path.is_dir():
+        driving_paths = list(driving_path.iterdir())
+    else:
+        driving_paths = [driving_path]

     # create the config
-    kp_detector_config = paddle_infer.Config(args.model_profix+"/kp_detector.pdmodel", args.model_profix+"/kp_detector.pdiparams")
-    generator_config = paddle_infer.Config(args.model_profix+"/generator.pdmodel", args.model_profix+"/generator.pdiparams")
-    kp_detector_config.set_mkldnn_cache_capacity(10)
-    kp_detector_config.enable_mkldnn()
-    generator_config.set_mkldnn_cache_capacity(10)
+    kp_detector_config = paddle_infer.Config(
+        os.path.join(args.model_path, "kp_detector.pdmodel"),
+        os.path.join(args.model_path, "kp_detector.pdiparams"))
+    generator_config = paddle_infer.Config(
+        os.path.join(args.model_path, "generator.pdmodel"),
+        os.path.join(args.model_path, "generator.pdiparams"))
+    if args.device == "gpu":
+        kp_detector_config.enable_use_gpu(100, 0)
+        generator_config.enable_use_gpu(100, 0)
+    else:
+        kp_detector_config.set_mkldnn_cache_capacity(10)
+        kp_detector_config.enable_mkldnn()
+        generator_config.set_mkldnn_cache_capacity(10)
@@ -115,50 +117,82 @@ def main():
         kp_detector_config.set_cpu_math_library_num_threads(6)
         generator_config.disable_gpu()
         generator_config.set_cpu_math_library_num_threads(6)

     # create the predictor from the config
     kp_detector_predictor = paddle_infer.create_predictor(kp_detector_config)
     generator_predictor = paddle_infer.create_predictor(generator_config)

+    for k in range(len(driving_paths)):
+        driving_path = driving_paths[k]
+        driving_video, fps = read_video(driving_path)
+        driving_video = [
+            cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video
+        ]
+        driving_len = len(driving_video)
+        driving_video = np.array(driving_video).astype(np.float32).transpose(
+            [0, 3, 1, 2])
+
+        if source_path == None:
+            source = driving_video[0:1]
+        else:
+            source_img = read_img(source_path)
+            #Todo:add blazeface static model
+            #left, right, up, bottom = face_detection(source_img, "/workspace/PaddleDetection/static/inference_model/blazeface/")
+            source = source_img  #[left:right, up:bottom]
+            source = cv2.resize(source, (256, 256)) / 255.0
+            source = source[np.newaxis].astype(np.float32).transpose(
+                [0, 3, 1, 2])

         # get the input names
         kp_detector_input_names = kp_detector_predictor.get_input_names()
-        kp_detector_input_handle = kp_detector_predictor.get_input_handle(kp_detector_input_names[0])
+        kp_detector_input_handle = kp_detector_predictor.get_input_handle(
+            kp_detector_input_names[0])
         kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
         kp_detector_input_handle.copy_from_cpu(source)
         kp_detector_predictor.run()
         kp_detector_output_names = kp_detector_predictor.get_output_names()
-        kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
+        kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+            kp_detector_output_names[0])
         source_j = kp_detector_output_handle.copy_to_cpu()
-        kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
+        kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+            kp_detector_output_names[1])
         source_v = kp_detector_output_handle.copy_to_cpu()

         kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
         kp_detector_input_handle.copy_from_cpu(driving_video[0:1])
         kp_detector_predictor.run()
         kp_detector_output_names = kp_detector_predictor.get_output_names()
-        kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
+        kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+            kp_detector_output_names[0])
         driving_init_j = kp_detector_output_handle.copy_to_cpu()
-        kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
+        kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+            kp_detector_output_names[1])
         driving_init_v = kp_detector_output_handle.copy_to_cpu()
         start_time = time.time()
         results = []
         for i in tqdm(range(0, driving_len)):
-            kp_detector_input_handle.copy_from_cpu(driving_video[i:i+1])
+            kp_detector_input_handle.copy_from_cpu(driving_video[i:i + 1])
             kp_detector_predictor.run()
             kp_detector_output_names = kp_detector_predictor.get_output_names()
-            kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
+            kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+                kp_detector_output_names[0])
             driving_j = kp_detector_output_handle.copy_to_cpu()
-            kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
+            kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+                kp_detector_output_names[1])
             driving_v = kp_detector_output_handle.copy_to_cpu()
-            generator_inputs = [source, source_j, source_v, driving_j, driving_v, driving_init_j, driving_init_v]
+            generator_inputs = [
+                source, source_j, source_v, driving_j, driving_v,
+                driving_init_j, driving_init_v
+            ]
             generator_input_names = generator_predictor.get_input_names()
             for i in range(len(generator_input_names)):
-                generator_input_handle = generator_predictor.get_input_handle(generator_input_names[i])
+                generator_input_handle = generator_predictor.get_input_handle(
+                    generator_input_names[i])
                 generator_input_handle.copy_from_cpu(generator_inputs[i])
             generator_predictor.run()
             generator_output_names = generator_predictor.get_output_names()
-            generator_output_handle = generator_predictor.get_output_handle(generator_output_names[0])
+            generator_output_handle = generator_predictor.get_output_handle(
+                generator_output_names[0])
             output_data = generator_output_handle.copy_to_cpu()
             output_data = np.transpose(output_data, [0, 2, 3, 1])[0] * 255.0
@@ -167,16 +201,32 @@ def main():
             #frame[left:right, up:bottom] = cv2.resize(output_data.astype(np.uint8), (bottom - up, right - left), cv2.INTER_AREA)
             results.append(output_data.astype(np.uint8))
         print(time.time() - start_time)
-    imageio.mimsave(args.output_path, [frame for frame in results], fps=fps)
+        imageio.mimsave(os.path.join(args.output_path,
+                                     "result_" + str(k) + ".mp4"),
+                        [frame for frame in results],
+                        fps=fps)

 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_profix", type=str, help="model filename profix")
+    parser.add_argument("--model_path", type=str, help="model filename prefix")
     parser.add_argument("--batch_size", type=int, default=1, help="batch size")
-    parser.add_argument("--source_path", type=str, default=1, help="source_path")
-    parser.add_argument("--driving_path", type=str, default=1, help="driving_path")
-    parser.add_argument("--output_path", type=str, default=1, help="output_path")
+    parser.add_argument("--source_path",
+                        type=str,
+                        default=None,
+                        help="source_path")
+    parser.add_argument("--driving_path",
+                        type=str,
+                        default=None,
+                        help="driving_path")
+    parser.add_argument("--output_path",
+                        type=str,
+                        default="infer_output/fom/",
+                        help="output_path")
+    parser.add_argument("--device", type=str, default="gpu", help="device")
     return parser.parse_args()

 if __name__ == "__main__":
     main()
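Every predictor call in `main()` follows the same Paddle Inference handle protocol. Condensed, the pattern is the following (the model file names are placeholders):

```python
import numpy as np
import paddle.inference as paddle_infer

# Condensed form of the predictor protocol used throughout main().
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")  # placeholders
predictor = paddle_infer.create_predictor(config)

x = np.zeros([1, 3, 256, 256], dtype=np.float32)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.reshape(list(x.shape))
input_handle.copy_from_cpu(x)
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
output = output_handle.copy_to_cpu()
```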
@@ -19,13 +19,14 @@ def parse_args():
         default=None,
         type=str,
         required=True,
-        help="The path prefix of inference model to be used.", )
-    parser.add_argument(
-        "--model_type",
-        default=None,
-        type=str,
-        required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES))
+        help="The path prefix of inference model to be used.",
+    )
+    parser.add_argument("--model_type",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Model type selected in the list: " +
+                        ", ".join(MODEL_CLASSES))
     parser.add_argument(
         "--device",
         default="gpu",
@@ -65,12 +66,14 @@ def main():
     args = parse_args()
     cfg = get_config(args.config_file, args.opt)
     predictor = create_predictor(args.model_path, args.device)
-    input_handles = [predictor.get_input_handle(
-        name) for name in predictor.get_input_names()]
-    output_handle = predictor.get_output_handle(
-        predictor.get_output_names()[0])
-    test_dataloader = build_dataloader(
-        cfg.dataset.test, is_train=False, distributed=False)
+    input_handles = [
+        predictor.get_input_handle(name)
+        for name in predictor.get_input_names()
+    ]
+    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
+    test_dataloader = build_dataloader(cfg.dataset.test,
+                                       is_train=False,
+                                       distributed=False)
     max_eval_steps = len(test_dataloader)
     iter_loader = IterLoader(test_dataloader)
@@ -110,8 +113,8 @@ def main():
                 prediction[j] = prediction[j][::-1, :, :]
                 image_numpy = paddle.to_tensor(prediction[j])
                 image_numpy = tensor2img(image_numpy, (0, 1))
-                save_image(
-                    image_numpy, "infer_output/wav2lip/{}_{}.png".format(i, j))
+                save_image(image_numpy,
+                           "infer_output/wav2lip/{}_{}.png".format(i, j))
         elif model_type == "esrgan":
             lq = data['lq'].numpy()
             input_handles[0].copy_from_cpu(lq)
@@ -128,6 +131,23 @@ def main():
             prediction = paddle.to_tensor(prediction[0])
             image_numpy = tensor2img(prediction, min_max)
             save_image(image_numpy, "infer_output/edvr/{}.png".format(i))
+        elif model_type == "stylegan2":
+            noise = paddle.randn([1, 1, 512]).cpu().numpy()
+            input_handles[0].copy_from_cpu(noise)
+            input_handles[1].copy_from_cpu(np.array([0.7]).astype('float32'))
+            predictor.run()
+            prediction = output_handle.copy_to_cpu()
+            prediction = paddle.to_tensor(prediction[0])
+            image_numpy = tensor2img(prediction, min_max)
+            save_image(image_numpy, "infer_output/stylegan2/{}.png".format(i))
+        elif model_type == "basicvsr":
+            lq = data['lq'].numpy()
+            input_handles[0].copy_from_cpu(lq)
+            predictor.run()
+            prediction = output_handle.copy_to_cpu()
+            prediction = paddle.to_tensor(prediction[0])
+            image_numpy = tensor2img(prediction, min_max)
+            save_image(image_numpy, "infer_output/basicvsr/{}.png".format(i))

 if __name__ == '__main__':
......