# -*- coding: utf-8 -*-
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import google.protobuf.text_format
import numpy as np
from .proto import general_model_config_pb2 as m_config
import paddle.inference as paddle_infer
import logging

logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("LocalPredictor")
logger.setLevel(logging.INFO)

precision_map = {
    'int8': paddle_infer.PrecisionType.Int8,
    'fp32': paddle_infer.PrecisionType.Float32,
    'fp16': paddle_infer.PrecisionType.Half,
}
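
# Note: "bf16" has no entry in this map; bfloat16 is enabled separately
# below via config.enable_mkldnn_bfloat16() when predicting on CPU.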


class LocalPredictor(object):
    """
    Prediction in the current process of the local environment, in process
    call, Compared with RPC/HTTP, LocalPredictor has better performance, 
    because of no network and packaging load.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.feed_types_ = {}
        self.fetch_types_ = {}
        self.feed_shapes_ = {}
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_idx_ = {}
        self.fetch_names_to_type_ = {}

    def load_model_config(self,
                          model_path,
                          use_gpu=False,
                          gpu_id=0,
                          use_profile=False,
                          thread_num=1,
                          mem_optim=True,
                          ir_optim=False,
                          use_trt=False,
                          use_lite=False,
                          use_xpu=False,
                          precision="fp32",
                          use_calib=False,
                          use_feed_fetch_ops=False):
        """
        Load the model config and create the Paddle predictor via the Paddle Inference API.
   
        Args:
            model_path: model config path.
            use_gpu: compute with GPU, default False.
            gpu_id: GPU card id, default 0.
            use_profile: enable predictor profiling, default False.
            thread_num: number of CPU math library threads, default 1.
            mem_optim: enable memory optimization, default True.
            ir_optim: enable IR (computation graph) optimization, default False.
            use_trt: use NVIDIA TensorRT optimization, default False.
            use_lite: use the Paddle-Lite engine, default False.
            use_xpu: run prediction on Baidu Kunlun (XPU), default False.
            precision: precision mode, one of "fp32", "fp16", "int8" or
                "bf16", default "fp32".
            use_calib: use TensorRT calibration, default False.
            use_feed_fetch_ops: use feed/fetch ops, default False.
        """
        client_config = "{}/serving_server_conf.prototxt".format(model_path)
        model_conf = m_config.GeneralModelConfig()
        # Parse the prototxt model config; use a context manager so the
        # file handle is closed after reading.
        with open(client_config, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(f.read(), model_conf)
        if os.path.exists(os.path.join(model_path, "__params__")):
            config = paddle_infer.Config(
                os.path.join(model_path, "__model__"),
                os.path.join(model_path, "__params__"))
        else:
            config = paddle_infer.Config(model_path)

        logger.info(
            "LocalPredictor load_model_config params: model_path:{}, "
            "use_gpu:{}, gpu_id:{}, use_profile:{}, thread_num:{}, "
            "mem_optim:{}, ir_optim:{}, use_trt:{}, use_lite:{}, use_xpu:{}, "
            "precision:{}, use_calib:{}, use_feed_fetch_ops:{}".format(
                model_path, use_gpu, gpu_id, use_profile, thread_num,
                mem_optim, ir_optim, use_trt, use_lite, use_xpu, precision,
                use_calib, use_feed_fetch_ops))

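        # Build lookup tables from the model config: alias name -> index,
        # feed/fetch type code, and feed shape.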
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_idx_ = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type

        precision_type = paddle_infer.PrecisionType.Float32
        if precision in precision_map:
            precision_type = precision_map[precision]
        if use_profile:
            config.enable_profile()
        if mem_optim:
            config.enable_memory_optim()
        config.switch_ir_optim(ir_optim)
        config.set_cpu_math_library_num_threads(thread_num)
        config.switch_use_feed_fetch_ops(use_feed_fetch_ops)
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")

        if not use_gpu:
            config.disable_gpu()
        else:
            # enable_use_gpu(memory_pool_init_size_mb, device_id):
            # reserve 100 MB of initial GPU memory on the given card
            config.enable_use_gpu(100, gpu_id)
            if use_trt:
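                # TensorRT subgraph engine: 1 << 20 bytes (1 MB) of workspace,
                # static max batch size 32; only subgraphs with at least 3 ops
                # are converted (min_subgraph_size=3)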
                config.enable_tensorrt_engine(
                    precision_mode=precision_type,
                    workspace_size=1 << 20,
                    max_batch_size=32,
                    min_subgraph_size=3,
                    use_static=False,
                    use_calib_mode=False)

        if use_lite:
            config.enable_lite_engine(
                precision_mode=precision_type,
                zero_copy=True,
                passes_filter=[],
                ops_filter=[])

        if use_xpu:
            # 8 MB L3 cache (the argument is in bytes)
            config.enable_xpu(8 * 1024 * 1024)

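        # CPU-only precision handling: int8 uses the MKL-DNN quantizer,
        # bf16 uses MKL-DNN bfloat16 kernels.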
        if not use_gpu and not use_lite:
            if precision_type == paddle_infer.PrecisionType.Int8:
                config.enable_quantizer()
            if precision == "bf16":
                config.enable_mkldnn_bfloat16()
        self.predictor = paddle_infer.create_predictor(config)

    def predict(self, feed=None, fetch=None, batch=False, log_id=0):
        """
        Run model inference via the Paddle Inference API.

        Args:
            feed: feed var (dict or list of dict).
            fetch: fetch var (str or list of str).
            batch: whether feed data is a batch, default False. If batch is
                   False, a new dimension is prepended to each input's
                   shape (np.newaxis).
            log_id: id used for logging.

        Returns:
            fetch_map: dict
        """
        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction.\
                log_id:{}".format(log_id))
        fetch_list = []
        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string.\
                log_id:{}".format(log_id))

        feed_batch = []
        if isinstance(feed, dict):
            feed_batch.append(feed)
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict.\
                log_id:{}".format(log_id))

        fetch_names = []
        # Filter invalid fetch names
        for key in fetch_list:
            if key in self.fetch_names_:
                fetch_names.append(key)

        if len(fetch_names) == 0:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.\
                    log_id:{}".format(log_id))

        # Assemble the input data for the Paddle predictor
        input_names = self.predictor.get_input_names()
        for name in input_names:
            if isinstance(feed[name], list):
                feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
                    name])
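            # feed_type codes from the model config: 0 = int64, 1 = float32,
            # 2 = int32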
            if self.feed_types_[name] == 0:
                feed[name] = feed[name].astype("int64")
            elif self.feed_types_[name] == 1:
                feed[name] = feed[name].astype("float32")
            elif self.feed_types_[name] == 2:
                feed[name] = feed[name].astype("int32")
            else:
                raise ValueError("LocalPredictor received an unsupported feed "
                                 "type: {}".format(self.feed_types_[name]))
            input_tensor_handle = self.predictor.get_input_handle(name)
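            # Optional LoD (level of details) info for variable-length inputs
            # is passed alongside the tensor as "<name>.lod"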
            if "{}.lod".format(name) in feed:
                input_tensor_handle.set_lod([feed["{}.lod".format(name)]])
            if not batch:
                # Prepend a batch dimension of size 1 for non-batched input
                input_tensor_handle.copy_from_cpu(feed[name][np.newaxis, :])
            else:
                input_tensor_handle.copy_from_cpu(feed[name])
        output_tensor_handles = []
        output_names = self.predictor.get_output_names()
        for output_name in output_names:
            output_tensor_handle = self.predictor.get_output_handle(output_name)
            output_tensor_handles.append(output_tensor_handle)

        # Run inference 
        self.predictor.run()

        # Assemble output data of predict results
        outputs = []
        for output_tensor_handle in output_tensor_handles:
            output = output_tensor_handle.copy_to_cpu()
            outputs.append(output)
        fetch_map = {}
        # Map outputs back to the requested fetch names (use the normalized
        # fetch_list so a single string fetch is handled correctly); attach
        # LoD info as "<name>.lod" when the output tensor carries it
        for i, name in enumerate(fetch_list):
            fetch_map[name] = outputs[i]
            if len(output_tensor_handles[i].lod()) > 0:
                fetch_map[name + ".lod"] = np.array(output_tensor_handles[i]
                                                    .lod()[0]).astype('int32')
        return fetch_map