local_predict.py 8.5 KB
Newer Older
D
dongdaxiang 已提交
1 2
# -*- coding: utf-8 -*-
"""
D
dongdaxiang 已提交
3
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
D
dongdaxiang 已提交
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import google.protobuf.text_format
import numpy as np
import argparse
import paddle.fluid as fluid
23
import paddle.inference as inference
D
dongdaxiang 已提交
24 25 26 27 28 29 30 31 32 33 34
from .proto import general_model_config_pb2 as m_config
from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
import logging

# Timestamped, level-tagged records; this module's logger reports INFO and up.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT)

logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


W
wangjiawei04 已提交
35
class LocalPredictor(object):
    """
    In-process predictor for the local environment.

    The model is loaded and executed inside the calling process, so compared
    with RPC/HTTP serving there is no network or packaging overhead.

    Typical use: create an instance, call ``load_model_config`` once, then
    call ``predict`` for each request.
    """

    # Maps the integer ``feed_type`` stored in the model config to the numpy
    # dtype the input is cast to before being copied into the predictor.
    _FEED_DTYPES = {0: "int64", 1: "float32", 2: "int32"}

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.feed_types_ = {}
        self.fetch_types_ = {}
        self.feed_shapes_ = {}
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_idx_ = {}
        self.fetch_names_to_type_ = {}

    def load_model_config(self,
                          model_path,
                          use_gpu=False,
                          gpu_id=0,
                          use_profile=False,
                          thread_num=1,
                          mem_optim=True,
                          ir_optim=False,
                          use_trt=False,
                          use_lite=False,
                          use_xpu=False,
                          use_feed_fetch_ops=False):
        """
        Load model config and set the engine config for the paddle predictor

        Reads ``<model_path>/serving_server_conf.prototxt``, records the
        feed/fetch variable metadata on this instance, builds the engine
        config and stores the created predictor in ``self.predictor``.

        Args:
            model_path: model config path.
            use_gpu: calculating with gpu, False default.
            gpu_id: gpu id, 0 default.
            use_profile: use predictor profiles, False default.
            thread_num: thread nums, default 1.
            mem_optim: memory optimization, True default.
            ir_optim: open calculation chart optimization, False default.
            use_trt: use nvidia TensorRT optimization, False default.
                     Only takes effect when use_gpu is True.
            use_lite: use Paddle-Lite Engine, False default
            use_xpu: run predict on Baidu Kunlun, False default
            use_feed_fetch_ops: use feed/fetch ops, False default.
        """
        client_config = "{}/serving_server_conf.prototxt".format(model_path)
        model_conf = m_config.GeneralModelConfig()
        # Use a context manager so the file handle is closed even when
        # parsing the prototxt fails.
        with open(client_config, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
        config = AnalysisConfig(model_path)
        # Adjacent literals keep the message on one clean line (no embedded
        # indentation); %s args defer formatting to the logging framework.
        logger.info(
            "load_model_config params: model_path:%s, use_gpu:%s, "
            "gpu_id:%s, use_profile:%s, thread_num:%s, mem_optim:%s, "
            "ir_optim:%s, use_trt:%s, use_lite:%s, use_xpu:%s, "
            "use_feed_fetch_ops:%s", model_path, use_gpu, gpu_id,
            use_profile, thread_num, mem_optim, ir_optim, use_trt, use_lite,
            use_xpu, use_feed_fetch_ops)

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Reset every per-model lookup table so that loading a second model
        # on the same instance does not keep entries from the previous one.
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type

        if use_profile:
            config.enable_profile()
        if mem_optim:
            config.enable_memory_optim()
        config.switch_ir_optim(ir_optim)
        config.set_cpu_math_library_num_threads(thread_num)
        config.switch_use_feed_fetch_ops(use_feed_fetch_ops)
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")

        if not use_gpu:
            config.disable_gpu()
        else:
            # NOTE(review): the first argument is presumably the initial GPU
            # memory pool size in MB -- confirm against the Paddle docs.
            config.enable_use_gpu(100, gpu_id)
            if use_trt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 20,
                    max_batch_size=32,
                    min_subgraph_size=3,
                    use_static=False,
                    use_calib_mode=False)

        if use_lite:
            config.enable_lite_engine(
                precision_mode=inference.PrecisionType.Float32,
                zero_copy=True,
                passes_filter=[],
                ops_filter=[])

        if use_xpu:
            # NOTE(review): 8 MiB is presumably the XPU workspace size --
            # confirm against the Paddle docs.
            config.enable_xpu(8 * 1024 * 1024)

        self.predictor = create_paddle_predictor(config)

    def predict(self, feed=None, fetch=None, batch=False, log_id=0):
        """
        Predict locally

        Args:
            feed: feed var. Either a dict mapping input names to values
                  (numpy arrays or nested lists), or a list holding exactly
                  one such dict. An optional "<name>.lod" entry supplies the
                  LoD of input <name>. The dict is not modified.
            fetch: fetch var. A fetch name or a list of fetch names. Names
                   are paired with the predictor outputs by position.
            batch: batch data or not, False default.If batch is False, a new
                   dimension is added to header of the shape[np.newaxis].
            log_id: for logging (currently unused by this method).

        Returns:
            fetch_map: dict mapping each fetch name to its output array, plus
                       "<name>.lod" (int32 array) for outputs carrying LoD.

        Raises:
            ValueError: if feed or fetch is missing or of an unsupported
                        type, if feed is a list that does not hold exactly
                        one dict, if no requested fetch name is known to the
                        loaded model, or if an input has an unknown
                        feed type.
        """
        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_dict = feed
        elif isinstance(feed, list):
            # Inputs are looked up by name below, so a list is only
            # meaningful when it wraps a single feed dict.
            if len(feed) != 1 or not isinstance(feed[0], dict):
                raise ValueError(
                    "Feed list should contain exactly one dict; put batched "
                    "data into a single dict and set batch=True")
            feed_dict = feed[0]
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        if not any(key in self.fetch_names_ for key in fetch_list):
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        for name in self.predictor.get_input_names():
            # Work on a local value so the caller's feed dict stays untouched.
            data = feed_dict[name]
            if isinstance(data, list):
                data = np.array(data).reshape(self.feed_shapes_[name])
            dtype = self._FEED_DTYPES.get(self.feed_types_[name])
            if dtype is None:
                raise ValueError("local predictor receives wrong data type")
            data = data.astype(dtype)

            input_tensor = self.predictor.get_input_tensor(name)
            lod_key = "{}.lod".format(name)
            if lod_key in feed_dict:
                input_tensor.set_lod([feed_dict[lod_key]])
            if not batch:
                # Single sample: prepend the batch dimension.
                input_tensor.copy_from_cpu(data[np.newaxis, :])
            else:
                input_tensor.copy_from_cpu(data)

        # Grab the output handles before running, then copy results out.
        output_tensors = [
            self.predictor.get_output_tensor(output_name)
            for output_name in self.predictor.get_output_names()
        ]
        self.predictor.zero_copy_run()
        outputs = [tensor.copy_to_cpu() for tensor in output_tensors]

        fetch_map = {}
        # Iterate the normalized list: enumerating a plain string would walk
        # over its characters instead of treating it as one fetch name.
        for i, name in enumerate(fetch_list):
            fetch_map[name] = outputs[i]
            lod = output_tensors[i].lod()
            if len(lod) > 0:
                fetch_map[name + ".lod"] = np.array(lod[0]).astype('int32')
        return fetch_map