Commit dcd959b7 authored by barrierye

refactor pipeline server

Parent 8aafd5ed
......@@ -39,9 +39,6 @@ py_grpc_proto_compile(multi_lang_general_model_service_py_proto SRCS proto/multi
add_custom_target(multi_lang_general_model_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(multi_lang_general_model_service_py_proto multi_lang_general_model_service_py_proto_init)
py_grpc_proto_compile(general_python_service_py_proto SRCS proto/general_python_service.proto)
add_custom_target(general_python_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(general_python_service_py_proto general_python_service_py_proto_init)
if (CLIENT)
py_proto_compile(sdk_configure_py_proto SRCS proto/sdk_configure.proto)
add_custom_target(sdk_configure_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
......@@ -65,11 +62,6 @@ add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMENT "Copy generated multi_lang_general_model_service proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET general_python_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMAND cp general_python_service*.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMENT "Copy generated general_python_service proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
if (APP)
......@@ -85,10 +77,6 @@ py_proto_compile(server_config_py_proto SRCS proto/server_configure.proto)
add_custom_target(server_config_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(server_config_py_proto server_config_py_proto_init)
py_proto_compile(pyserving_channel_py_proto SRCS proto/pyserving_channel.proto)
add_custom_target(pyserving_channel_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(pyserving_channel_py_proto pyserving_channel_py_proto_init)
if (NOT WITH_GPU)
add_custom_command(TARGET server_config_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
......@@ -103,18 +91,6 @@ add_custom_command(TARGET general_model_config_py_proto POST_BUILD
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET general_python_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMAND cp general_python_service*.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMENT "Copy generated general_python_service proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET pyserving_channel_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMAND cp pyserving_channel*.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMENT "Copy generated pyserving_channel proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMAND cp multi_lang_general_model_service*.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
......@@ -141,18 +117,6 @@ add_custom_command(TARGET general_model_config_py_proto POST_BUILD
paddle_serving_server_gpu/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET general_python_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto
COMMAND cp general_python_service*.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto
COMMENT "Copy generated general_python_service proto file into directory paddle_serving_server_gpu/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET pyserving_channel_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto
COMMAND cp pyserving_channel*.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto
COMMENT "Copy generated pyserving_channel proto file into directory paddle_serving_server_gpu/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto
COMMAND cp multi_lang_general_model_service*.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.pyserving;
service GeneralPythonService {
rpc inference(Request) returns (Response) {}
}
message Request {
repeated bytes feed_insts = 1;
repeated string feed_var_names = 2;
repeated bytes shape = 3;
repeated string type = 4;
}
message Response {
repeated bytes fetch_insts = 1;
repeated string fetch_var_names = 2;
required int32 ecode = 3;
optional string error_info = 4;
repeated bytes shape = 5;
repeated string type = 6;
}
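For reference, each feed variable travels through this Request as three index-aligned repeated fields: the raw array bytes, a serialized int32 shape, and a dtype string. A minimal packing sketch, mirroring _pack_data_for_infer in pipeline_client.py later in this diff (the import path is illustrative, since this proto is dropped by the commit):

import numpy as np
from paddle_serving_client.pipeline.proto import general_python_service_pb2  # illustrative path

# Pack one feed variable; data, shape and dtype stay index-aligned.
req = general_python_service_pb2.Request()
data = np.array([[8, 233, 52]], dtype="int64")
req.feed_var_names.append("words")
req.feed_insts.append(data.tobytes())
req.shape.append(np.array(data.shape, dtype="int32").tobytes())
req.type.append(str(data.dtype))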
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.pyserving;
message ChannelData {
repeated Inst insts = 1;
required int32 id = 2;
required int32 ecode = 4;
optional string error_info = 5;
}
message Inst {
required bytes data = 1;
required string name = 2;
required bytes shape = 3;
required string type = 4;
}
if (CLIENT)
file(INSTALL pipeline DESTINATION paddle_serving_client)
execute_process(COMMAND ${PYTHON_EXECUTABLE} run_codegen.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/paddle_serving_client/pipeline/proto)
file(GLOB_RECURSE SERVING_CLIENT_PY_FILES paddle_serving_client/*.py)
set(PY_FILES ${SERVING_CLIENT_PY_FILES})
SET(PACKAGE_NAME "serving_client")
......@@ -8,9 +11,13 @@ endif()
if (SERVER)
if (NOT WITH_GPU)
file(INSTALL pipeline DESTINATION paddle_serving_server)
execute_process(COMMAND ${PYTHON_EXECUTABLE} run_codegen.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/paddle_serving_server/pipeline/proto)
file(GLOB_RECURSE SERVING_SERVER_PY_FILES paddle_serving_server/*.py)
else()
file(INSTALL pipeline DESTINATION paddle_serving_server_gpu)
execute_process(COMMAND ${PYTHON_EXECUTABLE} run_codegen.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/paddle_serving_server_gpu/pipeline/proto)
file(GLOB_RECURSE SERVING_SERVER_PY_FILES paddle_serving_server_gpu/*.py)
endif()
set(PY_FILES ${SERVING_SERVER_PY_FILES})
......
......@@ -18,7 +18,7 @@ import sys
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9292"])
client.connect(["127.0.0.1:9393"])
# you can define any english sentence or dataset here
# This example reuses imdb reader in training, you
......@@ -29,6 +29,6 @@ imdb_dataset.load_resource(sys.argv[2])
for line in sys.stdin:
word_ids, label = imdb_dataset.get_words_and_label(line)
feed = {"words": word_ids}
fetch = ["acc", "cost", "prediction"]
fetch = ["prediction"]
fetch_map = client.predict(feed=feed, fetch=fetch)
print("{} {}".format(fetch_map["prediction"][0], label[0]))
use_multithread: true
client_type: brpc
retry: 2
profile: false
port: 8080
worker_num: 2
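prepare_server() in pipeline_server.py (later in this diff) loads this file and reads each key with a default fallback, so a missing or misspelled key silently falls back. A sketch of that lookup (safe_load is a hardening choice here; the diff itself calls yaml.load):

import yaml

# Defaults mirror the .get() fallbacks in PipelineServer.prepare_server.
with open("config.yml") as f:
    conf = yaml.safe_load(f)
port = conf.get("port", 8080)              # key must be spelled "port"
worker_num = conf.get("worker_num", 2)
retry = conf.get("retry", 1)
client_type = conf.get("client_type", "brpc")
use_multithread = conf.get("use_multithread", True)
profile = conf.get("profile", False)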
......@@ -13,7 +13,6 @@
# limitations under the License.
from paddle_serving_client.pipeline import PipelineClient
import numpy as np
from paddle_serving_app.reader import IMDBDataset
from line_profiler import LineProfiler
client = PipelineClient()
......@@ -23,12 +22,9 @@ lp = LineProfiler()
lp_wrapper = lp(client.predict)
words = 'i am very sad | 0'
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource('imdb.vocab')
for i in range(1):
word_ids, label = imdb_dataset.get_words_and_label(words)
fetch_map = lp_wrapper(feed={"words": word_ids}, fetch=["prediction"])
for i in range(10):
fetch_map = lp_wrapper(feed_dict={"words": words}, fetch=["prediction"])
print(fetch_map)
#lp.print_stats()
......@@ -13,18 +13,31 @@
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server.pipeline import Op
from paddle_serving_server.pipeline import Op, ReadOp
from paddle_serving_server.pipeline import PipelineServer
import numpy as np
import logging
from paddle_serving_app.reader import IMDBDataset
logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M',
#level=logging.DEBUG)
# level=logging.DEBUG)
level=logging.INFO)
class ImdbOp(Op):
def preprocess(self, input_data):
data = input_data.parse()
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource('imdb.vocab')
word_ids, _ = imdb_dataset.get_words_and_label(data['words'])
return {"words": word_ids}
# def postprocess(self, fetch_data):
# return {key: str(value) for key, value in fetch_data.items()}
class CombineOp(Op):
def preprocess(self, input_data):
combined_prediction = 0
......@@ -32,42 +45,39 @@ class CombineOp(Op):
data = channeldata.parse()
logging.info("{}: {}".format(op_name, data["prediction"]))
combined_prediction += data["prediction"]
data = {"prediction": combined_prediction / 2}
data = {"prediction": str(combined_prediction / 2)}
return data
read_op = Op(name="read", inputs=None)
bow_op = Op(name="bow",
inputs=[read_op],
server_model="imdb_bow_model",
server_port="9393",
device="cpu",
client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["prediction"],
concurrency=1,
timeout=0.1,
retry=2)
cnn_op = Op(name="cnn",
inputs=[read_op],
server_model="imdb_cnn_model",
server_port="9292",
device="cpu",
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9292",
fetch_names=["prediction"],
concurrency=1,
timeout=-1,
retry=1)
read_op = ReadOp()
bow_op = ImdbOp(
name="bow",
input_ops=[read_op],
server_endpoints=["127.0.0.1:9393"],
fetch_list=["prediction"],
client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
concurrency=1,
timeout=-1,
retry=1)
cnn_op = ImdbOp(
name="cnn",
input_ops=[read_op],
server_endpoints=["127.0.0.1:9292"],
fetch_list=["prediction"],
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
concurrency=1,
timeout=-1,
retry=1)
combine_op = CombineOp(
name="combine", inputs=[bow_op, cnn_op], concurrency=1, timeout=-1, retry=1)
pyserver = PipelineServer(
use_multithread=True,
client_type='grpc',
use_future=False,
profile=False,
name="combine",
input_ops=[bow_op, cnn_op],
concurrency=1,
timeout=-1,
retry=1)
pyserver.add_ops([read_op, bow_op, cnn_op, combine_op])
pyserver.prepare_server(port=8080, worker_num=2)
pyserver.run_server()
server = PipelineServer()
server.add_ops([read_op, bow_op, cnn_op, combine_op])
# server.set_response_op(bow_op)
server.set_response_op(combine_op)
server.prepare_server('config.yml')
server.run_server()
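With this server listening on the port from config.yml, a matching client call might look like the following sketch, based on the refactored usage in test_pipeline_client.py (endpoint and connect signature are assumptions; the raw string is turned into word ids by ImdbOp.preprocess on the server side):

from paddle_serving_client.pipeline import PipelineClient

client = PipelineClient()
client.connect('127.0.0.1:8080')  # port from config.yml
# feed_dict carries the raw sentence; the server-side ImdbOp converts it
fetch_map = client.predict(feed_dict={"words": "i am very sad | 0"},
                           fetch=["prediction"])
print(fetch_map)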
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
from .proto import general_python_service_pb2
from .proto import general_python_service_pb2_grpc
import numpy as np
class PipelineClient(object):
def __init__(self):
self._channel = None
def connect(self, endpoint):
self._channel = grpc.insecure_channel(endpoint)
self._stub = general_python_service_pb2_grpc.GeneralPythonServiceStub(
self._channel)
def _pack_data_for_infer(self, feed_data):
req = general_python_service_pb2.Request()
for name, data in feed_data.items():
if isinstance(data, list):
data = np.array(data)
elif not isinstance(data, np.ndarray):
raise TypeError("only list and numpy array type is supported.")
req.feed_var_names.append(name)
req.feed_insts.append(data.tobytes())
req.shape.append(np.array(data.shape, dtype="int32").tobytes())
req.type.append(str(data.dtype))
return req
def predict(self, feed, fetch):
if not isinstance(feed, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if not isinstance(fetch, list):
raise TypeError(
"fetch_with_type must be list type with format: [name].")
req = self._pack_data_for_infer(feed)
resp = self._stub.inference(req)
if resp.ecode != 0:
return {"ecode": resp.ecode, "error_info": resp.error_info}
fetch_map = {"ecode": resp.ecode}
for idx, name in enumerate(resp.fetch_var_names):
if name not in fetch:
continue
fetch_map[name] = np.frombuffer(
resp.fetch_insts[idx], dtype=resp.type[idx])
fetch_map[name].shape = np.frombuffer(
resp.shape[idx], dtype="int32")
return fetch_map
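A usage sketch for this client (endpoint assumed; a list feed is converted to np.ndarray before packing, and a nonzero ecode short-circuits to an error dict):

import numpy as np

client = PipelineClient()
client.connect('127.0.0.1:8080')
# lists are accepted and converted to np.ndarray internally
fetch_map = client.predict(feed={"words": [8, 233, 52, 601]},
                           fetch=["prediction"])
if fetch_map["ecode"] == 0:        # 0 == ChannelDataEcode.OK
    print(fetch_map["prediction"])
else:
    print(fetch_map["error_info"])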
......@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from operator import Op
from .operator import Op, ReadOp
from .pipeline_server import PipelineServer
from .pipeline_client import PipelineClient
......@@ -22,15 +22,10 @@ elif sys.version_info.major == 3:
import queue as Queue
else:
raise Exception("Error Python version")
<<<<<<< HEAD
=======
from concurrent import futures
import numpy as np
from ..proto import pyserving_channel_pb2 as channel_pb2
import logging
import enum
import copy
>>>>>>> d84910a1180061b57c51824e35e3ca5c857eb3b5
class ChannelDataEcode(enum.Enum):
......@@ -43,48 +38,25 @@ class ChannelDataEcode(enum.Enum):
class ChannelDataType(enum.Enum):
<<<<<<< HEAD
DICT = 0
CHANNEL_NPDATA = 1
ERROR = 2
class ChannelData(object):
pass
class ThreadChannel(Queue.Queue):
pass
class ProcessChannel(multiprocessing.queues.Queue):
pass
=======
CHANNEL_PBDATA = 0
CHANNEL_FUTURE = 1
CHANNEL_NPDATA = 2
ERROR = 3
class ChannelData(object):
def __init__(self,
datatype=None,
future=None,
pbdata=None,
npdata=None,
dictdata=None,
data_id=None,
callback_func=None,
ecode=None,
error_info=None):
'''
There are several ways to use it:
1. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, pbdata[, callback_func])
2. ChannelData(ChannelDataType.CHANNEL_FUTURE.value, future, data_id[, callback_func])
3. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, pbdata)
4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id)
5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id)
6. ChannelData(ecode, error_info, data_id)
1. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id)
2. ChannelData(ChannelDataType.DICT.value, dictdata, data_id)
3. ChannelData(ecode, error_info, data_id)
Protobufs are not pickle-able:
https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
......@@ -94,67 +66,55 @@ class ChannelData(object):
raise ValueError("data_id and error_info cannot be None")
datatype = ChannelDataType.ERROR.value
else:
if datatype == ChannelDataType.CHANNEL_FUTURE.value:
if data_id is None:
raise ValueError("data_id cannot be None")
ecode = ChannelDataEcode.OK.value
elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
ecode, error_info = self._check_npdata(npdata)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
else:
for name, value in npdata.items():
inst = channel_pb2.Inst()
inst.data = value.tobytes()
inst.name = name
inst.shape = np.array(
value.shape, dtype="int32").tobytes()
inst.type = str(value.dtype)
pbdata.insts.append(inst)
elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
ecode, error_info = self._check_npdata(npdata)
if datatype == ChannelDataType.CHANNEL_NPDATA.value:
ecode, error_info = ChannelData.check_npdata(npdata)
if ecode != ChannelDataEcode.OK.value:
datatype = ChannelDataType.ERROR.value
logging.error(error_info)
elif datatype == ChannelDataType.DICT.value:
ecode, error_info = ChannelData.check_dictdata(dictdata)
if ecode != ChannelDataEcode.OK.value:
datatype = ChannelDataType.ERROR.value
logging.error(error_info)
else:
raise ValueError("datatype not match")
self.future = future
self.pbdata = pbdata
self.npdata = npdata
self.datatype = datatype
self.npdata = npdata
self.dictdata = dictdata
self.id = data_id
self.ecode = ecode
self.error_info = error_info
self.callback_func = callback_func
def _check_npdata(self, npdata):
@staticmethod
def check_dictdata(dictdata):
ecode = ChannelDataEcode.OK.value
error_info = None
if not isinstance(dictdata, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of postped_data must " \
"be dict, but get {}".format(type(dictdata))
return ecode, error_info
@staticmethod
def check_npdata(npdata):
ecode = ChannelDataEcode.OK.value
error_info = None
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value)))
error_info = "the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value))
break
return ecode, error_info
def parse(self):
# return narray
feed = None
if self.datatype == ChannelDataType.CHANNEL_PBDATA.value:
feed = {}
for inst in self.pbdata.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=inst.type)
feed[inst.name].shape = np.frombuffer(inst.shape, dtype="int32")
elif self.datatype == ChannelDataType.CHANNEL_FUTURE.value:
feed = self.future.result()
if self.callback_func is not None:
feed = self.callback_func(feed)
elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
if self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
# return narray
feed = self.npdata
elif self.datatype == ChannelDataType.DICT.value:
# return dict
feed = self.dictdata
else:
raise TypeError("Error type({}) in datatype.".format(self.datatype))
return feed
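A minimal round trip through the two supported payload types, assuming this module's ChannelData, ChannelDataType and ChannelDataEcode are in scope:

import numpy as np

# numpy payload: parse() hands back the dict of arrays unchanged
npd = ChannelData(ChannelDataType.CHANNEL_NPDATA.value,
                  npdata={"prediction": np.array([0.9], dtype="float32")},
                  data_id=0)
assert npd.ecode == ChannelDataEcode.OK.value
feed = npd.parse()   # {"prediction": array([0.9], dtype=float32)}

# dict payload: used for raw (e.g. string) inputs from the RPC layer
dd = ChannelData(ChannelDataType.DICT.value,
                 dictdata={"words": "i am very sad | 0"},
                 data_id=1)
feed = dd.parse()    # {"words": "i am very sad | 0"}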
......@@ -609,4 +569,3 @@ class ThreadChannel(Queue.Queue):
self.close()
self._stop = True
self._cv.notify_all()
>>>>>>> d84910a1180061b57c51824e35e3ca5c857eb3b5
......@@ -23,35 +23,33 @@ import func_timeout
from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
from .util import NameGenerator
_name_gen = NameGenerator("Op")
_op_name_gen = NameGenerator("Op")
class Op(object):
def __init__(self,
name=None,
input_ops=[],
server_endpoints=[],
fetch_list=[],
client_config=None,
concurrency=1,
timeout=-1,
retry=1):
if name is None:
name = _name_gen.next()
name = _op_name_gen.next()
self._is_run = False
self.name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency
self.concurrency = concurrency # amount of concurrency
self.set_input_ops(input_ops)
self._server_endpoints = server_endpoints
self.with_serving = False
'''
self._client_config = client_config
self._server_name = server_name
self._fetch_names = fetch_names
self._server_model = server_model
self._server_port = server_port
self._device = device
if self._client_config is not None and \
self._server_name is not None and \
self._fetch_names is not None:
if len(self._server_endpoints) != 0:
self.with_serving = True
'''
self._client_config = client_config
self._fetch_names = fetch_list
self._timeout = timeout
self._retry = max(1, retry)
self._input = None
......@@ -66,12 +64,12 @@ class Op(object):
return
self._profiler.record(string)
def init_client(self, client_type, client_config, server_name, fetch_names):
def init_client(self, client_type, client_config, server_endpoints,
fetch_names):
if self.with_serving == False:
logging.debug("{} no client".format(self.name))
return
logging.debug("{} client_config: {}".format(self.name, client_config))
logging.debug("{} server_name: {}".format(self.name, server_name))
logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
if client_type == 'brpc':
self._client = Client()
......@@ -80,10 +78,10 @@ class Op(object):
else:
raise ValueError("unknow client type: {}".format(client_type))
self._client.load_client_config(client_config)
self._client.connect([server_name])
self._client.connect(server_endpoints)
self._fetch_names = fetch_names
def get_input_channel(self):
def _get_input_channel(self):
return self._input
def get_input_ops(self):
......@@ -108,7 +106,7 @@ class Op(object):
channel.add_consumer(self.name)
self._input = channel
def get_output_channels(self):
def _get_output_channels(self):
return self._outputs
def add_output_channel(self, channel):
......@@ -120,32 +118,41 @@ class Op(object):
self._outputs.append(channel)
def preprocess(self, channeldata):
# multiple previous Op
if isinstance(channeldata, dict):
raise NotImplementedError(
'this Op has multiple previous inputs. Please override this method'
)
feed = channeldata.parse()
return feed
def midprocess(self, data, use_future=True):
if not isinstance(data, dict):
if channeldata.datatype is not ChannelDataType.CHANNEL_NPDATA.value:
raise NotImplementedError(
'datatype in channeldata is not CHANNEL_NPDATA({}). '
'Please override this method'.format(channeldata.datatype))
# get numpy dict
feed_data = channeldata.parse()
return feed_data
def process(self, feed_dict):
if not isinstance(feed_dict, dict):
raise Exception(
self._log(
'data must be dict type(the output of preprocess()), but get {}'.
format(type(data))))
logging.debug(self._log('data: {}'.format(data)))
'feed_dict must be dict type(the output of preprocess()), but get {}'.
format(type(feed_dict))))
logging.debug(self._log('feed_dict: {}'.format(feed_dict)))
logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
if isinstance(self._client, MultiLangClient):
call_result = self._client.predict(
feed=data, fetch=self._fetch_names, asyn=use_future)
feed=feed_dict, fetch=self._fetch_names)
logging.debug(self._log("get call_result"))
else:
call_result = self._client.predict(
feed=data, fetch=self._fetch_names)
logging.debug(self._log("get call_result"))
feed=feed_dict, fetch=self._fetch_names)
logging.debug(self._log("get fetch_dict"))
return call_result
def postprocess(self, output_data):
return output_data
def postprocess(self, fetch_dict):
return fetch_dict
def stop(self):
self._input.stop()
......@@ -175,37 +182,45 @@ class Op(object):
for channel in channels:
channel.push(data, name)
def start_with_process(self, client_type, use_future):
def start_with_process(self, client_type):
processes = []
for concurrency_idx in range(self._concurrency):
for concurrency_idx in range(self.concurrency):
p = multiprocessing.Process(
target=self._run,
args=(concurrency_idx, self.get_input_channel(),
self.get_output_channels(), client_type, use_future))
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type))
p.start()
processes.append(p)
return processes
def start_with_thread(self, client_type, use_future):
def start_with_thread(self, client_type):
threads = []
for concurrency_idx in range(self._concurrency):
for concurrency_idx in range(self.concurrency):
t = threading.Thread(
target=self._run,
args=(concurrency_idx, self.get_input_channel(),
self.get_output_channels(), client_type, use_future))
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type))
t.start()
threads.append(t)
return threads
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
use_future):
# create client based on client_type
self.init_client(client_type, self._client_config, self._server_name,
self._fetch_names)
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._is_run = True
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
# create client based on client_type
self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
channeldata = input_channel.front(self.name)
......@@ -260,7 +275,7 @@ class Op(object):
output_channels)
continue
# midprocess
# process
midped_data = None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
......@@ -268,7 +283,7 @@ class Op(object):
tid))
if self._timeout <= 0:
try:
midped_data = self.midprocess(preped_data, use_future)
midped_data = self.process(preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
......@@ -278,8 +293,8 @@ class Op(object):
try:
midped_data = func_timeout.func_timeout(
self._timeout,
self.midprocess,
args=(preped_data, use_future))
self.process,
args=(preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -310,42 +325,39 @@ class Op(object):
# postprocess
output_data = None
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
if self.with_serving and client_type == 'grpc' and use_future:
# use call_future
output_data = ChannelData(
datatype=ChannelDataType.CHANNEL_FUTURE.value,
future=midped_data,
data_id=data_id,
callback_func=self.postprocess)
else:
try:
postped_data = self.postprocess(midped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id),
output_channels)
continue
if not isinstance(postped_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id),
output_channels)
continue
try:
postped_data = self.postprocess(midped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id),
output_channels)
continue
if not isinstance(postped_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id),
output_channels)
continue
err, _ = ChannelData.check_npdata(postped_data)
if err == 0:
output_data = ChannelData(
ChannelDataType.CHANNEL_NPDATA.value,
npdata=postped_data,
data_id=data_id)
else:
output_data = ChannelData(
ChannelDataType.DICT.value,
dictdata=postped_data,
data_id=data_id)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
# push data to channel (if run succ)
......@@ -356,14 +368,12 @@ class Op(object):
def _log(self, info):
return "{} {}".format(self.name, info)
def _get_log_func(self, op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
def get_concurrency(self):
return self._concurrency
class ReadOp(Op):
def __init__(self, concurrency=1):
# PipelineService.name = "#G"
super(ReadOp, self).__init__(
name="#G", input_ops=[], concurrency=concurrency)
class VirtualOp(Op):
......@@ -371,7 +381,7 @@ class VirtualOp(Op):
def __init__(self, name, concurrency=1):
super(VirtualOp, self).__init__(
name=name, inputs=None, concurrency=concurrency)
name=name, input_ops=None, concurrency=concurrency)
self._virtual_pred_ops = []
def add_virtual_pred_op(self, op):
......@@ -386,10 +396,18 @@ class VirtualOp(Op):
channel.add_producer(op.name)
self._outputs.append(channel)
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
use_future):
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
......@@ -407,4 +425,3 @@ class VirtualOp(Op):
channels=output_channels,
name=self._virtual_pred_ops[0].name)
self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))
>>>>>>> d84910a1180061b57c51824e35e3ca5c857eb3b5
......@@ -12,28 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
<<<<<<< HEAD
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, in_channel, out_channel, retry=2):
super(PipelineService, self).__init__()
pass
class PipelineServer(object):
def __init__(self):
pass
def set_response_op(self, response_op):
pass
def prepare_server(self, yml_file):
pass
def run_server(self):
pass
=======
import threading
import multiprocessing
import multiprocessing.queues
......@@ -49,9 +27,6 @@ from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import numpy as np
import grpc
from ..proto import general_python_service_pb2 as pyservice_pb2
from ..proto import pyserving_channel_pb2 as channel_pb2
from ..proto import general_python_service_pb2_grpc
import logging
import random
import time
......@@ -59,18 +34,23 @@ import func_timeout
import enum
import collections
import copy
import socket
from contextlib import closing
import yaml
from .operator import Op, VirtualOp
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
from .operator import Op, ReadOp, VirtualOp
from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode, ChannelDataType
from .profiler import TimeProfiler
from .util import NameGenerator
_profiler = TimeProfiler()
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonServiceServicer):
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
def __init__(self, in_channel, out_channel, retry=2):
super(GeneralPythonService, self).__init__()
super(PipelineService, self).__init__()
self.name = "#G"
self.set_in_channel(in_channel)
self.set_out_channel(out_channel)
......@@ -85,7 +65,7 @@ class GeneralPythonService(
self._id_counter = 0
self._retry = retry
self._recive_func = threading.Thread(
target=GeneralPythonService._recive_out_channel_func, args=(self, ))
target=PipelineService._recive_out_channel_func, args=(self, ))
self._recive_func.start()
def _log(self, info_str):
......@@ -136,17 +116,12 @@ class GeneralPythonService(
def _pack_data_for_infer(self, request):
logging.debug(self._log('start inference'))
data_id = self._get_next_id()
npdata = {}
dictdata = {}
try:
for idx, name in enumerate(request.feed_var_names):
logging.debug(
self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug(
self._log('data: {}'.format(request.feed_insts[idx])))
npdata[name] = np.frombuffer(
request.feed_insts[idx], dtype=request.type[idx])
npdata[name].shape = np.frombuffer(
request.shape[idx], dtype="int32")
for idx, key in enumerate(request.key):
logging.debug(self._log('key: {}'.format(key)))
logging.debug(self._log('value: {}'.format(request.value[idx])))
dictdata[key] = request.value[idx]
except Exception as e:
return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
......@@ -154,31 +129,25 @@ class GeneralPythonService(
data_id=data_id), data_id
else:
return ChannelData(
datatype=ChannelDataType.CHANNEL_NPDATA.value,
npdata=npdata,
datatype=ChannelDataType.DICT.value,
dictdata=dictdata,
data_id=data_id), data_id
def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata'))
resp = pyservice_pb2.Response()
resp = pipeline_service_pb2.Response()
resp.ecode = channeldata.ecode
if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts:
resp.fetch_insts.append(inst.data)
resp.fetch_var_names.append(inst.name)
resp.shape.append(inst.shape)
resp.type.append(inst.type)
elif channeldata.datatype in (ChannelDataType.CHANNEL_FUTURE.value,
ChannelDataType.CHANNEL_NPDATA.value):
if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = channeldata.parse()
for name, var in feed.items():
resp.value.append(var.tostring()) #TODO: no shape and type
resp.key.append(name)
elif channeldata.datatype == ChannelDataType.DICT.value:
feed = channeldata.parse()
for name, var in feed.items():
resp.fetch_insts.append(var.tobytes())
resp.fetch_var_names.append(name)
resp.shape.append(
np.array(
var.shape, dtype="int32").tobytes())
resp.type.append(str(var.dtype))
resp.value.append(str(var))
resp.key.append(name)
else:
raise TypeError(
self._log("Error type({}) in datatype.".format(
......@@ -218,12 +187,7 @@ class GeneralPythonService(
class PipelineServer(object):
def __init__(self,
use_multithread=True,
client_type='brpc',
use_future=False,
retry=2,
profile=False):
def __init__(self):
self._channels = []
self._user_ops = []
self._actual_ops = []
......@@ -231,20 +195,7 @@ class PipelineServer(object):
self._worker_num = None
self._in_channel = None
self._out_channel = None
self._retry = retry
self._use_multithread = use_multithread
self._client_type = client_type
self._use_future = use_future
if not self._use_multithread:
self._manager = multiprocessing.Manager()
if profile:
raise Exception(
"profile cannot be used in multiprocess version temporarily")
if self._use_future:
raise Exception("cannot use future in multiprocess")
if self._client_type == 'brpc' and self._use_future:
logging.warn("brpc impl cannot use future")
_profiler.enable(profile)
self._response_op = None
def add_channel(self, channel):
self._channels.append(channel)
......@@ -259,33 +210,55 @@ class PipelineServer(object):
logging.info('here will generate desc for PAAS')
pass
def _topo_sort(self):
indeg_num = {}
def set_response_op(self, response_op):
if not isinstance(response_op, Op):
raise Exception("response_op must be Op type.")
self._response_op = response_op
def _topo_sort(self, response_op):
if response_op is None:
raise Exception("response_op has not been set.")
def get_use_ops(root):
unique_names = set()
use_ops = set()
succ_ops_of_use_op = {} # {op_name: succ_ops}
que = Queue.Queue()
que.put(root)
use_ops.add(root)
unique_names.add(root.name)
while que.qsize() != 0:
op = que.get()
for pred_op in op.get_input_ops():
if pred_op.name not in succ_ops_of_use_op:
succ_ops_of_use_op[pred_op.name] = []
succ_ops_of_use_op[pred_op.name].append(op)
if pred_op not in use_ops:
que.put(pred_op)
use_ops.add(pred_op)
# check the name of op is globally unique
if pred_op.name in unique_names:
raise Exception("the name of Op must be unique: {}".
format(pred_op.name))
unique_names.add(pred_op.name)
return use_ops, succ_ops_of_use_op
use_ops, out_degree_ops = get_use_ops(response_op)
name2op = {op.name: op for op in use_ops}
out_degree_num = {
name: len(ops)
for name, ops in out_degree_ops.items()
}
que_idx = 0 # scroll queue
ques = [Queue.Queue() for _ in range(2)]
for op in self._user_ops:
zero_indegree_num = 0
for op in use_ops:
if len(op.get_input_ops()) == 0:
op.name = "#G" # update read_op.name
break
outdegs = {op.name: [] for op in self._user_ops}
zero_indeg_num, zero_outdeg_num = 0, 0
for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique
if op.name in indeg_num:
raise Exception("the name of Op must be unique")
indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0:
ques[que_idx].put(op)
zero_indeg_num += 1
for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op)
if zero_indeg_num != 1:
zero_indegree_num += 1
if zero_indegree_num != 1:
raise Exception("DAG contains multiple input Ops")
for _, succ_list in outdegs.items():
if len(succ_list) == 0:
zero_outdeg_num += 1
if zero_outdeg_num != 1:
raise Exception("DAG contains multiple output Ops")
ques[que_idx].put(response_op)
# topo sort to get dag_views
dag_views = []
......@@ -298,63 +271,36 @@ class PipelineServer(object):
op = que.get()
dag_view.append(op)
sorted_op_num += 1
for succ_op in outdegs[op.name]:
indeg_num[succ_op.name] -= 1
if indeg_num[succ_op.name] == 0:
next_que.put(succ_op)
for pred_op in op.get_input_ops():
out_degree_num[pred_op.name] -= 1
if out_degree_num[pred_op.name] == 0:
next_que.put(pred_op)
dag_views.append(dag_view)
if next_que.qsize() == 0:
break
que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops):
if sorted_op_num < len(use_ops):
raise Exception("not legal DAG")
# create channels and virtual ops
def name_generator(prefix):
def number_generator():
idx = 0
while True:
yield "{}{}".format(prefix, idx)
idx += 1
return number_generator()
def gen_channel(name_gen):
channel = None
if self._use_multithread:
if sys.version_info.major == 2:
channel = ThreadChannel(name=name_gen.next())
elif sys.version_info.major == 3:
channel = ThreadChannel(name=name_gen.__next__())
else:
raise Exception("Error Python version")
channel = ThreadChannel(name=name_gen.next())
else:
if sys.version_info.major == 2:
channel = ProcessChannel(
self._manager, name=name_gen.next())
elif sys.version_info.major == 3:
channel = ProcessChannel(
self._manager, name=name_gen.__next__())
else:
raise Exception("Error Python version")
channel = ProcessChannel(self._manager, name=name_gen.next())
return channel
def gen_virtual_op(name_gen):
virtual_op = None
if sys.version_info.major == 2:
virtual_op = VirtualOp(name=name_gen.next())
elif sys.version_info.major == 3:
virtual_op = VirtualOp(name=op_name_gen.__next__())
else:
raise Exception("Error Python version")
return virtual_op
return VirtualOp(name=name_gen.next())
virtual_op_name_gen = name_generator("vir")
channel_name_gen = name_generator("chl")
virtual_op_name_gen = NameGenerator("vir")
channel_name_gen = NameGenerator("chl")
virtual_ops = []
channels = []
input_channel = None
actual_view = None
dag_views = list(reversed(dag_views))
for v_idx, view in enumerate(dag_views):
if v_idx + 1 >= len(dag_views):
break
......@@ -365,7 +311,7 @@ class PipelineServer(object):
pred_op_of_next_view_op = {}
for op in actual_view:
# find actual succ op in next view and create virtual op
for succ_op in outdegs[op.name]:
for succ_op in out_degree_ops[op.name]:
if succ_op in next_view:
if succ_op not in actual_next_view:
actual_next_view.append(succ_op)
......@@ -376,7 +322,7 @@ class PipelineServer(object):
# create virtual op
virtual_op = gen_virtual_op(virtual_op_name_gen)
virtual_ops.append(virtual_op)
outdegs[virtual_op.name] = [succ_op]
out_degree_ops[virtual_op.name] = [succ_op]
actual_next_view.append(virtual_op)
pred_op_of_next_view_op[virtual_op.name] = [op]
virtual_op.add_virtual_pred_op(op)
......@@ -419,11 +365,10 @@ class PipelineServer(object):
processed_op.add(other_op.name)
output_channel = gen_channel(channel_name_gen)
channels.append(output_channel)
last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel)
response_op.add_output_channel(output_channel)
self._actual_ops = virtual_ops
for op in self._user_ops:
for op in use_ops:
if len(op.get_input_ops()) == 0:
# pass read op
continue
......@@ -433,11 +378,33 @@ class PipelineServer(object):
logging.debug(c.debug())
return input_channel, output_channel
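Because the traversal walks backwards from the response op, successor (out-degree) counts play the role in-degrees played in the old version. A standalone toy of the layering, with illustrative op names only:

input_ops = {"combine": ["bow", "cnn"], "bow": ["read"],
             "cnn": ["read"], "read": []}  # op -> its input ops
succ = {}
for op, preds in input_ops.items():
    for p in preds:
        succ.setdefault(p, []).append(op)
out_degree = {name: len(ops) for name, ops in succ.items()}
views, frontier = [], ["combine"]  # start from the response op
while frontier:
    views.append(frontier)
    nxt = []
    for op in frontier:
        for p in input_ops[op]:
            out_degree[p] -= 1
            if out_degree[p] == 0:
                nxt.append(p)
    frontier = nxt
print(list(reversed(views)))  # [['read'], ['bow', 'cnn'], ['combine']]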
def prepare_server(self, port, worker_num):
self._port = port
self._worker_num = worker_num
def _port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
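# connect_ex returns 0 only when something already accepts on the port, so a nonzero result means the port is free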
return result != 0
def prepare_server(self, yml_file):
with open(yml_file) as f:
yml_config = yaml.load(f.read())
self._port = yml_config.get('port', 8080)
if not self._port_is_available(self._port):
raise SystemExit("Prot {} is already used".format(self._port))
self._worker_num = yml_config.get('worker_num', 2)
self._retry = yml_config.get('retry', 1)
self._client_type = yml_config.get('client_type', 'brpc')
self._use_multithread = yml_config.get('use_multithread', True)
profile = yml_config.get('profile', False)
input_channel, output_channel = self._topo_sort()
if not self._use_multithread:
self._manager = multiprocessing.Manager()
if profile:
raise Exception(
"profile cannot be used in multiprocess version temporarily")
_profiler.enable(profile)
input_channel, output_channel = self._topo_sort(self._response_op)
self._in_channel = input_channel
self._out_channel = output_channel
for op in self._actual_ops:
......@@ -451,10 +418,10 @@ class PipelineServer(object):
op.init_profiler(_profiler)
if self._use_multithread:
threads_or_proces.extend(
op.start_with_thread(self._client_type, self._use_future))
op.start_with_thread(self._client_type))
else:
threads_or_proces.extend(
op.start_with_process(self._client_type, self._use_future))
op.start_with_process(self._client_type))
return threads_or_proces
def _stop_ops(self):
......@@ -465,9 +432,9 @@ class PipelineServer(object):
op_threads_or_proces = self._run_ops()
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
GeneralPythonService(self._in_channel, self._out_channel,
self._retry), server)
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineService(self._in_channel, self._out_channel, self._retry),
server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
server.wait_for_termination()
......@@ -476,26 +443,5 @@ class PipelineServer(object):
x.join()
def prepare_serving(self, op):
model_path = op._server_model
port = op._server_port
device = op._device
if self._client_type == "grpc":
if device == "cpu":
cmd = "(Use grpc impl) python -m paddle_serving_server.serve" \
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
else:
cmd = "(Use grpc impl) python -m paddle_serving_server_gpu.serve" \
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
elif self._client_type == "brpc":
if device == "cpu":
cmd = "(Use brpc impl) python -m paddle_serving_server.serve" \
" --model {} --thread 4 --port {} &>/dev/null &".format(model_path, port)
else:
cmd = "(Use brpc impl) python -m paddle_serving_server_gpu.serve" \
" --model {} --thread 4 --port {} &>/dev/null &".format(model_path, port)
else:
raise Exception("unknow client type: {}".format(self._client_type))
# run a server (not in PyServing)
logging.info("run a server (not in PyServing): {}".format(cmd))
>>>>>>> d84910a1180061b57c51824e35e3ca5c857eb3b5
logging.info("run a server (not in PyServing)")
......@@ -13,8 +13,6 @@
# limitations under the License.
# pylint: disable=doc-string-missing
<<<<<<< HEAD
=======
import os
import sys
if sys.version_info.major == 2:
......@@ -25,7 +23,6 @@ else:
raise Exception("Error Python version")
import time
>>>>>>> d84910a1180061b57c51824e35e3ca5c857eb3b5
class TimeProfiler(object):
def __init__(self):
......@@ -48,30 +45,18 @@ class TimeProfiler(object):
def print_profile(self):
if self._enable is False:
return
<<<<<<< HEAD
print_str = self._print_head
=======
sys.stderr.write(self._print_head)
>>>>>>> d84910a1180061b57c51824e35e3ca5c857eb3b5
tmp = {}
while not self._time_record.empty():
name, tag, timestamp = self._time_record.get()
if name in tmp:
ptag, ptimestamp = tmp.pop(name)
<<<<<<< HEAD
print_str += "{}_{}:{} ".format(name, ptag, ptimestamp)
print_str += "{}_{}:{} ".format(name, tag, timestamp)
else:
tmp[name] = (tag, timestamp)
print_str += "\n"
sys.stderr.write(print_str)
=======
sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
else:
tmp[name] = (tag, timestamp)
sys.stderr.write('\n')
>>>>>>> d84910a1180061b57c51824e35e3ca5c857eb3b5
for name, item in tmp.items():
tag, timestamp = item
self._time_record.put((name, tag, timestamp))
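The op workers bracket each stage with paired records whose tags end in _0/_1, and print_profile stitches them into name_tag:timestamp pairs on stderr. A usage sketch; the name/tag split inside record() sits in a collapsed hunk, so the exact string format is an assumption:

profiler = TimeProfiler()
profiler.enable(True)                  # mirrors _profiler.enable(profile)
profiler.record("[bow|0]-prep#140_0")  # stage start (assumed format)
# ... preprocessing work ...
profiler.record("[bow|0]-prep#140_1")  # stage end
profiler.print_profile()               # pairs _0/_1 entries on stderr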
......@@ -13,6 +13,7 @@
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.pipeline_serving;
message Request {
repeated string key = 1;
......@@ -22,6 +23,8 @@ message Request {
message Response {
repeated string key = 1;
repeated string value = 2;
required int32 ecode = 3;
optional string error_info = 4;
};
service PipelineService {
......
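Under the new key/value schema, a request is packed symmetrically to the server-side loop over request.key and request.value. A sketch; the stub class name follows standard grpc codegen for PipelineService, and the rpc method name is assumed to still be inference as in the removed proto:

import grpc
from paddle_serving_server.pipeline.proto import pipeline_service_pb2
from paddle_serving_server.pipeline.proto import pipeline_service_pb2_grpc

channel = grpc.insecure_channel('127.0.0.1:8080')
stub = pipeline_service_pb2_grpc.PipelineServiceStub(channel)

req = pipeline_service_pb2.Request()
req.key.append("words")
req.value.append("i am very sad | 0")
resp = stub.inference(req)  # rpc name assumed from the removed proto
if resp.ecode == 0:
    print(dict(zip(resp.key, resp.value)))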
......@@ -65,11 +65,14 @@ REQUIRED_PACKAGES = [
if not find_package("paddlepaddle") and not find_package("paddlepaddle-gpu"):
REQUIRED_PACKAGES.append("paddlepaddle")
packages=['paddle_serving_client',
'paddle_serving_client.proto',
'paddle_serving_client.io',
'paddle_serving_client.metric',
'paddle_serving_client.utils',]
'paddle_serving_client.metric',
'paddle_serving_client.utils',
'paddle_serving_client.pipeline',
'paddle_serving_client.pipeline.proto']
package_data={'paddle_serving_client': ['serving_client.so','lib/*'],}
package_dir={'paddle_serving_client':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client',
......@@ -77,10 +80,14 @@ package_dir={'paddle_serving_client':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto',
'paddle_serving_client.io':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/io',
'paddle_serving_client.metric':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/metric',
'paddle_serving_client.utils':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/utils',}
'paddle_serving_client.metric':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/metric',
'paddle_serving_client.utils':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/utils',
'paddle_serving_client.pipeline':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/pipeline',
'paddle_serving_client.pipeline.proto':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/pipeline/proto'}
setup(
name='paddle-serving-client',
......
......@@ -43,12 +43,17 @@ REQUIRED_PACKAGES = [
packages=['paddle_serving_server',
'paddle_serving_server.proto',
'paddle_serving_server.pipeline']
'paddle_serving_server.pipeline',
'paddle_serving_server.pipeline.proto']
package_dir={'paddle_serving_server':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server',
'paddle_serving_server.proto':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto'}
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto',
'paddle_serving_server.pipeline':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/pipeline',
'paddle_serving_server.pipeline.proto':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/pipeline/proto'}
setup(
name='paddle-serving-server',
......
......@@ -44,12 +44,17 @@ REQUIRED_PACKAGES = [
packages=['paddle_serving_server_gpu',
'paddle_serving_server_gpu.proto',
'paddle_serving_server.pipeline']
'paddle_serving_server_gpu.pipeline',
'paddle_serving_server_gpu.pipeline.proto']
package_dir={'paddle_serving_server_gpu':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu',
'paddle_serving_server_gpu.proto':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto'}
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto',
'paddle_serving_server_gpu.pipeline':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/pipeline',
'paddle_serving_server_gpu.pipeline.proto':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/pipeline/proto'}
setup(
name='paddle-serving-server-gpu',
......