Commit 0af45cc0 authored by TeslaZhao, committed by GitHub

Merge pull request #667 from barrierye/pyserving

[WIP] Pipeline Serving
......@@ -143,7 +143,6 @@ function(grpc_protobuf_generate_python SRCS)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
# Print and set the protobuf library information,
# then finish this CMake run and return from this file.
macro(PROMPT_PROTOBUF_LIB)
......
......@@ -86,6 +86,7 @@ add_custom_command(TARGET general_model_config_py_proto POST_BUILD
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
......
if (CLIENT)
file(INSTALL pipeline DESTINATION paddle_serving_client)
execute_process(COMMAND ${PYTHON_EXECUTABLE} run_codegen.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/paddle_serving_client/pipeline/proto)
file(GLOB_RECURSE SERVING_CLIENT_PY_FILES paddle_serving_client/*.py)
set(PY_FILES ${SERVING_CLIENT_PY_FILES})
SET(PACKAGE_NAME "serving_client")
......@@ -7,8 +10,14 @@ endif()
if (SERVER)
if (NOT WITH_GPU)
file(INSTALL pipeline DESTINATION paddle_serving_server)
execute_process(COMMAND ${PYTHON_EXECUTABLE} run_codegen.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/paddle_serving_server/pipeline/proto)
file(GLOB_RECURSE SERVING_SERVER_PY_FILES paddle_serving_server/*.py)
else()
file(INSTALL pipeline DESTINATION paddle_serving_server_gpu)
execute_process(COMMAND ${PYTHON_EXECUTABLE} run_codegen.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/paddle_serving_server_gpu/pipeline/proto)
file(GLOB_RECURSE SERVING_SERVER_PY_FILES paddle_serving_server_gpu/*.py)
endif()
set(PY_FILES ${SERVING_SERVER_PY_FILES})
......
......@@ -29,6 +29,6 @@ imdb_dataset.load_resource(sys.argv[2])
for line in sys.stdin:
word_ids, label = imdb_dataset.get_words_and_label(line)
feed = {"words": word_ids}
fetch = ["acc", "cost", "prediction"]
fetch = ["prediction"]
fetch_map = client.predict(feed=feed, fetch=fetch)
print("{} {}".format(fetch_map["prediction"][0], label[0]))
use_multithread: true
client_type: brpc
retry: 1
profile: false
port: 8080
worker_num: 2
wget --no-check-certificate https://fleet.bj.bcebos.com/text_classification_data.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imdb-demo/imdb_model.tar.gz
tar -zxvf text_classification_data.tar.gz
tar -zxvf imdb_model.tar.gz
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client.pipeline import PipelineClient
import numpy as np
from line_profiler import LineProfiler
client = PipelineClient()
client.connect('localhost:8080')
lp = LineProfiler()
lp_wrapper = lp(client.predict)
words = 'i am very sad | 0'
for i in range(1):
fetch_map = lp_wrapper(feed_dict={"words": words}, fetch=["prediction"])
print(fetch_map)
#lp.print_stats()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server.pipeline import PipelineServer
from paddle_serving_server.pipeline.proto import pipeline_service_pb2
from paddle_serving_server.pipeline.channel import ChannelDataEcode
import numpy as np
import logging
from paddle_serving_app.reader import IMDBDataset
_LOGGER = logging.getLogger(__name__)
logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M',
level=logging.DEBUG)
class ImdbRequestOp(RequestOp):
def load_user_resources(self):
self.imdb_dataset = IMDBDataset()
self.imdb_dataset.load_resource('imdb.vocab')
def unpack_request_package(self, request):
dictdata = {}
for idx, key in enumerate(request.key):
if key != "words":
continue
words = request.value[idx]
word_ids, _ = self.imdb_dataset.get_words_and_label(words)
dictdata[key] = np.array(word_ids)
return dictdata
class CombineOp(Op):
def preprocess(self, input_data):
combined_prediction = 0
for op_name, data in input_data.items():
_LOGGER.info("{}: {}".format(op_name, data["prediction"]))
combined_prediction += data["prediction"]
data = {"prediction": combined_prediction / 2}
return data
class ImdbResponseOp(ResponseOp):
# Here ImdbResponseOp is consistent with the default ResponseOp implementation
def pack_response_package(self, channeldata):
resp = pipeline_service_pb2.Response()
resp.ecode = channeldata.ecode
if resp.ecode == ChannelDataEcode.OK.value:
feed = channeldata.parse()
# ndarray to string
for name, var in feed.items():
resp.value.append(var.__repr__())
resp.key.append(name)
else:
resp.error_info = channeldata.error_info
return resp
read_op = ImdbRequestOp()
bow_op = Op(name="bow",
input_ops=[read_op],
server_endpoints=["127.0.0.1:9393"],
fetch_list=["prediction"],
client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
concurrency=1,
timeout=-1,
retry=1)
cnn_op = Op(name="cnn",
input_ops=[read_op],
server_endpoints=["127.0.0.1:9292"],
fetch_list=["prediction"],
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
concurrency=1,
timeout=-1,
retry=1)
combine_op = CombineOp(
name="combine",
input_ops=[bow_op, cnn_op],
concurrency=1,
timeout=-1,
retry=1)
# fetch output of bow_op
# response_op = ImdbResponseOp(input_ops=[bow_op])
# fetch output of combine_op
response_op = ImdbResponseOp(input_ops=[combine_op])
# use default ResponseOp implementation
# response_op = ResponseOp(input_ops=[combine_op])
server = PipelineServer()
server.set_response_op(response_op)
server.prepare_server('config.yml')
server.run_server()
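For reference, a minimal sketch (not part of this PR; ScaleOp and its scaling factor are illustrative) of how a user Op can override the same preprocess/postprocess hooks that CombineOp uses above:

from paddle_serving_server.pipeline import Op

class ScaleOp(Op):
    def preprocess(self, input_dicts):
        # exactly one predecessor here, so unpack its output dict directly
        (_, input_dict), = input_dicts.items()
        return input_dict

    def postprocess(self, fetch_dict):
        # rescale the fetched variable before pushing it downstream
        fetch_dict["prediction"] = fetch_dict["prediction"] * 0.5
        return fetch_dict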
......@@ -16,10 +16,16 @@ def prase(pid_str, time_str, counter):
if len(name_list) == 2:
name = name_list[0]
else:
- name = name_list[0] + "_" + name_list[1]
+ name = "_".join(name_list[:-1])
name_list = name.split("#")
if len(name_list) > 1:
tid = name_list[-1]
name = "#".join(name_list[:-1])
else:
tid = 0
event_dict = {}
event_dict["name"] = name
event_dict["tid"] = 0
event_dict["tid"] = tid
event_dict["pid"] = pid
event_dict["ts"] = ts
event_dict["ph"] = ph
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .operator import Op, RequestOp, ResponseOp
from .pipeline_server import PipelineServer
from .pipeline_client import PipelineClient
(This diff is collapsed.)
......@@ -12,3 +12,473 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
from numpy import *  # brings array() etc. into scope for eval() of ndarray reprs
from .proto import pipeline_service_pb2
from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
from .util import NameGenerator
_LOGGER = logging.getLogger(__name__)
_op_name_gen = NameGenerator("Op")
class Op(object):
def __init__(self,
name=None,
input_ops=[],
server_endpoints=[],
fetch_list=[],
client_config=None,
concurrency=1,
timeout=-1,
retry=1):
if name is None:
name = _op_name_gen.next()
self._is_run = False
self.name = name # identifies the Op; must be globally unique
self.concurrency = concurrency # number of concurrent workers
self.set_input_ops(input_ops)
self._server_endpoints = server_endpoints
self.with_serving = False
if len(self._server_endpoints) != 0:
self.with_serving = True
self._client_config = client_config
self._fetch_names = fetch_list
self._timeout = timeout
self._retry = max(1, retry)
self._input = None
self._outputs = []
self._profiler = None
def init_profiler(self, profiler):
self._profiler = profiler
def _profiler_record(self, string):
if self._profiler is None:
return
self._profiler.record(string)
def init_client(self, client_type, client_config, server_endpoints,
fetch_names):
if not self.with_serving:
_LOGGER.debug("{} no client".format(self.name))
return
_LOGGER.debug("{} client_config: {}".format(self.name, client_config))
_LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
if client_type == 'brpc':
self._client = Client()
self._client.load_client_config(client_config)
elif client_type == 'grpc':
self._client = MultiLangClient()
else:
raise ValueError("unknow client type: {}".format(client_type))
self._client.connect(server_endpoints)
self._fetch_names = fetch_names
def _get_input_channel(self):
return self._input
def get_input_ops(self):
return self._input_ops
def set_input_ops(self, ops):
if not isinstance(ops, list):
ops = [] if ops is None else [ops]
self._input_ops = []
for op in ops:
if not isinstance(op, Op):
raise TypeError(
self._log('input op must be Op type, not {}'.format(
type(op))))
self._input_ops.append(op)
def add_input_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('input channel must be Channel type, not {}'.format(
type(channel))))
channel.add_consumer(self.name)
self._input = channel
def _get_output_channels(self):
return self._outputs
def add_output_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('output channel must be Channel type, not {}'.format(
type(channel))))
channel.add_producer(self.name)
self._outputs.append(channel)
def preprocess(self, input_dicts):
# multiple previous Op
if len(input_dicts) != 1:
raise NotImplementedError(
'this Op has multiple previous inputs; please override this method.')
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, feed_dict):
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0:
raise NotImplementedError(
"{} Please override preprocess func.".format(err_info))
_LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict)))
_LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names)))
call_result = self._client.predict(
feed=feed_dict, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result"))
return call_result
def postprocess(self, fetch_dict):
return fetch_dict
def stop(self):
self._is_run = False
def _parse_channeldata(self, channeldata_dict):
data_id, error_channeldata = None, None
parsed_data = {}
key = list(channeldata_dict.keys())[0]
data_id = channeldata_dict[key].id
for name, data in channeldata_dict.items():
if data.ecode != ChannelDataEcode.OK.value:
error_channeldata = data
break
parsed_data[name] = data.parse()
return data_id, error_channeldata, parsed_data
def _push_to_output_channels(self, data, channels, name=None):
if name is None:
name = self.name
for channel in channels:
channel.push(data, name)
def start_with_process(self, client_type):
processes = []
for concurrency_idx in range(self.concurrency):
p = multiprocessing.Process(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type))
p.start()
processes.append(p)
return processes
def start_with_thread(self, client_type):
threads = []
for concurrency_idx in range(self.concurrency):
t = threading.Thread(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type))
t.start()
threads.append(t)
return threads
def load_user_resources(self):
pass
def _run_preprocess(self, parsed_data, data_id, log_func):
preped_data, error_channeldata = None, None
try:
preped_data = self.preprocess(parsed_data)
except NotImplementedError as e:
# preprocess function not implemented
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info,
data_id=data_id)
except TypeError as e:
# Error type in channeldata.datatype
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return preped_data, error_channeldata
def _run_process(self, preped_data, data_id, log_func):
midped_data, error_channeldata = None, None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
_LOGGER.error(error_info)
else:
for i in range(self._retry):
try:
midped_data = func_timeout.func_timeout(
self._timeout, self.process, args=(preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
error_info = log_func(e)
_LOGGER.error(error_info)
else:
_LOGGER.warn(
log_func("timeout, retry({})".format(i + 1)))
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
_LOGGER.error(error_info)
break
else:
break
if ecode != ChannelDataEcode.OK.value:
error_channeldata = ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id)
elif midped_data is None:
# the Op client returned None
error_channeldata = ChannelData(
ecode=ChannelDataEcode.CLIENT_ERROR.value,
error_info=log_func(
"predict failed, please check the server side."),
data_id=data_id)
else:
midped_data = preped_data
return midped_data, error_channeldata
def _run_postprocess(self, midped_data, data_id, log_func):
output_data, error_channeldata = None, None
try:
postped_data = self.postprocess(midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return output_data, error_channeldata
if not isinstance(postped_data, dict):
error_info = log_func("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return output_data, error_channeldata
err, _ = ChannelData.check_npdata(postped_data)
if err == 0:
output_data = ChannelData(
ChannelDataType.CHANNEL_NPDATA.value,
npdata=postped_data,
data_id=data_id)
else:
output_data = ChannelData(
ChannelDataType.DICT.value,
dictdata=postped_data,
data_id=data_id)
return output_data, error_channeldata
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
# create client based on client_type
self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# load user resources
self.load_user_resources()
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
channeldata_dict = input_channel.front(self.name)
self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
data_id, error_channeldata, parsed_data = self._parse_channeldata(
channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# preprocess
self._profiler_record("{}-prep#{}_0".format(op_info_prefix, tid))
preped_data, error_channeldata = self._run_preprocess(parsed_data,
data_id, log)
self._profiler_record("{}-prep#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# process
self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
midped_data, error_channeldata = self._run_process(preped_data,
data_id, log)
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# postprocess
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
output_data, error_channeldata = self._run_postprocess(midped_data,
data_id, log)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# push data to channel (if run succ)
self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
self._push_to_output_channels(output_data, output_channels)
self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))
def _log(self, info):
return "{} {}".format(self.name, info)
class RequestOp(Op):
""" RequestOp do not run preprocess, process, postprocess. """
def __init__(self, concurrency=1):
# PipelineService.name = "#G"
super(RequestOp, self).__init__(
name="#G", input_ops=[], concurrency=concurrency)
# load user resources
self.load_user_resources()
def unpack_request_package(self, request):
dictdata = {}
for idx, key in enumerate(request.key):
data = request.value[idx]
try:
# values are transported as strings; eval restores Python/numpy objects
data = eval(data)
except Exception as e:
pass
dictdata[key] = data
return dictdata
class ResponseOp(Op):
""" ResponseOp do not run preprocess, process, postprocess. """
def __init__(self, input_ops, concurrency=1):
super(ResponseOp, self).__init__(
name="#R", input_ops=input_ops, concurrency=concurrency)
# load user resources
self.load_user_resources()
def pack_response_package(self, channeldata):
resp = pipeline_service_pb2.Response()
resp.ecode = channeldata.ecode
if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = channeldata.parse()
# ndarray to string:
# https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
for name, var in feed.items():
resp.value.append(var.__repr__())
resp.key.append(name)
elif channeldata.datatype == ChannelDataType.DICT.value:
feed = channeldata.parse()
for name, var in feed.items():
if not isinstance(var, str):
resp.ecode = ChannelDataEcode.TYPE_ERROR.value
resp.error_info = self._log(
"fetch var type must be str({}).".format(
type(var)))
break
resp.value.append(var)
resp.key.append(name)
else:
resp.ecode = ChannelDataEcode.TYPE_ERROR.value
resp.error_info = self._log(
"Error type({}) in datatype.".format(channeldata.datatype))
_LOGGER.error(resp.error_info)
else:
resp.error_info = channeldata.error_info
return resp
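The ndarray-to-string convention above round-trips through repr() and eval(); a small sketch of what that implies (and why the client side does a wildcard numpy import):

import numpy as np
from numpy import array, float32  # names that eval() needs in scope

s = repr(np.array([0.5], dtype=np.float32))  # "array([0.5], dtype=float32)"
restored = eval(s)                           # back to an ndarray
# caveat: repr() truncates large arrays with '...', which eval() cannot parse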
class VirtualOp(Op):
''' For connecting two channels. '''
def __init__(self, name, concurrency=1):
super(VirtualOp, self).__init__(
name=name, input_ops=None, concurrency=concurrency)
self._virtual_pred_ops = []
def add_virtual_pred_op(self, op):
self._virtual_pred_ops.append(op)
def add_output_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('output channel must be Channel type, not {}'.format(
type(channel))))
for op in self._virtual_pred_ops:
channel.add_producer(op.name)
self._outputs.append(channel)
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
channeldata_dict = input_channel.front(self.name)
self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
for name, data in channeldata_dict.items():
self._push_to_output_channels(
data, channels=output_channels, name=name)
self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
import numpy as np
from numpy import *  # brings array() etc. into scope for eval() of ndarray reprs
import logging
import functools
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
_LOGGER = logging.getLogger(__name__)
class PipelineClient(object):
def __init__(self):
self._channel = None
def connect(self, endpoint):
self._channel = grpc.insecure_channel(endpoint)
self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
self._channel)
def _pack_request_package(self, feed_dict):
req = pipeline_service_pb2.Request()
for key, value in feed_dict.items():
req.key.append(key)
if isinstance(value, np.ndarray):
req.value.append(value.__repr__())
elif isinstance(value, str):
req.value.append(value)
elif isinstance(value, list):
req.value.append(np.array(value).__repr__())
else:
raise TypeError("only str and np.ndarray type is supported: {}".
format(type(value)))
return req
def _unpack_response_package(self, resp, fetch):
if resp.ecode != 0:
return {"ecode": resp.ecode, "error_info": resp.error_info}
fetch_map = {"ecode": resp.ecode}
for idx, key in enumerate(resp.key):
if key not in fetch:
continue
data = resp.value[idx]
try:
data = eval(data)
except Exception as e:
pass
fetch_map[key] = data
return fetch_map
def predict(self, feed_dict, fetch, asyn=False):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict)
if not asyn:
resp = self._stub.inference(req)
return self._unpack_response_package(resp, fetch)
else:
call_future = self._stub.inference.future(req)
return PipelinePredictFuture(
call_future,
functools.partial(
self._unpack_response_package, fetch=fetch))
class PipelinePredictFuture(object):
def __init__(self, call_future, callback_func):
self.call_future_ = call_future
self.callback_func_ = callback_func
def result(self):
resp = self.call_future_.result()
return self.callback_func_(resp)
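A minimal usage sketch of the asynchronous path (assuming a pipeline server is listening on localhost:8080): with asyn=True, predict returns a PipelinePredictFuture whose result() blocks on the underlying gRPC future and then unpacks the response:

from paddle_serving_client.pipeline import PipelineClient

client = PipelineClient()
client.connect('localhost:8080')
future = client.predict(
    feed_dict={"words": "i am very sad | 0"}, fetch=["prediction"], asyn=True)
fetch_map = future.result()  # blocks until the server replies
print(fetch_map)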
......@@ -12,3 +12,440 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import multiprocessing.queues
import sys
if sys.version_info.major == 2:
import Queue
elif sys.version_info.major == 3:
import queue as Queue
else:
raise Exception("Error Python version")
import os
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import numpy as np
import grpc
import logging
import random
import time
import func_timeout
import enum
import collections
import copy
import socket
from contextlib import closing
import yaml
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode, ChannelDataType
from .profiler import TimeProfiler
from .util import NameGenerator
_LOGGER = logging.getLogger(__name__)
_profiler = TimeProfiler()
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, in_channel, out_channel, unpack_func, pack_func,
retry=2):
super(PipelineService, self).__init__()
self.name = "#G"
self.set_in_channel(in_channel)
self.set_out_channel(out_channel)
_LOGGER.debug(self._log(in_channel.debug()))
_LOGGER.debug(self._log(out_channel.debug()))
#TODO:
# multiple locks for different clients
# different locks for server and client
self._id_lock = threading.Lock()
self._cv = threading.Condition()
self._global_resp_dict = {}
self._id_counter = 0
self._reset_max_id = 1000000000000000000
self._retry = retry
self._is_run = True
self._pack_func = pack_func
self._unpack_func = unpack_func
self._receive_func = threading.Thread(
target=PipelineService._receive_out_channel_func, args=(self, ))
self._receive_func.start()
def _log(self, info_str):
return "[{}] {}".format(self.name, info_str)
def set_in_channel(self, in_channel):
if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('in_channel must be Channel type, but got {}'.format(
type(in_channel))))
in_channel.add_producer(self.name)
self._in_channel = in_channel
def set_out_channel(self, out_channel):
if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('out_channel must be Channel type, but got {}'.format(
type(out_channel))))
out_channel.add_consumer(self.name)
self._out_channel = out_channel
def stop(self):
self._is_run = False
def _receive_out_channel_func(self):
while self._is_run:
channeldata_dict = self._out_channel.front(self.name)
if len(channeldata_dict) != 1:
raise Exception("out_channel cannot have multiple input ops")
(_, channeldata), = channeldata_dict.items()
if not isinstance(channeldata, ChannelData):
raise TypeError(
self._log('data must be ChannelData type, but got {}'.
format(type(channeldata))))
with self._cv:
data_id = channeldata.id
self._global_resp_dict[data_id] = channeldata
self._cv.notify_all()
def _get_next_id(self):
with self._id_lock:
if self._id_counter >= self._reset_max_id:
self._id_counter -= self._reset_max_id
self._id_counter += 1
return self._id_counter - 1
def _get_data_in_global_resp_dict(self, data_id):
resp = None
with self._cv:
while data_id not in self._global_resp_dict:
self._cv.wait()
resp = self._global_resp_dict.pop(data_id)
self._cv.notify_all()
return resp
def _pack_data_for_infer(self, request):
_LOGGER.debug(self._log('start inference'))
data_id = self._get_next_id()
dictdata = None
try:
dictdata = self._unpack_func(request)
except Exception as e:
return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
error_info="rpc package error: {}".format(e),
data_id=data_id), data_id
else:
return ChannelData(
datatype=ChannelDataType.DICT.value,
dictdata=dictdata,
data_id=data_id), data_id
def _pack_data_for_resp(self, channeldata):
_LOGGER.debug(self._log('get channeldata'))
return self._pack_func(channeldata)
def inference(self, request, context):
_profiler.record("{}-prepack_0".format(self.name))
data, data_id = self._pack_data_for_infer(request)
_profiler.record("{}-prepack_1".format(self.name))
resp_channeldata = None
for i in range(self._retry):
_LOGGER.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self.name))
self._in_channel.push(data, self.name)
_profiler.record("{}-push_1".format(self.name))
_LOGGER.debug(self._log('wait for infer'))
_profiler.record("{}-fetch_0".format(self.name))
resp_channeldata = self._get_data_in_global_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
break
if i + 1 < self._retry:
_LOGGER.warn("retry({}): {}".format(
i + 1, resp_channeldata.error_info))
_profiler.record("{}-postpack_0".format(self.name))
resp = self._pack_data_for_resp(resp_channeldata)
_profiler.record("{}-postpack_1".format(self.name))
_profiler.print_profile()
return resp
class PipelineServer(object):
def __init__(self):
self._channels = []
self._actual_ops = []
self._port = None
self._worker_num = None
self._in_channel = None
self._out_channel = None
self._response_op = None
self._pack_func = None
self._unpack_func = None
def add_channel(self, channel):
self._channels.append(channel)
def gen_desc(self):
_LOGGER.info('the service description for PaaS will be generated here')
pass
def set_response_op(self, response_op):
if not isinstance(response_op, Op):
raise Exception("response_op must be Op type.")
if len(response_op.get_input_ops()) != 1:
raise Exception("response_op can only have one previous op.")
self._response_op = response_op
def _topo_sort(self, response_op):
if response_op is None:
raise Exception("response_op has not been set.")
def get_use_ops(root):
# root: response_op
unique_names = set()
use_ops = set()
succ_ops_of_use_op = {} # {op_name: succ_ops}
que = Queue.Queue()
que.put(root)
#use_ops.add(root)
#unique_names.add(root.name)
while que.qsize() != 0:
op = que.get()
for pred_op in op.get_input_ops():
if pred_op.name not in succ_ops_of_use_op:
succ_ops_of_use_op[pred_op.name] = []
if op != root:
succ_ops_of_use_op[pred_op.name].append(op)
if pred_op not in use_ops:
que.put(pred_op)
use_ops.add(pred_op)
# check that the name of each Op is globally unique
if pred_op.name in unique_names:
raise Exception("the name of an Op must be unique: {}".
format(pred_op.name))
unique_names.add(pred_op.name)
return use_ops, succ_ops_of_use_op
use_ops, out_degree_ops = get_use_ops(response_op)
if len(use_ops) <= 1:
raise Exception(
"Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
)
name2op = {op.name: op for op in use_ops}
out_degree_num = {
name: len(ops)
for name, ops in out_degree_ops.items()
}
que_idx = 0 # scroll queue
ques = [Queue.Queue() for _ in range(2)]
zero_indegree_num = 0
for op in use_ops:
if len(op.get_input_ops()) == 0:
zero_indegree_num += 1
if zero_indegree_num != 1:
raise Exception("DAG contains multiple input Ops")
last_op = response_op.get_input_ops()[0]
ques[que_idx].put(last_op)
# topo sort to get dag_views
dag_views = []
sorted_op_num = 0
while True:
que = ques[que_idx]
next_que = ques[(que_idx + 1) % 2]
dag_view = []
while que.qsize() != 0:
op = que.get()
dag_view.append(op)
sorted_op_num += 1
for pred_op in op.get_input_ops():
out_degree_num[pred_op.name] -= 1
if out_degree_num[pred_op.name] == 0:
next_que.put(pred_op)
dag_views.append(dag_view)
if next_que.qsize() == 0:
break
que_idx = (que_idx + 1) % 2
if sorted_op_num < len(use_ops):
raise Exception("not legal DAG")
# create channels and virtual ops
def gen_channel(name_gen):
channel = None
if self._use_multithread:
channel = ThreadChannel(name=name_gen.next())
else:
channel = ProcessChannel(self._manager, name=name_gen.next())
return channel
def gen_virtual_op(name_gen):
return VirtualOp(name=name_gen.next())
virtual_op_name_gen = NameGenerator("vir")
channel_name_gen = NameGenerator("chl")
virtual_ops = []
channels = []
input_channel = None
actual_view = None
dag_views = list(reversed(dag_views))
for v_idx, view in enumerate(dag_views):
if v_idx + 1 >= len(dag_views):
break
next_view = dag_views[v_idx + 1]
if actual_view is None:
actual_view = view
actual_next_view = []
pred_op_of_next_view_op = {}
for op in actual_view:
# find actual succ op in next view and create virtual op
for succ_op in out_degree_ops[op.name]:
if succ_op in next_view:
if succ_op not in actual_next_view:
actual_next_view.append(succ_op)
if succ_op.name not in pred_op_of_next_view_op:
pred_op_of_next_view_op[succ_op.name] = []
pred_op_of_next_view_op[succ_op.name].append(op)
else:
# create virtual op
virtual_op = gen_virtual_op(virtual_op_name_gen)
virtual_ops.append(virtual_op)
out_degree_ops[virtual_op.name] = [succ_op]
actual_next_view.append(virtual_op)
pred_op_of_next_view_op[virtual_op.name] = [op]
virtual_op.add_virtual_pred_op(op)
actual_view = actual_next_view
# create channel
processed_op = set()
for o_idx, op in enumerate(actual_next_view):
if op.name in processed_op:
continue
channel = gen_channel(channel_name_gen)
channels.append(channel)
_LOGGER.debug("{} => {}".format(channel.name, op.name))
op.add_input_channel(channel)
pred_ops = pred_op_of_next_view_op[op.name]
if v_idx == 0:
input_channel = channel
else:
# if pred_op is a virtual Op, its ancestors act as the channel's producers
for pred_op in pred_ops:
_LOGGER.debug("{} => {}".format(pred_op.name,
channel.name))
pred_op.add_output_channel(channel)
processed_op.add(op.name)
# find same input op to combine channel
for other_op in actual_next_view[o_idx + 1:]:
if other_op.name in processed_op:
continue
other_pred_ops = pred_op_of_next_view_op[other_op.name]
if len(other_pred_ops) != len(pred_ops):
continue
same_flag = True
for pred_op in pred_ops:
if pred_op not in other_pred_ops:
same_flag = False
break
if same_flag:
_LOGGER.debug("{} => {}".format(channel.name,
other_op.name))
other_op.add_input_channel(channel)
processed_op.add(other_op.name)
output_channel = gen_channel(channel_name_gen)
channels.append(output_channel)
last_op.add_output_channel(output_channel)
pack_func, unpack_func = None, None
pack_func = self._response_op.pack_response_package
self._actual_ops = virtual_ops
for op in use_ops:
if len(op.get_input_ops()) == 0:
unpack_func = op.unpack_request_package
continue
self._actual_ops.append(op)
self._channels = channels
for c in channels:
_LOGGER.debug(c.debug())
return input_channel, output_channel, pack_func, unpack_func
def _port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
return result != 0
def prepare_server(self, yml_file):
with open(yml_file) as f:
yml_config = yaml.load(f.read(), Loader=yaml.SafeLoader)
self._port = yml_config.get('port', 8080)
if not self._port_is_available(self._port):
raise SystemExit("Prot {} is already used".format(self._port))
self._worker_num = yml_config.get('worker_num', 2)
self._retry = yml_config.get('retry', 1)
self._client_type = yml_config.get('client_type', 'brpc')
self._use_multithread = yml_config.get('use_multithread', True)
profile = yml_config.get('profile', False)
if not self._use_multithread:
self._manager = multiprocessing.Manager()
if profile:
raise Exception(
"profiling is temporarily unavailable in the multiprocess version")
_profiler.enable(profile)
input_channel, output_channel, self._pack_func, self._unpack_func = self._topo_sort(
self._response_op)
self._in_channel = input_channel
self._out_channel = output_channel
for op in self._actual_ops:
if op.with_serving:
self.prepare_serving(op)
self.gen_desc()
def _run_ops(self):
threads_or_proces = []
for op in self._actual_ops:
op.init_profiler(_profiler)
if self._use_multithread:
threads_or_proces.extend(
op.start_with_thread(self._client_type))
else:
threads_or_proces.extend(
op.start_with_process(self._client_type))
return threads_or_proces
def _stop_all(self, service):
service.stop()
for op in self._actual_ops:
op.stop()
for chl in self._channels:
chl.stop()
def run_server(self):
op_threads_or_proces = self._run_ops()
service = PipelineService(self._in_channel, self._out_channel,
self._unpack_func, self._pack_func,
self._retry)
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(service,
server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
server.wait_for_termination()
self._stop_all(service) # TODO
for x in op_threads_or_proces:
x.join()
def prepare_serving(self, op):
# run a server (not in PyServing)
_LOGGER.info("run a server (not in PyServing)")
......@@ -12,3 +12,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import os
import sys
import logging
if sys.version_info.major == 2:
import Queue
elif sys.version_info.major == 3:
import queue as Queue
else:
raise Exception("Error Python version")
import time
_LOGGER = logging.getLogger(__name__)
class TimeProfiler(object):
def __init__(self):
self._pid = os.getpid()
self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
self._time_record = Queue.Queue()
self._enable = False
def enable(self, enable):
self._enable = enable
def record(self, name_with_tag):
if self._enable is False:
return
name_with_tag = name_with_tag.split("_")
tag = name_with_tag[-1]
name = '_'.join(name_with_tag[:-1])
self._time_record.put((name, tag, int(round(time.time() * 1000000))))
def print_profile(self):
if self._enable is False:
return
print_str = self._print_head
tmp = {}
while not self._time_record.empty():
name, tag, timestamp = self._time_record.get()
if name in tmp:
ptag, ptimestamp = tmp.pop(name)
print_str += "{}_{}:{} ".format(name, ptag, ptimestamp)
print_str += "{}_{}:{} ".format(name, tag, timestamp)
else:
tmp[name] = (tag, timestamp)
print_str += "\n"
sys.stderr.write(print_str)
for name, item in tmp.items():
tag, timestamp = item
self._time_record.put((name, tag, timestamp))
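A sketch of the intended usage: matching `_0`/`_1` tags mark the beginning and end of a span, and print_profile emits paired microsecond timestamps on stderr (the timestamps below are illustrative):

profiler = TimeProfiler()
profiler.enable(True)
profiler.record("op1-prep_0")  # span start
profiler.record("op1-prep_1")  # span end
profiler.print_profile()
# stderr, roughly: PROFILE  pid:1234  op1-prep_0:1593000000000000 op1-prep_1:1593000000015000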
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.pipeline_serving;
message Request {
repeated string key = 1;
repeated string value = 2;
};
message Response {
repeated string key = 1;
repeated string value = 2;
required int32 ecode = 3;
optional string error_info = 4;
};
service PipelineService {
rpc inference(Request) returns (Response) {}
};
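The generated Python stubs mirror this schema; for illustration, a request equivalent to what the client's _pack_request_package builds can be assembled by hand (import path per the server-side package layout in this PR):

from paddle_serving_server.pipeline.proto import pipeline_service_pb2

req = pipeline_service_pb2.Request()
req.key.append("words")
req.value.append("i am very sad | 0")  # all values travel as strings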
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs protoc with the gRPC plugin to generate messages and gRPC stubs."""
from grpc_tools import protoc
protoc.main((
'',
'-I.',
'--python_out=.',
'--grpc_python_out=.',
'pipeline_service.proto', ))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
class NameGenerator(object):
def __init__(self, prefix):
self._idx = -1
self._prefix = prefix
def next(self):
self._idx += 1
return "{}{}".format(self._prefix, self._idx)
numpy>=1.12, <=1.16.4 ; python_version<"3.5"
grpcio-tools>=1.28.1
grpcio>=1.28.1
func-timeout>=4.3.5
......@@ -65,11 +65,14 @@ REQUIRED_PACKAGES = [
if not find_package("paddlepaddle") and not find_package("paddlepaddle-gpu"):
REQUIRED_PACKAGES.append("paddlepaddle")
packages=['paddle_serving_client',
'paddle_serving_client.proto',
'paddle_serving_client.io',
- 'paddle_serving_client.metric',
- 'paddle_serving_client.utils',]
+ 'paddle_serving_client.metric',
+ 'paddle_serving_client.utils',
+ 'paddle_serving_client.pipeline',
+ 'paddle_serving_client.pipeline.proto']
package_data={'paddle_serving_client': ['serving_client.so','lib/*'],}
package_dir={'paddle_serving_client':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client',
......@@ -77,10 +80,14 @@ package_dir={'paddle_serving_client':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto',
'paddle_serving_client.io':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/io',
- 'paddle_serving_client.metric':
- '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/metric',
- 'paddle_serving_client.utils':
- '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/utils',}
+ 'paddle_serving_client.metric':
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/metric',
+ 'paddle_serving_client.utils':
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/utils',
+ 'paddle_serving_client.pipeline':
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/pipeline',
+ 'paddle_serving_client.pipeline.proto':
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/pipeline/proto'}
setup(
name='paddle-serving-client',
......
......@@ -42,12 +42,18 @@ REQUIRED_PACKAGES = [
]
packages=['paddle_serving_server',
- 'paddle_serving_server.proto']
+ 'paddle_serving_server.proto',
+ 'paddle_serving_server.pipeline',
+ 'paddle_serving_server.pipeline.proto']
package_dir={'paddle_serving_server':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server',
'paddle_serving_server.proto':
- '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto'}
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto',
+ 'paddle_serving_server.pipeline':
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/pipeline',
+ 'paddle_serving_server.pipeline.proto':
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/pipeline/proto'}
setup(
name='paddle-serving-server',
......
......@@ -43,12 +43,18 @@ REQUIRED_PACKAGES = [
packages=['paddle_serving_server_gpu',
- 'paddle_serving_server_gpu.proto']
+ 'paddle_serving_server_gpu.proto',
+ 'paddle_serving_server_gpu.pipeline',
+ 'paddle_serving_server_gpu.pipeline.proto']
package_dir={'paddle_serving_server_gpu':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu',
'paddle_serving_server_gpu.proto':
- '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto'}
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto',
+ 'paddle_serving_server_gpu.pipeline':
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/pipeline',
+ 'paddle_serving_server_gpu.pipeline.proto':
+ '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/pipeline/proto'}
setup(
name='paddle-serving-server-gpu',
......