# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
import sys
import time
import numpy as np
from numpy import *
import logging
import functools
import json
import socket
from .channel import ChannelDataErrcode
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
import six
from io import BytesIO
_LOGGER = logging.getLogger(__name__)


class PipelineClient(object):
    """
    PipelineClient provides the basic capabilities of the pipeline SDK:
    it connects to a PipelineServer over gRPC and issues synchronous or
    asynchronous prediction requests.
    """

    # numpy dtype name -> (repeated proto field name, proto elem_type code).
    # elem_type codes 12 (string) and 13 (bytes) are handled separately in
    # _pack_tensor_format.
    _NUMPY_DTYPE_TO_PROTO = {
        "int64": ("int64_data", 0),
        "float32": ("float_data", 1),
        "int32": ("int_data", 2),
        "float64": ("float64_data", 3),
        "int16": ("int_data", 4),
        "float16": ("float_data", 5),
        "uint16": ("uint32_data", 6),
        "uint8": ("uint32_data", 7),
        "int8": ("int_data", 8),
        "bool": ("bool_data", 9),
    }

    def __init__(self):
        # The gRPC channel and stub are created lazily in connect().
        self._channel = None
        self._profile_key = "pipeline.profile"
        self._profile_value = "1"

    def connect(self, endpoints):
        """Open an insecure gRPC channel to the pipeline server.

        Args:
            endpoints: iterable of "host:port" strings; requests are
                load-balanced round-robin across all of them.
        """
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self._channel = grpc.insecure_channel(g_endpoint, options=options)
        self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
            self._channel)

    def _pack_request_package(self, feed_dict, pack_tensor_format,
                              use_tensor_bytes, profile):
        """Build a pipeline_service_pb2.Request from ``feed_dict``.

        Args:
            feed_dict: {name: value} inputs. The reserved keys "logid" and
                "clientip" are consumed here and are not packed as data.
            pack_tensor_format: False => string key/value format,
                True => typed tensor format.
            use_tensor_bytes: in tensor format, serialize ndarrays with
                np.save into byte_data instead of typed repeated fields.
            profile: append the profiling marker key/value pair
                (string format only, matching the original behavior).

        Returns:
            A populated pipeline_service_pb2.Request.
        """
        req = pipeline_service_pb2.Request()

        # Work on a shallow copy so the pop() calls below do not mutate
        # the caller's dict (the previous implementation mutated it).
        feed_dict = dict(feed_dict)

        logid = feed_dict.pop("logid", None)
        if logid is None:
            req.logid = 0
        elif sys.version_info.major == 2:
            # py2 compatibility path kept from the original code.
            req.logid = long(logid)
        else:
            req.logid = int(logid)

        clientip = feed_dict.pop("clientip", None)
        if clientip is None:
            # Fall back to this host's resolved address.
            hostname = socket.gethostname()
            req.clientip = socket.gethostbyname(hostname)
        else:
            req.clientip = clientip

        # repr() of ndarrays (string format) must not be truncated by numpy.
        np.set_printoptions(threshold=sys.maxsize)

        if pack_tensor_format is False:
            self._pack_string_format(req, feed_dict, profile)
        else:
            self._pack_tensor_format(req, feed_dict, use_tensor_bytes)
        return req

    def _pack_string_format(self, req, feed_dict, profile):
        """Pack feed_dict into the parallel req.key / req.value string lists."""
        for key, value in feed_dict.items():
            req.key.append(key)

            if (sys.version_info.major == 2 and
                    isinstance(value, (str, unicode)) or
                ((sys.version_info.major == 3) and isinstance(value, str))):
                req.value.append(value)
            elif isinstance(value, np.ndarray):
                # Arrays travel as their full repr (see set_printoptions above).
                req.value.append(value.__repr__())
            elif isinstance(value, list):
                req.value.append(np.array(value).__repr__())
            else:
                raise TypeError(
                    "only str and np.ndarray type is supported: {}".format(
                        type(value)))

        if profile:
            req.key.append(self._profile_key)
            req.value.append(self._profile_value)

    def _pack_tensor_format(self, req, feed_dict, use_tensor_bytes):
        """Pack feed_dict into typed req.tensors entries."""
        for key, value in feed_dict.items():
            one_tensor = req.tensors.add()
            one_tensor.name = key

            if isinstance(value, str):
                # BUGFIX: repeated scalar (string) proto fields have no
                # add() method; append() is the correct protobuf API.
                one_tensor.string_data.append(value)
                one_tensor.elem_type = 12  #12 => string in proto
                continue

            if not isinstance(value, np.ndarray):
                raise TypeError(
                    "only str and np.ndarray type is supported: {}".format(
                        type(value)))

            _LOGGER.debug(
                "key:{}, use_tensor_bytes:{}, value.shape:{}, value.dtype:{}".
                format(key, use_tensor_bytes, value.shape, value.dtype))
            # Copy the shape before the data.
            one_tensor.shape.extend(value.shape)

            # Packed into bytes: np.save preserves shape and dtype.
            if use_tensor_bytes is True:
                np_bytes = BytesIO()
                np.save(np_bytes, value, allow_pickle=True)
                one_tensor.byte_data = np_bytes.getvalue()
                one_tensor.elem_type = 13  #13 => bytes in proto
                continue

            dtype_entry = self._NUMPY_DTYPE_TO_PROTO.get(str(value.dtype))
            if dtype_entry is None:
                # Matches original behavior: log the error and skip, no raise.
                _LOGGER.error("value type {} of tensor {} is not supported.".
                              format(value.dtype, key))
                continue
            field_name, elem_type = dtype_entry
            getattr(one_tensor, field_name).extend(value.flatten().tolist())
            one_tensor.elem_type = elem_type

    def _unpack_response_package(self, resp, fetch):
        # The raw response is returned as-is; ``fetch`` is currently unused.
        return resp

    def predict(self,
                feed_dict,
                fetch=None,
                asyn=False,
                pack_tensor_format=False,
                use_tensor_bytes=False,
                profile=False,
                log_id=0):
        """Run one prediction request against the connected server.

        Args:
            feed_dict: {name: value} inputs (must be a dict).
            fetch: optional list of output names to retrieve.
            asyn: when True, return a PipelinePredictFuture immediately.
            pack_tensor_format / use_tensor_bytes / profile: see
                _pack_request_package.
            log_id: request trace id; overrides any "logid" in feed_dict.

        Returns:
            The server response, or a PipelinePredictFuture when asyn=True.

        Raises:
            TypeError: when feed_dict is not a dict or fetch is not a list.
        """
        if not isinstance(feed_dict, dict):
            raise TypeError(
                "feed must be dict type with format: {name: value}.")
        if fetch is not None and not isinstance(fetch, list):
            raise TypeError("fetch must be list type with format: [name].")
        # Timing traces go to the module logger instead of stdout prints.
        _LOGGER.debug("PipelineClient::predict pack_data time:%s", time.time())
        req = self._pack_request_package(feed_dict, pack_tensor_format,
                                         use_tensor_bytes, profile)
        req.logid = log_id
        if not asyn:
            _LOGGER.debug("PipelineClient::predict before time:%s",
                          time.time())
            resp = self._stub.inference(req)
            return self._unpack_response_package(resp, fetch)
        # Asynchronous path: wrap the gRPC future so result() unpacks it.
        call_future = self._stub.inference.future(req)
        return PipelinePredictFuture(
            call_future,
            functools.partial(
                self._unpack_response_package, fetch=fetch))


class PipelinePredictFuture(object):
    """Handle for an asynchronous prediction started by predict(asyn=True).

    Pairs the pending gRPC call future with the client's response-unpacking
    callback; callers block on result() to obtain the final response.
    """

    def __init__(self, call_future, callback_func):
        # Keep both pieces until result() is requested.
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Wait for the RPC to finish and return the unpacked response."""
        raw_resp = self.call_future_.result()
        return self.callback_func_(raw_resp)