提交 b4907919 编写于 作者: B barrierye

[WIP] update thread to process

上级 d947eff9
......@@ -14,6 +14,7 @@
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import multiprocessing.queues
import Queue
import os
import sys
......@@ -109,34 +110,27 @@ class ChannelData(object):
4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id)
5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id)
6. ChannelData(ecode, error_info, data_id)
Protobufs are not pickle-able:
https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
'''
if ecode is not None:
if data_id is None or error_info is None:
raise ValueError("data_id and error_info cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.ecode = ecode
pbdata.id = data_id
pbdata.error_info = error_info
datatype = ChannelDataType.ERROR.value
else:
if datatype == ChannelDataType.CHANNEL_FUTURE.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.id = data_id
if data_id is None:
raise ValueError("data_id cannot be None")
ecode = ChannelDataEcode.OK.value
elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
if pbdata is None:
if data_id is None:
raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
ecode, error_info = self._check_npdata(npdata)
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
else:
for name, value in npdata.items():
inst = channel_pb2.Inst()
......@@ -148,23 +142,18 @@ class ChannelData(object):
pbdata.insts.append(inst)
elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
ecode, error_info = self._check_npdata(npdata)
pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
if ecode != ChannelDataEcode.OK.value:
logging.error(error_info)
else:
raise ValueError("datatype not match")
if not isinstance(pbdata, channel_pb2.ChannelData):
raise TypeError(
"pbdata must be pyserving_channel_pb2.ChannelData type({})".
format(type(pbdata)))
self.future = future
self.pbdata = pbdata
self.npdata = npdata
self.datatype = datatype
self.callback_func = callback_func
self.id = data_id
self.ecode = ecode
self.error_info = error_info
def _check_npdata(self, npdata):
ecode = ChannelDataEcode.OK.value
......@@ -176,8 +165,8 @@ class ChannelData(object):
"be str, but get {}".format(type(name)))
break
if not isinstance(value, np.ndarray):
pbdata.ecode = ChannelDataEcode.TYPE_ERROR.value
pbdata.error_info = log("the value of postped_data must " \
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value)))
break
return ecode, error_info
......@@ -197,16 +186,15 @@ class ChannelData(object):
elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = self.npdata
else:
raise TypeError("Error type({}) in pbdata.type.".format(
self.pbdata.type))
raise TypeError("Error type({}) in datatype.".format(self.datatype))
return feed
def __str__(self):
return "type[{}], ecode[{}]".format(
ChannelDataType(self.datatype).name, self.pbdata.ecode)
ChannelDataType(self.datatype).name, self.ecode)
class Channel(Queue.Queue):
class Channel(multiprocessing.queues.Queue):
"""
The channel used for communication between Ops.
......@@ -225,7 +213,8 @@ class Channel(Queue.Queue):
"""
def __init__(self, name=None, maxsize=-1, timeout=None):
Queue.Queue.__init__(self, maxsize=maxsize)
# https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5
multiprocessing.queues.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self.name = name
......@@ -299,7 +288,7 @@ class Channel(Queue.Queue):
"There are multiple producers, so op_name cannot be None."))
producer_num = len(self._producers)
data_id = channeldata.pbdata.id
data_id = channeldata.id
put_data = None
with self._cv:
logging.debug(self._log("{} get lock".format(op_name)))
......@@ -418,29 +407,41 @@ class Op(object):
self.name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency
self.set_input_ops(inputs)
self.set_client(client_config, server_name, fetch_names)
self._server_model = server_model
self._server_port = server_port
self._device = device
self._timeout = timeout
self._retry = max(1, retry)
self._input = None
self._outputs = []
def set_client(self, client_config, server_name, fetch_names):
self.with_serving = False
self._client_config = client_config
self._server_name = server_name
self._fetch_names = fetch_names
self._server_model = server_model
self._server_port = server_port
self._device = device
if self._client_config is not None and \
self._server_name is not None and \
self._fetch_names is not None and \
self._server_model is not None and \
self._server_port is not None and \
self._device is not None:
self.with_serving = True
def init_client(self, client_config, server_name, fetch_names):
self._client = None
if client_config is None or \
server_name is None or \
fetch_names is None:
logging.debug("no client")
return
logging.debug("client_config: {}".format(client_config))
logging.debug("server_name: {}".format(server_name))
logging.debug("fetch_names: {}".format(fetch_names))
self._client = Client()
self._client.load_client_config(client_config)
self._client.connect([server_name])
self._fetch_names = fetch_names
def with_serving(self):
return self._client is not None
def get_input_channel(self):
return self._input
......@@ -508,45 +509,56 @@ class Op(object):
self._run = False
def _parse_channeldata(self, channeldata):
data_id, error_pbdata = None, None
data_id, error_channeldata = None, None
if isinstance(channeldata, dict):
parsed_data = {}
key = channeldata.keys()[0]
data_id = channeldata[key].pbdata.id
data_id = channeldata[key].id
for _, data in channeldata.items():
if data.pbdata.ecode != ChannelDataEcode.OK.value:
error_pbdata = data.pbdata
if data.ecode != ChannelDataEcode.OK.value:
error_channeldata = data
break
else:
data_id = channeldata.pbdata.id
if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
error_pbdata = channeldata.pbdata
return data_id, error_pbdata
data_id = channeldata.id
if channeldata.ecode != ChannelDataEcode.OK.value:
error_channeldata = channeldata
return data_id, error_channeldata
def _push_to_output_channels(self, data, name=None):
def _push_to_output_channels(self, data, channels, name=None):
if name is None:
name = self.name
for channel in self._outputs:
for channel in channels:
channel.push(data, name)
def start(self, concurrency_idx):
def start(self):
proces = []
for concurrency_idx in range(self._concurrency):
p = multiprocessing.Process(
target=self._run,
args=(concurrency_idx, self.get_input_channel(),
self.get_output_channels()))
p.start()
proces.append(p)
return proces
def _run(self, input_channel, output_channels):
self.init_client(self._client_config, self._server_name,
self._fetch_names)
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
_profiler.record("{}-get_0".format(op_info_prefix))
channeldata = self._input.front(self.name)
channeldata = input_channel.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix))
logging.debug(log("input_data: {}".format(channeldata)))
data_id, error_pbdata = self._parse_channeldata(channeldata)
data_id, error_channeldata = self._parse_channeldata(channeldata)
# error data in predecessor Op
if error_pbdata is not None:
self._push_to_output_channels(
ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=error_pbdata))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# preprecess
......@@ -562,17 +574,19 @@ class Op(object):
ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info,
data_id=data_id))
data_id=data_id),
output_channels)
continue
except TypeError as e:
# Error type in channeldata.pbdata.type
# Error type in channeldata.datatype
error_info = log(e)
logging.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id))
data_id=data_id),
output_channels)
continue
except Exception as e:
error_info = log(e)
......@@ -581,12 +595,13 @@ class Op(object):
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id))
data_id=data_id),
output_channels)
continue
# midprocess
call_future = None
if self.with_serving():
if self.with_serving:
ecode = ChannelDataEcode.OK.value
_profiler.record("{}-midp_0".format(op_info_prefix))
if self._timeout <= 0:
......@@ -622,14 +637,15 @@ class Op(object):
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
data_id=data_id),
output_channels)
continue
_profiler.record("{}-midp_1".format(op_info_prefix))
# postprocess
output_data = None
_profiler.record("{}-postp_0".format(op_info_prefix))
if self.with_serving():
if self.with_serving:
# use call_future
output_data = ChannelData(
datatype=ChannelDataType.CHANNEL_FUTURE.value,
......@@ -646,7 +662,8 @@ class Op(object):
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
data_id=data_id),
output_channels)
continue
if not isinstance(postped_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
......@@ -656,7 +673,8 @@ class Op(object):
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id))
data_id=data_id),
output_channels)
continue
output_data = ChannelData(
......@@ -667,7 +685,7 @@ class Op(object):
# push data to channel (if run succ)
_profiler.record("{}-push_0".format(op_info_prefix))
self._push_to_output_channels(output_data)
self._push_to_output_channels(output_data, output_channels)
_profiler.record("{}-push_1".format(op_info_prefix))
def _log(self, info):
......@@ -703,22 +721,25 @@ class VirtualOp(Op):
channel.add_producer(op.name)
self._outputs.append(channel)
def start(self, concurrency_idx):
def _run(self, input_channel, output_channels):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix)
self._run = True
while self._run:
_profiler.record("{}-get_0".format(op_info_prefix))
channeldata = self._input.front(self.name)
channeldata = input_channel.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix))
_profiler.record("{}-push_0".format(op_info_prefix))
if isinstance(channeldata, dict):
for name, data in channeldata.items():
self._push_to_output_channels(data, name=name)
self._push_to_output_channels(
data, channels=output_channels, name=name)
else:
self._push_to_output_channels(channeldata,
self._virtual_pred_ops[0].name)
self._push_to_output_channels(
channeldata,
channels=output_channels,
name=self._virtual_pred_ops[0].name)
_profiler.record("{}-push_1".format(op_info_prefix))
......@@ -770,7 +791,7 @@ class GeneralPythonService(
self._log('data must be ChannelData type, but get {}'.
format(type(channeldata))))
with self._cv:
data_id = channeldata.pbdata.id
data_id = channeldata.id
self._globel_resp_dict[data_id] = channeldata
self._cv.notify_all()
......@@ -790,33 +811,33 @@ class GeneralPythonService(
def _pack_data_for_infer(self, request):
logging.debug(self._log('start inferce'))
pbdata = channel_pb2.ChannelData()
data_id = self._get_next_id()
pbdata.id = data_id
pbdata.ecode = ChannelDataEcode.OK.value
npdata = {}
try:
for idx, name in enumerate(request.feed_var_names):
logging.debug(
self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug(
self._log('data: {}'.format(request.feed_insts[idx])))
inst = channel_pb2.Inst()
inst.data = request.feed_insts[idx]
inst.shape = request.shape[idx]
inst.name = name
inst.type = request.type[idx]
pbdata.insts.append(inst)
npdata[name] = np.frombuffer(
request.feed_insts[idx], dtype=request.type[idx])
npdata[name].shape = np.frombuffer(
request.shape[idx], dtype="int32")
except Exception as e:
pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
pbdata.error_info = "rpc package error"
return ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=pbdata), data_id
return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
error_info="rpc package error",
data_id=data_id), data_id
else:
return ChannelData(
datatype=ChannelDataType.CHANNEL_NPDATA.value,
npdata=npdata,
data_id=data_id), data_id
def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata'))
resp = pyservice_pb2.Response()
resp.ecode = channeldata.pbdata.ecode
resp.ecode = channeldata.ecode
if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts:
......@@ -836,10 +857,10 @@ class GeneralPythonService(
resp.type.append(str(var.dtype))
else:
raise TypeError(
self._log("Error type({}) in pbdata.type.".format(
self._log("Error type({}) in datatype.".format(
channeldata.datatype)))
else:
resp.error_info = channeldata.pbdata.error_info
resp.error_info = channeldata.error_info
return resp
def inference(self, request, context):
......@@ -859,11 +880,11 @@ class GeneralPythonService(
resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
break
if i + 1 < self._retry:
logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.pbdata.error_info))
i + 1, resp_channeldata.error_info))
_profiler.record("{}-postpack_0".format(self.name))
resp = self._pack_data_for_resp(resp_channeldata)
......@@ -877,7 +898,6 @@ class PyServer(object):
self._channels = []
self._user_ops = []
self._actual_ops = []
self._op_threads = []
self._port = None
self._worker_num = None
self._in_channel = None
......@@ -1046,30 +1066,22 @@ class PyServer(object):
self._in_channel = input_channel
self._out_channel = output_channel
for op in self._actual_ops:
if op.with_serving():
if op.with_serving:
self.prepare_serving(op)
self.gen_desc()
def _op_start_wrapper(self, op, concurrency_idx):
return op.start(concurrency_idx)
def _run_ops(self):
proces = []
for op in self._actual_ops:
op_concurrency = op.get_concurrency()
logging.debug("run op: {}, op_concurrency: {}".format(
op.name, op_concurrency))
for c in range(op_concurrency):
th = threading.Thread(
target=self._op_start_wrapper, args=(op, c))
th.start()
self._op_threads.append(th)
proces.extend(op.start())
return proces
def _stop_ops(self):
for op in self._actual_ops:
op.stop()
def run_server(self):
self._run_ops()
op_proces = self._run_ops()
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
......@@ -1079,8 +1091,8 @@ class PyServer(object):
server.start()
server.wait_for_termination()
self._stop_ops() # TODO
for th in self._op_threads:
th.join()
for p in op_proces:
p.join()
def prepare_serving(self, op):
model_path = op._server_model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册