提交 3172da66 编写于 作者: B barriery

update channel to support timeout

上级 8ec22efb
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from time import time as _time
import threading
import multiprocessing
import multiprocessing.queues
......@@ -175,7 +176,7 @@ class ProcessChannel(object):
Only when all types of Ops get the data of the same ID,
the data will be poped; The Op of the same type will not
get the data of the same ID.
3. (TODO) Timeout and BatchSize are not fully supported.
3. Function front support timeout param to make auto-batching.
Note:
1. The ID of the data in the channel must be different.
......@@ -194,7 +195,7 @@ class ProcessChannel(object):
maintains the data obtained from queue.
"""
def __init__(self, manager, name=None, maxsize=0, timeout=None):
def __init__(self, manager, name=None, maxsize=0):
# For queue multiprocess: after putting an object on
# an empty queue there may be an infinitessimal delay
# before the queue's :meth:`~Queue.empty`
......@@ -203,7 +204,6 @@ class ProcessChannel(object):
# - https://hg.python.org/cpython/rev/860fc6a2bd21
self._que = manager.Queue(maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self.name = name
self._stop = manager.Value('i', 0)
......@@ -327,7 +327,13 @@ class ProcessChannel(object):
self._cv.notify_all()
return True
def front(self, op_name=None):
def front(self, op_name=None, timeout=None):
endtime = None
if timeout is not None and timeout <= 0:
timeout = None
else:
endtime = _time() + timeout
_LOGGER.debug(self._log("{} try to get data...".format(op_name)))
if len(self._consumer_cursors) == 0:
raise Exception(
......@@ -345,16 +351,15 @@ class ProcessChannel(object):
resp = self._que.get(timeout=0)
break
except Queue.Empty:
_LOGGER.debug(
self._log(
"{} wait for empty queue(with channel empty: {})".
format(op_name, self._que.empty())))
self._cv.wait()
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
raise ChannelTimeoutError()
self._cv.wait(remaining)
else:
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
return resp
elif op_name is None:
raise Exception(
......@@ -389,11 +394,13 @@ class ProcessChannel(object):
self._output_buf.append(channeldata)
break
except Queue.Empty:
_LOGGER.debug(
self._log(
"{} wait for empty queue(with channel size: {})".
format(op_name, self._que.qsize())))
self._cv.wait()
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
raise ChannelTimeoutError()
self._cv.wait(remaining)
else:
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
......@@ -458,7 +465,7 @@ class ThreadChannel(Queue.Queue):
Only when all types of Ops get the data of the same ID,
the data will be poped; The Op of the same type will not
get the data of the same ID.
3. (TODO) Timeout and BatchSize are not fully supported.
3. Function front support timeout param to make auto-batching.
Note:
1. The ID of the data in the channel must be different.
......@@ -477,10 +484,9 @@ class ThreadChannel(Queue.Queue):
maintains the data obtained from queue.
"""
def __init__(self, name=None, maxsize=-1, timeout=None):
def __init__(self, name=None, maxsize=-1):
Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self.name = name
self._stop = False
......@@ -592,7 +598,13 @@ class ThreadChannel(Queue.Queue):
self._cv.notify_all()
return True
def front(self, op_name=None):
def front(self, op_name=None, timeout=None):
endtime = None
if timeout is not None and timeout <= 0:
timeout = None
else:
endtime = _time() + timeout
_LOGGER.debug(self._log("{} try to get data".format(op_name)))
if len(self._consumer_cursors) == 0:
raise Exception(
......@@ -607,7 +619,13 @@ class ThreadChannel(Queue.Queue):
resp = self.get(timeout=0)
break
except Queue.Empty:
self._cv.wait()
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
raise ChannelTimeoutError()
self._cv.wait(remaining)
else:
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
......@@ -639,7 +657,13 @@ class ThreadChannel(Queue.Queue):
self._output_buf.append(channeldata)
break
except Queue.Empty:
self._cv.wait()
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
raise ChannelTimeoutError()
self._cv.wait(remaining)
else:
self._cv.wait()
if self._stop:
raise ChannelStopError()
......@@ -687,6 +711,9 @@ class ThreadChannel(Queue.Queue):
with self._cv:
self._cv.notify_all()
class ChannelTimeoutError(RuntimeError):
def __init__(self):
pass
class ChannelStopError(RuntimeError):
def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册