import paddle
import argparse
import numpy as np
import random
import os
from collections import OrderedDict

from ppgan.utils.config import get_config
from ppgan.datasets.builder import build_dataloader
from ppgan.engine.trainer import IterLoader
from ppgan.utils.visual import save_image
from ppgan.utils.visual import tensor2img
from ppgan.utils.filesystem import makedirs
from ppgan.metrics import build_metric


MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
                 "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "prenet"]


def parse_args():
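    """Parse command-line arguments for exported-model inference."""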
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        default=None,
        type=str,
        required=True,
        help="The path prefix of the inference model to be used.",
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES))
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        choices=["cpu", "gpu", "xpu"],
        help="The device to run inference on; must be one of cpu/gpu/xpu.")
    parser.add_argument('-c',
                        '--config-file',
                        metavar="FILE",
                        help='config file path')
    parser.add_argument("--output_path",
                        type=str,
                        default="infer_output",
                        help="path to save the inference results")
    # config options
    parser.add_argument("-o",
                        "--opt",
                        nargs='+',
                        help="set configuration options")
    # fix random numbers by setting seed
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='fix random numbers by setting seed.')
    # for tensorRT
    parser.add_argument("--run_mode",
                        default="fluid",
                        type=str,
                        choices=["fluid", "trt_fp32", "trt_fp16"],
                        help="mode of running (fluid/trt_fp32/trt_fp16)")
    parser.add_argument("--trt_min_shape",
                        default=1,
                        type=int,
                        help="trt_min_shape for tensorRT")
    parser.add_argument("--trt_max_shape",
                        default=1280,
                        type=int,
                        help="trt_max_shape for tensorRT")
    parser.add_argument("--trt_opt_shape",
                        default=640,
                        type=int,
                        help="trt_opt_shape for tensorRT")
    parser.add_argument("--min_subgraph_size",
                        default=3,
                        type=int,
                        help="min_subgraph_size for tensorRT")
    parser.add_argument("--batch_size",
                        default=1,
                        type=int,
                        help="batch_size for tensorRT")
    parser.add_argument("--use_dynamic_shape",
                        dest="use_dynamic_shape",
                        action="store_true",
                        help="use_dynamic_shape for tensorRT")
    parser.add_argument("--trt_calib_mode",
                        dest="trt_calib_mode",
                        action="store_true",
                        help="trt_calib_mode for tensorRT")
    args = parser.parse_args()
    return args


def create_predictor(model_path,
                     device="gpu",
                     run_mode='fluid',
                     batch_size=1,
                     min_subgraph_size=3,
                     use_dynamic_shape=False,
                     trt_min_shape=1,
                     trt_max_shape=1280,
                     trt_opt_shape=640,
                     trt_calib_mode=False):
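    """Create a paddle.inference predictor from an exported model prefix.

    `model_path` is the common prefix of the exported ``.pdmodel`` and
    ``.pdiparams`` files. The TensorRT-related arguments only take effect
    when `run_mode` selects a TensorRT precision.
    """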
    config = paddle.inference.Config(model_path + ".pdmodel",
                                     model_path + ".pdiparams")
    if device == "gpu":
        config.enable_use_gpu(100, 0)
    elif device == "cpu":
        config.disable_gpu()
    elif device == "xpu":
        config.enable_xpu(100)
    else:
        config.disable_gpu()

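    # Map TensorRT run modes to inference precisions; any other run_mode
    # (e.g. the default "fluid") skips the TensorRT engine entirely.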
    precision_map = {
        'trt_int8': paddle.inference.Config.Precision.Int8,
        'trt_fp32': paddle.inference.Config.Precision.Float32,
        'trt_fp16': paddle.inference.Config.Precision.Half
    }
    if run_mode in precision_map:
        config.enable_tensorrt_engine(workspace_size=1 << 25,
                                      max_batch_size=batch_size,
                                      min_subgraph_size=min_subgraph_size,
                                      precision_mode=precision_map[run_mode],
                                      use_static=False,
                                      use_calib_mode=trt_calib_mode)

        if use_dynamic_shape:
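            # Register min/opt/max shape ranges for the "image" input so the
            # TensorRT engine can handle variable input resolutions.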
            min_input_shape = {
                'image': [batch_size, 3, trt_min_shape, trt_min_shape]
            }
            max_input_shape = {
                'image': [batch_size, 3, trt_max_shape, trt_max_shape]
            }
            opt_input_shape = {
                'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
            }
            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                              opt_input_shape)
            print('trt set dynamic shape done!')

    predictor = paddle.inference.create_predictor(config)
    return predictor


def setup_metrics(cfg):
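    """Build the metrics declared in the validate config.

    Accepts either a single metric config or a dict of named metric configs
    and returns an OrderedDict mapping metric names to metric objects.
    """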
    metrics = OrderedDict()
    if isinstance(list(cfg.values())[0], dict):
        for metric_name, cfg_ in cfg.items():
            metrics[metric_name] = build_metric(cfg_)
    else:
        metric = build_metric(cfg)
        metrics[metric.__class__.__name__] = metric

    return metrics


def main():
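    """Run inference with an exported PaddleGAN model and report metrics if configured."""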
    args = parse_args()
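    # Fix all random sources when a seed is given, for reproducible results.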
    if args.seed:
        paddle.seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
    cfg = get_config(args.config_file, args.opt)
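    # Build the predictor from the exported model using the device and
    # TensorRT options from the CLI.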
    predictor = create_predictor(args.model_path, args.device, args.run_mode,
                                 args.batch_size, args.min_subgraph_size,
                                 args.use_dynamic_shape, args.trt_min_shape,
                                 args.trt_max_shape, args.trt_opt_shape,
                                 args.trt_calib_mode)
    input_handles = [
        predictor.get_input_handle(name)
        for name in predictor.get_input_names()
    ]

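    # Read the first output by default; the basicvsr/msvsr branch below
    # switches to the last output when the model exposes more than one.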
    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
    test_dataloader = build_dataloader(cfg.dataset.test,
                                       is_train=False,
                                       distributed=False)

    max_eval_steps = len(test_dataloader)
    iter_loader = IterLoader(test_dataloader)
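    # min_max is the value range tensor2img uses to map tensors back to images.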
    min_max = cfg.get('min_max', None)
    if min_max is None:
        min_max = (-1., 1.)

    model_type = args.model_type
    makedirs(os.path.join(args.output_path, model_type))

    validate_cfg = cfg.get('validate', None)
    metrics = None
    if validate_cfg and 'metrics' in validate_cfg:
        metrics = setup_metrics(validate_cfg['metrics'])
        for metric in metrics.values():
            metric.reset()

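    # Run one pass over the test set; each branch feeds the model-specific
    # inputs, saves the predicted image(s) and updates metrics where applicable.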
    for i in range(max_eval_steps):
        data = next(iter_loader)
        if model_type == "pix2pix":
            real_A = data['B'].numpy()
            input_handles[0].copy_from_cpu(real_A)
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            prediction = paddle.to_tensor(prediction)
            image_numpy = tensor2img(prediction[0], min_max)
            save_image(
                image_numpy,
                os.path.join(args.output_path, "pix2pix/{}.png".format(i)))
            metric_file = os.path.join(args.output_path, "pix2pix/metric.txt")
            real_B = paddle.to_tensor(data['A'])
            for metric in metrics.values():
                metric.update(prediction, real_B)

        elif model_type == "cyclegan":
            real_A = data['A'].numpy()
            input_handles[0].copy_from_cpu(real_A)
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            prediction = paddle.to_tensor(prediction)
            image_numpy = tensor2img(prediction[0], min_max)
            save_image(
                image_numpy,
                os.path.join(args.output_path, "cyclegan/{}.png".format(i)))
            metric_file = os.path.join(args.output_path, "cyclegan/metric.txt")
            real_B = paddle.to_tensor(data['B'])
            for metric in metrics.values():
                metric.update(prediction, real_B)

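        # wav2lip takes mel chunks and face frames; the channel order of each
        # predicted frame is reversed (likely BGR -> RGB) before saving.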
        elif model_type == "wav2lip":
            indiv_mels, x = data['indiv_mels'].numpy()[0], data['x'].numpy()[0]
            x = x.transpose([1, 0, 2, 3])
            input_handles[0].copy_from_cpu(indiv_mels)
            input_handles[1].copy_from_cpu(x)
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            for j in range(prediction.shape[0]):
                prediction[j] = prediction[j][::-1, :, :]
                image_numpy = paddle.to_tensor(prediction[j])
                image_numpy = tensor2img(image_numpy, (0, 1))
                save_image(
                    image_numpy,
                    os.path.join(args.output_path,
                                 "wav2lip/{}_{}.png".format(i, j)))

        elif model_type == "esrgan":
            lq = data['lq'].numpy()
            input_handles[0].copy_from_cpu(lq)
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            prediction = paddle.to_tensor(prediction[0])
            image_numpy = tensor2img(prediction, min_max)
            save_image(
                image_numpy,
                os.path.join(args.output_path, "esrgan/{}.png".format(i)))
        elif model_type == "edvr":
            lq = data['lq'].numpy()
            input_handles[0].copy_from_cpu(lq)
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            prediction = paddle.to_tensor(prediction[0])
            image_numpy = tensor2img(prediction, min_max)
            save_image(
                image_numpy,
                os.path.join(args.output_path, "edvr/{}.png".format(i)))
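        # stylegan2 takes random latent noise plus a constant scalar input
        # (0.7) instead of dataset images.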
        elif model_type == "stylegan2":
            noise = paddle.randn([1, 1, 512]).cpu().numpy()
            input_handles[0].copy_from_cpu(noise)
            input_handles[1].copy_from_cpu(np.array([0.7]).astype('float32'))
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            prediction = paddle.to_tensor(prediction)
            image_numpy = tensor2img(prediction[0], min_max)
            save_image(
                image_numpy,
                os.path.join(args.output_path, "stylegan2/{}.png".format(i)))
            metric_file = os.path.join(args.output_path, "stylegan2/metric.txt")
            real_img = paddle.to_tensor(data['A'])
            for metric in metrics.values():
                metric.update(prediction, real_img)
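        # Video SR models return a frame sequence; metrics are updated over
        # the whole sequence of output/ground-truth frame pairs.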
        elif model_type in ["basicvsr", "msvsr"]:
            lq = data['lq'].numpy()
            input_handles[0].copy_from_cpu(lq)
            predictor.run()
            if len(predictor.get_output_names()) > 1:
                output_handle = predictor.get_output_handle(
                    predictor.get_output_names()[-1])
            prediction = output_handle.copy_to_cpu()
            prediction = paddle.to_tensor(prediction)
            _, t, _, _, _ = prediction.shape

            out_img = []
            gt_img = []
            for ti in range(t):
                out_tensor = prediction[0, ti]
                gt_tensor = data['gt'][0, ti]
                out_img.append(tensor2img(out_tensor, (0., 1.)))
                gt_img.append(tensor2img(gt_tensor, (0., 1.)))

            image_numpy = tensor2img(prediction[0], min_max)
            save_image(
                image_numpy,
                os.path.join(args.output_path, model_type, "{}.png".format(i)))

            metric_file = os.path.join(args.output_path, model_type,
                                       "metric.txt")
            for metric in metrics.values():
                metric.update(out_img, gt_img, is_seq=True)
        elif model_type == "singan":
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            prediction = paddle.to_tensor(prediction)
            image_numpy = tensor2img(prediction, min_max)
            save_image(
                image_numpy,
                os.path.join(args.output_path, "singan/{}.png".format(i)))
            metric_file = os.path.join(args.output_path, "singan/metric.txt")
            for metric in metrics.values():
                metric.update(prediction, data['A'])
        elif model_type == "prenet":
            lq = data['lq'].numpy()
            gt = data['gt'].numpy()
            input_handles[0].copy_from_cpu(lq)
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            prediction = paddle.to_tensor(prediction)
            gt = paddle.to_tensor(gt)
            image_numpy = tensor2img(prediction, min_max)
            gt_img = tensor2img(gt, min_max)
            save_image(
                image_numpy,
                os.path.join(args.output_path, "prenet/{}.png".format(i)))
            metric_file = os.path.join(args.output_path, "prenet/metric.txt")
            for metric in metrics.values():
                metric.update(image_numpy, gt_img)

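    # Append the accumulated metric values to metric.txt in the model's
    # output directory (metric_file is set inside the branches above).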
    if metrics:
        log_file = open(metric_file, 'a')
        for metric_name, metric in metrics.items():
            loss_string = "Metric {}: {:.4f}".format(metric_name,
                                                     metric.accumulate())
            print(loss_string, file=log_file)
        log_file.close()


if __name__ == '__main__':
    main()