pipeline_client.py 7.4 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
B
barrierye 已提交
16
import sys
B
barrierye 已提交
17
import numpy as np
B
barrierye 已提交
18 19
from numpy import *
import logging
B
barrierye 已提交
20
import functools
T
TeslaZhao 已提交
21
import json
22
import socket
T
TeslaZhao 已提交
23
from .channel import ChannelDataErrcode
B
barrierye 已提交
24 25
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
W
wangjiawei04 已提交
26
import six
27
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
28

B
barrierye 已提交
29 30

class PipelineClient(object):
    """
    PipelineClient provides the basic capabilities of the pipeline SDK:
    it packs a feed dict into a pipeline_service_pb2.Request protobuf and
    sends it to a pipeline server over gRPC, synchronously or asynchronously.
    """

    # Maps a numpy dtype name to (repeated proto field name, elem_type code)
    # used when packing in tensor format. The codes mirror the server-side
    # tensor element-type enum (12 is reserved for string, handled separately).
    _NUMPY_DTYPE_TO_TENSOR_FIELD = {
        "int64": ("int64_data", 0),
        "float32": ("float_data", 1),
        "int32": ("int_data", 2),
        "float64": ("float64_data", 3),
        "int16": ("int_data", 4),
        "float16": ("float_data", 5),
        "uint16": ("uint32_data", 6),
        "uint8": ("uint32_data", 7),
        "int8": ("int_data", 8),
        "bool": ("bool_data", 9),
    }

    def __init__(self):
        self._channel = None
        # Reserved key/value pair: appending it to a request asks the server
        # to enable profiling for that request.
        self._profile_key = "pipeline.profile"
        self._profile_value = "1"

    def connect(self, endpoints):
        """Open an insecure gRPC channel to the pipeline server.

        Args:
            endpoints: list of "ip:port" strings. Multiple endpoints are
                joined into a single ipv4 target and load-balanced with
                the round_robin policy.
        """
        # 512 MB message caps: feed/fetch tensors can be much larger than
        # gRPC's 4 MB default.
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self._channel = grpc.insecure_channel(g_endpoint, options=options)
        self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
            self._channel)

    def _pack_request_package(self, feed_dict, pack_tensor_format, profile):
        """Build a pipeline_service_pb2.Request from a feed dict.

        The reserved keys "logid" and "clientip" are consumed (popped) from
        feed_dict when present; every remaining entry must be a str,
        np.ndarray, or list (string key/val format only for list).

        Args:
            feed_dict: dict of {name: value}. NOTE: mutated in place — the
                reserved keys are removed.
            pack_tensor_format: False packs values as printable strings into
                req.key/req.value; True packs typed tensors into req.tensors.
            profile: if True (string format only), append the reserved
                profile key/value so the server records timing info.

        Returns:
            The populated pipeline_service_pb2.Request.

        Raises:
            TypeError: for unsupported value types.
        """
        req = pipeline_service_pb2.Request()

        logid = feed_dict.get("logid")
        if logid is None:
            req.logid = 0
        else:
            if sys.version_info.major == 2:
                req.logid = long(logid)
            elif sys.version_info.major == 3:
                req.logid = int(logid)
            feed_dict.pop("logid")

        clientip = feed_dict.get("clientip")
        if clientip is None:
            # Fall back to this host's resolved address.
            hostname = socket.gethostname()
            ip = socket.gethostbyname(hostname)
            req.clientip = ip
        else:
            req.clientip = clientip
            feed_dict.pop("clientip")

        # Global side effect: raise numpy's print threshold so ndarray
        # __repr__ below is never elided with "..." (the server re-parses it).
        np.set_printoptions(threshold=sys.maxsize)
        if pack_tensor_format is False:
            # pack string key/val format
            for key, value in feed_dict.items():
                req.key.append(key)

                if (sys.version_info.major == 2 and
                        isinstance(value, (str, unicode)) or
                    ((sys.version_info.major == 3) and isinstance(value, str))):
                    req.value.append(value)
                    continue

                if isinstance(value, np.ndarray):
                    req.value.append(value.__repr__())
                elif isinstance(value, list):
                    req.value.append(np.array(value).__repr__())
                else:
                    raise TypeError(
                        "only str and np.ndarray type is supported: {}".format(
                            type(value)))

            if profile:
                req.key.append(self._profile_key)
                req.value.append(self._profile_value)
        else:
            # pack tensor format
            for key, value in feed_dict.items():
                one_tensor = req.tensors.add()
                one_tensor.name = key

                if (sys.version_info.major == 2 and
                        isinstance(value, (str, unicode)) or
                    ((sys.version_info.major == 3) and isinstance(value, str))):
                    # BUGFIX: string_data is a repeated scalar field, which
                    # has no .add() method — use append().
                    one_tensor.string_data.append(value)
                    one_tensor.elem_type = 12  #12 => string
                    continue

                if isinstance(value, np.ndarray):
                    # copy shape
                    _LOGGER.info("value shape is {}".format(value.shape))
                    for one_dim in value.shape:
                        one_tensor.shape.append(one_dim)

                    flat_value = value.flatten().tolist()
                    # copy data into the dtype-appropriate repeated field
                    entry = self._NUMPY_DTYPE_TO_TENSOR_FIELD.get(
                        str(value.dtype))
                    if entry is None:
                        _LOGGER.error(
                            "value type {} of tensor {} is not supported.".
                            format(value.dtype, key))
                    else:
                        field_name, elem_type = entry
                        getattr(one_tensor, field_name).extend(flat_value)
                        one_tensor.elem_type = elem_type
                else:
                    raise TypeError(
                        "only str and np.ndarray type is supported: {}".format(
                            type(value)))
        return req

    def _unpack_response_package(self, resp, fetch):
        # NOTE: fetch is currently ignored; the raw Response proto is
        # returned to the caller unchanged.
        return resp

    def predict(self,
                feed_dict,
                fetch=None,
                asyn=False,
                pack_tensor_format=False,
                profile=False,
                log_id=0):
        """Send one inference request.

        Args:
            feed_dict: dict of {name: value} (str / np.ndarray / list).
            fetch: optional list of output names (currently unused by
                _unpack_response_package but validated for type).
            asyn: if True, return a PipelinePredictFuture instead of
                blocking on the response.
            pack_tensor_format: see _pack_request_package.
            profile: see _pack_request_package.
            log_id: request log id; always overrides any "logid" entry
                that was consumed from feed_dict.

        Returns:
            The Response proto, or a PipelinePredictFuture when asyn=True.

        Raises:
            TypeError: if feed_dict is not a dict or fetch is not a list.
        """
        if not isinstance(feed_dict, dict):
            raise TypeError(
                "feed must be dict type with format: {name: value}.")
        if fetch is not None and not isinstance(fetch, list):
            raise TypeError("fetch must be list type with format: [name].")

        req = self._pack_request_package(feed_dict, pack_tensor_format, profile)
        # log_id takes precedence over feed_dict["logid"].
        req.logid = log_id
        if not asyn:
            resp = self._stub.inference(req)
            return self._unpack_response_package(resp, fetch)
        else:
            call_future = self._stub.inference.future(req)
            return PipelinePredictFuture(
                call_future,
                functools.partial(
                    self._unpack_response_package, fetch=fetch))


class PipelinePredictFuture(object):
    """Wraps a gRPC call future so the response is post-processed lazily.

    The callback (typically a bound _unpack_response_package) is applied
    to the raw response only when result() is called.
    """

    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Block until the RPC completes and return the unpacked response."""
        return self.callback_func_(self.call_future_.result())