提交 580e8478 编写于 作者: B barrierye

use future in channel && remove type def

上级 93e3d126
...@@ -39,6 +39,9 @@ py_grpc_proto_compile(multi_lang_general_model_service_py_proto SRCS proto/multi ...@@ -39,6 +39,9 @@ py_grpc_proto_compile(multi_lang_general_model_service_py_proto SRCS proto/multi
add_custom_target(multi_lang_general_model_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(multi_lang_general_model_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(multi_lang_general_model_service_py_proto multi_lang_general_model_service_py_proto_init) add_dependencies(multi_lang_general_model_service_py_proto multi_lang_general_model_service_py_proto_init)
# Python gRPC bindings for proto/general_python_service.proto.
# NOTE(review): py_grpc_proto_compile is defined outside this file; presumably it
# creates the target named by its first argument -- confirm.
py_grpc_proto_compile(general_python_service_py_proto SRCS proto/general_python_service.proto)
# Helper target (part of ALL) that touches __init__.py; no WORKING_DIRECTORY is
# given, so the file lands in the default directory for custom targets.
add_custom_target(general_python_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
# Make the proto target depend on the helper so __init__.py is touched first.
add_dependencies(general_python_service_py_proto general_python_service_py_proto_init)
if (CLIENT) if (CLIENT)
py_proto_compile(sdk_configure_py_proto SRCS proto/sdk_configure.proto) py_proto_compile(sdk_configure_py_proto SRCS proto/sdk_configure.proto)
add_custom_target(sdk_configure_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(sdk_configure_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
...@@ -60,6 +63,12 @@ add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD ...@@ -60,6 +63,12 @@ add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMENT "Copy generated multi_lang_general_model_service proto file into directory paddle_serving_client/proto." COMMENT "Copy generated multi_lang_general_model_service proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
# After general_python_service_py_proto is built: ensure
# python/paddle_serving_client/proto exists under PADDLE_SERVING_BINARY_DIR,
# then copy every *.py found in the current binary directory into it.
# NOTE(review): `cp` with a shell glob is not portable to non-POSIX hosts;
# the sibling commands in this file use the same form.
add_custom_command(TARGET general_python_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMENT "Copy generated general_python_service proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif() endif()
if (APP) if (APP)
...@@ -79,9 +88,6 @@ py_proto_compile(pyserving_channel_py_proto SRCS proto/pyserving_channel.proto) ...@@ -79,9 +88,6 @@ py_proto_compile(pyserving_channel_py_proto SRCS proto/pyserving_channel.proto)
add_custom_target(pyserving_channel_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(pyserving_channel_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(pyserving_channel_py_proto pyserving_channel_py_proto_init) add_dependencies(pyserving_channel_py_proto pyserving_channel_py_proto_init)
py_grpc_proto_compile(general_python_service_py_proto SRCS proto/general_python_service.proto)
add_custom_target(general_python_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(general_python_service_py_proto general_python_service_py_proto_init)
if (NOT WITH_GPU) if (NOT WITH_GPU)
add_custom_command(TARGET server_config_py_proto POST_BUILD add_custom_command(TARGET server_config_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
syntax = "proto2"; syntax = "proto2";
package baidu.paddle_serving.pyserving;
service GeneralPythonService { service GeneralPythonService {
rpc inference(Request) returns (Response) {} rpc inference(Request) returns (Response) {}
...@@ -21,11 +22,15 @@ service GeneralPythonService { ...@@ -21,11 +22,15 @@ service GeneralPythonService {
message Request { message Request {
repeated bytes feed_insts = 1; repeated bytes feed_insts = 1;
repeated string feed_var_names = 2; repeated string feed_var_names = 2;
repeated bytes shape = 3;
repeated string type = 4;
} }
message Response { message Response {
repeated bytes fetch_insts = 1; repeated bytes fetch_insts = 1;
repeated string fetch_var_names = 2; repeated string fetch_var_names = 2;
required int32 is_error = 3; required int32 ecode = 3;
optional string error_info = 4; optional string error_info = 4;
repeated bytes shape = 5;
repeated string type = 6;
} }
...@@ -13,17 +13,19 @@ ...@@ -13,17 +13,19 @@
// limitations under the License. // limitations under the License.
syntax = "proto2"; syntax = "proto2";
package baidu.paddle_serving.pyserving;
message ChannelData { message ChannelData {
repeated Inst insts = 1; repeated Inst insts = 1;
required int32 id = 2; required int32 id = 2;
optional string type = 3 required int32 type = 3 [ default = 0 ];
[ default = "CD" ]; // CD(channel data), CF(channel futures) required int32 ecode = 4;
required int32 is_error = 4;
optional string error_info = 5; optional string error_info = 5;
} }
message Inst { message Inst {
required bytes data = 1; required bytes data = 1;
required string name = 2; required string name = 2;
required bytes shape = 3;
required string type = 4;
} }
...@@ -30,8 +30,7 @@ lp = LineProfiler() ...@@ -30,8 +30,7 @@ lp = LineProfiler()
lp_wrapper = lp(client.predict) lp_wrapper = lp(client.predict)
for i in range(1): for i in range(1):
fetch_map = lp_wrapper( fetch_map = lp_wrapper(feed={"x": x}, fetch=["combine_op_output"])
feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
# fetch_map = client.predict( # fetch_map = client.predict(
# feed={"x": x}, fetch_with_type={"combine_op_output": "float"}) # feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
print(fetch_map) print(fetch_map)
......
...@@ -16,29 +16,26 @@ ...@@ -16,29 +16,26 @@
from paddle_serving_server.pyserver import Op from paddle_serving_server.pyserver import Op
from paddle_serving_server.pyserver import Channel from paddle_serving_server.pyserver import Channel
from paddle_serving_server.pyserver import PyServer from paddle_serving_server.pyserver import PyServer
from paddle_serving_server import python_service_channel_pb2
import numpy as np import numpy as np
import logging import logging
logging.basicConfig( logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M', datefmt='%Y-%m-%d %H:%M',
#level=logging.DEBUG)
level=logging.INFO) level=logging.INFO)
# channel data: {name(str): data(bytes)} # channel data: {name(str): data(narray)}
class CombineOp(Op): class CombineOp(Op):
def preprocess(self, input_data): def preprocess(self, input_data):
cnt = 0 cnt = 0
for op_name, data in input_data.items(): for op_name, channeldata in input_data.items():
logging.debug("CombineOp preprocess: {}".format(op_name)) logging.debug("CombineOp preprocess: {}".format(op_name))
cnt += np.frombuffer(data.insts[0].data, dtype='float') data = channeldata.parse()
data = python_service_channel_pb2.ChannelData() cnt += data["prediction"]
inst = python_service_channel_pb2.Inst() data = {"combine_op_output": cnt}
inst.data = np.ndarray.tobytes(cnt)
inst.name = "combine_op_output"
data.insts.append(inst)
return data return data
def postprocess(self, output_data): def postprocess(self, output_data):
...@@ -47,12 +44,8 @@ class CombineOp(Op): ...@@ -47,12 +44,8 @@ class CombineOp(Op):
class UciOp(Op): class UciOp(Op):
def postprocess(self, output_data): def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData() pred = np.array(output_data["price"][0][0], dtype='float32')
inst = python_service_channel_pb2.Inst() data = {"prediction": pred}
pred = np.array(output_data["price"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
data.insts.append(inst)
return data return data
...@@ -60,12 +53,10 @@ read_channel = Channel(name="read_channel") ...@@ -60,12 +53,10 @@ read_channel = Channel(name="read_channel")
combine_channel = Channel(name="combine_channel") combine_channel = Channel(name="combine_channel")
out_channel = Channel(name="out_channel") out_channel = Channel(name="out_channel")
cnn_op = UciOp( uci1_op = UciOp(
name="cnn", name="uci1",
input=read_channel, input=read_channel,
in_dtype='float',
outputs=[combine_channel], outputs=[combine_channel],
out_dtype='float',
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9393", server_port="9393",
device="cpu", device="cpu",
...@@ -73,15 +64,13 @@ cnn_op = UciOp( ...@@ -73,15 +64,13 @@ cnn_op = UciOp(
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"], fetch_names=["price"],
concurrency=1, concurrency=1,
timeout=0.01, timeout=0.1,
retry=2) retry=2)
bow_op = UciOp( uci2_op = UciOp(
name="bow", name="uci2",
input=read_channel, input=read_channel,
in_dtype='float',
outputs=[combine_channel], outputs=[combine_channel],
out_dtype='float',
server_model="./uci_housing_model", server_model="./uci_housing_model",
server_port="9292", server_port="9292",
device="cpu", device="cpu",
...@@ -95,9 +84,7 @@ bow_op = UciOp( ...@@ -95,9 +84,7 @@ bow_op = UciOp(
combine_op = CombineOp( combine_op = CombineOp(
name="combine", name="combine",
input=combine_channel, input=combine_channel,
in_dtype='float',
outputs=[out_channel], outputs=[out_channel],
out_dtype='float',
concurrency=1, concurrency=1,
timeout=-1, timeout=-1,
retry=1) retry=1)
...@@ -109,8 +96,8 @@ pyserver = PyServer(profile=False, retry=1) ...@@ -109,8 +96,8 @@ pyserver = PyServer(profile=False, retry=1)
pyserver.add_channel(read_channel) pyserver.add_channel(read_channel)
pyserver.add_channel(combine_channel) pyserver.add_channel(combine_channel)
pyserver.add_channel(out_channel) pyserver.add_channel(out_channel)
pyserver.add_op(cnn_op) pyserver.add_op(uci1_op)
pyserver.add_op(bow_op) pyserver.add_op(uci2_op)
pyserver.add_op(combine_op) pyserver.add_op(combine_op)
pyserver.prepare_server(port=8080, worker_num=2) pyserver.prepare_server(port=8080, worker_num=2)
pyserver.run_server() pyserver.run_server()
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
// gRPC service exposing a single unary call: one Request in, one Response out.
service GeneralPythonService {
rpc inference(Request) returns (Response) {}
}
// Inference input: raw payloads plus the variable name for each payload.
// NOTE(review): the two lists look index-aligned (entry i of feed_insts
// belongs to entry i of feed_var_names) -- confirm against the packers.
message Request {
repeated bytes feed_insts = 1;
repeated string feed_var_names = 2;
}
// Inference output: raw payloads plus the variable name for each payload.
// NOTE(review): presumably index-aligned like Request -- confirm.
message Response {
repeated bytes fetch_insts = 1;
repeated string fetch_var_names = 2;
}
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
import grpc import grpc
import general_python_service_pb2 from .proto import general_python_service_pb2
import general_python_service_pb2_grpc from .proto import general_python_service_pb2_grpc
import numpy as np import numpy as np
...@@ -33,24 +33,27 @@ class PyClient(object): ...@@ -33,24 +33,27 @@ class PyClient(object):
if not isinstance(data, np.ndarray): if not isinstance(data, np.ndarray):
raise TypeError( raise TypeError(
"only numpy array type is supported temporarily.") "only numpy array type is supported temporarily.")
data2bytes = np.ndarray.tobytes(data)
req.feed_var_names.append(name) req.feed_var_names.append(name)
req.feed_insts.append(data2bytes) req.feed_insts.append(data.tobytes())
req.shape.append(np.array(data.shape, dtype="int32").tobytes())
req.type.append(str(data.dtype))
return req return req
def predict(self, feed, fetch):
    """Run one inference round-trip and return the requested variables.

    Args:
        feed: dict mapping variable name to the value to send; it is
            handed unchanged to ``self._pack_data_for_infer``.
        fetch: list of variable names to keep from the response.

    Returns:
        dict mapping each requested name that is present in the response
        to a numpy array, rebuilt from the raw bytes with the dtype and
        shape carried in the response itself.

    Raises:
        TypeError: if ``feed`` is not a dict or ``fetch`` is not a list.
    """
    if not isinstance(feed, dict):
        raise TypeError(
            "feed must be dict type with format: {name: value}.")
    if not isinstance(fetch, list):
        # The message used to name the removed ``fetch_with_type`` parameter.
        raise TypeError("fetch must be list type with format: [name].")
    req = self._pack_data_for_infer(feed)
    resp = self._stub.inference(req)
    fetch_map = {}
    for idx, name in enumerate(resp.fetch_var_names):
        if name not in fetch:
            continue
        # Shape travels as the raw bytes of an int32 vector.
        shape = tuple(np.frombuffer(resp.shape[idx], dtype="int32"))
        fetch_map[name] = np.frombuffer(
            resp.fetch_insts[idx], dtype=resp.type[idx]).reshape(shape)
    return fetch_map
...@@ -18,17 +18,19 @@ import Queue ...@@ -18,17 +18,19 @@ import Queue
import os import os
import sys import sys
import paddle_serving_server import paddle_serving_server
from paddle_serving_client import Client from paddle_serving_client import MultiLangClient as Client
from concurrent import futures from concurrent import futures
import numpy as np import numpy as np
import grpc import grpc
import general_python_service_pb2 from .proto import general_model_config_pb2 as m_config
import general_python_service_pb2_grpc from .proto import general_python_service_pb2 as pyservice_pb2
import python_service_channel_pb2 from .proto import pyserving_channel_pb2 as channel_pb2
from .proto import general_python_service_pb2_grpc
import logging import logging
import random import random
import time import time
import func_timeout import func_timeout
import enum
class _TimeProfiler(object): class _TimeProfiler(object):
...@@ -71,6 +73,51 @@ class _TimeProfiler(object): ...@@ -71,6 +73,51 @@ class _TimeProfiler(object):
_profiler = _TimeProfiler() _profiler = _TimeProfiler()
class ChannelDataEcode(enum.Enum):
    """Error codes stored in the ``ecode`` field of channel data."""

    # No error.
    OK = 0
    # Named for a timed-out call.
    TIMEOUT = 1
class ChannelDataType(enum.Enum):
    """Values of the ``type`` field that say how channel data is carried."""

    # Payload is serialized inside the protobuf ``insts`` list.
    CHANNEL_PBDATA = 0
    # Payload is obtained by resolving an attached future.
    CHANNEL_FUTURE = 1
class ChannelData(object):
    """One item travelling through a Channel.

    It wraps a protobuf ``ChannelData`` message (``pbdata``) and, for
    future-typed items, the pending ``future`` plus an optional
    ``callback_func`` applied to the future's result.
    """

    def __init__(self,
                 future=None,
                 pbdata=None,
                 data_id=None,
                 callback_func=None):
        """Wrap an existing protobuf, or build a future-typed one.

        Args:
            future: object with a ``result()`` method; only read by
                ``parse`` when ``pbdata.type`` is CHANNEL_FUTURE.
            pbdata: ready-made protobuf message. When omitted, a new
                message of type CHANNEL_FUTURE with ecode OK is created.
            data_id: id stored in the new message; required only when
                ``pbdata`` is omitted.
            callback_func: optional callable applied to the future's
                result inside ``parse``.

        Raises:
            ValueError: if both ``pbdata`` and ``data_id`` are None.
        """
        self.future = future
        if pbdata is None:
            if data_id is None:
                raise ValueError("data_id cannot be None")
            pbdata = channel_pb2.ChannelData()
            pbdata.type = ChannelDataType.CHANNEL_FUTURE.value
            pbdata.ecode = ChannelDataEcode.OK.value
            pbdata.id = data_id
        self.pbdata = pbdata
        self.callback_func = callback_func

    def parse(self):
        """Return the payload as ``{name: numpy array}``.

        For CHANNEL_PBDATA every inst is rebuilt from its raw bytes using
        the dtype and int32 shape vector stored next to it. For
        CHANNEL_FUTURE the future is resolved and passed through
        ``callback_func`` when one is set.

        Raises:
            TypeError: if ``pbdata.type`` is neither known value.
        """
        data_type = self.pbdata.type
        if data_type == ChannelDataType.CHANNEL_PBDATA.value:
            feed = {}
            for inst in self.pbdata.insts:
                shape = tuple(np.frombuffer(inst.shape, dtype="int32"))
                feed[inst.name] = np.frombuffer(
                    inst.data, dtype=inst.type).reshape(shape)
            return feed
        if data_type == ChannelDataType.CHANNEL_FUTURE.value:
            feed = self.future.result()
            if self.callback_func is not None:
                feed = self.callback_func(feed)
            return feed
        # This class has no _log helper; calling self._log here used to
        # raise AttributeError instead of the intended TypeError.
        raise TypeError(
            "Error type({}) in pbdata.type.".format(data_type))
class Channel(Queue.Queue): class Channel(Queue.Queue):
""" """
The channel used for communication between Ops. The channel used for communication between Ops.
...@@ -94,6 +141,7 @@ class Channel(Queue.Queue): ...@@ -94,6 +141,7 @@ class Channel(Queue.Queue):
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout self._timeout = timeout
self._name = name self._name = name
self._stop = False
self._cv = threading.Condition() self._cv = threading.Condition()
...@@ -101,7 +149,6 @@ class Channel(Queue.Queue): ...@@ -101,7 +149,6 @@ class Channel(Queue.Queue):
self._producer_res_count = {} # {data_id: count} self._producer_res_count = {} # {data_id: count}
self._push_res = {} # {data_id: {op_name: data}} self._push_res = {} # {data_id: {op_name: data}}
self._front_wait_interval = 0.1 # second
self._consumers = {} # {op_name: idx} self._consumers = {} # {op_name: idx}
self._idx_consumer_num = {} # {idx: num} self._idx_consumer_num = {} # {idx: num}
self._consumer_base_idx = 0 self._consumer_base_idx = 0
...@@ -138,9 +185,10 @@ class Channel(Queue.Queue): ...@@ -138,9 +185,10 @@ class Channel(Queue.Queue):
self._idx_consumer_num[0] = 0 self._idx_consumer_num[0] = 0
self._idx_consumer_num[0] += 1 self._idx_consumer_num[0] += 1
def push(self, data, op_name=None): def push(self, channeldata, op_name=None):
logging.debug( logging.debug(
self._log("{} try to push data: {}".format(op_name, data))) self._log("{} try to push data: {}".format(op_name,
channeldata.pbdata)))
if len(self._producers) == 0: if len(self._producers) == 0:
raise Exception( raise Exception(
self._log( self._log(
...@@ -148,9 +196,9 @@ class Channel(Queue.Queue): ...@@ -148,9 +196,9 @@ class Channel(Queue.Queue):
)) ))
elif len(self._producers) == 1: elif len(self._producers) == 1:
with self._cv: with self._cv:
while True: while self._stop is False:
try: try:
self.put(data, timeout=0) self.put(channeldata, timeout=0)
break break
except Queue.Empty: except Queue.Empty:
self._cv.wait() self._cv.wait()
...@@ -163,17 +211,17 @@ class Channel(Queue.Queue): ...@@ -163,17 +211,17 @@ class Channel(Queue.Queue):
"There are multiple producers, so op_name cannot be None.")) "There are multiple producers, so op_name cannot be None."))
producer_num = len(self._producers) producer_num = len(self._producers)
data_id = data.id data_id = channeldata.pbdata.id
put_data = None put_data = None
with self._cv: with self._cv:
logging.debug(self._log("{} get lock ~".format(op_name))) logging.debug(self._log("{} get lock".format(op_name)))
if data_id not in self._push_res: if data_id not in self._push_res:
self._push_res[data_id] = { self._push_res[data_id] = {
name: None name: None
for name in self._producers for name in self._producers
} }
self._producer_res_count[data_id] = 0 self._producer_res_count[data_id] = 0
self._push_res[data_id][op_name] = data self._push_res[data_id][op_name] = channeldata
if self._producer_res_count[data_id] + 1 == producer_num: if self._producer_res_count[data_id] + 1 == producer_num:
put_data = self._push_res[data_id] put_data = self._push_res[data_id]
self._push_res.pop(data_id) self._push_res.pop(data_id)
...@@ -183,10 +231,10 @@ class Channel(Queue.Queue): ...@@ -183,10 +231,10 @@ class Channel(Queue.Queue):
if put_data is None: if put_data is None:
logging.debug( logging.debug(
self._log("{} push data succ, not not push to queue.". self._log("{} push data succ, but not push to queue.".
format(op_name))) format(op_name)))
else: else:
while True: while self._stop is False:
try: try:
self.put(put_data, timeout=0) self.put(put_data, timeout=0)
break break
...@@ -208,7 +256,7 @@ class Channel(Queue.Queue): ...@@ -208,7 +256,7 @@ class Channel(Queue.Queue):
elif len(self._consumers) == 1: elif len(self._consumers) == 1:
resp = None resp = None
with self._cv: with self._cv:
while resp is None: while self._stop is False and resp is None:
try: try:
resp = self.get(timeout=0) resp = self.get(timeout=0)
break break
...@@ -223,11 +271,11 @@ class Channel(Queue.Queue): ...@@ -223,11 +271,11 @@ class Channel(Queue.Queue):
with self._cv: with self._cv:
# data_idx = consumer_idx - base_idx # data_idx = consumer_idx - base_idx
while self._consumers[op_name] - self._consumer_base_idx >= len( while self._stop is False and self._consumers[
self._front_res): op_name] - self._consumer_base_idx >= len(self._front_res):
try: try:
data = self.get(timeout=0) channeldata = self.get(timeout=0)
self._front_res.append(data) self._front_res.append(channeldata)
break break
except Queue.Empty: except Queue.Empty:
self._cv.wait() self._cv.wait()
...@@ -256,14 +304,17 @@ class Channel(Queue.Queue): ...@@ -256,14 +304,17 @@ class Channel(Queue.Queue):
logging.debug(self._log("multi | {} get data succ!".format(op_name))) logging.debug(self._log("multi | {} get data succ!".format(op_name)))
return resp # reference, read only return resp # reference, read only
def stop(self):
    """Mark the channel as stopped and wake every thread waiting on it.

    The flag is set first so it always takes effect; the push/front
    loops re-check ``self._stop`` each time they wake from
    ``self._cv.wait()``, so they are notified here.
    """
    self._stop = True
    with self._cv:
        self._cv.notify_all()
    # TODO: the original called self.close() unconditionally, but no
    # close() is visible on this class and Queue.Queue has none. Call it
    # only when a subclass actually provides one.
    close = getattr(self, "close", None)
    if callable(close):
        close()
class Op(object): class Op(object):
def __init__(self, def __init__(self,
name, name,
input, input,
in_dtype,
outputs, outputs,
out_dtype,
server_model=None, server_model=None,
server_port=None, server_port=None,
device=None, device=None,
...@@ -278,9 +329,7 @@ class Op(object): ...@@ -278,9 +329,7 @@ class Op(object):
self._name = name # to identify the type of OP, it must be globally unique self._name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency self._concurrency = concurrency # amount of concurrency
self.set_input(input) self.set_input(input)
self._in_dtype = in_dtype
self.set_outputs(outputs) self.set_outputs(outputs)
self._out_dtype = out_dtype
self._client = None self._client = None
if client_config is not None and \ if client_config is not None and \
server_name is not None and \ server_name is not None and \
...@@ -324,15 +373,13 @@ class Op(object): ...@@ -324,15 +373,13 @@ class Op(object):
channel.add_producer(self._name) channel.add_producer(self._name)
self._outputs = channels self._outputs = channels
def preprocess(self, data): def preprocess(self, channeldata):
if isinstance(data, dict): if isinstance(channeldata, dict):
raise Exception( raise Exception(
self._log( self._log(
'this Op has multiple previous inputs. Please override this method' 'this Op has multiple previous inputs. Please override this method'
)) ))
feed = {} feed = channeldata.parse()
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
return feed return feed
def midprocess(self, data): def midprocess(self, data):
...@@ -343,47 +390,58 @@ class Op(object): ...@@ -343,47 +390,58 @@ class Op(object):
format(type(data)))) format(type(data))))
logging.debug(self._log('data: {}'.format(data))) logging.debug(self._log('data: {}'.format(data)))
logging.debug(self._log('fetch: {}'.format(self._fetch_names))) logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
fetch_map = self._client.predict(feed=data, fetch=self._fetch_names) call_future = self._client.predict(
logging.debug(self._log("finish predict")) feed=data, fetch=self._fetch_names, asyn=True)
return fetch_map logging.debug(self._log("get call_future"))
return call_future
def postprocess(self, output_data): def postprocess(self, output_data):
raise Exception( raise Exception(
self._log( self._log(
'Please override this method to convert data to the format in channel.' 'Please override this method to convert data to the format in channel.' \
' The return value format should be in {name(str): var(narray)}'
)) ))
def errorprocess(self, error_info): def errorprocess(self, error_info, data_id):
data = python_service_channel_pb2.ChannelData() data = channel_pb2.ChannelData()
data.is_error = 1 data.ecode = 1
data.id = data_id
data.error_info = error_info data.error_info = error_info
return data return data
def stop(self): def stop(self):
self._input.stop()
for channel in self._outputs:
channel.stop()
self._run = False self._run = False
def _parse_channeldata(self, channeldata):
data_id, error_data = None, None
if isinstance(channeldata, dict):
parsed_data = {}
key = channeldata.keys()[0]
data_id = channeldata[key].pbdata.id
for _, data in channeldata.items():
if data.pbdata.ecode != 0:
error_data = data
break
else:
data_id = channeldata.pbdata.id
if channeldata.pbdata.ecode != 0:
error_data = channeldata.pbdata
return data_id, error_data
def start(self, concurrency_idx): def start(self, concurrency_idx):
self._run = True self._run = True
while self._run: while self._run:
_profiler.record("{}{}-get_0".format(self._name, concurrency_idx)) _profiler.record("{}{}-get_0".format(self._name, concurrency_idx))
input_data = self._input.front(self._name) input_data = self._input.front(self._name)
_profiler.record("{}{}-get_1".format(self._name, concurrency_idx)) _profiler.record("{}{}-get_1".format(self._name, concurrency_idx))
data_id = None
output_data = None
error_data = None
logging.debug(self._log("input_data: {}".format(input_data))) logging.debug(self._log("input_data: {}".format(input_data)))
if isinstance(input_data, dict):
key = input_data.keys()[0]
data_id = input_data[key].id
for _, data in input_data.items():
if data.is_error != 0:
error_data = data
break
else:
data_id = input_data.id
if input_data.is_error != 0:
error_data = input_data
data_id, error_data = self._parse_channeldata(input_data)
output_data = None
if error_data is None: if error_data is None:
_profiler.record("{}{}-prep_0".format(self._name, _profiler.record("{}{}-prep_0".format(self._name,
concurrency_idx)) concurrency_idx))
...@@ -391,6 +449,7 @@ class Op(object): ...@@ -391,6 +449,7 @@ class Op(object):
_profiler.record("{}{}-prep_1".format(self._name, _profiler.record("{}{}-prep_1".format(self._name,
concurrency_idx)) concurrency_idx))
call_future = None
error_info = None error_info = None
if self.with_serving(): if self.with_serving():
for i in range(self._retry): for i in range(self._retry):
...@@ -398,7 +457,7 @@ class Op(object): ...@@ -398,7 +457,7 @@ class Op(object):
concurrency_idx)) concurrency_idx))
if self._timeout > 0: if self._timeout > 0:
try: try:
middata = func_timeout.func_timeout( call_future = func_timeout.func_timeout(
self._timeout, self._timeout,
self.midprocess, self.midprocess,
args=(data, )) args=(data, ))
...@@ -411,38 +470,48 @@ class Op(object): ...@@ -411,38 +470,48 @@ class Op(object):
error_info = "{}({}): {}".format( error_info = "{}({}): {}".format(
self._name, concurrency_idx, e) self._name, concurrency_idx, e)
else: else:
middata = self.midprocess(data) call_future = self.midprocess(data)
_profiler.record("{}{}-midp_1".format(self._name, _profiler.record("{}{}-midp_1".format(self._name,
concurrency_idx)) concurrency_idx))
if error_info is None:
data = middata
break
if i + 1 < self._retry: if i + 1 < self._retry:
error_info = None error_info = None
logging.warn( logging.warn(
self._log("warn: timeout, retry({})".format(i + self._log("warn: timeout, retry({})".format(i +
1))) 1)))
_profiler.record("{}{}-postp_0".format(self._name, _profiler.record("{}{}-postp_0".format(self._name,
concurrency_idx)) concurrency_idx))
if error_info is not None: if error_info is not None:
output_data = self.errorprocess(error_info) error_data = self.errorprocess(error_info, data_id)
output_data = ChannelData(pbdata=error_data)
else: else:
output_data = self.postprocess(data) if self.with_serving(): # use call_future
output_data = ChannelData(
if not isinstance(output_data, future=call_future,
python_service_channel_pb2.ChannelData): data_id=data_id,
raise TypeError( callback_func=self.postprocess)
self._log( else:
'output_data must be ChannelData type, but get {}'. post_data = self.postprocess(data)
format(type(output_data)))) if not isinstance(post_data, dict):
output_data.is_error = 0 raise TypeError(
self._log(
'output_data must be dict type, but get {}'.
format(type(output_data))))
pbdata = channel_pb2.ChannelData()
for name, value in post_data.items():
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(
value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
pbdata.ecode = 0
pbdata.id = data_id
output_data = ChannelData(pbdata=pbdata)
_profiler.record("{}{}-postp_1".format(self._name, _profiler.record("{}{}-postp_1".format(self._name,
concurrency_idx)) concurrency_idx))
output_data.id = data_id
else: else:
output_data = error_data output_data = ChannelData(pbdata=error_data)
_profiler.record("{}{}-push_0".format(self._name, concurrency_idx)) _profiler.record("{}{}-push_0".format(self._name, concurrency_idx))
for channel in self._outputs: for channel in self._outputs:
...@@ -498,13 +567,14 @@ class GeneralPythonService( ...@@ -498,13 +567,14 @@ class GeneralPythonService(
def _recive_out_channel_func(self): def _recive_out_channel_func(self):
while True: while True:
data = self._out_channel.front(self._name) channeldata = self._out_channel.front(self._name)
if not isinstance(data, python_service_channel_pb2.ChannelData): if not isinstance(channeldata, ChannelData):
raise TypeError( raise TypeError(
self._log('data must be ChannelData type, but get {}'. self._log('data must be ChannelData type, but get {}'.
format(type(data)))) format(type(channeldata))))
with self._cv: with self._cv:
self._globel_resp_dict[data.id] = data data_id = channeldata.pbdata.id
self._globel_resp_dict[data_id] = channeldata
self._cv.notify_all() self._cv.notify_all()
def _get_next_id(self): def _get_next_id(self):
...@@ -523,34 +593,58 @@ class GeneralPythonService( ...@@ -523,34 +593,58 @@ class GeneralPythonService(
def _pack_data_for_infer(self, request): def _pack_data_for_infer(self, request):
logging.debug(self._log('start inferce')) logging.debug(self._log('start inferce'))
data = python_service_channel_pb2.ChannelData() pbdata = channel_pb2.ChannelData()
data_id = self._get_next_id() data_id = self._get_next_id()
data.id = data_id pbdata.id = data_id
data.is_error = 0
for idx, name in enumerate(request.feed_var_names): for idx, name in enumerate(request.feed_var_names):
logging.debug( logging.debug(
self._log('name: {}'.format(request.feed_var_names[idx]))) self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug(self._log('data: {}'.format(request.feed_insts[idx]))) logging.debug(self._log('data: {}'.format(request.feed_insts[idx])))
inst = python_service_channel_pb2.Inst() inst = channel_pb2.Inst()
inst.data = request.feed_insts[idx] inst.data = request.feed_insts[idx]
inst.shape = request.shape[idx]
inst.name = name inst.name = name
data.insts.append(inst) inst.type = request.type[idx]
return data, data_id pbdata.insts.append(inst)
pbdata.ecode = 0 #TODO: parse request error
return ChannelData(pbdata=pbdata), data_id
def _pack_data_for_resp(self, data): def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get data')) logging.debug(self._log('get channeldata'))
resp = general_python_service_pb2.Response()
logging.debug(self._log('gen resp')) logging.debug(self._log('gen resp'))
logging.debug(data) resp = pyservice_pb2.Response()
resp.is_error = data.is_error resp.ecode = channeldata.pbdata.ecode
if resp.is_error == 0: if resp.ecode == 0:
for inst in data.insts: if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
logging.debug(self._log('append data')) for inst in channeldata.pbdata.insts:
resp.fetch_insts.append(inst.data) logging.debug(self._log('append data'))
logging.debug(self._log('append name')) resp.fetch_insts.append(inst.data)
resp.fetch_var_names.append(inst.name) logging.debug(self._log('append name'))
resp.fetch_var_names.append(inst.name)
logging.debug(self._log('append shape'))
resp.shape.append(inst.shape)
logging.debug(self._log('append type'))
resp.type.append(inst.type)
elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
feed = channeldata.futures.result()
if channeldata.callback_func is not None:
feed = channeldata.callback_func(feed)
for name, var in feed:
logging.debug(self._log('append data'))
resp.fetch_insts.append(var.tobytes())
logging.debug(self._log('append name'))
resp.fetch_var_names.append(name)
logging.debug(self._log('append shape'))
resp.shape.append(
np.array(
var.shape, dtype="int32").tobytes())
resp.type.append(str(var.dtype))
else:
raise TypeError(
self._log("Error type({}) in pbdata.type.".format(
self.pbdata.type)))
else: else:
resp.error_info = data.error_info resp.error_info = channeldata.pbdata.error_info
return resp return resp
def inference(self, request, context): def inference(self, request, context):
...@@ -558,6 +652,7 @@ class GeneralPythonService( ...@@ -558,6 +652,7 @@ class GeneralPythonService(
data, data_id = self._pack_data_for_infer(request) data, data_id = self._pack_data_for_infer(request)
_profiler.record("{}-prepack_1".format(self._name)) _profiler.record("{}-prepack_1".format(self._name))
resp_channeldata = None
for i in range(self._retry): for i in range(self._retry):
logging.debug(self._log('push data')) logging.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self._name)) _profiler.record("{}-push_0".format(self._name))
...@@ -565,17 +660,17 @@ class GeneralPythonService( ...@@ -565,17 +660,17 @@ class GeneralPythonService(
_profiler.record("{}-push_1".format(self._name)) _profiler.record("{}-push_1".format(self._name))
logging.debug(self._log('wait for infer')) logging.debug(self._log('wait for infer'))
resp_data = None
_profiler.record("{}-fetch_0".format(self._name)) _profiler.record("{}-fetch_0".format(self._name))
resp_data = self._get_data_in_globel_resp_dict(data_id) resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self._name)) _profiler.record("{}-fetch_1".format(self._name))
if resp_data.is_error == 0: if resp_channeldata.pbdata.ecode == 0:
break break
logging.warn("retry({}): {}".format(i + 1, resp_data.error_info)) logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.pbdata.error_info))
_profiler.record("{}-postpack_0".format(self._name)) _profiler.record("{}-postpack_0".format(self._name))
resp = self._pack_data_for_resp(resp_data) resp = self._pack_data_for_resp(resp_channeldata)
_profiler.record("{}-postpack_1".format(self._name)) _profiler.record("{}-postpack_1".format(self._name))
_profiler.print_profile() _profiler.print_profile()
return resp return resp
...@@ -600,7 +695,7 @@ class PyServer(object): ...@@ -600,7 +695,7 @@ class PyServer(object):
self._ops.append(op) self._ops.append(op)
def gen_desc(self): def gen_desc(self):
logging.info('here will generate desc for paas') logging.info('here will generate desc for PAAS')
pass pass
def prepare_server(self, port, worker_num): def prepare_server(self, port, worker_num):
...@@ -638,6 +733,10 @@ class PyServer(object): ...@@ -638,6 +733,10 @@ class PyServer(object):
th.start() th.start()
self._op_threads.append(th) self._op_threads.append(th)
def _stop_ops(self):
for op in self._ops:
op.stop()
def run_server(self): def run_server(self):
self._run_ops() self._run_ops()
server = grpc.server( server = grpc.server(
...@@ -647,12 +746,10 @@ class PyServer(object): ...@@ -647,12 +746,10 @@ class PyServer(object):
self._retry), server) self._retry), server)
server.add_insecure_port('[::]:{}'.format(self._port)) server.add_insecure_port('[::]:{}'.format(self._port))
server.start() server.start()
try: server.wait_for_termination()
for th in self._op_threads: self._stop_ops() # TODO
th.join() for th in self._op_threads:
server.join() th.join()
except KeyboardInterrupt:
server.stop(0)
def prepare_serving(self, op): def prepare_serving(self, op):
model_path = op._server_model model_path = op._server_model
...@@ -660,11 +757,11 @@ class PyServer(object): ...@@ -660,11 +757,11 @@ class PyServer(object):
device = op._device device = op._device
if device == "cpu": if device == "cpu":
cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format( cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
model_path, port) " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
else: else:
cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format( cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
model_path, port) " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
# run a server (not in PyServing) # run a server (not in PyServing)
logging.info("run a server (not in PyServing): {}".format(cmd)) logging.info("run a server (not in PyServing): {}".format(cmd))
return return
......
...@@ -49,6 +49,10 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -49,6 +49,10 @@ def parse_args(): # pylint: disable=doc-string-missing
type=int, type=int,
default=512 * 1024 * 1024, default=512 * 1024 * 1024,
help="Limit sizes of messages") help="Limit sizes of messages")
parser.add_argument(
"--use_multilang",
action='store_true',
help="Use Multi-language-service")
return parser.parse_args() return parser.parse_args()
...@@ -63,6 +67,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -63,6 +67,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
ir_optim = args.ir_optim ir_optim = args.ir_optim
max_body_size = args.max_body_size max_body_size = args.max_body_size
use_mkl = args.use_mkl use_mkl = args.use_mkl
use_multilang = args.use_multilang
if model == "": if model == "":
print("You must specify your serving model") print("You must specify your serving model")
...@@ -79,14 +84,19 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -79,14 +84,19 @@ def start_standard_model(): # pylint: disable=doc-string-missing
op_seq_maker.add_op(general_infer_op) op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(general_response_op) op_seq_maker.add_op(general_response_op)
server = serving.Server() server = None
server.set_op_sequence(op_seq_maker.get_op_sequence()) if use_multilang:
server.set_num_threads(thread_num) server = serving.MultiLangServer()
server.set_memory_optimize(mem_optim) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_ir_optimize(ir_optim) else:
server.use_mkl(use_mkl) server = serving.Server()
server.set_max_body_size(max_body_size) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_port(port) server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size)
server.set_port(port)
server.load_model_config(model) server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device) server.prepare_server(workdir=workdir, port=port, device=device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册