# -*- coding: utf-8 -*-
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import google.protobuf.text_format
import numpy as np
import argparse
import paddle.fluid as fluid
import paddle.inference as inference
from .proto import general_model_config_pb2 as m_config
from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
import logging

logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


class LocalPredictor(object):
    """
    Prediction in the current process of the local environment, in process
    call, Compared with RPC/HTTP, LocalPredictor has better performance, 
    because of no network and packaging load.
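
    Example:
        A minimal usage sketch; "uci_housing_model" and the feed/fetch names
        "x" and "price" are illustrative and depend on the exported model::

            import numpy as np

            predictor = LocalPredictor()
            predictor.load_model_config("uci_housing_model")
            fetch_map = predictor.predict(
                feed={"x": np.array([[0.1] * 13])}, fetch=["price"],
                batch=True)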
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.feed_types_ = {}
        self.fetch_types_ = {}
        self.feed_shapes_ = {}
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_idx_ = {}
        self.fetch_names_to_type_ = {}

    def load_model_config(self,
                          model_path,
                          use_gpu=False,
                          gpu_id=0,
                          use_profile=False,
                          thread_num=1,
                          mem_optim=True,
                          ir_optim=False,
                          use_trt=False,
                          use_lite=False,
                          use_xpu=False,
                          use_feed_fetch_ops=False):
        """
        Load model config and set the engine config for the paddle predictor
   
        Args:
            model_path: model config path.
            use_gpu: calculating with gpu, False default.
            gpu_id: gpu id, 0 default.
            use_profile: use predictor profiles, False default.
            thread_num: thread nums, default 1. 
            mem_optim: memory optimization, True default.
            ir_optim: open calculation chart optimization, False default.
            use_trt: use nvidia TensorRT optimization, False default
Z
zhangjun 已提交
76 77
            use_lite: use Paddle-Lite Engint, False default
            use_xpu: run predict on Baidu Kunlun, False default
78 79
            use_feed_fetch_ops: use feed/fetch ops, False default.
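
        Example:
            A minimal sketch; the model directory name is illustrative::

                predictor = LocalPredictor()
                predictor.load_model_config(
                    "uci_housing_model", use_gpu=True, gpu_id=0)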
        """
        client_config = "{}/serving_server_conf.prototxt".format(model_path)
        model_conf = m_config.GeneralModelConfig()
        with open(client_config, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
        # a single "__params__" file means the parameters were saved combined;
        # otherwise AnalysisConfig loads per-variable files from the directory
        if os.path.exists(os.path.join(model_path, "__params__")):
            config = AnalysisConfig(
                os.path.join(model_path, "__model__"),
                os.path.join(model_path, "__params__"))
        else:
            config = AnalysisConfig(model_path)
        logger.info("load_model_config params: model_path:{}, use_gpu:{},\
            gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\
Z
zhangjun 已提交
91
            use_trt:{}, use_lite:{}, use_xpu: {}, use_feed_fetch_ops:{}".format(
92
            model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
Z
zhangjun 已提交
93
            ir_optim, use_trt, use_lite, use_xpu, use_feed_fetch_ops))

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_idx_ = {}

        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type

        if use_profile:
            config.enable_profile()
        if mem_optim:
            config.enable_memory_optim()
        config.switch_ir_optim(ir_optim)
        config.set_cpu_math_library_num_threads(thread_num)
        config.switch_use_feed_fetch_ops(use_feed_fetch_ops)
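        # this fuse pass is known to mis-optimize some models, so it is
        # removed from the pass pipeline as a precaution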
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")

        if not use_gpu:
            config.disable_gpu()
        else:
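            # reserve an initial 100 MB GPU memory pool on the selected card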
            config.enable_use_gpu(100, gpu_id)
            if use_trt:
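                # 1 MB TensorRT workspace and a max batch size of 32; tune
                # these for real workloads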
                config.enable_tensorrt_engine(
                    workspace_size=1 << 20,
                    max_batch_size=32,
                    min_subgraph_size=3,
                    use_static=False,
                    use_calib_mode=False)

        if use_lite:
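            # run supported subgraphs through the Paddle-Lite engine in FP32
            # with zero-copy tensors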
            config.enable_lite_engine(
                precision_mode=inference.PrecisionType.Float32,
                zero_copy=True,
                passes_filter=[],
                ops_filter=[])

        if use_xpu:
            # reserve an 8 MB L3 cache on the XPU (the argument is in bytes)
            config.enable_xpu(8 * 1024 * 1024)

        self.predictor = create_paddle_predictor(config)

    def predict(self, feed=None, fetch=None, batch=False, log_id=0):
        """
        Predict locally

        Args:
            feed: feed var
            fetch: fetch var
            batch: batch data or not, False default.If batch is False, a new
                   dimension is added to header of the shape[np.newaxis].
            log_id: for logging

        Returns:
            fetch_map: dict 
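
        Example:
            A minimal sketch; "x" and "price" stand for the feed/fetch alias
            names declared in serving_server_conf.prototxt::

                fetch_map = predictor.predict(
                    feed={"x": np.array([[0.1] * 13])}, fetch=["price"],
                    batch=True)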
        """
        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")
        fetch_list = []
        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        feed_batch = []
        if isinstance(feed, dict):
            feed_batch.append(feed)
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        fetch_names = []

        for key in fetch_list:
            if key in self.fetch_names_:
                fetch_names.append(key)

        if len(fetch_names) == 0:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        input_names = self.predictor.get_input_names()
        for name in input_names:
            if isinstance(feed[name], list):
                feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
                    name])
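            # feed_type codes from the model config: 0 = int64, 1 = float32,
            # 2 = int32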
            if self.feed_types_[name] == 0:
                feed[name] = feed[name].astype("int64")
            elif self.feed_types_[name] == 1:
                feed[name] = feed[name].astype("float32")
            elif self.feed_types_[name] == 2:
                feed[name] = feed[name].astype("int32")
            else:
                raise ValueError("local predictor receives wrong data type")
            input_tensor = self.predictor.get_input_tensor(name)
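            # an optional "<name>.lod" entry in feed carries the LoD
            # (variable-length) info for this input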
            if "{}.lod".format(name) in feed:
                input_tensor.set_lod([feed["{}.lod".format(name)]])
            if not batch:
                input_tensor.copy_from_cpu(feed[name][np.newaxis, :])
            else:
                input_tensor.copy_from_cpu(feed[name])
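        # run inference with zero-copy tensors, then copy every output back
        # to host memory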
        output_tensors = []
        output_names = self.predictor.get_output_names()
        for output_name in output_names:
            output_tensor = self.predictor.get_output_tensor(output_name)
            output_tensors.append(output_tensor)
        outputs = []
        self.predictor.zero_copy_run()
        for output_tensor in output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)
        fetch_map = {}
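        # expose the output LoD (if any) alongside the data as "<name>.lod"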
        for i, name in enumerate(fetch_list):
            fetch_map[name] = outputs[i]
            if len(output_tensors[i].lod()) > 0:
                fetch_map[name + ".lod"] = np.array(output_tensors[i].lod()[
                    0]).astype('int32')
        return fetch_map