# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import os
import argparse
import base64
import shutil
import cv2
import numpy as np

from paddle.inference import Config
from paddle.inference import create_predictor


class Predictor(object):
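    """Base predictor that wraps Paddle Inference or ONNX Runtime.

    __init__ builds the underlying predictor from a dict-like config object;
    subclasses are expected to implement predict().
    """
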
    def __init__(self, args, inference_model_dir=None):
        # Half-precision (FP16) prediction only works when TensorRT is enabled
        if args.use_fp16 is True:
            assert args.use_tensorrt is True
        self.args = args
        if self.args.get("use_onnx", False):
            self.predictor, self.config = self.create_onnx_predictor(
                args, inference_model_dir)
        else:
            self.predictor, self.config = self.create_paddle_predictor(
                args, inference_model_dir)

    def predict(self, image):
        raise NotImplementedError

    def create_paddle_predictor(self, args, inference_model_dir=None):
        if inference_model_dir is None:
            inference_model_dir = args.inference_model_dir
        if "inference_int8.pdiparams" in os.listdir(inference_model_dir):
            params_file = os.path.join(inference_model_dir,
                                       "inference_int8.pdiparams")
            model_file = os.path.join(inference_model_dir,
                                      "inference_int8.pdmodel")
            assert args.get(
                "use_fp16", False
            ) is False, "fp16 mode is not supported for int8 model inference, please set use_fp16 as False during inference."
        else:
            params_file = os.path.join(inference_model_dir,
                                       "inference.pdiparams")
            model_file = os.path.join(inference_model_dir, "inference.pdmodel")
            assert args.get(
                "use_int8", False
            ) is False, "int8 mode is not supported for fp32 model inference, please set use_int8 as False during inference."

        config = Config(model_file, params_file)

        if args.get("use_gpu", False):
            config.enable_use_gpu(args.gpu_mem, 0)
        elif args.get("use_npu", False):
            config.enable_npu()
        elif args.get("use_xpu", False):
            config.enable_xpu()
        else:
            config.disable_gpu()
            if args.enable_mkldnn:
                # set_mkldnn_cache_capacity() is not available on macOS
                if platform.system() != "Darwin":
                    # cache 10 different input shapes for MKL-DNN to avoid memory leaks
                    config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
        config.set_cpu_math_library_num_threads(args.cpu_num_threads)

        if args.enable_profile:
            config.enable_profile()
        config.disable_glog_info()
        config.switch_ir_optim(args.ir_optim)  # default true
        if args.use_tensorrt:
            precision = Config.Precision.Float32
            if args.get("use_int8", False):
                precision = Config.Precision.Int8
            elif args.get("use_fp16", False):
                precision = Config.Precision.Half

            config.enable_tensorrt_engine(
                precision_mode=precision,
                max_batch_size=args.batch_size,
                workspace_size=1 << 30,
                min_subgraph_size=30,
                use_calib_mode=False)

        config.enable_memory_optim()
        # use zero copy
        config.switch_use_feed_fetch_ops(False)
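        # With feed/fetch ops disabled, the predictor is driven through the
        # zero-copy handle API (get_input_handle / copy_from_cpu / run /
        # copy_to_cpu), which is typically how subclasses implement predict().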
        predictor = create_predictor(config)

        return predictor, config

    def create_onnx_predictor(self, args, inference_model_dir=None):
        import onnxruntime as ort
        if inference_model_dir is None:
            inference_model_dir = args.inference_model_dir
        model_file = os.path.join(inference_model_dir, "inference.onnx")
        config = ort.SessionOptions()
        if args.use_gpu:
            raise ValueError(
                "onnx inference now only supports cpu! please specify use_gpu false."
            )
        else:
            config.intra_op_num_threads = args.cpu_num_threads
            if args.ir_optim:
                config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        predictor = ort.InferenceSession(model_file, sess_options=config)
        return predictor, config
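

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes an exported
# Paddle inference model lives under "./inference_demo" (hypothetical path)
# and that the config is a dict-like object with attribute access, mirroring
# the fields read above; a real deployment would add its own pre/post-
# processing inside a Predictor subclass.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DemoArgs(dict):
        """Hypothetical stand-in for the config object: a dict with attribute access."""
        __getattr__ = dict.__getitem__

    class _DemoPredictor(Predictor):
        def predict(self, image):
            # Zero-copy run: feed the input handle, execute, fetch the output.
            input_handle = self.predictor.get_input_handle(
                self.predictor.get_input_names()[0])
            output_handle = self.predictor.get_output_handle(
                self.predictor.get_output_names()[0])
            input_handle.copy_from_cpu(image)
            self.predictor.run()
            return output_handle.copy_to_cpu()

    demo_args = _DemoArgs(
        inference_model_dir="./inference_demo",  # hypothetical model directory
        use_gpu=False,
        gpu_mem=8000,
        use_fp16=False,
        use_int8=False,
        use_tensorrt=False,
        use_onnx=False,
        enable_mkldnn=False,
        cpu_num_threads=10,
        enable_profile=False,
        ir_optim=True,
        batch_size=1,
    )
    demo_predictor = _DemoPredictor(demo_args)
    dummy_input = np.random.rand(1, 3, 224, 224).astype("float32")
    print(demo_predictor.predict(dummy_input).shape)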