pipeline_client.py 7.6 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
B
barrierye 已提交
16
import sys
17
import time
B
barrierye 已提交
18
import numpy as np
B
barrierye 已提交
19 20
from numpy import *
import logging
B
barrierye 已提交
21
import functools
T
TeslaZhao 已提交
22
import json
23
import socket
T
TeslaZhao 已提交
24
from .channel import ChannelDataErrcode
B
barrierye 已提交
25 26
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
W
wangjiawei04 已提交
27
import six
28
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
29

B
barrierye 已提交
30 31

class PipelineClient(object):
    """
    PipelineClient provides the basic capabilities of the pipeline SDK:
    it packs a feed dict into a pipeline_service_pb2.Request proto and
    sends it (synchronously or asynchronously) to a PipelineService
    gRPC endpoint.
    """

    # Maps a numpy dtype name to (repeated proto field name, elem_type code)
    # used when packing in tensor format.
    _TENSOR_DTYPE_MAP = {
        "int64": ("int64_data", 0),
        "float32": ("float_data", 1),
        "int32": ("int_data", 2),
        "float64": ("float64_data", 3),
        "int16": ("int_data", 4),
        "float16": ("float_data", 5),
        "uint16": ("uint32_data", 6),
        "uint8": ("uint32_data", 7),
        "int8": ("int_data", 8),
        "bool": ("bool_data", 9),
    }

    def __init__(self):
        # gRPC channel/stub are created lazily in connect().
        self._channel = None
        # Reserved key/value pair appended to a string-format request to
        # ask the server for profiling information.
        self._profile_key = "pipeline.profile"
        self._profile_value = "1"

    def connect(self, endpoints):
        """Open an insecure gRPC channel to the service.

        Args:
            endpoints: list of "host:port" strings; requests are load
                balanced across them with the round_robin policy.
        """
        options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
                   ('grpc.max_send_message_length', 512 * 1024 * 1024),
                   ('grpc.lb_policy_name', 'round_robin')]
        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
        self._channel = grpc.insecure_channel(g_endpoint, options=options)
        self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
            self._channel)

    @staticmethod
    def _is_string(value):
        """Return True when value is a text string on this interpreter."""
        if sys.version_info.major == 2:
            return isinstance(value, (str, unicode))
        return isinstance(value, str)

    def _pack_request_package(self, feed_dict, pack_tensor_format, profile):
        """Pack feed_dict into a pipeline_service_pb2.Request proto.

        Args:
            feed_dict: {name: value} dict; values are str, list or
                np.ndarray. Reserved keys "logid" and "clientip" are
                consumed here and not packed as feed data.
            pack_tensor_format: when False, values are serialized as
                repr() strings into req.key/req.value; when True they are
                packed as typed tensors into req.tensors.
            profile: append the profile key/value pair (string format only).

        Returns:
            pipeline_service_pb2.Request

        Raises:
            TypeError: for unsupported value types.
        """
        req = pipeline_service_pb2.Request()

        # BUGFIX: work on a shallow copy so the caller's dict is not
        # mutated when the reserved keys are removed.
        feed_dict = dict(feed_dict)

        logid = feed_dict.pop("logid", None)
        if logid is None:
            req.logid = 0
        elif sys.version_info.major == 2:
            req.logid = long(logid)
        else:
            req.logid = int(logid)

        clientip = feed_dict.pop("clientip", None)
        if clientip is None:
            # Fall back to this host's ip when the caller supplies none.
            hostname = socket.gethostname()
            req.clientip = socket.gethostbyname(hostname)
        else:
            req.clientip = clientip

        # Make repr() of large arrays emit all elements: string-format
        # packing below serializes ndarrays with __repr__.
        np.set_printoptions(threshold=sys.maxsize)
        if pack_tensor_format is False:
            self._pack_string_format(req, feed_dict, profile)
        else:
            self._pack_tensor_format(req, feed_dict)
        return req

    def _pack_string_format(self, req, feed_dict, profile):
        """Pack feed values as repr() strings into req.key / req.value."""
        for key, value in feed_dict.items():
            req.key.append(key)

            if self._is_string(value):
                req.value.append(value)
                continue

            if isinstance(value, np.ndarray):
                req.value.append(value.__repr__())
            elif isinstance(value, list):
                req.value.append(np.array(value).__repr__())
            else:
                raise TypeError(
                    "only str and np.ndarray type is supported: {}".format(
                        type(value)))

        if profile:
            req.key.append(self._profile_key)
            req.value.append(self._profile_value)

    def _pack_tensor_format(self, req, feed_dict):
        """Pack feed values as typed tensors into req.tensors."""
        for key, value in feed_dict.items():
            one_tensor = req.tensors.add()
            one_tensor.name = key

            if self._is_string(value):
                # BUGFIX: repeated scalar (string) proto fields have no
                # add() method — that is only for repeated message
                # fields; append() is the correct call.
                one_tensor.string_data.append(value)
                one_tensor.elem_type = 12  #12 => string
                continue

            if not isinstance(value, np.ndarray):
                raise TypeError(
                    "only str and np.ndarray type is supported: {}".format(
                        type(value)))

            # copy shape
            _LOGGER.info("value shape is {}".format(value.shape))
            one_tensor.shape.extend(value.shape)

            # copy data into the repeated field matching the dtype
            flat_value = value.flatten().tolist()
            dtype_entry = self._TENSOR_DTYPE_MAP.get(value.dtype.name)
            if dtype_entry is None:
                _LOGGER.error("value type {} of tensor {} is not supported.".
                              format(value.dtype, key))
                continue
            field_name, elem_type = dtype_entry
            getattr(one_tensor, field_name).extend(flat_value)
            one_tensor.elem_type = elem_type

    def _unpack_response_package(self, resp, fetch):
        # The raw proto response is returned as-is; `fetch` is currently
        # not used to filter outputs.
        return resp

    def predict(self,
                feed_dict,
                fetch=None,
                asyn=False,
                pack_tensor_format=False,
                profile=False,
                log_id=0):
        """Run one pipeline inference.

        Args:
            feed_dict: {name: value} dict of feed data.
            fetch: optional list of output names (currently unused by
                _unpack_response_package, which returns the raw response).
            asyn: when True, return a PipelinePredictFuture immediately
                instead of blocking for the response.
            pack_tensor_format: pack values as typed tensors instead of
                repr() strings.
            profile: ask the server for profiling information.
            log_id: request log id; NOTE this overwrites any "logid"
                carried inside feed_dict.

        Returns:
            The response proto, or a PipelinePredictFuture when asyn=True.

        Raises:
            TypeError: when feed_dict / fetch have the wrong type.
        """
        if not isinstance(feed_dict, dict):
            raise TypeError(
                "feed must be dict type with format: {name: value}.")
        if fetch is not None and not isinstance(fetch, list):
            raise TypeError("fetch must be list type with format: [name].")
        # NOTE(review): stdout timing probes kept for behavior parity;
        # consider switching to _LOGGER.debug.
        print("PipelineClient::predict pack_data time:{}".format(time.time()))
        req = self._pack_request_package(feed_dict, pack_tensor_format, profile)
        req.logid = log_id
        if not asyn:
            print("PipelineClient::predict before time:{}".format(time.time()))
            resp = self._stub.inference(req)
            return self._unpack_response_package(resp, fetch)
        else:
            call_future = self._stub.inference.future(req)
            return PipelinePredictFuture(
                call_future,
                functools.partial(
                    self._unpack_response_package, fetch=fetch))


class PipelinePredictFuture(object):
    """Thin wrapper over a gRPC call future.

    result() blocks on the underlying future and then runs the supplied
    callback (the response-unpacking function) on the raw response.
    """

    def __init__(self, call_future, callback_func):
        # Keep both so result() can be invoked later, once the RPC
        # has been issued.
        self.call_future_ = call_future
        self.callback_func_ = callback_func

    def result(self):
        """Wait for the RPC to finish and return the unpacked response."""
        return self.callback_func_(self.call_future_.result())