Commit 8090d687 authored by barrierye

update thread to process

Parent a0c3d077
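The commit switches the pipeline workers from threads to processes: Channel's condition variable becomes a multiprocessing.Condition, and its bookkeeping containers (_producer_res_count, _push_res, _consumers, _idx_consumer_num, _consumer_base_idx, _front_res) move into multiprocessing.Manager proxies created by PyServer, so Op workers running in separate processes see the same state. Below is a minimal standalone sketch of that sharing pattern under those assumptions; the names worker and shared_counts are illustrative and are not part of pyserver.

import multiprocessing


def worker(idx, shared_counts, cond):
    # Each worker runs in its own process; a plain dict would not be shared
    # across processes, so the state lives in a Manager proxy instead.
    with cond:
        shared_counts[idx] = shared_counts.get(idx, 0) + 1
        cond.notify_all()


if __name__ == "__main__":
    manager = multiprocessing.Manager()
    shared_counts = manager.dict()      # analogous to Channel._producer_res_count
    cond = multiprocessing.Condition()  # analogous to Channel._cv after this commit

    procs = [
        multiprocessing.Process(target=worker, args=(i, shared_counts, cond))
        for i in range(4)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(dict(shared_counts))  # e.g. {0: 1, 1: 1, 2: 1, 3: 1}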
@@ -20,6 +20,7 @@ import os
 import sys
 import paddle_serving_server
 from paddle_serving_client import MultiLangClient as Client
+from paddle_serving_client import MultiLangPredictFuture
 from concurrent import futures
 import numpy as np
 import grpc
@@ -190,8 +191,8 @@ class ChannelData(object):
         return feed
 
     def __str__(self):
-        return "type[{}], ecode[{}]".format(
-            ChannelDataType(self.datatype).name, self.ecode)
+        return "type[{}], ecode[{}], id[{}]".format(
+            ChannelDataType(self.datatype).name, self.ecode, self.id)
 
 
 class Channel(multiprocessing.queues.Queue):
@@ -212,24 +213,30 @@ class Channel(multiprocessing.queues.Queue):
     and can only be called during initialization.
     """
 
-    def __init__(self, name=None, maxsize=-1, timeout=None):
-        # https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5
+    def __init__(self, manager, name=None, maxsize=0, timeout=None):
+        # https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5/
        multiprocessing.queues.Queue.__init__(self, maxsize=maxsize)
         self._maxsize = maxsize
         self._timeout = timeout
         self.name = name
         self._stop = False
 
-        self._cv = threading.Condition()
+        self._cv = multiprocessing.Condition()
 
         self._producers = []
-        self._producer_res_count = {}  # {data_id: count}
-        self._push_res = {}  # {data_id: {op_name: data}}
+        self._producer_res_count = manager.dict()  # {data_id: count}
+        # self._producer_res_count = {}  # {data_id: count}
+        self._push_res = manager.dict()  # {data_id: {op_name: data}}
+        # self._push_res = {}  # {data_id: {op_name: data}}
 
-        self._consumers = {}  # {op_name: idx}
-        self._idx_consumer_num = {}  # {idx: num}
-        self._consumer_base_idx = 0
-        self._front_res = []
+        self._consumers = manager.dict()  # {op_name: idx}
+        # self._consumers = {}  # {op_name: idx}
+        self._idx_consumer_num = manager.dict()  # {idx: num}
+        # self._idx_consumer_num = {}  # {idx: num}
+        self._consumer_base_idx = manager.Value('i', 0)
+        # self._consumer_base_idx = 0
+        self._front_res = manager.list()
+        # self._front_res = []
 
     def get_producers(self):
         return self._producers
@@ -277,9 +284,13 @@ class Channel(multiprocessing.queues.Queue):
                 try:
                     self.put(channeldata, timeout=0)
                     break
-                except Queue.Empty:
+                except Queue.Full:
                     self._cv.wait()
+            logging.debug(
+                self._log("{} channel size: {}".format(op_name,
+                                                       self.qsize())))
             self._cv.notify_all()
+            logging.debug(self._log("{} notify all".format(op_name)))
             logging.debug(self._log("{} push data succ!".format(op_name)))
             return True
         elif op_name is None:
@@ -298,7 +309,12 @@ class Channel(multiprocessing.queues.Queue):
                         for name in self._producers
                    }
                    self._producer_res_count[data_id] = 0
-                self._push_res[data_id][op_name] = channeldata
+                # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
+                # self._push_res[data_id][op_name] = channeldata
+                tmp_push_res = self._push_res[data_id]
+                tmp_push_res[op_name] = channeldata
+                self._push_res[data_id] = tmp_push_res
+
                 if self._producer_res_count[data_id] + 1 == producer_num:
                    put_data = self._push_res[data_id]
                    self._push_res.pop(data_id)
@@ -313,6 +329,9 @@ class Channel(multiprocessing.queues.Queue):
                 else:
                    while self._stop is False:
                        try:
+                            logging.debug(
+                                self._log("{} push data succ: {}".format(
+                                    op_name, put_data.__str__())))
                            self.put(put_data, timeout=0)
                            break
                        except Queue.Empty:
@@ -324,7 +343,7 @@ class Channel(multiprocessing.queues.Queue):
         return True
 
     def front(self, op_name=None):
-        logging.debug(self._log("{} try to get data".format(op_name)))
+        logging.debug(self._log("{} try to get data...".format(op_name)))
         if len(self._consumers) == 0:
            raise Exception(
                self._log(
@@ -335,9 +354,18 @@ class Channel(multiprocessing.queues.Queue):
            with self._cv:
                while self._stop is False and resp is None:
                    try:
-                        resp = self.get(timeout=0)
+                        logging.debug(
+                            self._log("{} try to get(with channel size: {})".
+                                      format(op_name, self.qsize())))
+                        #TODO: bug to fix
+                        # (multiple processes) the queue is not empty, but it raise Queue.Empty
+                        resp = self.get(timeout=1e-3)
                        break
                    except Queue.Empty:
+                        logging.debug(
+                            self._log(
+                                "{} wait for empty queue(with channel size: {})".
+                                format(op_name, self.qsize())))
                        self._cv.wait()
            logging.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
@@ -351,16 +379,31 @@ class Channel(multiprocessing.queues.Queue):
        with self._cv:
            # data_idx = consumer_idx - base_idx
            while self._stop is False and self._consumers[
-                    op_name] - self._consumer_base_idx >= len(self._front_res):
+                    op_name] - self._consumer_base_idx.value >= len(
+                        self._front_res):
+                logging.debug(
+                    self._log(
+                        "({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
+                        format(op_name, self._consumers, self.
+                               _consumer_base_idx.value, len(self._front_res))))
                try:
-                    channeldata = self.get(timeout=0)
+                    logging.debug(
+                        self._log("{} try to get(with channel size: {})".format(
+                            op_name, self.qsize())))
+                    #TODO: bug to fix
+                    # (multiple processes) the queue is not empty, but it raise Queue.Empty
+                    channeldata = self.get(timeout=1e-3)
                    self._front_res.append(channeldata)
                    break
                except Queue.Empty:
+                    logging.debug(
+                        self._log(
+                            "{} wait for empty queue(with channel size: {})".
+                            format(op_name, self.qsize())))
                    self._cv.wait()
 
            consumer_idx = self._consumers[op_name]
-            base_idx = self._consumer_base_idx
+            base_idx = self._consumer_base_idx.value
            data_idx = consumer_idx - base_idx
            resp = self._front_res[data_idx]
            logging.debug(self._log("{} get data: {}".format(op_name, resp)))
@@ -370,14 +413,19 @@ class Channel(multiprocessing.queues.Queue):
                    consumer_idx] == 0:
                self._idx_consumer_num.pop(consumer_idx)
                self._front_res.pop(0)
-                self._consumer_base_idx += 1
+                self._consumer_base_idx.value += 1
 
            self._consumers[op_name] += 1
            new_consumer_idx = self._consumers[op_name]
            if self._idx_consumer_num.get(new_consumer_idx) is None:
                self._idx_consumer_num[new_consumer_idx] = 0
            self._idx_consumer_num[new_consumer_idx] += 1
+            logging.debug(
+                self._log(
+                    "({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
+                    format(op_name, self._consumers, self._consumer_base_idx.
+                           value, len(self._front_res))))
+            logging.debug(self._log("{} notify all".format(op_name)))
            self._cv.notify_all()
 
        logging.debug(self._log("multi | {} get data succ!".format(op_name)))
@@ -403,7 +451,7 @@ class Op(object):
                 concurrency=1,
                 timeout=-1,
                 retry=2):
-        self._run = False
+        self._is_run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self._concurrency = concurrency  # amount of concurrency
        self.set_input_ops(inputs)
@@ -428,15 +476,12 @@ class Op(object):
            self.with_serving = True
 
    def init_client(self, client_config, server_name, fetch_names):
-        self._client = None
-        if client_config is None or \
-                server_name is None or \
-                fetch_names is None:
-            logging.debug("no client")
+        if self.with_serving == False:
+            logging.debug("{} no client".format(self.name))
            return
-        logging.debug("client_config: {}".format(client_config))
-        logging.debug("server_name: {}".format(server_name))
-        logging.debug("fetch_names: {}".format(fetch_names))
+        logging.debug("{} client_config: {}".format(self.name, client_config))
+        logging.debug("{} server_name: {}".format(self.name, server_name))
+        logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
        self._client = Client()
        self._client.load_client_config(client_config)
        self._client.connect([server_name])
@@ -506,7 +551,7 @@ class Op(object):
        self._input.stop()
        for channel in self._outputs:
            channel.stop()
-        self._run = False
+        self._is_run = False
 
    def _parse_channeldata(self, channeldata):
        data_id, error_channeldata = None, None
@@ -541,13 +586,13 @@ class Op(object):
            proces.append(p)
        return proces
 
-    def _run(self, input_channel, output_channels):
+    def _run(self, concurrency_idx, input_channel, output_channels):
        self.init_client(self._client_config, self._server_name,
                         self._fetch_names)
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
-        self._run = True
-        while self._run:
+        self._is_run = True
+        while self._is_run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = input_channel.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
@@ -593,7 +638,7 @@ class Op(object):
                logging.error(error_info)
                self._push_to_output_channels(
                    ChannelData(
-                        ecode=ChannelDataEcode.TYPE_ERROR.value,
+                        ecode=ChannelDataEcode.UNKNOW.value,
                        error_info=error_info,
                        data_id=data_id),
                    output_channels)
@@ -652,6 +697,15 @@ class Op(object):
                        future=call_future,
                        data_id=data_id,
                        callback_func=self.postprocess)
+                #TODO: for future are not picklable
+                npdata = self.postprocess(call_future.result())
+                self._push_to_output_channels(
+                    ChannelData(
+                        ChannelDataType.CHANNEL_NPDATA.value,
+                        npdata=npdata,
+                        data_id=data_id),
+                    output_channels)
+                continue
            else:
                try:
                    postped_data = self.postprocess(preped_data)
@@ -724,8 +778,8 @@ class VirtualOp(Op):
    def _run(self, input_channel, output_channels):
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = self._get_log_func(op_info_prefix)
-        self._run = True
-        while self._run:
+        self._is_run = True
+        while self._is_run:
            _profiler.record("{}-get_0".format(op_info_prefix))
            channeldata = input_channel.front(self.name)
            _profiler.record("{}-get_1".format(op_info_prefix))
@@ -903,6 +957,7 @@ class PyServer(object):
        self._in_channel = None
        self._out_channel = None
        self._retry = retry
+        self._manager = multiprocessing.Manager()
        _profiler.enable(profile)
 
    def add_channel(self, channel):
@@ -1011,7 +1066,7 @@ class PyServer(object):
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
-                channel = Channel(name=channel_name_gen.next())
+                channel = Channel(self._manager, name=channel_name_gen.next())
                channels.append(channel)
                logging.debug("{} => {}".format(channel.name, op.name))
                op.add_input_channel(channel)
@@ -1042,7 +1097,7 @@ class PyServer(object):
                            other_op.name))
                    other_op.add_input_channel(channel)
                    processed_op.add(other_op.name)
-        output_channel = Channel(name=channel_name_gen.next())
+        output_channel = Channel(self._manager, name=channel_name_gen.next())
        channels.append(output_channel)
        last_op = dag_views[-1][0]
        last_op.add_output_channel(output_channel)
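One detail worth calling out from the Channel.push change: an item read from a manager.dict() is a copy, so mutating a nested plain dict in place is silently lost. That is why the commit replaces self._push_res[data_id][op_name] = channeldata with a read/modify/write-back sequence (see the proxy-objects note in the multiprocessing docs linked in the code). A small sketch of that behavior, with illustrative names rather than the real pyserver objects:

import multiprocessing

if __name__ == "__main__":
    manager = multiprocessing.Manager()
    push_res = manager.dict()
    push_res["data_id_0"] = {}  # a plain dict stored as a value in the proxy

    # In-place mutation is lost: indexing the proxy returns a copy of the value.
    push_res["data_id_0"]["op_a"] = "channeldata"
    print(push_res["data_id_0"])  # {}

    # Workaround used in the commit: copy out, mutate, assign back.
    tmp = push_res["data_id_0"]
    tmp["op_a"] = "channeldata"
    push_res["data_id_0"] = tmp
    print(push_res["data_id_0"])  # {'op_a': 'channeldata'}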