pipeline_client.py 8.0 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
B
barrierye 已提交
16
import sys
17
import time
B
barrierye 已提交
18
import numpy as np
B
barrierye 已提交
19 20
from numpy import *
import logging
B
barrierye 已提交
21
import functools
T
TeslaZhao 已提交
22
import json
23
import socket
T
TeslaZhao 已提交
24
from .channel import ChannelDataErrcode
B
barrierye 已提交
25 26
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
W
wangjiawei04 已提交
27
import six
28
from io import BytesIO
29
# Module-level logger used by the classes in this file.
_LOGGER = logging.getLogger(name=__name__)
B
barrierye 已提交
30

B
barrierye 已提交
31 32

class PipelineClient(object):
    """
    PipelineClient provides the basic capabilities of the pipeline SDK.

    Call ``connect()`` once with a list of ``"ip:port"`` endpoints, then call
    ``predict()`` with a feed dict.
    """

    # numpy dtype name -> (name of the repeated field on the proto Tensor,
    # elem_type code written into that Tensor). Codes 12 (string) and
    # 13 (bytes) are assigned separately in _pack_tensor_format.
    _NUMPY_DTYPE_TO_TENSOR = {
        "int64": ("int64_data", 0),
        "float32": ("float_data", 1),
        "int32": ("int_data", 2),
        "float64": ("float64_data", 3),
        "int16": ("int_data", 4),
        "float16": ("float_data", 5),
        "uint16": ("uint32_data", 6),
        "uint8": ("uint32_data", 7),
        "int8": ("int_data", 8),
        "bool": ("bool_data", 9),
    }

    def __init__(self):
        self._channel = None
        # Created by connect(); predict() needs it.
        self._stub = None
        self._profile_key = "pipeline.profile"
        self._profile_value = "1"

    def connect(self, endpoints):
        """
        Open a gRPC channel to the pipeline server(s).

        Args:
            endpoints: list of ``"ip:port"`` strings. They are joined into a
                single ``ipv4:`` target and the channel is configured with
                the ``round_robin`` load-balancing policy.
        """
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self._channel = grpc.insecure_channel(g_endpoint, options=options)
        self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
            self._channel)

    @staticmethod
    def _local_ip():
        """Best-effort address of this host; "" if it cannot be resolved."""
        try:
            return socket.gethostbyname(socket.gethostname())
        except socket.error as err:
            # The client ip is request metadata only; do not fail the
            # request because the local hostname does not resolve.
            _LOGGER.warning("failed to resolve local client ip: %s", err)
            return ""

    def _pack_request_package(self, feed_dict, pack_tensor_format,
                              use_tensor_bytes, profile):
        """
        Build a ``pipeline_service_pb2.Request`` from ``feed_dict``.

        The reserved keys ``"logid"`` and ``"clientip"`` are consumed as
        request metadata and are not sent as data. The caller's dict is not
        modified.

        Args:
            feed_dict: dict of input name -> value. Supported values are str
                and np.ndarray; list is additionally accepted in key/value
                format.
            pack_tensor_format: falsy -> send every value as a string in
                ``req.key``/``req.value``; truthy -> send ``req.tensors``.
            use_tensor_bytes: in tensor format, serialize ndarrays with
                ``np.save`` into ``byte_data`` instead of the typed fields.
            profile: in key/value format, append the profile flag pair.

        Returns:
            The populated Request.

        Raises:
            TypeError: if a feed value has an unsupported type.
        """
        # Shallow copy: popping the metadata keys must not mutate the
        # caller's dict.
        feed_dict = dict(feed_dict)
        req = pipeline_service_pb2.Request()

        logid = feed_dict.pop("logid", None)
        # int() also yields an arbitrary-precision integer on Python 2.
        req.logid = 0 if logid is None else int(logid)

        clientip = feed_dict.pop("clientip", None)
        req.clientip = self._local_ip() if clientip is None else clientip

        if not pack_tensor_format:
            self._pack_key_value_format(req, feed_dict, profile)
        else:
            self._pack_tensor_format(req, feed_dict, use_tensor_bytes)
        return req

    def _pack_key_value_format(self, req, feed_dict, profile):
        """Append every feed item to ``req.key``/``req.value`` as a string."""
        # repr() of a large array is abbreviated with "..." by default. Lift
        # the threshold only while packing instead of changing numpy's
        # global print options for the whole process.
        with np.printoptions(threshold=sys.maxsize):
            for key, value in feed_dict.items():
                req.key.append(key)
                if isinstance(value, six.string_types):
                    req.value.append(value)
                elif isinstance(value, np.ndarray):
                    req.value.append(repr(value))
                elif isinstance(value, list):
                    req.value.append(repr(np.array(value)))
                else:
                    raise TypeError(
                        "only str and np.ndarray type is supported: {}".format(
                            type(value)))

        if profile:
            req.key.append(self._profile_key)
            req.value.append(self._profile_value)

    def _pack_tensor_format(self, req, feed_dict, use_tensor_bytes):
        """Append one proto Tensor per feed item to ``req.tensors``."""
        for key, value in feed_dict.items():
            one_tensor = req.tensors.add()
            one_tensor.name = key

            if isinstance(value, six.string_types):
                # Repeated scalar fields support append(); add() exists only
                # on repeated message fields.
                # NOTE(review): the field name must match the Tensor message
                # in pipeline_service.proto — confirm.
                one_tensor.string_data.append(value)
                one_tensor.elem_type = 12  # 12 => string in proto
                continue

            if not isinstance(value, np.ndarray):
                raise TypeError(
                    "only str and np.ndarray type is supported: {}".format(
                        type(value)))

            _LOGGER.debug("value shape is %s", value.shape)
            one_tensor.shape.extend(value.shape)

            if use_tensor_bytes:
                np_bytes = BytesIO()
                np.save(np_bytes, value, allow_pickle=True)
                one_tensor.byte_data = np_bytes.getvalue()
                one_tensor.elem_type = 13  # 13 => bytes in proto
                # The bytes payload is complete. Falling through would also
                # copy the data into a typed field and overwrite elem_type.
                continue

            field_and_type = self._NUMPY_DTYPE_TO_TENSOR.get(value.dtype.name)
            if field_and_type is None:
                _LOGGER.error("value type %s of tensor %s is not supported.",
                              value.dtype, key)
                continue
            field_name, elem_type = field_and_type
            getattr(one_tensor, field_name).extend(value.flatten().tolist())
            one_tensor.elem_type = elem_type

    def _unpack_response_package(self, resp, fetch):
        """Return ``resp`` unchanged; ``fetch`` is currently not applied."""
        return resp

    def predict(self,
                feed_dict,
                fetch=None,
                asyn=False,
                pack_tensor_format=False,
                use_tensor_bytes=False,
                profile=False,
                log_id=0):
        """
        Send one inference request to the connected pipeline server.

        Args:
            feed_dict: dict of input name -> value. It may also carry the
                reserved keys ``"logid"`` and ``"clientip"``.
            fetch: optional list of output names, forwarded to
                ``_unpack_response_package``.
            asyn: if True, return a ``PipelinePredictFuture`` instead of
                blocking on the response.
            pack_tensor_format: send values as proto tensors, not strings.
            use_tensor_bytes: with tensor format, send ndarrays as
                ``np.save`` bytes.
            profile: add the profile flag (key/value format only).
            log_id: request log id. A non-zero value overrides any
                ``"logid"`` in ``feed_dict``.

        Returns:
            The response (synchronous) or a ``PipelinePredictFuture``.

        Raises:
            TypeError: if ``feed_dict`` is not a dict, ``fetch`` is neither
                None nor a list, or a feed value has an unsupported type.
        """
        if not isinstance(feed_dict, dict):
            raise TypeError(
                "feed must be dict type with format: {name: value}.")
        if fetch is not None and not isinstance(fetch, list):
            raise TypeError("fetch must be list type with format: [name].")
        _LOGGER.debug("PipelineClient::predict pack_data time:%s",
                      time.time())
        req = self._pack_request_package(feed_dict, pack_tensor_format,
                                         use_tensor_bytes, profile)
        # Overwriting unconditionally would discard a "logid" supplied in
        # feed_dict whenever log_id keeps its default of 0.
        if log_id:
            req.logid = log_id

        if not asyn:
            _LOGGER.debug("PipelineClient::predict before time:%s",
                          time.time())
            resp = self._stub.inference(req)
            return self._unpack_response_package(resp, fetch)

        call_future = self._stub.inference.future(req)
        return PipelinePredictFuture(
            call_future,
            functools.partial(
                self._unpack_response_package, fetch=fetch))


class PipelinePredictFuture(object):
    """
    Handle for an asynchronous pipeline prediction.

    Pairs a pending call future with a post-processing callback. Calling
    ``result()`` waits on the future and passes its value through the
    callback.
    """

    def __init__(self, call_future, callback_func):
        # Attribute names keep their trailing underscore in case external
        # code reads them.
        self.callback_func_ = callback_func
        self.call_future_ = call_future

    def result(self):
        """Block on the wrapped future and return the callback's output."""
        return self.callback_func_(self.call_future_.result())