提交 d1c969ac 编写于 作者: B barrierye

add condition

上级 9c0e751b
...@@ -184,7 +184,11 @@ class GeneralPythonService( ...@@ -184,7 +184,11 @@ class GeneralPythonService(
super(GeneralPythonService, self).__init__() super(GeneralPythonService, self).__init__()
self._in_channel = in_channel self._in_channel = in_channel
self._out_channel = out_channel self._out_channel = out_channel
self._lock = threading.Lock() #TODO:
# multi-lock for different clients
# diffenert lock for server and client
self._id_lock = threading.Lock()
self._cv = threading.Condition()
self._globel_resp_dict = {} self._globel_resp_dict = {}
self._id_counter = 0 self._id_counter = 0
self._recive_func = threading.Thread( self._recive_func = threading.Thread(
...@@ -202,20 +206,24 @@ class GeneralPythonService( ...@@ -202,20 +206,24 @@ class GeneralPythonService(
if data_id != d.id: if data_id != d.id:
raise Exception("id not match: {} vs {}".format(data_id, raise Exception("id not match: {} vs {}".format(data_id,
d.id)) d.id))
with self._lock: self._cv.acquire()
self._globel_resp_dict[data_id] = data self._globel_resp_dict[data_id] = data
#TODO wake up inference self._cv.notify_all()
self._cv.release()
def _get_next_id(self): def _get_next_id(self):
with self._lock: with self._id_lock:
self._id_counter += 1 self._id_counter += 1
return self._id_counter - 1 return self._id_counter - 1
def _get_data_in_globel_resp_dict(self, data_id): def _get_data_in_globel_resp_dict(self, data_id):
if data_id in self._globel_resp_dict: self._cv.acquire()
with self._lock: while data_id not in self._globel_resp_dict:
return self._globel_resp_dict.pop(data_id) self._cv.wait()
return None resp = self._globel_resp_dict.pop(data_id)
self._cv.notify_all()
self._cv.release()
return resp
def _pack_data_for_infer(self, request): def _pack_data_for_infer(self, request):
logging.debug('start inferce') logging.debug('start inferce')
...@@ -250,11 +258,7 @@ class GeneralPythonService( ...@@ -250,11 +258,7 @@ class GeneralPythonService(
self._in_channel.push(data) self._in_channel.push(data)
logging.debug('wait for infer') logging.debug('wait for infer')
resp_data = None resp_data = None
while True: resp_data = self._get_data_in_globel_resp_dict(data_id)
resp_data = self._get_data_in_globel_resp_dict(data_id)
if resp_data is not None:
break
time.sleep(0.05) #TODO: wake up by _recive_out_channel_func
resp = self._pack_data_for_resp(resp_data) resp = self._pack_data_for_resp(resp_data)
return resp return resp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册