提交 3172da66 编写于 作者: B barriery

update channel to support timeout

上级 8ec22efb
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from time import time as _time
import threading import threading
import multiprocessing import multiprocessing
import multiprocessing.queues import multiprocessing.queues
...@@ -175,7 +176,7 @@ class ProcessChannel(object): ...@@ -175,7 +176,7 @@ class ProcessChannel(object):
Only when all types of Ops get the data of the same ID, Only when all types of Ops get the data of the same ID,
the data will be poped; The Op of the same type will not the data will be poped; The Op of the same type will not
get the data of the same ID. get the data of the same ID.
3. (TODO) Timeout and BatchSize are not fully supported. 3. Function front support timeout param to make auto-batching.
Note: Note:
1. The ID of the data in the channel must be different. 1. The ID of the data in the channel must be different.
...@@ -194,7 +195,7 @@ class ProcessChannel(object): ...@@ -194,7 +195,7 @@ class ProcessChannel(object):
maintains the data obtained from queue. maintains the data obtained from queue.
""" """
def __init__(self, manager, name=None, maxsize=0, timeout=None): def __init__(self, manager, name=None, maxsize=0):
# For queue multiprocess: after putting an object on # For queue multiprocess: after putting an object on
# an empty queue there may be an infinitessimal delay # an empty queue there may be an infinitessimal delay
# before the queue's :meth:`~Queue.empty` # before the queue's :meth:`~Queue.empty`
...@@ -203,7 +204,6 @@ class ProcessChannel(object): ...@@ -203,7 +204,6 @@ class ProcessChannel(object):
# - https://hg.python.org/cpython/rev/860fc6a2bd21 # - https://hg.python.org/cpython/rev/860fc6a2bd21
self._que = manager.Queue(maxsize=maxsize) self._que = manager.Queue(maxsize=maxsize)
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout
self.name = name self.name = name
self._stop = manager.Value('i', 0) self._stop = manager.Value('i', 0)
...@@ -327,7 +327,13 @@ class ProcessChannel(object): ...@@ -327,7 +327,13 @@ class ProcessChannel(object):
self._cv.notify_all() self._cv.notify_all()
return True return True
def front(self, op_name=None): def front(self, op_name=None, timeout=None):
endtime = None
if timeout is not None and timeout <= 0:
timeout = None
else:
endtime = _time() + timeout
_LOGGER.debug(self._log("{} try to get data...".format(op_name))) _LOGGER.debug(self._log("{} try to get data...".format(op_name)))
if len(self._consumer_cursors) == 0: if len(self._consumer_cursors) == 0:
raise Exception( raise Exception(
...@@ -345,16 +351,15 @@ class ProcessChannel(object): ...@@ -345,16 +351,15 @@ class ProcessChannel(object):
resp = self._que.get(timeout=0) resp = self._que.get(timeout=0)
break break
except Queue.Empty: except Queue.Empty:
_LOGGER.debug( if timeout is not None:
self._log( remaining = endtime - _time()
"{} wait for empty queue(with channel empty: {})". if remaining <= 0.0:
format(op_name, self._que.empty()))) raise ChannelTimeoutError()
self._cv.wait() self._cv.wait(remaining)
else:
self._cv.wait()
if self._stop.value == 1: if self._stop.value == 1:
raise ChannelStopError() raise ChannelStopError()
_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
return resp return resp
elif op_name is None: elif op_name is None:
raise Exception( raise Exception(
...@@ -389,11 +394,13 @@ class ProcessChannel(object): ...@@ -389,11 +394,13 @@ class ProcessChannel(object):
self._output_buf.append(channeldata) self._output_buf.append(channeldata)
break break
except Queue.Empty: except Queue.Empty:
_LOGGER.debug( if timeout is not None:
self._log( remaining = endtime - _time()
"{} wait for empty queue(with channel size: {})". if remaining <= 0.0:
format(op_name, self._que.qsize()))) raise ChannelTimeoutError()
self._cv.wait() self._cv.wait(remaining)
else:
self._cv.wait()
if self._stop.value == 1: if self._stop.value == 1:
raise ChannelStopError() raise ChannelStopError()
...@@ -458,7 +465,7 @@ class ThreadChannel(Queue.Queue): ...@@ -458,7 +465,7 @@ class ThreadChannel(Queue.Queue):
Only when all types of Ops get the data of the same ID, Only when all types of Ops get the data of the same ID,
the data will be poped; The Op of the same type will not the data will be poped; The Op of the same type will not
get the data of the same ID. get the data of the same ID.
3. (TODO) Timeout and BatchSize are not fully supported. 3. Function front support timeout param to make auto-batching.
Note: Note:
1. The ID of the data in the channel must be different. 1. The ID of the data in the channel must be different.
...@@ -477,10 +484,9 @@ class ThreadChannel(Queue.Queue): ...@@ -477,10 +484,9 @@ class ThreadChannel(Queue.Queue):
maintains the data obtained from queue. maintains the data obtained from queue.
""" """
def __init__(self, name=None, maxsize=-1, timeout=None): def __init__(self, name=None, maxsize=-1):
Queue.Queue.__init__(self, maxsize=maxsize) Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout
self.name = name self.name = name
self._stop = False self._stop = False
...@@ -592,7 +598,13 @@ class ThreadChannel(Queue.Queue): ...@@ -592,7 +598,13 @@ class ThreadChannel(Queue.Queue):
self._cv.notify_all() self._cv.notify_all()
return True return True
def front(self, op_name=None): def front(self, op_name=None, timeout=None):
endtime = None
if timeout is not None and timeout <= 0:
timeout = None
else:
endtime = _time() + timeout
_LOGGER.debug(self._log("{} try to get data".format(op_name))) _LOGGER.debug(self._log("{} try to get data".format(op_name)))
if len(self._consumer_cursors) == 0: if len(self._consumer_cursors) == 0:
raise Exception( raise Exception(
...@@ -607,7 +619,13 @@ class ThreadChannel(Queue.Queue): ...@@ -607,7 +619,13 @@ class ThreadChannel(Queue.Queue):
resp = self.get(timeout=0) resp = self.get(timeout=0)
break break
except Queue.Empty: except Queue.Empty:
self._cv.wait() if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
raise ChannelTimeoutError()
self._cv.wait(remaining)
else:
self._cv.wait()
if self._stop: if self._stop:
raise ChannelStopError() raise ChannelStopError()
_LOGGER.debug( _LOGGER.debug(
...@@ -639,7 +657,13 @@ class ThreadChannel(Queue.Queue): ...@@ -639,7 +657,13 @@ class ThreadChannel(Queue.Queue):
self._output_buf.append(channeldata) self._output_buf.append(channeldata)
break break
except Queue.Empty: except Queue.Empty:
self._cv.wait() if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
raise ChannelTimeoutError()
self._cv.wait(remaining)
else:
self._cv.wait()
if self._stop: if self._stop:
raise ChannelStopError() raise ChannelStopError()
...@@ -687,6 +711,9 @@ class ThreadChannel(Queue.Queue): ...@@ -687,6 +711,9 @@ class ThreadChannel(Queue.Queue):
with self._cv: with self._cv:
self._cv.notify_all() self._cv.notify_all()
class ChannelTimeoutError(RuntimeError):
def __init__(self):
pass
class ChannelStopError(RuntimeError): class ChannelStopError(RuntimeError):
def __init__(self): def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册