pipeline_client.py 8.6 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
B
barrierye 已提交
16
import sys
17
import time
B
barrierye 已提交
18
import numpy as np
B
barrierye 已提交
19 20
from numpy import *
import logging
B
barrierye 已提交
21
import functools
T
TeslaZhao 已提交
22
import json
23
import socket
T
TeslaZhao 已提交
24
from .channel import ChannelDataErrcode
B
barrierye 已提交
25 26
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
W
wangjiawei04 已提交
27
import six
28
from io import BytesIO
29
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
30

B
barrierye 已提交
31 32

class PipelineClient(object):
    """
    PipelineClient provides the basic capabilities of the pipeline SDK.

    Typical usage:
        client = PipelineClient()
        client.connect(["127.0.0.1:9393"])
        resp = client.predict(feed_dict={"image": img_ndarray})
    """

    # Maps a numpy dtype name to (Tensor proto repeated field, elem_type code).
    # The codes must agree with the Tensor message in pipeline_service.proto
    # (12 => string and 13 => bytes are handled separately below).
    # NOTE(review): int16/float16/uint16 have no dedicated repeated field in
    # the proto, so they travel in the wider fields, tagged by elem_type.
    _NP_DTYPE_TO_TENSOR_FIELD = {
        "int64": ("int64_data", 0),
        "float32": ("float_data", 1),
        "int32": ("int_data", 2),
        "float64": ("float64_data", 3),
        "int16": ("int_data", 4),
        "float16": ("float_data", 5),
        "uint16": ("uint32_data", 6),
        "uint8": ("uint32_data", 7),
        "int8": ("int_data", 8),
        "bool": ("bool_data", 9),
    }

    def __init__(self):
        # gRPC channel/stub are created lazily by connect().
        self._channel = None
        # Reserved key/value appended to a request when profiling is on.
        self._profile_key = "pipeline.profile"
        self._profile_value = "1"

    def connect(self, endpoints):
        """Create an insecure gRPC channel and stub to the pipeline servers.

        Args:
            endpoints: iterable of "host:port" strings. Multiple endpoints
                are load-balanced with gRPC's round_robin policy.
        """
        # 512 MB message caps: tensor payloads can far exceed gRPC's 4 MB
        # default.
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self._channel = grpc.insecure_channel(g_endpoint, options=options)
        self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
            self._channel)

    def _pack_request_package(self, feed_dict, pack_tensor_format,
                              use_tensor_bytes, profile):
        """Build a pipeline_service_pb2.Request from a feed dict.

        Args:
            feed_dict: {name: value}. Reserved keys "logid" and "clientip"
                are consumed here and not packed as feed data. A
                "<name>.lod" key supplies lod info for feed var <name>
                (tensor format only).
            pack_tensor_format: False => pack as parallel string key/value
                lists; otherwise => pack as typed Tensor protos.
            use_tensor_bytes: tensor format only; ship each ndarray as an
                npy-serialized bytes blob instead of typed repeated fields.
            profile: append the profiling key/value pair (string format only,
                matching the original behavior).

        Returns:
            A populated pipeline_service_pb2.Request.

        Raises:
            TypeError: a feed value is neither str nor np.ndarray (nor list,
                in string format).
        """
        req = pipeline_service_pb2.Request()

        # Work on a shallow copy so popping the reserved keys below does not
        # mutate the caller's dict (the original code popped in place).
        feed_dict = dict(feed_dict)

        logid = feed_dict.pop("logid", None)
        if logid is None:
            req.logid = 0
        elif sys.version_info.major == 2:
            req.logid = long(logid)  # noqa: F821 -- py2 builtin
        else:
            req.logid = int(logid)

        clientip = feed_dict.pop("clientip", None)
        if clientip is None:
            # Default to this host's primary IP so the server can log origin.
            clientip = socket.gethostbyname(socket.gethostname())
        req.clientip = clientip

        # repr() of large ndarrays must not be elided with "..." when they
        # are serialized as strings below. NOTE: this changes numpy's global
        # print options for the whole process.
        np.set_printoptions(threshold=sys.maxsize)

        if pack_tensor_format is False:
            self._pack_string_pairs(req, feed_dict, profile)
        else:
            self._pack_tensors(req, feed_dict, use_tensor_bytes)
        return req

    def _pack_string_pairs(self, req, feed_dict, profile):
        """Pack feed_dict into the parallel req.key / req.value string lists."""
        for key, value in feed_dict.items():
            req.key.append(key)

            # Plain strings pass through unchanged (py2: str or unicode).
            if (sys.version_info.major == 2 and
                    isinstance(value, (str, unicode)) or  # noqa: F821
                ((sys.version_info.major == 3) and isinstance(value, str))):
                req.value.append(value)
                continue

            # Arrays (and lists, via np.array) are serialized as their repr.
            if isinstance(value, np.ndarray):
                req.value.append(value.__repr__())
            elif isinstance(value, list):
                req.value.append(np.array(value).__repr__())
            else:
                raise TypeError(
                    "only str and np.ndarray type is supported: {}".format(
                        type(value)))

        if profile:
            # Ask the server to collect and return profiling information.
            req.key.append(self._profile_key)
            req.value.append(self._profile_value)

    def _pack_tensors(self, req, feed_dict, use_tensor_bytes):
        """Pack feed_dict into req.tensors as typed Tensor protos."""
        for key, value in feed_dict.items():
            # Skip lod entries here; "<name>.lod" is consumed when <name>
            # itself is packed below. The lod key must appear after (behind)
            # its feed var in the feed declaration.
            if ".lod" in key:
                continue

            one_tensor = req.tensors.add()
            one_tensor.name = key

            if isinstance(value, str):
                one_tensor.str_data.append(value)
                one_tensor.elem_type = 12  # 12 => string in proto
                continue

            if not isinstance(value, np.ndarray):
                raise TypeError(
                    "only str and np.ndarray type is supported: {}".format(
                        type(value)))

            _LOGGER.debug(
                "key:{}, use_tensor_bytes:{}, value.shape:{}, value.dtype:{}".
                format(key, use_tensor_bytes, value.shape, value.dtype))

            one_tensor.shape.extend(value.shape)

            # Set lod info if the caller supplied "<name>.lod" (a list).
            lod_list = feed_dict.get(key + ".lod")
            if lod_list is not None:
                one_tensor.lod.extend(lod_list)

            if use_tensor_bytes is True:
                # Ship the whole ndarray as one npy-serialized blob.
                np_bytes = BytesIO()
                np.save(np_bytes, value, allow_pickle=True)
                one_tensor.byte_data = np_bytes.getvalue()
                one_tensor.elem_type = 13  # 13 => bytes in proto
                continue

            field_info = self._NP_DTYPE_TO_TENSOR_FIELD.get(str(value.dtype))
            if field_info is None:
                # Original behavior preserved: log the error and skip
                # (the tensor is left without data), do not raise.
                _LOGGER.error("value type {} of tensor {} is not supported.".
                              format(value.dtype, key))
                continue
            field_name, elem_type = field_info
            getattr(one_tensor, field_name).extend(value.flatten().tolist())
            one_tensor.elem_type = elem_type

    def _unpack_response_package(self, resp, fetch):
        """Return the raw Response; `fetch` filtering is left to the caller."""
        return resp

    def predict(self,
                feed_dict,
                fetch=None,
                asyn=False,
                pack_tensor_format=False,
                use_tensor_bytes=False,
                profile=False,
                log_id=0):
        """Send one inference request to the pipeline service.

        Args:
            feed_dict: {name: value} feed data; see _pack_request_package
                for the reserved "logid"/"clientip"/"<name>.lod" keys.
            fetch: optional list of output names (passed through to
                _unpack_response_package).
            asyn: when True, return a PipelinePredictFuture instead of
                blocking on the RPC.
            pack_tensor_format: pack feeds as typed tensors instead of
                string key/value pairs.
            use_tensor_bytes: serialize ndarrays as npy bytes blobs
                (tensor format only).
            profile: request server-side profiling info.
            log_id: request log id; only overrides a "logid" supplied in
                feed_dict when nonzero.

        Returns:
            The Response proto, or a PipelinePredictFuture when asyn=True.

        Raises:
            TypeError: feed_dict is not a dict, or fetch is not a list.
        """
        if not isinstance(feed_dict, dict):
            raise TypeError(
                "feed must be dict type with format: {name: value}.")
        if fetch is not None and not isinstance(fetch, list):
            raise TypeError("fetch must be list type with format: [name].")
        # Timing goes to the logger at debug level; a library must not
        # print() to stdout unconditionally.
        _LOGGER.debug("PipelineClient::predict pack_data time:%s", time.time())
        req = self._pack_request_package(feed_dict, pack_tensor_format,
                                         use_tensor_bytes, profile)
        if log_id:
            # Only override when explicitly given, so the default log_id=0
            # no longer clobbers a "logid" entry taken from feed_dict.
            req.logid = log_id
        if not asyn:
            _LOGGER.debug("PipelineClient::predict before time:%s",
                          time.time())
            resp = self._stub.inference(req)
            return self._unpack_response_package(resp, fetch)
        call_future = self._stub.inference.future(req)
        return PipelinePredictFuture(
            call_future,
            functools.partial(
                self._unpack_response_package, fetch=fetch))


class PipelinePredictFuture(object):
    """Handle for an asynchronous predict() call.

    Wraps the gRPC call future together with the response-unpacking
    callback; result() blocks until the RPC finishes and returns the
    unpacked response.
    """

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Wait for the RPC to complete and unpack its response."""
        raw_response = self.call_future_.result()
        unpacked = self.callback_func_(raw_response)
        return unpacked