# inference.py
import paddle
import argparse
import numpy as np
import random
import os
from collections import OrderedDict

from ppgan.utils.config import get_config
from ppgan.datasets.builder import build_dataloader
from ppgan.engine.trainer import IterLoader
from ppgan.utils.visual import save_image
from ppgan.utils.visual import tensor2img
from ppgan.utils.filesystem import makedirs
from ppgan.metrics import build_metric

# Model type names advertised in the ``--model_type`` help text.
MODEL_CLASSES = [
    "pix2pix",
    "cyclegan",
    "wav2lip",
    "esrgan",
    "edvr",
    "fom",
    "stylegan2",
    "basicvsr",
]


def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the usual
            command-line behaviour.

    Returns:
        argparse.Namespace: holds ``model_path``, ``model_type``, ``device``,
        ``config_file``, ``output_path``, ``opt`` and ``seed``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="The path prefix of inference model to be used.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
    )
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        choices=["cpu", "gpu", "xpu"],
        help="The device to run inference on; it must be cpu/gpu/xpu.",
    )
    # argparse exposes ``--config-file`` as ``args.config_file``.
    parser.add_argument(
        "-c",
        "--config-file",
        metavar="FILE",
        help="config file path",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="infer_output",
        help="output_path",
    )
    # config options
    parser.add_argument(
        "-o",
        "--opt",
        nargs="+",
        help="set configuration options",
    )
    # fix random numbers by setting seed
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="fix random numbers by setting seed.",
    )
    return parser.parse_args(argv)


def create_predictor(model_path, device="gpu"):
    """Build a Paddle inference predictor for an exported model.

    Args:
        model_path: Path prefix of the exported model; ``.pdmodel`` and
            ``.pdiparams`` are appended to locate the two files.
        device: ``"gpu"`` or ``"xpu"`` select that backend; any other value
            (including ``"cpu"``) disables the GPU.

    Returns:
        The object returned by ``paddle.inference.create_predictor``.
    """
    model_file = model_path + ".pdmodel"
    params_file = model_path + ".pdiparams"
    config = paddle.inference.Config(model_file, params_file)

    if device == "gpu":
        # NOTE(review): arguments are presumably (initial memory pool in MB,
        # device id) -- confirm against the Paddle Inference Config docs.
        config.enable_use_gpu(100, 0)
    elif device == "xpu":
        config.enable_xpu(100)
    else:
        # "cpu" and every unrecognised device name run without the GPU.
        config.disable_gpu()

    return paddle.inference.create_predictor(config)

def setup_metrics(cfg):
    """Build the metric objects described by a metrics config.

    Two layouts are accepted, told apart by the type of the first value:

    * a mapping of ``name -> metric config dict``: one metric is built per
      entry and stored under its configured name;
    * a single metric config: one metric is built from ``cfg`` itself and
      stored under its class name.

    Args:
        cfg: Mapping holding the metric configuration. An empty (or ``None``)
            config yields no metrics instead of raising ``IndexError``.

    Returns:
        OrderedDict mapping metric name to the object returned by
        ``build_metric``.
    """
    metrics = OrderedDict()
    if not cfg:
        return metrics

    # Peek at the first value without materialising the whole value list.
    first_value = next(iter(cfg.values()))
    if isinstance(first_value, dict):
        for metric_name, metric_cfg in cfg.items():
            metrics[metric_name] = build_metric(metric_cfg)
    else:
        metric = build_metric(cfg)
        metrics[metric.__class__.__name__] = metric

    return metrics

def main():
    """Run an exported model over the test dataset and save its outputs.

    Steps:
      1. Parse the command line and optionally seed paddle / random / numpy.
      2. Load the config, create the predictor and build the test dataloader.
      3. For every batch, feed the model-specific inputs, run the predictor
         and save the resulting image(s) under
         ``<output_path>/<model_type>/``.
      4. For pix2pix, cyclegan, stylegan2 and basicvsr, update the metrics
         configured under ``validate.metrics`` (if any) and append their
         accumulated values to ``<output_path>/<model_type>/metric.txt``.

    Raises:
        ValueError: if ``--model_type`` has no inference branch in this
            script.
    """
    args = parse_args()

    # Compare against None so that an explicit ``--seed 0`` is honoured.
    if args.seed is not None:
        paddle.seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

    model_type = args.model_type
    # Model types whose branch also feeds the configured metrics.
    metric_model_types = ("pix2pix", "cyclegan", "stylegan2", "basicvsr")
    # Model types whose branch only writes images.
    image_only_model_types = ("wav2lip", "esrgan", "edvr")
    supported_model_types = metric_model_types + image_only_model_types
    if model_type not in supported_model_types:
        # Fail loudly instead of iterating the dataset without producing
        # any output.
        raise ValueError(
            "Inference for model type {!r} is not implemented in this "
            "script; supported types: {}".format(
                model_type, ", ".join(supported_model_types)))

    cfg = get_config(args.config_file, args.opt)
    predictor = create_predictor(args.model_path, args.device)
    input_handles = [
        predictor.get_input_handle(name)
        for name in predictor.get_input_names()
    ]
    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
    test_dataloader = build_dataloader(cfg.dataset.test,
                                       is_train=False,
                                       distributed=False)

    max_eval_steps = len(test_dataloader)
    iter_loader = IterLoader(test_dataloader)

    min_max = cfg.get('min_max', None)
    if min_max is None:
        min_max = (-1., 1.)

    # Every branch writes below the directory that is actually created, so
    # a custom --output_path is honoured for all model types.
    output_dir = os.path.join(args.output_path, model_type)
    makedirs(output_dir)
    metric_file = os.path.join(output_dir, "metric.txt")

    # An empty mapping (rather than None) keeps the per-batch metric updates
    # safe when the config declares no metrics.
    metrics = OrderedDict()
    validate_cfg = cfg.get('validate', None)
    if (model_type in metric_model_types and validate_cfg
            and 'metrics' in validate_cfg):
        metrics = setup_metrics(validate_cfg['metrics'])
        for metric in metrics.values():
            metric.reset()

    def run_model(*arrays):
        """Copy ``arrays`` into the input handles in order, run, fetch output."""
        for idx, array in enumerate(arrays):
            input_handles[idx].copy_from_cpu(array)
        predictor.run()
        return output_handle.copy_to_cpu()

    for i in range(max_eval_steps):
        data = next(iter_loader)
        image_path = os.path.join(output_dir, "{}.png".format(i))

        if model_type == "pix2pix":
            # Input is read from 'B'; 'A' is the metric reference.
            prediction = paddle.to_tensor(run_model(data['B'].numpy()))
            save_image(tensor2img(prediction[0], min_max), image_path)
            if metrics:
                real_B = paddle.to_tensor(data['A'])
                for metric in metrics.values():
                    metric.update(prediction, real_B)

        elif model_type == "cyclegan":
            # Input is read from 'A'; 'B' is the metric reference.
            prediction = paddle.to_tensor(run_model(data['A'].numpy()))
            save_image(tensor2img(prediction[0], min_max), image_path)
            if metrics:
                real_B = paddle.to_tensor(data['B'])
                for metric in metrics.values():
                    metric.update(prediction, real_B)

        elif model_type == "wav2lip":
            indiv_mels = data['indiv_mels'].numpy()[0]
            x = data['x'].numpy()[0].transpose([1, 0, 2, 3])
            prediction = run_model(indiv_mels, x)
            for j in range(prediction.shape[0]):
                # Reverse the first axis of each output item before saving.
                prediction[j] = prediction[j][::-1, :, :]
                image_numpy = tensor2img(paddle.to_tensor(prediction[j]),
                                         (0, 1))
                save_image(
                    image_numpy,
                    os.path.join(output_dir, "{}_{}.png".format(i, j)))

        elif model_type in ("esrgan", "edvr"):
            prediction = paddle.to_tensor(run_model(data['lq'].numpy())[0])
            save_image(tensor2img(prediction, min_max), image_path)

        elif model_type == "stylegan2":
            noise = paddle.randn([1, 1, 512]).cpu().numpy()
            # Second input is the constant 0.7 as float32.
            second_input = np.array([0.7]).astype('float32')
            prediction = paddle.to_tensor(run_model(noise, second_input))
            save_image(tensor2img(prediction[0], min_max), image_path)
            if metrics:
                real_img = paddle.to_tensor(data['A'])
                for metric in metrics.values():
                    metric.update(prediction, real_img)

        elif model_type == "basicvsr":
            prediction = paddle.to_tensor(run_model(data['lq'].numpy()))
            # The output must be 5-D; its second axis is iterated per frame.
            _, num_frames, _, _, _ = prediction.shape
            save_image(tensor2img(prediction[0], min_max), image_path)
            if metrics:
                out_img = [
                    tensor2img(prediction[0, ti], (0., 1.))
                    for ti in range(num_frames)
                ]
                gt_img = [
                    tensor2img(data['gt'][0, ti], (0., 1.))
                    for ti in range(num_frames)
                ]
                for metric in metrics.values():
                    metric.update(out_img, gt_img, is_seq=True)

    if metrics:
        # Append, so results from repeated runs accumulate in one file.
        with open(metric_file, 'a') as log_file:
            for metric_name, metric in metrics.items():
                loss_string = "Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate())
                print(loss_string, file=log_file)

# Run inference only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()