diff --git a/benchmark/benchmark.yaml b/benchmark/benchmark.yaml
index 84274e1010345dcc92951f0bd39e64885bd00eee..6df2140b5944d705b1556401363c10dedc9f6ffb 100644
--- a/benchmark/benchmark.yaml
+++ b/benchmark/benchmark.yaml
@@ -12,7 +12,7 @@ FOMM:
     fp_item: fp32
     bs_item: 8 16
     epochs: 1
-    log_interval: 11
+    log_interval: 1
 
 esrgan:
     dataset_web: https://paddlegan.bj.bcebos.com/datasets/DIV2KandSet14.tar
diff --git a/configs/basicvsr_reds.yaml b/configs/basicvsr_reds.yaml
index 566a4b26e3258da01438cab532fcf2f3a950769b..9e034f009708edc4b5f42eb22a148b41899cc650 100644
--- a/configs/basicvsr_reds.yaml
+++ b/configs/basicvsr_reds.yaml
@@ -38,6 +38,7 @@ dataset:
     use_rot: True
     scale: 4
     val_partition: REDS4
+    num_clips: 270
 
   test:
     name: SRREDSMultipleGTDataset
@@ -90,3 +91,6 @@ log_config:
 
 snapshot_config:
   interval: 5000
+
+export_model:
+  - {name: 'generator', inputs_num: 1}
diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml
index 77a28610e1b46b6529380770b08ecc897839d8eb..f5c80547d60ea8949b30b8c3bb49d359c7ca3f5b 100644
--- a/configs/cyclegan_horse2zebra.yaml
+++ b/configs/cyclegan_horse2zebra.yaml
@@ -119,3 +119,7 @@ log_config:
 
 snapshot_config:
   interval: 5
+
+export_model:
+  - {name: 'netG_A', inputs_num: 1}
+  - {name: 'netG_B', inputs_num: 1}
diff --git a/configs/firstorder_vox_256.yaml b/configs/firstorder_vox_256.yaml
index 4fd12581d62b6d8b3f097d0cf189b3f48f2582dc..dc9be729bc59587494e85f92de850703f51f804d 100755
--- a/configs/firstorder_vox_256.yaml
+++ b/configs/firstorder_vox_256.yaml
@@ -123,3 +123,6 @@ snapshot_config:
 
 optimizer:
   name: Adam
+
+export_model:
+  - {}
diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml
index 1654ac213325bad448f98e35f17bbb485cbf7588..ffe127f41097932b7947bd9a356b3211bb246a13 100644
--- a/configs/pix2pix_facades.yaml
+++ b/configs/pix2pix_facades.yaml
@@ -115,3 +115,6 @@ validate:
     fid: # metric name, can be arbitrary
       name: FID
       batch_size: 8
+
+export_model:
+  - {name: 'netG', inputs_num: 1}
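The `export_model` blocks added across these configs declare which sub-networks get exported and how many inputs each `forward` takes. A minimal sketch of how a driver might consume them, assuming the `get_config`/`build_model` helpers and the `export_model()` method added later in this patch (the driver script and output path are assumptions, not part of the patch):

```python
# Hypothetical export driver: reads the export_model list from a config and
# forwards it to the model's export_model() method introduced in this patch.
from ppgan.utils.config import get_config
from ppgan.models.builder import build_model

cfg = get_config("configs/cyclegan_horse2zebra.yaml")
model = build_model(cfg.model)

# cfg.export_model is the list added above, e.g.
# [{'name': 'netG_A', 'inputs_num': 1}, {'name': 'netG_B', 'inputs_num': 1}]
model.export_model(export_model=cfg.export_model,
                   output_dir="inference_models/")
```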
""" # generate keys - keys = [f'{i:03d}' for i in range(0, 270)] + keys = [f'{i:03d}' for i in range(0, self.num_clips)] if self.val_partition == 'REDS4': val_partition = ['000', '011', '015', '020'] @@ -170,7 +172,9 @@ class SRREDSMultipleGTDataset(Dataset): gt_list = rlt[number_frames:] # stack LQ images to NHWC, N is the frame number - frame_list = [v.transpose(2, 0, 1).astype('float32') for v in frame_list] + frame_list = [ + v.transpose(2, 0, 1).astype('float32') for v in frame_list + ] gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list] img_LQs = np.stack(frame_list, axis=0) diff --git a/ppgan/models/generators/basicvsr.py b/ppgan/models/generators/basicvsr.py index 48bc7cc54c37d5f8645c52f60cf6aa800c6eec9b..d7ccbc8b427c7e4f0b7829e9550e0939660a2854 100644 --- a/ppgan/models/generators/basicvsr.py +++ b/ppgan/models/generators/basicvsr.py @@ -160,7 +160,9 @@ def flow_warp(x, Returns: Tensor: Warped image or feature map. """ - if x.shape[-2:] != flow.shape[1:3]: + x_h, x_w = x.shape[-2:] + flow_h, flow_w = flow.shape[1:3] + if x_h != flow_h or x_w != flow_w: raise ValueError(f'The spatial sizes of input ({x.shape[-2:]}) and ' f'flow ({flow.shape[1:3]}) are not the same.') _, _, h, w = x.shape @@ -293,7 +295,7 @@ class SPyNet(nn.Layer): supp = supp[::-1] # flow computation - flow = paddle.to_tensor(np.zeros([n, 2, h // 32, w // 32], 'float32')) + flow = paddle.zeros([n, 2, h // 32, w // 32]) # level=0 flow_up = flow @@ -555,6 +557,7 @@ class BasicVSRNet(nn.Layer): """ n, t, c, h, w = lrs.shape + t = paddle.to_tensor(t) assert h >= 64 and w >= 64, ( 'The height and width of inputs should be at least 64, ' f'but got {h} and {w}.') @@ -567,19 +570,18 @@ class BasicVSRNet(nn.Layer): # backward-time propgation outputs = [] - feat_prop = paddle.to_tensor( - np.zeros([n, self.mid_channels, h, w], 'float32')) + feat_prop = paddle.zeros([n, self.mid_channels, h, w]) for i in range(t - 1, -1, -1): if i < t - 1: # no warping required for the last timestep - flow = flows_backward[:, i, :, :, :] - feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1])) + flow1 = flows_backward[:, i, :, :, :] + feat_prop = flow_warp(feat_prop, flow1.transpose([0, 2, 3, 1])) feat_prop = paddle.concat([lrs[:, i, :, :, :], feat_prop], axis=1) feat_prop = self.backward_resblocks(feat_prop) outputs.append(feat_prop) outputs = outputs[::-1] - + # forward-time propagation and upsampling feat_prop = paddle.zeros_like(feat_prop) for i in range(0, t): @@ -610,6 +612,7 @@ class BasicVSRNet(nn.Layer): class SecondOrderDeformableAlignment(nn.Layer): """Second-order deformable alignment module. + Args: in_channels (int): Same as nn.Conv2d. out_channels (int): Same as nn.Conv2d. 
diff --git a/ppgan/models/generators/generator_styleganv2.py b/ppgan/models/generators/generator_styleganv2.py
index 6297a3ea8f949dc806bae25aef8984c983d096bb..64c2c09335d049baa0becac2885879a896dff2b1 100644
--- a/ppgan/models/generators/generator_styleganv2.py
+++ b/ppgan/models/generators/generator_styleganv2.py
@@ -32,9 +32,9 @@ class PixelNorm(nn.Layer):
     def __init__(self):
         super().__init__()
 
-    def forward(self, input):
-        return input * paddle.rsqrt(
-            paddle.mean(input * input, 1, keepdim=True) + 1e-8)
+    def forward(self, inputs):
+        return inputs * paddle.rsqrt(
+            paddle.mean(inputs * inputs, 1, keepdim=True) + 1e-8)
 
 
 class ModulatedConv2D(nn.Layer):
@@ -93,8 +93,8 @@ class ModulatedConv2D(nn.Layer):
             f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
             f"upsample={self.upsample}, downsample={self.downsample})")
 
-    def forward(self, input, style):
-        batch, in_channel, height, width = input.shape
+    def forward(self, inputs, style):
+        batch, in_channel, height, width = inputs.shape
 
         style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
         weight = self.scale * self.weight * style
@@ -107,13 +107,13 @@ class ModulatedConv2D(nn.Layer):
                                      self.kernel_size, self.kernel_size))
 
         if self.upsample:
-            input = input.reshape((1, batch * in_channel, height, width))
+            inputs = inputs.reshape((1, batch * in_channel, height, width))
             weight = weight.reshape((batch, self.out_channel, in_channel,
                                      self.kernel_size, self.kernel_size))
             weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                 (batch * in_channel, self.out_channel, self.kernel_size,
                  self.kernel_size))
-            out = F.conv2d_transpose(input,
+            out = F.conv2d_transpose(inputs,
                                      weight,
                                      padding=0,
                                      stride=2,
@@ -123,16 +123,16 @@ class ModulatedConv2D(nn.Layer):
             out = self.blur(out)
 
         elif self.downsample:
-            input = self.blur(input)
-            _, _, height, width = input.shape
-            input = input.reshape((1, batch * in_channel, height, width))
-            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+            inputs = self.blur(inputs)
+            _, _, height, width = inputs.shape
+            inputs = inputs.reshape((1, batch * in_channel, height, width))
+            out = F.conv2d(inputs, weight, padding=0, stride=2, groups=batch)
             _, _, height, width = out.shape
             out = out.reshape((batch, self.out_channel, height, width))
 
         else:
-            input = input.reshape((1, batch * in_channel, height, width))
-            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+            inputs = inputs.reshape((1, batch * in_channel, height, width))
+            out = F.conv2d(inputs, weight, padding=self.padding, groups=batch)
             _, _, height, width = out.shape
             out = out.reshape((batch, self.out_channel, height, width))
 
@@ -165,8 +165,8 @@ class ConstantInput(nn.Layer):
             (1, channel, size, size),
             default_initializer=nn.initializer.Normal())
 
-    def forward(self, input):
-        batch = input.shape[0]
+    def forward(self, inputs):
+        batch = inputs.shape[0]
         out = self.input.tile((batch, 1, 1, 1))
 
         return out
@@ -198,8 +198,8 @@ class StyledConv(nn.Layer):
         self.activate = FusedLeakyReLU(out_channel *
                                        2 if is_concat else out_channel)
 
-    def forward(self, input, style, noise=None):
-        out = self.conv(input, style)
+    def forward(self, inputs, style, noise=None):
+        out = self.conv(inputs, style)
         out = self.noise(out, noise=noise)
         out = self.activate(out)
 
@@ -225,8 +225,8 @@ class ToRGB(nn.Layer):
         self.bias = self.create_parameter((1, 3, 1, 1),
                                           nn.initializer.Constant(0.0))
 
-    def forward(self, input, style, skip=None):
-        out = self.conv(input, style)
+    def forward(self, inputs, style, skip=None):
+        out = self.conv(inputs, style)
         out = out + self.bias
 
         if skip is not None:
@@ -349,15 +349,28 @@ class StyleGANv2Generator(nn.Layer):
 
         return latent
 
-    def get_latent(self, input):
-        return self.style(input)
+    def get_latent(self, inputs):
+        return self.style(inputs)
+
+    def get_mean_style(self):
+        mean_style = None
+        with paddle.no_grad():
+            for i in range(10):
+                style = self.mean_latent(1024)
+                if mean_style is None:
+                    mean_style = style
+                else:
+                    mean_style += style
+
+        mean_style /= 10
+        return mean_style
 
     def forward(
         self,
         styles,
         return_latents=False,
         inject_index=None,
-        truncation=1,
+        truncation=1.0,
         truncation_latent=None,
         input_is_latent=False,
         noise=None,
@@ -375,9 +388,10 @@ class StyleGANv2Generator(nn.Layer):
                 for i in range(self.num_layers)
             ]
 
-        if truncation < 1:
+        if truncation < 1.0:
             style_t = []
-
+            if truncation_latent is None:
+                truncation_latent = self.get_mean_style()
             for style in styles:
                 style_t.append(truncation_latent + truncation *
                                (style - truncation_latent))
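The `get_mean_style()` fallback wired into `forward()` above implements the usual truncation trick: each style is pulled toward the mean style in proportion to `1 - truncation`. A self-contained sketch of the arithmetic (the tensors below are placeholders for the mapped latents):

```python
import paddle

truncation = 0.7
truncation_latent = paddle.randn([1, 512])  # stand-in for get_mean_style()
style = paddle.randn([1, 512])              # stand-in for a mapped latent

# Same expression as style_t.append(...) in forward(): truncation=1.0 is a
# no-op, truncation=0.0 collapses every style to the mean style.
style_t = truncation_latent + truncation * (style - truncation_latent)
```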
diff --git a/ppgan/models/styleganv2_model.py b/ppgan/models/styleganv2_model.py
index e41c55216182cee7cba216050a94fa4493b24e3e..73ac5c9e0bb8be3171db48eca265d2b6682007ea 100644
--- a/ppgan/models/styleganv2_model.py
+++ b/ppgan/models/styleganv2_model.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import math
 import random
 import paddle
@@ -24,6 +25,7 @@ from .discriminators.builder import build_discriminator
 from ..solver import build_lr_scheduler, build_optimizer
 
 
+
 def r1_penalty(real_pred, real_img):
     """
     R1 regularization for discriminator. The core idea is to
@@ -195,7 +197,6 @@ class StyleGAN2Model(BaseModel):
         noises = []
         for _ in range(num_noise):
             noises.append(paddle.randn([batch, self.num_style_feat]))
-
         return noises
 
     def mixing_noise(self, batch, prob):
@@ -294,3 +295,25 @@ class StyleGAN2Model(BaseModel):
                 metric.update(fake_img, self.real_img)
 
         self.nets['gen_ema'].train()
+
+    class InferGenerator(paddle.nn.Layer):
+        def set_generator(self, generator):
+            self.generator = generator
+
+        def forward(self, style, truncation):
+            truncation_latent = self.generator.get_mean_style()
+            out = self.generator(styles=style,
+                                 truncation=truncation,
+                                 truncation_latent=truncation_latent)
+            return out[0]
+
+    def export_model(self,
+                     export_model=None,
+                     output_dir=None,
+                     inputs_size=[[1, 1, 512], [1, 1]]):
+        infer_generator = self.InferGenerator()
+        infer_generator.set_generator(self.nets['gen'])
+        style = paddle.rand(shape=inputs_size[0], dtype='float32')
+        truncation = paddle.rand(shape=inputs_size[1], dtype='float32')
+        paddle.jit.save(infer_generator,
+                        os.path.join(output_dir, "stylegan2model_gen"),
+                        input_spec=[style, truncation])
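`export_model()` above saves the generator with `paddle.jit.save` under the prefix `stylegan2model_gen`; a minimal round-trip sketch, assuming it was exported with `output_dir="inference_models/"` (the path is an assumption; the input shapes mirror `inputs_size`):

```python
# Load the generator saved by export_model() above and run one forward pass.
import paddle

gen = paddle.jit.load("inference_models/stylegan2model_gen")  # assumed path
style = paddle.randn([1, 1, 512])                       # matches inputs_size[0]
truncation = paddle.full([1, 1], 0.7, dtype='float32')  # matches inputs_size[1]
img = gen(style, truncation)  # InferGenerator.forward returns out[0]
print(img.shape)
```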
diff --git a/tools/fom_infer.py b/tools/fom_infer.py
index 4896b71e1c766633a3e41455eefd4ae5a425343b..73fd7ad9efd6c289778fb48cc2d2e5104d793c33 100644
--- a/tools/fom_infer.py
+++ b/tools/fom_infer.py
@@ -9,15 +9,19 @@ import paddle.fluid as fluid
 import os
 from functools import reduce
 import paddle
+from ppgan.utils.filesystem import makedirs
+from pathlib import Path
+
 
 def read_img(path):
-  img = imageio.imread(path)
-  if img.ndim == 2:
-    img = np.expand_dims(img, axis=2)
-  # som images have 4 channels
-  if img.shape[2] > 3:
-    img = img[:,:,:3]
-  return img
+    img = imageio.imread(path)
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    # some images have 4 channels
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+    return img
+
 
 def read_video(path):
     reader = imageio.get_reader(path)
@@ -32,10 +36,10 @@ def read_video(path):
     reader.close()
     return driving_video, fps
 
+
 def face_detection(img_ori, weight_path):
-    config = paddle_infer.Config(
-        os.path.join(weight_path, '__model__'),
-        os.path.join(weight_path, '__params__'))
+    config = paddle_infer.Config(os.path.join(weight_path, '__model__'),
+                                 os.path.join(weight_path, '__params__'))
     config.disable_gpu()
     # disable print log when predict
     config.disable_glog_info()
@@ -44,10 +48,11 @@ def face_detection(img_ori, weight_path):
     # disable feed, fetch OP, needed by zero_copy_run
     config.switch_use_feed_fetch_ops(False)
     predictor = paddle_infer.create_predictor(config)
-    
+
     img = img_ori.astype(np.float32)
     mean = np.array([123, 117, 104])[np.newaxis, np.newaxis, :]
-    std = np.array([127.502231, 127.502231, 127.502231])[np.newaxis, np.newaxis, :]
+    std = np.array([127.502231, 127.502231, 127.502231])[np.newaxis,
+                                                         np.newaxis, :]
     img -= mean
     img /= std
     img = img[:, :, [2, 1, 0]]
@@ -82,101 +87,146 @@
     return int(y1), int(y2), int(x1), int(x2)
-
 
 def main():
     args = parse_args()
-
-    source_path = args.source_path
-    driving_path = args.driving_path
-    source_img = read_img(source_path)
-
-    #Todo:add blazeface static model
-    #left, right, up, bottom = face_detection(source_img, "/workspace/PaddleDetection/static/inference_model/blazeface/")
-    source = source_img #[left:right, up:bottom]
-
-    source = cv2.resize(source, (256, 256)) / 255.0
-    source = source[np.newaxis].astype(np.float32).transpose([0, 3, 1, 2])
-
-
-    driving_video, fps = read_video(driving_path)
-    driving_video = [cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video]
-    driving_len = len(driving_video)
-    driving_video = np.array(driving_video).astype(np.float32).transpose([0, 3, 1, 2])
+    source_path = args.source_path
+    driving_path = Path(args.driving_path)
+    makedirs(args.output_path)
+    if driving_path.is_dir():
+        driving_paths = list(driving_path.iterdir())
+    else:
+        driving_paths = [driving_path]
 
     # create config
-    kp_detector_config = paddle_infer.Config(args.model_profix+"/kp_detector.pdmodel", args.model_profix+"/kp_detector.pdiparams")
-    generator_config = paddle_infer.Config(args.model_profix+"/generator.pdmodel", args.model_profix+"/generator.pdiparams")
-    kp_detector_config.set_mkldnn_cache_capacity(10)
-    kp_detector_config.enable_mkldnn()
-    generator_config.set_mkldnn_cache_capacity(10)
-    generator_config.enable_mkldnn()
-    kp_detector_config.disable_gpu()
-    kp_detector_config.set_cpu_math_library_num_threads(6)
-    generator_config.disable_gpu()
-    generator_config.set_cpu_math_library_num_threads(6)
-    
+    kp_detector_config = paddle_infer.Config(
+        os.path.join(args.model_path, "kp_detector.pdmodel"),
+        os.path.join(args.model_path, "kp_detector.pdiparams"))
+    generator_config = paddle_infer.Config(
+        os.path.join(args.model_path, "generator.pdmodel"),
+        os.path.join(args.model_path, "generator.pdiparams"))
+    if args.device == "gpu":
+        kp_detector_config.enable_use_gpu(100, 0)
+        generator_config.enable_use_gpu(100, 0)
+    else:
+        kp_detector_config.set_mkldnn_cache_capacity(10)
+        kp_detector_config.enable_mkldnn()
+        generator_config.set_mkldnn_cache_capacity(10)
+        generator_config.enable_mkldnn()
+        kp_detector_config.disable_gpu()
+        kp_detector_config.set_cpu_math_library_num_threads(6)
+        generator_config.disable_gpu()
+        generator_config.set_cpu_math_library_num_threads(6)
     # create the predictors from the configs
     kp_detector_predictor = paddle_infer.create_predictor(kp_detector_config)
     generator_predictor = paddle_infer.create_predictor(generator_config)
 
-    # get input names
-    kp_detector_input_names = kp_detector_predictor.get_input_names()
-    kp_detector_input_handle = kp_detector_predictor.get_input_handle(kp_detector_input_names[0])
-
-    kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
-    kp_detector_input_handle.copy_from_cpu(source)
-    kp_detector_predictor.run()
-    kp_detector_output_names = kp_detector_predictor.get_output_names()
-    kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
-    source_j = kp_detector_output_handle.copy_to_cpu()
-    kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
-    source_v = kp_detector_output_handle.copy_to_cpu()
-
-    kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
-    kp_detector_input_handle.copy_from_cpu(driving_video[0:1])
-    kp_detector_predictor.run()
-    kp_detector_output_names = kp_detector_predictor.get_output_names()
-    kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
-    driving_init_j = kp_detector_output_handle.copy_to_cpu()
-    kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
-    driving_init_v = kp_detector_output_handle.copy_to_cpu()
-    start_time = time.time()
-    results = []
-    for i in tqdm(range(0, driving_len)):
-        kp_detector_input_handle.copy_from_cpu(driving_video[i:i+1])
+    for k in range(len(driving_paths)):
+        driving_path = driving_paths[k]
+        driving_video, fps = read_video(driving_path)
+        driving_video = [
+            cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video
+        ]
+        driving_len = len(driving_video)
+        driving_video = np.array(driving_video).astype(np.float32).transpose(
+            [0, 3, 1, 2])
+
+        if source_path is None:
+            source = driving_video[0:1]
+        else:
+            source_img = read_img(source_path)
+            #Todo:add blazeface static model
+            #left, right, up, bottom = face_detection(source_img, "/workspace/PaddleDetection/static/inference_model/blazeface/")
+            source = source_img  #[left:right, up:bottom]
+            source = cv2.resize(source, (256, 256)) / 255.0
+            source = source[np.newaxis].astype(np.float32).transpose(
+                [0, 3, 1, 2])
+
+        # get input names
+        kp_detector_input_names = kp_detector_predictor.get_input_names()
+        kp_detector_input_handle = kp_detector_predictor.get_input_handle(
+            kp_detector_input_names[0])
+
+        kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
+        kp_detector_input_handle.copy_from_cpu(source)
+        kp_detector_predictor.run()
+        kp_detector_output_names = kp_detector_predictor.get_output_names()
+        kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+            kp_detector_output_names[0])
+        source_j = kp_detector_output_handle.copy_to_cpu()
+        kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+            kp_detector_output_names[1])
+        source_v = kp_detector_output_handle.copy_to_cpu()
+
+        kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
+        kp_detector_input_handle.copy_from_cpu(driving_video[0:1])
         kp_detector_predictor.run()
         kp_detector_output_names = kp_detector_predictor.get_output_names()
-        kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
-        driving_j = kp_detector_output_handle.copy_to_cpu()
-        kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
-        driving_v = kp_detector_output_handle.copy_to_cpu()
-        generator_inputs = [source, source_j, source_v, driving_j, driving_v, driving_init_j, driving_init_v]
-        generator_input_names = generator_predictor.get_input_names()
-        for i in range(len(generator_input_names)):
-            generator_input_handle = generator_predictor.get_input_handle(generator_input_names[i])
-            generator_input_handle.copy_from_cpu(generator_inputs[i])
-        generator_predictor.run()
-        generator_output_names = generator_predictor.get_output_names()
-        generator_output_handle = generator_predictor.get_output_handle(generator_output_names[0])
-        output_data = generator_output_handle.copy_to_cpu()
-        output_data = np.transpose(output_data, [0, 2, 3, 1])[0] * 255.0
-
-        #Todo:add blazeface static model
-        #frame = source_img.copy()
-        #frame[left:right, up:bottom] = cv2.resize(output_data.astype(np.uint8), (bottom - up, right - left), cv2.INTER_AREA)
-        results.append(output_data.astype(np.uint8))
-    print(time.time() - start_time)
-    imageio.mimsave(args.output_path, [frame for frame in results], fps=fps)
+        kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+            kp_detector_output_names[0])
+        driving_init_j = kp_detector_output_handle.copy_to_cpu()
+        kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+            kp_detector_output_names[1])
+        driving_init_v = kp_detector_output_handle.copy_to_cpu()
+        start_time = time.time()
+        results = []
+        for i in tqdm(range(0, driving_len)):
+            kp_detector_input_handle.copy_from_cpu(driving_video[i:i + 1])
+            kp_detector_predictor.run()
+            kp_detector_output_names = kp_detector_predictor.get_output_names()
+            kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+                kp_detector_output_names[0])
+            driving_j = kp_detector_output_handle.copy_to_cpu()
+            kp_detector_output_handle = kp_detector_predictor.get_output_handle(
+                kp_detector_output_names[1])
+            driving_v = kp_detector_output_handle.copy_to_cpu()
+            generator_inputs = [
+                source, source_j, source_v, driving_j, driving_v,
+                driving_init_j, driving_init_v
+            ]
+            generator_input_names = generator_predictor.get_input_names()
+            for j in range(len(generator_input_names)):
+                generator_input_handle = generator_predictor.get_input_handle(
+                    generator_input_names[j])
+                generator_input_handle.copy_from_cpu(generator_inputs[j])
+            generator_predictor.run()
+            generator_output_names = generator_predictor.get_output_names()
+            generator_output_handle = generator_predictor.get_output_handle(
+                generator_output_names[0])
+            output_data = generator_output_handle.copy_to_cpu()
+            output_data = np.transpose(output_data, [0, 2, 3, 1])[0] * 255.0
+
+            #Todo:add blazeface static model
+            #frame = source_img.copy()
+            #frame[left:right, up:bottom] = cv2.resize(output_data.astype(np.uint8), (bottom - up, right - left), cv2.INTER_AREA)
+            results.append(output_data.astype(np.uint8))
+        print(time.time() - start_time)
+        imageio.mimsave(os.path.join(args.output_path,
+                                     "result_" + str(k) + ".mp4"),
+                        [frame for frame in results],
+                        fps=fps)
 
+
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_profix", type=str, help="model filename profix")
+    parser.add_argument("--model_path", type=str, help="model filename prefix")
     parser.add_argument("--batch_size", type=int, default=1, help="batch size")
-    parser.add_argument("--source_path", type=str, default=1, help="source_path")
-    parser.add_argument("--driving_path", type=str, default=1, help="driving_path")
-    parser.add_argument("--output_path", type=str, default=1, help="output_path")
+    parser.add_argument("--source_path",
+                        type=str,
+                        default=None,
+                        help="path to source image")
parser.add_argument("--driving_path", + type=str, + default=None, + help="driving_path") + parser.add_argument("--output_path", + type=str, + default="infer_output/fom/", + help="output_path") + parser.add_argument("--device", type=str, default="gpu", help="device") + return parser.parse_args() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tools/inference.py b/tools/inference.py index 01fccb878a8c1886cb21cbaf6184a7d2754f0f3d..b4d0200d153c54de5aacca98bfff041037b31b87 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -19,13 +19,14 @@ def parse_args(): default=None, type=str, required=True, - help="The path prefix of inference model to be used.", ) - parser.add_argument( - "--model_type", - default=None, - type=str, - required=True, - help="Model type selected in the list: " + ", ".join(MODEL_CLASSES)) + help="The path prefix of inference model to be used.", + ) + parser.add_argument("--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES)) parser.add_argument( "--device", default="gpu", @@ -65,12 +66,14 @@ def main(): args = parse_args() cfg = get_config(args.config_file, args.opt) predictor = create_predictor(args.model_path, args.device) - input_handles = [predictor.get_input_handle( - name) for name in predictor.get_input_names()] - output_handle = predictor.get_output_handle( - predictor.get_output_names()[0]) - test_dataloader = build_dataloader( - cfg.dataset.test, is_train=False, distributed=False) + input_handles = [ + predictor.get_input_handle(name) + for name in predictor.get_input_names() + ] + output_handle = predictor.get_output_handle(predictor.get_output_names()[0]) + test_dataloader = build_dataloader(cfg.dataset.test, + is_train=False, + distributed=False) max_eval_steps = len(test_dataloader) iter_loader = IterLoader(test_dataloader) @@ -110,8 +113,8 @@ def main(): prediction[j] = prediction[j][::-1, :, :] image_numpy = paddle.to_tensor(prediction[j]) image_numpy = tensor2img(image_numpy, (0, 1)) - save_image( - image_numpy, "infer_output/wav2lip/{}_{}.png".format(i, j)) + save_image(image_numpy, + "infer_output/wav2lip/{}_{}.png".format(i, j)) elif model_type == "esrgan": lq = data['lq'].numpy() input_handles[0].copy_from_cpu(lq) @@ -128,6 +131,23 @@ def main(): prediction = paddle.to_tensor(prediction[0]) image_numpy = tensor2img(prediction, min_max) save_image(image_numpy, "infer_output/edvr/{}.png".format(i)) + elif model_type == "stylegan2": + noise = paddle.randn([1, 1, 512]).cpu().numpy() + input_handles[0].copy_from_cpu(noise) + input_handles[1].copy_from_cpu(np.array([0.7]).astype('float32')) + predictor.run() + prediction = output_handle.copy_to_cpu() + prediction = paddle.to_tensor(prediction[0]) + image_numpy = tensor2img(prediction, min_max) + save_image(image_numpy, "infer_output/stylegan2/{}.png".format(i)) + elif model_type == "basicvsr": + lq = data['lq'].numpy() + input_handles[0].copy_from_cpu(lq) + predictor.run() + prediction = output_handle.copy_to_cpu() + prediction = paddle.to_tensor(prediction[0]) + image_numpy = tensor2img(prediction, min_max) + save_image(image_numpy, "infer_output/basicvsr/{}.png".format(i)) if __name__ == '__main__':