提交 cfa1ccc3 编写于 作者: B barrierye

clean up code

......@@ -23,7 +23,8 @@ else:
raise Exception("Error Python version")
import os
import paddle_serving_server
from paddle_serving_client import MultiLangClient as Client
#from paddle_serving_client import MultiLangClient as Client
from paddle_serving_client import Client
from concurrent import futures
import numpy as np
import grpc
......@@ -118,30 +119,20 @@ class ChannelData(object):
if ecode is not None:
if data_id is None or error_info is None:
raise ValueError("data_id and error_info cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.ecode = ecode
pbdata.id = data_id
pbdata.error_info = error_info
datatype = ChannelDataType.ERROR.value
else:
if datatype == ChannelDataType.CHANNEL_FUTURE.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.id = data_id
ecode = ChannelDataEcode.OK.value
elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
ecode, error_info = self._check_npdata(npdata)
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
else:
for name, value in npdata.items():
inst = channel_pb2.Inst()
......@@ -153,33 +144,23 @@ class ChannelData(object):
pbdata.insts.append(inst)
elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
ecode, error_info = self._check_npdata(npdata)
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
else:
raise ValueError("datatype not match")
if not isinstance(pbdata, channel_pb2.ChannelData):
raise TypeError(
"pbdata must be pyserving_channel_pb2.ChannelData type({})".
format(type(pbdata)))
self.future = future
self.pbdata = pbdata
self.npdata = npdata
self.datatype = datatype
self.id = data_id
self.ecode = ecode
self.error_info = error_info
self.callback_func = callback_func
def _check_npdata(self, npdata):
ecode = ChannelDataEcode.OK.value
error_info = None
for name, value in npdata.items():
if not isinstance(name, (str, unicode)):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the key of postped_data must " \
"be str, but get {}".format(type(name)))
break
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the value of postped_data must " \
......@@ -202,12 +183,12 @@ class ChannelData(object):
elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = self.npdata
else:
raise TypeError("Error type({}) in datatype.".format(datatype))
raise TypeError("Error type({}) in datatype.".format(self.datatype))
return feed
def __str__(self):
return "type[{}], ecode[{}]".format(
ChannelDataType(self.datatype).name, self.pbdata.ecode)
return "type[{}], ecode[{}], id[{}]".format(
ChannelDataType(self.datatype).name, self.ecode, self.id)
class Channel(Queue.Queue):
......@@ -303,7 +284,7 @@ class Channel(Queue.Queue):
"There are multiple producers, so op_name cannot be None."))
producer_num = len(self._producers)
data_id = channeldata.pbdata.id
data_id = channeldata.id
put_data = None
with self._cv:
logging.debug(self._log("{} get lock".format(op_name)))
......@@ -418,7 +399,7 @@ class Op(object):
concurrency=1,
timeout=-1,
retry=2):
self._run = False
self._is_run = False
self.name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency
self.set_input_ops(inputs)
......@@ -432,19 +413,17 @@ class Op(object):
self._outputs = []
def set_client(self, client_config, server_name, fetch_names):
self._client = None
self.with_serving = True
if client_config is None or \
server_name is None or \
fetch_names is None:
self.with_serving = False
return
self._client = Client()
self._client.load_client_config(client_config)
self._client.connect([server_name])
self._fetch_names = fetch_names
def with_serving(self):
return self._client is not None
def get_input_channel(self):
return self._input
......@@ -489,7 +468,7 @@ class Op(object):
feed = channeldata.parse()
return feed
def midprocess(self, data):
def midprocess(self, data, asyn=True):
if not isinstance(data, dict):
raise Exception(
self._log(
......@@ -497,10 +476,16 @@ class Op(object):
format(type(data))))
logging.debug(self._log('data: {}'.format(data)))
logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
call_future = self._client.predict(
feed=data, fetch=self._fetch_names, asyn=True)
logging.debug(self._log("get call_future"))
return call_future
if Client.__name__ == "MultiLangClient":
call_result = self._client.predict(
feed=data, fetch=self._fetch_names, asyn=asyn)
elif Client.__name__ == "Client":
call_result = self._client.predict(
feed=data, fetch=self._fetch_names)
else:
raise Exception("unknow client type: {}".format(Client.__name__))
logging.debug(self._log("get call_result"))
return call_result
def postprocess(self, output_data):
return output_data
......@@ -509,23 +494,23 @@ class Op(object):
self._input.stop()
for channel in self._outputs:
channel.stop()
self._run = False
self._is_run = False
def _parse_channeldata(self, channeldata):
data_id, error_pbdata = None, None
data_id, error_channeldata = None, None
if isinstance(channeldata, dict):
parsed_data = {}
key = channeldata.keys()[0]
data_id = channeldata[key].pbdata.id
key = list(channeldata.keys())[0]
data_id = channeldata[key].id
for _, data in channeldata.items():
if data.pbdata.ecode != ChannelDataEcode.OK.value:
error_pbdata = data.pbdata
if data.ecode != ChannelDataEcode.OK.value:
error_channeldata = data
break
else:
data_id = channeldata.pbdata.id
if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
error_pbdata = channeldata.pbdata
return data_id, error_pbdata
data_id = channeldata.id
if channeldata.ecode != ChannelDataEcode.OK.value:
error_channeldata = channeldata
return data_id, error_channeldata
def _push_to_output_channels(self, data, name=None):
if name is None:
......@@ -536,21 +521,18 @@ class Op(object):
def start(self, concurrency_idx):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
self._is_run = True
while self._is_run:
_profiler.record("{}-get_0".format(op_info_prefix))
channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix))
logging.debug(log("input_data: {}".format(channeldata)))
data_id, error_pbdata = self._parse_channeldata(channeldata)
data_id, error_channeldata = self._parse_channeldata(channeldata)
# error data in predecessor Op
if error_pbdata is not None:
self._push_to_output_channels(
ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=error_pbdata))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata)
continue
# preprecess
......@@ -589,13 +571,14 @@ class Op(object):
continue
# midprocess
call_future = None
if self.with_serving():
midped_data = None
asyn = False
if self.with_serving:
ecode = ChannelDataEcode.OK.value
_profiler.record("{}-midp_0".format(op_info_prefix))
if self._timeout <= 0:
try:
call_future = self.midprocess(preped_data)
midped_data = self.midprocess(preped_data, asyn)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
......@@ -603,10 +586,10 @@ class Op(object):
else:
for i in range(self._retry):
try:
call_future = func_timeout.func_timeout(
midped_data = func_timeout.func_timeout(
self._timeout,
self.midprocess,
args=(preped_data, ))
args=(preped_data, asyn))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -629,20 +612,22 @@ class Op(object):
data_id=data_id))
continue
_profiler.record("{}-midp_1".format(op_info_prefix))
else:
midped_data = preped_data
# postprocess
output_data = None
_profiler.record("{}-postp_0".format(op_info_prefix))
if self.with_serving():
if self.with_serving and asyn:
# use call_future
output_data = ChannelData(
datatype=ChannelDataType.CHANNEL_FUTURE.value,
future=call_future,
future=midped_data,
data_id=data_id,
callback_func=self.postprocess)
else:
try:
postped_data = self.postprocess(preped_data)
postped_data = self.postprocess(midped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
......@@ -710,8 +695,8 @@ class VirtualOp(Op):
def start(self, concurrency_idx):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
self._is_run = True
while self._is_run:
_profiler.record("{}-get_0".format(op_info_prefix))
channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix))
......@@ -727,7 +712,7 @@ class VirtualOp(Op):
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
general_python_service_pb2_grpc.GeneralPythonServiceServicer):
def __init__(self, in_channel, out_channel, retry=2):
super(GeneralPythonService, self).__init__()
self.name = "#G"
......@@ -774,7 +759,7 @@ class GeneralPythonService(
self._log('data must be ChannelData type, but get {}'.
format(type(channeldata))))
with self._cv:
data_id = channeldata.pbdata.id
data_id = channeldata.id
self._globel_resp_dict[data_id] = channeldata
self._cv.notify_all()
......@@ -794,33 +779,32 @@ class GeneralPythonService(
def _pack_data_for_infer(self, request):
logging.debug(self._log('start inferce'))
pbdata = channel_pb2.ChannelData()
data_id = self._get_next_id()
pbdata.id = data_id
pbdata.ecode = ChannelDataEcode.OK.value
npdata = {}
try:
for idx, name in enumerate(request.feed_var_names):
logging.debug(
self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug(
self._log('data: {}'.format(request.feed_insts[idx])))
inst = channel_pb2.Inst()
inst.data = request.feed_insts[idx]
inst.shape = request.shape[idx]
inst.name = name
inst.type = request.type[idx]
pbdata.insts.append(inst)
npdata[name] = np.frombuffer(
request.feed_insts[idx], dtype=request.type[idx])
npdata[name].shape = np.frombuffer(
request.shape[idx], dtype="int32")
except Exception as e:
pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
pbdata.error_info = "rpc package error"
return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
error_info="rpc package error",
data_id=data_id), data_id
return ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=pbdata), data_id
npdata=npdata,
data_id=data_id), data_id
def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata'))
resp = pyservice_pb2.Response()
resp.ecode = channeldata.pbdata.ecode
resp.ecode = channeldata.ecode
if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts:
......@@ -843,7 +827,7 @@ class GeneralPythonService(
self._log("Error type({}) in datatype.".format(
channeldata.datatype)))
else:
resp.error_info = channeldata.pbdata.error_info
resp.error_info = channeldata.error_info
return resp
def inference(self, request, context):
......@@ -863,11 +847,11 @@ class GeneralPythonService(
resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
break
if i + 1 < self._retry:
logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.pbdata.error_info))
i + 1, resp_channeldata.error_info))
_profiler.record("{}-postpack_0".format(self.name))
resp = self._pack_data_for_resp(resp_channeldata)
......@@ -911,6 +895,7 @@ class PyServer(object):
op.name = "#G" # update read_op.name
break
outdegs = {op.name: [] for op in self._user_ops}
zero_indeg_num, zero_outdeg_num = 0, 0
for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique
if op.name in indeg_num:
......@@ -918,8 +903,16 @@ class PyServer(object):
indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0:
ques[que_idx].put(op)
zero_indeg_num += 1
for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op)
if zero_indeg_num != 1:
raise Exception("DAG contains multiple input Ops")
for _, succ_list in outdegs.items():
if len(succ_list) == 0:
zero_outdeg_num += 1
if zero_outdeg_num != 1:
raise Exception("DAG contains multiple output Ops")
# topo sort to get dag_views
dag_views = []
......@@ -942,10 +935,6 @@ class PyServer(object):
que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops):
raise Exception("not legal DAG")
if len(dag_views[0]) != 1:
raise Exception("DAG contains multiple input Ops")
if len(dag_views[-1]) != 1:
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops
def name_generator(prefix):
......@@ -983,7 +972,14 @@ class PyServer(object):
else:
# create virtual op
virtual_op = None
virtual_op = VirtualOp(name=virtual_op_name_gen.next())
if sys.version_info.major == 2:
virtual_op = VirtualOp(
name=virtual_op_name_gen.next())
elif sys.version_info.major == 3:
virtual_op = VirtualOp(
name=virtual_op_name_gen.__next__())
else:
raise Exception("Error Python version")
virtual_ops.append(virtual_op)
outdegs[virtual_op.name] = [succ_op]
actual_next_view.append(virtual_op)
......@@ -995,7 +991,13 @@ class PyServer(object):
for o_idx, op in enumerate(actual_next_view):
if op.name in processed_op:
continue
if sys.version_info.major == 2:
channel = Channel(name=channel_name_gen.next())
elif sys.version_info.major == 3:
channel = Channel(
self._manager, name=channel_name_gen.__next__())
else:
raise Exception("Error Python version")
channels.append(channel)
logging.debug("{} => {}".format(channel.name, op.name))
op.add_input_channel(channel)
......@@ -1026,7 +1028,13 @@ class PyServer(object):
other_op.name))
other_op.add_input_channel(channel)
processed_op.add(other_op.name)
if sys.version_info.major == 2:
output_channel = Channel(name=channel_name_gen.next())
elif sys.version_info.major == 3:
output_channel = Channel(
self._manager, name=channel_name_gen.__next__())
else:
raise Exception("Error Python version")
channels.append(output_channel)
last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel)
......@@ -1050,7 +1058,7 @@ class PyServer(object):
self._in_channel = input_channel
self._out_channel = output_channel
for op in self._actual_ops:
if op.with_serving():
if op.with_serving:
self.prepare_serving(op)
self.gen_desc()
......@@ -1091,11 +1099,21 @@ class PyServer(object):
port = op._server_port
device = op._device
if Client.__name__ == "MultiLangClient":
if device == "cpu":
cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
cmd = "(Use grpc impl) python -m paddle_serving_server.serve" \
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
else:
cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
cmd = "(Use grpc impl) python -m paddle_serving_server_gpu.serve" \
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
elif Client.__name__ == "Client":
if device == "cpu":
cmd = "(Use brpc impl) python -m paddle_serving_server.serve" \
" --model {} --thread 4 --port {} &>/dev/null &".format(model_path, port)
else:
cmd = "(Use brpc impl) python -m paddle_serving_server_gpu.serve" \
" --model {} --thread 4 --port {} &>/dev/null &".format(model_path, port)
else:
raise Exception("unknow client type: {}".format(Client.__name__))
# run a server (not in PyServing)
logging.info("run a server (not in PyServing): {}".format(cmd))
......@@ -22,8 +22,9 @@ elif sys.version_info.major == 3:
else:
raise Exception("Error Python version")
import os
import paddle_serving_server
from paddle_serving_client import MultiLangClient as Client
import paddle_serving_server_gpu
#from paddle_serving_client import MultiLangClient as Client
from paddle_serving_client import Client
from concurrent import futures
import numpy as np
import grpc
......@@ -118,30 +119,20 @@ class ChannelData(object):
if ecode is not None:
if data_id is None or error_info is None:
raise ValueError("data_id and error_info cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.ecode = ecode
pbdata.id = data_id
pbdata.error_info = error_info
datatype = ChannelDataType.ERROR.value
else:
if datatype == ChannelDataType.CHANNEL_FUTURE.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.id = data_id
ecode = ChannelDataEcode.OK.value
elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
ecode, error_info = self._check_npdata(npdata)
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
else:
for name, value in npdata.items():
inst = channel_pb2.Inst()
......@@ -153,33 +144,23 @@ class ChannelData(object):
pbdata.insts.append(inst)
elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
ecode, error_info = self._check_npdata(npdata)
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
else:
raise ValueError("datatype not match")
if not isinstance(pbdata, channel_pb2.ChannelData):
raise TypeError(
"pbdata must be pyserving_channel_pb2.ChannelData type({})".
format(type(pbdata)))
self.future = future
self.pbdata = pbdata
self.npdata = npdata
self.datatype = datatype
self.id = data_id
self.ecode = ecode
self.error_info = error_info
self.callback_func = callback_func
def _check_npdata(self, npdata):
ecode = ChannelDataEcode.OK.value
error_info = None
for name, value in npdata.items():
if not isinstance(name, (str, unicode)):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the key of postped_data must " \
"be str, but get {}".format(type(name)))
break
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the value of postped_data must " \
......@@ -202,12 +183,12 @@ class ChannelData(object):
elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = self.npdata
else:
raise TypeError("Error type({}) in datatype.".format(datatype))
raise TypeError("Error type({}) in datatype.".format(self.datatype))
return feed
def __str__(self):
return "type[{}], ecode[{}]".format(
ChannelDataType(self.datatype).name, self.pbdata.ecode)
return "type[{}], ecode[{}], id[{}]".format(
ChannelDataType(self.datatype).name, self.ecode, self.id)
class Channel(Queue.Queue):
......@@ -303,7 +284,7 @@ class Channel(Queue.Queue):
"There are multiple producers, so op_name cannot be None."))
producer_num = len(self._producers)
data_id = channeldata.pbdata.id
data_id = channeldata.id
put_data = None
with self._cv:
logging.debug(self._log("{} get lock".format(op_name)))
......@@ -418,7 +399,7 @@ class Op(object):
concurrency=1,
timeout=-1,
retry=2):
self._run = False
self._is_run = False
self.name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency
self.set_input_ops(inputs)
......@@ -432,19 +413,17 @@ class Op(object):
self._outputs = []
def set_client(self, client_config, server_name, fetch_names):
self._client = None
self.with_serving = True
if client_config is None or \
server_name is None or \
fetch_names is None:
self.with_serving = False
return
self._client = Client()
self._client.load_client_config(client_config)
self._client.connect([server_name])
self._fetch_names = fetch_names
def with_serving(self):
return self._client is not None
def get_input_channel(self):
return self._input
......@@ -489,7 +468,7 @@ class Op(object):
feed = channeldata.parse()
return feed
def midprocess(self, data):
def midprocess(self, data, asyn=True):
if not isinstance(data, dict):
raise Exception(
self._log(
......@@ -497,10 +476,16 @@ class Op(object):
format(type(data))))
logging.debug(self._log('data: {}'.format(data)))
logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
call_future = self._client.predict(
feed=data, fetch=self._fetch_names, asyn=True)
logging.debug(self._log("get call_future"))
return call_future
if Client.__name__ == "MultiLangClient":
call_result = self._client.predict(
feed=data, fetch=self._fetch_names, asyn=asyn)
elif Client.__name__ == "Client":
call_result = self._client.predict(
feed=data, fetch=self._fetch_names)
else:
raise Exception("unknow client type: {}".format(Client.__name__))
logging.debug(self._log("get call_result"))
return call_result
def postprocess(self, output_data):
return output_data
......@@ -509,23 +494,23 @@ class Op(object):
self._input.stop()
for channel in self._outputs:
channel.stop()
self._run = False
self._is_run = False
def _parse_channeldata(self, channeldata):
data_id, error_pbdata = None, None
data_id, error_channeldata = None, None
if isinstance(channeldata, dict):
parsed_data = {}
key = channeldata.keys()[0]
data_id = channeldata[key].pbdata.id
key = list(channeldata.keys())[0]
data_id = channeldata[key].id
for _, data in channeldata.items():
if data.pbdata.ecode != ChannelDataEcode.OK.value:
error_pbdata = data.pbdata
if data.ecode != ChannelDataEcode.OK.value:
error_channeldata = data
break
else:
data_id = channeldata.pbdata.id
if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
error_pbdata = channeldata.pbdata
return data_id, error_pbdata
data_id = channeldata.id
if channeldata.ecode != ChannelDataEcode.OK.value:
error_channeldata = channeldata
return data_id, error_channeldata
def _push_to_output_channels(self, data, name=None):
if name is None:
......@@ -536,21 +521,18 @@ class Op(object):
def start(self, concurrency_idx):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
self._is_run = True
while self._is_run:
_profiler.record("{}-get_0".format(op_info_prefix))
channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix))
logging.debug(log("input_data: {}".format(channeldata)))
data_id, error_pbdata = self._parse_channeldata(channeldata)
data_id, error_channeldata = self._parse_channeldata(channeldata)
# error data in predecessor Op
if error_pbdata is not None:
self._push_to_output_channels(
ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=error_pbdata))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata)
continue
# preprecess
......@@ -589,13 +571,14 @@ class Op(object):
continue
# midprocess
call_future = None
if self.with_serving():
midped_data = None
asyn = False
if self.with_serving:
ecode = ChannelDataEcode.OK.value
_profiler.record("{}-midp_0".format(op_info_prefix))
if self._timeout <= 0:
try:
call_future = self.midprocess(preped_data)
midped_data = self.midprocess(preped_data, asyn)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
......@@ -603,10 +586,10 @@ class Op(object):
else:
for i in range(self._retry):
try:
call_future = func_timeout.func_timeout(
midped_data = func_timeout.func_timeout(
self._timeout,
self.midprocess,
args=(preped_data, ))
args=(preped_data, asyn))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -629,20 +612,22 @@ class Op(object):
data_id=data_id))
continue
_profiler.record("{}-midp_1".format(op_info_prefix))
else:
midped_data = preped_data
# postprocess
output_data = None
_profiler.record("{}-postp_0".format(op_info_prefix))
if self.with_serving():
if self.with_serving and asyn:
# use call_future
output_data = ChannelData(
datatype=ChannelDataType.CHANNEL_FUTURE.value,
future=call_future,
future=midped_data,
data_id=data_id,
callback_func=self.postprocess)
else:
try:
postped_data = self.postprocess(preped_data)
postped_data = self.postprocess(midped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
......@@ -710,8 +695,8 @@ class VirtualOp(Op):
def start(self, concurrency_idx):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
self._is_run = True
while self._is_run:
_profiler.record("{}-get_0".format(op_info_prefix))
channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix))
......@@ -727,7 +712,7 @@ class VirtualOp(Op):
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
general_python_service_pb2_grpc.GeneralPythonServiceServicer):
def __init__(self, in_channel, out_channel, retry=2):
super(GeneralPythonService, self).__init__()
self.name = "#G"
......@@ -774,7 +759,7 @@ class GeneralPythonService(
self._log('data must be ChannelData type, but get {}'.
format(type(channeldata))))
with self._cv:
data_id = channeldata.pbdata.id
data_id = channeldata.id
self._globel_resp_dict[data_id] = channeldata
self._cv.notify_all()
......@@ -794,33 +779,32 @@ class GeneralPythonService(
def _pack_data_for_infer(self, request):
logging.debug(self._log('start inferce'))
pbdata = channel_pb2.ChannelData()
data_id = self._get_next_id()
pbdata.id = data_id
pbdata.ecode = ChannelDataEcode.OK.value
npdata = {}
try:
for idx, name in enumerate(request.feed_var_names):
logging.debug(
self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug(
self._log('data: {}'.format(request.feed_insts[idx])))
inst = channel_pb2.Inst()
inst.data = request.feed_insts[idx]
inst.shape = request.shape[idx]
inst.name = name
inst.type = request.type[idx]
pbdata.insts.append(inst)
npdata[name] = np.frombuffer(
request.feed_insts[idx], dtype=request.type[idx])
npdata[name].shape = np.frombuffer(
request.shape[idx], dtype="int32")
except Exception as e:
pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
pbdata.error_info = "rpc package error"
return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
error_info="rpc package error",
data_id=data_id), data_id
return ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=pbdata), data_id
npdata=npdata,
data_id=data_id), data_id
def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata'))
resp = pyservice_pb2.Response()
resp.ecode = channeldata.pbdata.ecode
resp.ecode = channeldata.ecode
if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts:
......@@ -843,7 +827,7 @@ class GeneralPythonService(
self._log("Error type({}) in datatype.".format(
channeldata.datatype)))
else:
resp.error_info = channeldata.pbdata.error_info
resp.error_info = channeldata.error_info
return resp
def inference(self, request, context):
......@@ -863,11 +847,11 @@ class GeneralPythonService(
resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
break
if i + 1 < self._retry:
logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.pbdata.error_info))
i + 1, resp_channeldata.error_info))
_profiler.record("{}-postpack_0".format(self.name))
resp = self._pack_data_for_resp(resp_channeldata)
......@@ -911,6 +895,7 @@ class PyServer(object):
op.name = "#G" # update read_op.name
break
outdegs = {op.name: [] for op in self._user_ops}
zero_indeg_num, zero_outdeg_num = 0, 0
for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique
if op.name in indeg_num:
......@@ -918,8 +903,16 @@ class PyServer(object):
indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0:
ques[que_idx].put(op)
zero_indeg_num += 1
for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op)
if zero_indeg_num != 1:
raise Exception("DAG contains multiple input Ops")
for _, succ_list in outdegs.items():
if len(succ_list) == 0:
zero_outdeg_num += 1
if zero_outdeg_num != 1:
raise Exception("DAG contains multiple output Ops")
# topo sort to get dag_views
dag_views = []
......@@ -942,10 +935,6 @@ class PyServer(object):
que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops):
raise Exception("not legal DAG")
if len(dag_views[0]) != 1:
raise Exception("DAG contains multiple input Ops")
if len(dag_views[-1]) != 1:
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops
def name_generator(prefix):
......@@ -983,7 +972,14 @@ class PyServer(object):
else:
# create virtual op
virtual_op = None
virtual_op = VirtualOp(name=virtual_op_name_gen.next())
if sys.version_info.major == 2:
virtual_op = VirtualOp(
name=virtual_op_name_gen.next())
elif sys.version_info.major == 3:
virtual_op = VirtualOp(
name=virtual_op_name_gen.__next__())
else:
raise Exception("Error Python version")
virtual_ops.append(virtual_op)
outdegs[virtual_op.name] = [succ_op]
actual_next_view.append(virtual_op)
......@@ -995,7 +991,13 @@ class PyServer(object):
for o_idx, op in enumerate(actual_next_view):
if op.name in processed_op:
continue
if sys.version_info.major == 2:
channel = Channel(name=channel_name_gen.next())
elif sys.version_info.major == 3:
channel = Channel(
self._manager, name=channel_name_gen.__next__())
else:
raise Exception("Error Python version")
channels.append(channel)
logging.debug("{} => {}".format(channel.name, op.name))
op.add_input_channel(channel)
......@@ -1026,7 +1028,13 @@ class PyServer(object):
other_op.name))
other_op.add_input_channel(channel)
processed_op.add(other_op.name)
if sys.version_info.major == 2:
output_channel = Channel(name=channel_name_gen.next())
elif sys.version_info.major == 3:
output_channel = Channel(
self._manager, name=channel_name_gen.__next__())
else:
raise Exception("Error Python version")
channels.append(output_channel)
last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel)
......@@ -1050,7 +1058,7 @@ class PyServer(object):
self._in_channel = input_channel
self._out_channel = output_channel
for op in self._actual_ops:
if op.with_serving():
if op.with_serving:
self.prepare_serving(op)
self.gen_desc()
......@@ -1091,11 +1099,21 @@ class PyServer(object):
port = op._server_port
device = op._device
if Client.__name__ == "MultiLangClient":
if device == "cpu":
cmd = "(Use MultiLangServer) python -m paddle_serving_server.serve" \
cmd = "(Use grpc impl) python -m paddle_serving_server.serve" \
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
else:
cmd = "(Use MultiLangServer) python -m paddle_serving_server_gpu.serve" \
cmd = "(Use grpc impl) python -m paddle_serving_server_gpu.serve" \
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
elif Client.__name__ == "Client":
if device == "cpu":
cmd = "(Use brpc impl) python -m paddle_serving_server.serve" \
" --model {} --thread 4 --port {} &>/dev/null &".format(model_path, port)
else:
cmd = "(Use brpc impl) python -m paddle_serving_server_gpu.serve" \
" --model {} --thread 4 --port {} &>/dev/null &".format(model_path, port)
else:
raise Exception("unknow client type: {}".format(Client.__name__))
# run a server (not in PyServing)
logging.info("run a server (not in PyServing): {}".format(cmd))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册