# -*- coding: utf-8 -*-
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import google.protobuf.text_format
import numpy as np
import argparse
from .proto import general_model_config_pb2 as m_config
import paddle.inference as paddle_infer
import logging
import glob

logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("LocalPredictor")
logger.setLevel(logging.INFO)

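# Map user-facing precision strings to Paddle Inference precision types;
# load_model_config() falls back to fp32 when an unknown value is passed.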
precision_map = {
    'int8': paddle_infer.PrecisionType.Int8,
    'fp32': paddle_infer.PrecisionType.Float32,
    'fp16': paddle_infer.PrecisionType.Half,
}


class LocalPredictor(object):
    """
    Prediction in the current process of the local environment, in process
    call, Compared with RPC/HTTP, LocalPredictor has better performance, 
    because of no network and packaging load.
    """

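    # A minimal usage sketch (the model directory and tensor names below are
    # hypothetical, assuming a model exported for Paddle Serving):
    #
    #     predictor = LocalPredictor()
    #     predictor.load_model_config("serving_server", use_gpu=False, thread_num=4)
    #     fetch_map = predictor.predict(
    #         feed={"x": np.array([[1.0, 2.0]], dtype="float32")},
    #         fetch=["y"],
    #         batch=True)
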
    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.feed_types_ = {}
        self.fetch_types_ = {}
        self.feed_shapes_ = {}
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_idx_ = {}
        self.fetch_names_to_type_ = {}

    def search_suffix_files(self, model_path, target_suffix):
        """
        Find all files with the target suffix in the specified directory.

        Args:
            model_path: model directory, not None.
            target_suffix: glob pattern for the target suffix, not None. e.g. "*.pdmodel"

        Returns:
            file_list: None when the arguments are invalid, otherwise a
            (possibly empty) list of matching file paths.
        """
        if model_path is None or target_suffix is None:
            return None

        file_list = glob.glob(os.path.join(model_path, target_suffix))
        return file_list

    def load_model_config(self,
                          model_path,
                          use_gpu=False,
                          gpu_id=0,
                          use_profile=False,
                          thread_num=1,
                          mem_optim=True,
                          ir_optim=False,
                          use_trt=False,
                          use_lite=False,
                          use_xpu=False,
                          precision="fp32",
                          use_calib=False,
                          use_mkldnn=False,
                          mkldnn_cache_capacity=0,
                          mkldnn_op_list=None,
                          mkldnn_bf16_op_list=None,
                          use_feed_fetch_ops=False,
                          use_ascend_cl=False):
        """
        Load model configs and create the paddle predictor by Paddle Inference API.
   
        Args:
            model_path: model config path.
            use_gpu: calculating with gpu, False default.
            gpu_id: gpu id, 0 default.
            use_profile: use predictor profiles, False default.
            thread_num: number of threads for the CPU math library, default 1.
            mem_optim: enable memory optimization, True default.
            ir_optim: enable computational graph (IR) optimization, False default.
            use_trt: use NVIDIA TensorRT optimization, False default.
            use_lite: use Paddle-Lite engine, False default.
            use_xpu: run prediction on Baidu Kunlun XPU, False default.
            precision: precision mode, "fp32" default
            use_calib: use TensorRT calibration, False default
            use_mkldnn: use MKLDNN, False default.
            mkldnn_cache_capacity: cache capacity for input shapes, 0 default.
            mkldnn_op_list: op list accelerated using MKLDNN, None default.
            mkldnn_bf16_op_list: op list accelerated using MKLDNN bf16, None default.
            use_feed_fetch_ops: use feed/fetch ops, False default.
            use_ascend_cl: run prediction on Huawei Ascend, False default.
        """
        gpu_id = int(gpu_id)
        client_config = "{}/serving_server_conf.prototxt".format(model_path)
        model_conf = m_config.GeneralModelConfig()
        with open(client_config, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(f.read(), model_conf)

        # Init paddle_infer config
        # Paddle's model files and parameter files have multiple naming rules:
        #   1) __model__, __params__
        #   2) *.pdmodel, *.pdiparams
        #   3) __model__, conv2d_1.w_0, conv2d_2.w_0, fc_1.w_0, conv2d_1.b_0, ... 
        pdmodel_file_list = self.search_suffix_files(model_path, "*.pdmodel")
        pdiparams_file_list = self.search_suffix_files(model_path,
                                                       "*.pdiparams")
        if os.path.exists(os.path.join(model_path, "__params__")):
            # case 1) initializing
            config = paddle_infer.Config(
                os.path.join(model_path, "__model__"),
                os.path.join(model_path, "__params__"))
        elif pdmodel_file_list and pdiparams_file_list:
            # case 2) initializing
            logger.info("pdmodel_file_list:{}, pdiparams_file_list:{}".format(
                pdmodel_file_list, pdiparams_file_list))
            config = paddle_infer.Config(pdmodel_file_list[0],
                                         pdiparams_file_list[0])
        else:
            # case 3) initializing.
            config = paddle_infer.Config(model_path)

        logger.info(
            "LocalPredictor load_model_config params: model_path:{}, use_gpu:{}, "
            "gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{}, "
            "use_trt:{}, use_lite:{}, use_xpu:{}, precision:{}, use_calib:{}, "
            "use_mkldnn:{}, mkldnn_cache_capacity:{}, mkldnn_op_list:{}, "
            "mkldnn_bf16_op_list:{}, use_feed_fetch_ops:{}, "
            "use_ascend_cl:{} ".format(
                model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
                ir_optim, use_trt, use_lite, use_xpu, precision, use_calib,
                use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
                mkldnn_bf16_op_list, use_feed_fetch_ops, use_ascend_cl))

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_idx_ = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_types_[var.alias_name] = var.fetch_type
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type

        # set precision of inference.
        precision_type = paddle_infer.PrecisionType.Float32
        if precision is not None and precision.lower() in precision_map:
            precision_type = precision_map[precision.lower()]
        else:
            logger.warning("precision error!!! Please check precision:{}".
                           format(precision))
        # set profile
        if use_profile:
            config.enable_profile()
        # set memory optimization
        if mem_optim:
            config.enable_memory_optim()
        # set ir optimization, threads of cpu math library
        config.switch_ir_optim(ir_optim)
        # use feed & fetch ops
        config.switch_use_feed_fetch_ops(use_feed_fetch_ops)
        # pass optim
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")

        # set cpu & mkldnn
        config.set_cpu_math_library_num_threads(thread_num)
        if use_mkldnn:
            config.enable_mkldnn()
            if mkldnn_cache_capacity > 0:
                config.set_mkldnn_cache_capacity(mkldnn_cache_capacity)
            if mkldnn_op_list is not None:
                config.set_mkldnn_op(mkldnn_op_list)
        # set gpu
        if not use_gpu:
            config.disable_gpu()
        else:
            config.enable_use_gpu(100, gpu_id)
            if use_trt:
                config.enable_tensorrt_engine(
                    precision_mode=precision_type,
                    workspace_size=1 << 20,
                    max_batch_size=32,
                    min_subgraph_size=3,
                    use_static=False,
                    use_calib_mode=False)
        # set lite
        if use_lite:
            config.enable_lite_engine(
                precision_mode=precision_type,
                zero_copy=True,
                passes_filter=[],
                ops_filter=[])
            config.switch_ir_optim(True)
        # set xpu
        if use_xpu:
            # 8MB L3 cache
            config.enable_xpu(8 * 1024 * 1024)
            config.set_xpu_device_id(gpu_id)
        # set ascend cl
        if use_ascend_cl:
            if use_lite:
                nnadapter_device_names = "huawei_ascend_npu"
                nnadapter_context_properties = \
                    "HUAWEI_ASCEND_NPU_SELECTED_DEVICE_IDS={}".format(gpu_id)
                nnadapter_model_cache_dir = ""
                config.nnadapter() \
                .enable() \
                .set_device_names([nnadapter_device_names]) \
                .set_context_properties(nnadapter_context_properties) \
                .set_model_cache_dir(nnadapter_model_cache_dir)
        # set cpu low precision
        if not use_gpu and not use_lite:
            if precision_type == paddle_infer.PrecisionType.Int8:
                logger.warning(
                    "PRECISION INT8 is not supported on CPU right now! Please use fp16 or bf16."
                )
                #config.enable_quantizer()
            if precision is not None and precision.lower() == "bf16":
                config.enable_mkldnn_bfloat16()
                if mkldnn_bf16_op_list is not None:
                    config.set_bfloat16_op(mkldnn_bf16_op_list)

        self.predictor = paddle_infer.create_predictor(config)

    def predict(self, feed=None, fetch=None, batch=False, log_id=0):
        """
        Run model inference by Paddle Inference API.

        Args:
            feed: feed var list, None is not allowed.
            fetch: fetch var list, None allowed. When it is None, all fetch
                   vars are returned; otherwise, only the specified fetch results
                   are returned.
            batch: whether the feed data is already batched, False default. If
                   batch is False, a new batch dimension is prepended to each
                   input shape (np.newaxis).
            log_id: for logging

        Returns:
            fetch_map: dict 
        """
        if feed is None:
            raise ValueError("You should specify feed vars for prediction.\
270
                log_id:{}".format(log_id))

        feed_batch = []
        if isinstance(feed, dict):
            feed_batch.append(feed)
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict.\
                log_id:{}".format(log_id))

        fetch_list = []
        if fetch is not None:
            if isinstance(fetch, str):
                fetch_list = [fetch]
            elif isinstance(fetch, list):
                fetch_list = fetch

        # Filter invalid fetch names
        fetch_names = []
        for key in fetch_list:
            if key in self.fetch_names_:
                fetch_names.append(key)

        # Assemble the input data of paddle predictor, and filter invalid inputs. 
        input_names = self.predictor.get_input_names()
        for name in input_names:
            if isinstance(feed[name], list):
                feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
                    name])
            if self.feed_types_[name] == 0:
                feed[name] = feed[name].astype("int64")
            elif self.feed_types_[name] == 1:
                feed[name] = feed[name].astype("float32")
            elif self.feed_types_[name] == 2:
                feed[name] = feed[name].astype("int32")
            elif self.feed_types_[name] == 3:
                feed[name] = feed[name].astype("float64")
            elif self.feed_types_[name] == 4:
                feed[name] = feed[name].astype("int16")
            elif self.feed_types_[name] == 5:
                feed[name] = feed[name].astype("float16")
            elif self.feed_types_[name] == 6:
                feed[name] = feed[name].astype("uint16")
            elif self.feed_types_[name] == 7:
                feed[name] = feed[name].astype("uint8")
            elif self.feed_types_[name] == 8:
                feed[name] = feed[name].astype("int8")
            elif self.feed_types_[name] == 9:
                feed[name] = feed[name].astype("bool")
            elif self.feed_types_[name] == 10:
                feed[name] = feed[name].astype("complex64")
            elif self.feed_types_[name] == 11:
                feed[name] = feed[name].astype("complex128")
            else:
                raise ValueError("local predictor receives wrong data type")

            input_tensor_handle = self.predictor.get_input_handle(name)
            if "{}.lod".format(name) in feed:
                input_tensor_handle.set_lod([feed["{}.lod".format(name)]])
            if not batch:
                input_tensor_handle.copy_from_cpu(feed[name][np.newaxis, :])
            else:
                input_tensor_handle.copy_from_cpu(feed[name])

        # set output tensor handlers
        output_tensor_handles = []
        output_name_to_index_dict = {}
        output_names = self.predictor.get_output_names()
        for i, output_name in enumerate(output_names):
            output_tensor_handle = self.predictor.get_output_handle(output_name)
            output_tensor_handles.append(output_tensor_handle)
            output_name_to_index_dict[output_name] = i

        # Run inference 
        self.predictor.run()

        # Assemble output data of predict results
        outputs = []
        for output_tensor_handle in output_tensor_handles:
            output = output_tensor_handle.copy_to_cpu()
            outputs.append(output)
        outputs_len = len(outputs)

        # Copy fetch vars. If fetch is None, it will copy all results from output_tensor_handles. 
        # Otherwise, it will copy the fields specified from output_tensor_handles.
        fetch_map = {}
        if fetch is None:
            for i, name in enumerate(output_names):
                fetch_map[name] = outputs[i]
                if len(output_tensor_handles[i].lod()) > 0:
                    fetch_map[name + ".lod"] = np.array(output_tensor_handles[
                        i].lod()[0]).astype('int32')
        else:
            # The save_inference_model interface may insert an extra scale op into
            # the network, so the fetch_var names can differ from those in the
            # prototxt. For compatibility with v0.6.x and earlier model save
            # formats, names that do not match are also handled below.
            fetch_match_num = 0
            for i, name in enumerate(fetch):
                output_index = output_name_to_index_dict.get(name)
                if output_index is None:
                    continue

                fetch_map[name] = outputs[output_index]
                fetch_match_num += 1
                if len(output_tensor_handles[output_index].lod()) > 0:
                    fetch_map[name + ".lod"] = np.array(output_tensor_handles[
                        output_index].lod()[0]).astype('int32')

            # Compatible with v0.6.x and lower versions model saving formats.
            if fetch_match_num == 0:
                logger.debug("fetch match num is 0. Retrain the model please!")
                for i, name in enumerate(fetch):
                    if i >= outputs_len:
                        break
                    fetch_map[name] = outputs[i]
                    if len(output_tensor_handles[i].lod()) > 0:
                        fetch_map[name + ".lod"] = np.array(
                            output_tensor_handles[i].lod()[0]).astype('int32')

        return fetch_map