未验证 提交 a4a54e51 编写于 作者: R ronnywang 提交者: GitHub

add export_model and inference scripts (#340)

* add export_model and inference scripts
上级 059d52e3
......@@ -27,6 +27,10 @@ model:
name: GANLoss
gan_mode: lsgan
export_model:
- {name: 'netG_A', inputs_num: 1}
- {name: 'netG_B', inputs_num: 1}
dataset:
train:
name: UnpairedDataset
......
......@@ -26,6 +26,9 @@ model:
pixel_criterion:
name: CharbonnierLoss
export_model:
- {name: 'generator', inputs_num: 1}
dataset:
train:
name: REDSDataset
......
......@@ -32,6 +32,9 @@ model:
gan_mode: vanilla
loss_weight: !!float 5e-3
export_model:
- {name: 'generator', inputs_num: 1}
dataset:
train:
name: SRDataset
......
......@@ -25,6 +25,9 @@ model:
name: GANLoss
gan_mode: vanilla
export_model:
- {name: 'netG', inputs_num: 1}
dataset:
train:
name: PairedDataset
......
......@@ -24,6 +24,9 @@ model:
gen_iters: 4
disc_iters: 16
export_model:
- {name: 'gen', inputs_num: 2}
dataset:
train:
name: SingleDataset
......
......@@ -14,6 +14,9 @@ model:
discriminator_hq:
name: Wav2LipDiscQual
export_model:
- {name: 'netG', inputs_num: 2}
dataset:
train:
name: Wav2LipDataset
......
......@@ -182,3 +182,16 @@ class BaseModel(ABC):
if net is not None:
for param in net.parameters():
param.trainable = requires_grad
def export_model(self, export_model, output_dir=None, inputs_size=[]):
inputs_num = 0
for net in export_model:
input_spec = [paddle.static.InputSpec(
shape=inputs_size[inputs_num + i], dtype="float32") for i in range(net["inputs_num"])]
inputs_num = inputs_num + net["inputs_num"]
static_model = paddle.jit.to_static(self.nets[net["name"]],
input_spec=input_spec)
if output_dir is None:
output_dir = 'export_model'
paddle.jit.save(static_model, os.path.join(
output_dir, '{}_{}'.format(self.__class__.__name__.lower(), net["name"])))
......@@ -120,6 +120,26 @@ def deform_conv2d(x,
out = nn.elementwise_add(pre_bias, bias, axis=1)
else:
out = pre_bias
else:
helper = LayerHelper('deform_conv2d', **locals())
attrs = {'strides': stride, 'paddings': padding, 'dilations': dilation, 'deformable_groups': deformable_groups,
'groups': groups, 'im2col_step': 1}
if use_deform_conv2d_v1:
op_type = 'deformable_conv_v1'
inputs = {'Input': x, 'Offset': offset, 'Filter': weight}
else:
op_type = 'deformable_conv'
inputs = {'Input': x, 'Offset': offset,
'Mask': mask, 'Filter': weight}
pre_bias = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=op_type, inputs=inputs, outputs={
'Output': pre_bias}, attrs=attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=1)
else:
out = pre_bias
return out
......
......@@ -36,7 +36,7 @@ class FusedLeakyReLU(nn.Layer):
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
rest_dim = [1] * (len(input.shape) - len(bias.shape) - 1)
return (
F.leaky_relu(
input + bias.reshape((1, bias.shape[0], *rest_dim)), negative_slope=0.2
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import ppgan
from ppgan.utils.config import get_config
from ppgan.utils.setup import setup
from ppgan.engine.trainer import Trainer
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--export_model",
default=None,
type=str,
help="The path prefix of inference model to be used.", )
parser.add_argument('-c',
'--config-file',
metavar="FILE",
required=True,
help="config file path")
parser.add_argument("--load",
type=str,
default=None,
required=True,
help="put the path to resuming file if needed")
# config options
parser.add_argument("-o",
"--opt",
nargs="+",
help="set configuration options")
parser.add_argument("-s",
"--inputs_size",
type=str,
default=None,
required=True,
help="the inputs size")
args = parser.parse_args()
return args
def main(args, cfg):
inputs_size = [[int(size) for size in input_size.split(',')]
for input_size in args.inputs_size.split(';')]
model = ppgan.models.builder.build_model(cfg.model)
model.setup_train_mode(is_train=False)
state_dicts = ppgan.utils.filesystem.load(args.load)
for net_name, net in model.nets.items():
if net_name in state_dicts:
net.set_state_dict(state_dicts[net_name])
model.export_model(cfg.export_model, args.export_model, inputs_size)
if __name__ == "__main__":
args = parse_args()
cfg = get_config(args.config_file, args.opt)
main(args, cfg)
import paddle
import argparse
import numpy as np
from ppgan.utils.config import get_config
from ppgan.datasets.builder import build_dataloader
from ppgan.engine.trainer import IterLoader
from ppgan.utils.visual import save_image
from ppgan.utils.visual import tensor2img
from ppgan.utils.filesystem import makedirs
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", "edvr"]
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
default=None,
type=str,
required=True,
help="The path prefix of inference model to be used.", )
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES))
parser.add_argument(
"--device",
default="gpu",
type=str,
choices=["cpu", "gpu", "xpu"],
help="The device to select to train the model, is must be cpu/gpu/xpu.")
parser.add_argument('-c',
'--config-file',
metavar="FILE",
help='config file path')
# config options
parser.add_argument("-o",
"--opt",
nargs='+',
help="set configuration options")
args = parser.parse_args()
return args
def create_predictor(model_path, device="gpu"):
config = paddle.inference.Config(model_path + ".pdmodel",
model_path + ".pdiparams")
if device == "gpu":
config.enable_use_gpu(100, 0)
elif device == "cpu":
config.disable_gpu()
elif device == "xpu":
config.enable_xpu(100)
else:
config.disable_gpu()
predictor = paddle.inference.create_predictor(config)
return predictor
def main():
args = parse_args()
cfg = get_config(args.config_file, args.opt)
predictor = create_predictor(args.model_path, args.device)
input_handles = [predictor.get_input_handle(
name) for name in predictor.get_input_names()]
output_handle = predictor.get_output_handle(
predictor.get_output_names()[0])
test_dataloader = build_dataloader(
cfg.dataset.test, is_train=False, distributed=False)
max_eval_steps = len(test_dataloader)
iter_loader = IterLoader(test_dataloader)
min_max = cfg.get('min_max', None)
if min_max is None:
min_max = (-1., 1.)
model_type = args.model_type
makedirs("infer_output/" + model_type)
for i in range(max_eval_steps):
data = next(iter_loader)
if model_type == "pix2pix":
real_A = data['B'].numpy()
input_handles[0].copy_from_cpu(real_A)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/pix2pix/{}.png".format(i))
elif model_type == "cyclegan":
real_A = data['A'].numpy()
input_handles[0].copy_from_cpu(real_A)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/cyclegan/{}.png".format(i))
elif model_type == "wav2lip":
indiv_mels, x = data['indiv_mels'].numpy()[0], data['x'].numpy()[0]
x = x.transpose([1, 0, 2, 3])
input_handles[0].copy_from_cpu(indiv_mels)
input_handles[1].copy_from_cpu(x)
predictor.run()
prediction = output_handle.copy_to_cpu()
for j in range(prediction.shape[0]):
prediction[j] = prediction[j][::-1, :, :]
image_numpy = paddle.to_tensor(prediction[j])
image_numpy = tensor2img(image_numpy, (0, 1))
save_image(
image_numpy, "infer_output/wav2lip/{}_{}.png".format(i, j))
elif model_type == "esrgan":
lq = data['lq'].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/esrgan/{}.png".format(i))
elif model_type == "edvr":
lq = data['lq'].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/edvr/{}.png".format(i))
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册