From a4a54e5180bbb5c500e7b83c427c00204954b76d Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Thu, 17 Jun 2021 03:07:19 -0500 Subject: [PATCH] add export_model and inference scripts (#340) * add export_model and inference scripts --- configs/cyclegan_horse2zebra.yaml | 4 + configs/edvr_m_w_tsa.yaml | 3 + configs/esrgan_x4_div2k.yaml | 3 + configs/pix2pix_cityscapes.yaml | 3 + configs/stylegan_v2_256_ffhq.yaml | 3 + configs/wav2lip_hq.yaml | 3 + ppgan/models/base_model.py | 13 +++ ppgan/modules/dcn.py | 20 +++++ ppgan/modules/fused_act.py | 2 +- tools/export_model.py | 72 ++++++++++++++++ tools/inference.py | 134 ++++++++++++++++++++++++++++++ 11 files changed, 259 insertions(+), 1 deletion(-) create mode 100644 tools/export_model.py create mode 100644 tools/inference.py diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml index 7d831ff..77a2861 100644 --- a/configs/cyclegan_horse2zebra.yaml +++ b/configs/cyclegan_horse2zebra.yaml @@ -27,6 +27,10 @@ model: name: GANLoss gan_mode: lsgan +export_model: + - {name: 'netG_A', inputs_num: 1} + - {name: 'netG_B', inputs_num: 1} + dataset: train: name: UnpairedDataset diff --git a/configs/edvr_m_w_tsa.yaml b/configs/edvr_m_w_tsa.yaml index c9d5ea8..e40fa0c 100644 --- a/configs/edvr_m_w_tsa.yaml +++ b/configs/edvr_m_w_tsa.yaml @@ -26,6 +26,9 @@ model: pixel_criterion: name: CharbonnierLoss +export_model: + - {name: 'generator', inputs_num: 1} + dataset: train: name: REDSDataset diff --git a/configs/esrgan_x4_div2k.yaml b/configs/esrgan_x4_div2k.yaml index 53cabe9..5202389 100644 --- a/configs/esrgan_x4_div2k.yaml +++ b/configs/esrgan_x4_div2k.yaml @@ -32,6 +32,9 @@ model: gan_mode: vanilla loss_weight: !!float 5e-3 +export_model: + - {name: 'generator', inputs_num: 1} + dataset: train: name: SRDataset diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml index 47f1716..11ad4fc 100644 --- a/configs/pix2pix_cityscapes.yaml +++ b/configs/pix2pix_cityscapes.yaml @@ -25,6 +25,9 @@ model: name: GANLoss gan_mode: vanilla +export_model: + - {name: 'netG', inputs_num: 1} + dataset: train: name: PairedDataset diff --git a/configs/stylegan_v2_256_ffhq.yaml b/configs/stylegan_v2_256_ffhq.yaml index d87268c..1265589 100644 --- a/configs/stylegan_v2_256_ffhq.yaml +++ b/configs/stylegan_v2_256_ffhq.yaml @@ -24,6 +24,9 @@ model: gen_iters: 4 disc_iters: 16 +export_model: + - {name: 'gen', inputs_num: 2} + dataset: train: name: SingleDataset diff --git a/configs/wav2lip_hq.yaml b/configs/wav2lip_hq.yaml index a6a4f1c..9e9dc51 100644 --- a/configs/wav2lip_hq.yaml +++ b/configs/wav2lip_hq.yaml @@ -14,6 +14,9 @@ model: discriminator_hq: name: Wav2LipDiscQual +export_model: + - {name: 'netG', inputs_num: 2} + dataset: train: name: Wav2LipDataset diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py index dac434e..36d1454 100755 --- a/ppgan/models/base_model.py +++ b/ppgan/models/base_model.py @@ -182,3 +182,16 @@ class BaseModel(ABC): if net is not None: for param in net.parameters(): param.trainable = requires_grad + + def export_model(self, export_model, output_dir=None, inputs_size=[]): + inputs_num = 0 + for net in export_model: + input_spec = [paddle.static.InputSpec( + shape=inputs_size[inputs_num + i], dtype="float32") for i in range(net["inputs_num"])] + inputs_num = inputs_num + net["inputs_num"] + static_model = paddle.jit.to_static(self.nets[net["name"]], + input_spec=input_spec) + if output_dir is None: + output_dir = 'export_model' + paddle.jit.save(static_model, os.path.join( + output_dir, '{}_{}'.format(self.__class__.__name__.lower(), net["name"]))) diff --git a/ppgan/modules/dcn.py b/ppgan/modules/dcn.py index cf9a5a4..f5d60e8 100644 --- a/ppgan/modules/dcn.py +++ b/ppgan/modules/dcn.py @@ -120,6 +120,26 @@ def deform_conv2d(x, out = nn.elementwise_add(pre_bias, bias, axis=1) else: out = pre_bias + else: + helper = LayerHelper('deform_conv2d', **locals()) + attrs = {'strides': stride, 'paddings': padding, 'dilations': dilation, 'deformable_groups': deformable_groups, + 'groups': groups, 'im2col_step': 1} + if use_deform_conv2d_v1: + op_type = 'deformable_conv_v1' + inputs = {'Input': x, 'Offset': offset, 'Filter': weight} + else: + op_type = 'deformable_conv' + inputs = {'Input': x, 'Offset': offset, + 'Mask': mask, 'Filter': weight} + + pre_bias = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op(type=op_type, inputs=inputs, outputs={ + 'Output': pre_bias}, attrs=attrs) + + if bias is not None: + out = nn.elementwise_add(pre_bias, bias, axis=1) + else: + out = pre_bias return out diff --git a/ppgan/modules/fused_act.py b/ppgan/modules/fused_act.py index 8723af3..d1bc584 100644 --- a/ppgan/modules/fused_act.py +++ b/ppgan/modules/fused_act.py @@ -36,7 +36,7 @@ class FusedLeakyReLU(nn.Layer): def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): if bias is not None: - rest_dim = [1] * (input.ndim - bias.ndim - 1) + rest_dim = [1] * (len(input.shape) - len(bias.shape) - 1) return ( F.leaky_relu( input + bias.reshape((1, bias.shape[0], *rest_dim)), negative_slope=0.2 diff --git a/tools/export_model.py b/tools/export_model.py new file mode 100644 index 0000000..f589b15 --- /dev/null +++ b/tools/export_model.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import argparse + +import ppgan +from ppgan.utils.config import get_config +from ppgan.utils.setup import setup +from ppgan.engine.trainer import Trainer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--export_model", + default=None, + type=str, + help="The path prefix of inference model to be used.", ) + parser.add_argument('-c', + '--config-file', + metavar="FILE", + required=True, + help="config file path") + parser.add_argument("--load", + type=str, + default=None, + required=True, + help="put the path to resuming file if needed") + # config options + parser.add_argument("-o", + "--opt", + nargs="+", + help="set configuration options") + parser.add_argument("-s", + "--inputs_size", + type=str, + default=None, + required=True, + help="the inputs size") + args = parser.parse_args() + return args + + +def main(args, cfg): + inputs_size = [[int(size) for size in input_size.split(',')] + for input_size in args.inputs_size.split(';')] + model = ppgan.models.builder.build_model(cfg.model) + model.setup_train_mode(is_train=False) + state_dicts = ppgan.utils.filesystem.load(args.load) + for net_name, net in model.nets.items(): + if net_name in state_dicts: + net.set_state_dict(state_dicts[net_name]) + model.export_model(cfg.export_model, args.export_model, inputs_size) + + +if __name__ == "__main__": + args = parse_args() + cfg = get_config(args.config_file, args.opt) + main(args, cfg) diff --git a/tools/inference.py b/tools/inference.py new file mode 100644 index 0000000..01fccb8 --- /dev/null +++ b/tools/inference.py @@ -0,0 +1,134 @@ +import paddle +import argparse +import numpy as np + +from ppgan.utils.config import get_config +from ppgan.datasets.builder import build_dataloader +from ppgan.engine.trainer import IterLoader +from ppgan.utils.visual import save_image +from ppgan.utils.visual import tensor2img +from ppgan.utils.filesystem import makedirs + +MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", "edvr"] + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_path", + default=None, + type=str, + required=True, + help="The path prefix of inference model to be used.", ) + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES)) + parser.add_argument( + "--device", + default="gpu", + type=str, + choices=["cpu", "gpu", "xpu"], + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument('-c', + '--config-file', + metavar="FILE", + help='config file path') + # config options + parser.add_argument("-o", + "--opt", + nargs='+', + help="set configuration options") + args = parser.parse_args() + return args + + +def create_predictor(model_path, device="gpu"): + config = paddle.inference.Config(model_path + ".pdmodel", + model_path + ".pdiparams") + if device == "gpu": + config.enable_use_gpu(100, 0) + elif device == "cpu": + config.disable_gpu() + elif device == "xpu": + config.enable_xpu(100) + else: + config.disable_gpu() + + predictor = paddle.inference.create_predictor(config) + return predictor + + +def main(): + args = parse_args() + cfg = get_config(args.config_file, args.opt) + predictor = create_predictor(args.model_path, args.device) + input_handles = [predictor.get_input_handle( + name) for name in predictor.get_input_names()] + output_handle = predictor.get_output_handle( + predictor.get_output_names()[0]) + test_dataloader = build_dataloader( + cfg.dataset.test, is_train=False, distributed=False) + + max_eval_steps = len(test_dataloader) + iter_loader = IterLoader(test_dataloader) + + min_max = cfg.get('min_max', None) + if min_max is None: + min_max = (-1., 1.) + + model_type = args.model_type + makedirs("infer_output/" + model_type) + for i in range(max_eval_steps): + data = next(iter_loader) + if model_type == "pix2pix": + real_A = data['B'].numpy() + input_handles[0].copy_from_cpu(real_A) + predictor.run() + prediction = output_handle.copy_to_cpu() + prediction = paddle.to_tensor(prediction[0]) + image_numpy = tensor2img(prediction, min_max) + save_image(image_numpy, "infer_output/pix2pix/{}.png".format(i)) + elif model_type == "cyclegan": + real_A = data['A'].numpy() + input_handles[0].copy_from_cpu(real_A) + predictor.run() + prediction = output_handle.copy_to_cpu() + prediction = paddle.to_tensor(prediction[0]) + image_numpy = tensor2img(prediction, min_max) + save_image(image_numpy, "infer_output/cyclegan/{}.png".format(i)) + elif model_type == "wav2lip": + indiv_mels, x = data['indiv_mels'].numpy()[0], data['x'].numpy()[0] + x = x.transpose([1, 0, 2, 3]) + input_handles[0].copy_from_cpu(indiv_mels) + input_handles[1].copy_from_cpu(x) + predictor.run() + prediction = output_handle.copy_to_cpu() + for j in range(prediction.shape[0]): + prediction[j] = prediction[j][::-1, :, :] + image_numpy = paddle.to_tensor(prediction[j]) + image_numpy = tensor2img(image_numpy, (0, 1)) + save_image( + image_numpy, "infer_output/wav2lip/{}_{}.png".format(i, j)) + elif model_type == "esrgan": + lq = data['lq'].numpy() + input_handles[0].copy_from_cpu(lq) + predictor.run() + prediction = output_handle.copy_to_cpu() + prediction = paddle.to_tensor(prediction[0]) + image_numpy = tensor2img(prediction, min_max) + save_image(image_numpy, "infer_output/esrgan/{}.png".format(i)) + elif model_type == "edvr": + lq = data['lq'].numpy() + input_handles[0].copy_from_cpu(lq) + predictor.run() + prediction = output_handle.copy_to_cpu() + prediction = paddle.to_tensor(prediction[0]) + image_numpy = tensor2img(prediction, min_max) + save_image(image_numpy, "infer_output/edvr/{}.png".format(i)) + + +if __name__ == '__main__': + main() -- GitLab