提交 5ad6738c 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Allow a QueueRunner to create_threads on multiple sessions.

Change: 137701036
上级 57f42975
......@@ -95,9 +95,9 @@ class FeedingQueueRunner(qr.QueueRunner):
except (errors.OutOfRangeError, errors.CancelledError):
# This exception indicates that a queue was closed.
with self._lock:
self._runs -= 1
self._runs_per_session[sess] -= 1
decremented = True
if self._runs == 0:
if self._runs_per_session[sess] == 0:
try:
sess.run(self._close_op)
except Exception as e:
......@@ -117,10 +117,10 @@ class FeedingQueueRunner(qr.QueueRunner):
# Make sure we account for all terminations: normal or errors.
if not decremented:
with self._lock:
self._runs -= 1
self._runs_per_session[sess] -= 1
def create_threads(self, sess, coord=None, daemon=False, start=False):
"""Create threads to run the enqueue ops.
"""Create threads to run the enqueue ops for the given session.
This method requires a session in which the graph was launched. It creates
a list of threads, optionally starting them. There is one thread for each
......@@ -131,8 +131,8 @@ class FeedingQueueRunner(qr.QueueRunner):
this method starts an additional thread to close the queue when the
coordinator requests a stop.
This method may be called again as long as all threads from a previous call
have stopped.
If previously created threads for the given session are still running, no
new threads will be created.
Args:
sess: A `Session`.
......@@ -144,16 +144,16 @@ class FeedingQueueRunner(qr.QueueRunner):
Returns:
A list of threads.
Raises:
RuntimeError: If threads from a previous call to `create_threads()` are
still running.
"""
with self._lock:
if self._runs > 0:
# Already started: no new threads to return.
return []
self._runs = len(self._enqueue_ops)
try:
if self._runs_per_session[sess] > 0:
# Already started: no new threads to return.
return []
except KeyError:
# We haven't seen this session yet.
pass
self._runs_per_session[sess] = len(self._enqueue_ops)
self._exceptions_raised = []
ret_threads = [threading.Thread(target=self._run,
......
......@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import threading
import weakref
from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.framework import errors
......@@ -90,7 +91,9 @@ class QueueRunner(object):
queue_closed_exception_types=queue_closed_exception_types)
# Protect the count of runs to wait for.
self._lock = threading.Lock()
self._runs = 0
# A map from a session object to the number of outstanding queue runner
# threads for that session.
self._runs_per_session = weakref.WeakKeyDictionary()
# List of exceptions raised by the running threads.
self._exceptions_raised = []
......@@ -234,9 +237,9 @@ class QueueRunner(object):
except self._queue_closed_exception_types: # pylint: disable=catching-non-exception
# This exception indicates that a queue was closed.
with self._lock:
self._runs -= 1
self._runs_per_session[sess] -= 1
decremented = True
if self._runs == 0:
if self._runs_per_session[sess] == 0:
try:
sess.run(self._close_op)
except Exception as e:
......@@ -256,7 +259,7 @@ class QueueRunner(object):
# Make sure we account for all terminations: normal or errors.
if not decremented:
with self._lock:
self._runs -= 1
self._runs_per_session[sess] -= 1
def _close_on_stop(self, sess, cancel_op, coord):
"""Close the queue when the Coordinator requests stop.
......@@ -276,19 +279,19 @@ class QueueRunner(object):
# pylint: enable=broad-except
def create_threads(self, sess, coord=None, daemon=False, start=False):
"""Create threads to run the enqueue ops.
"""Create threads to run the enqueue ops for the given session.
This method requires a session in which the graph was launched. It creates
a list of threads, optionally starting them. There is one thread for each
op passed in `enqueue_ops`.
The `coord` argument is an optional coordinator, that the threads will use
The `coord` argument is an optional coordinator that the threads will use
to terminate together and report exceptions. If a coordinator is given,
this method starts an additional thread to close the queue when the
coordinator requests a stop.
This method may be called again as long as all threads from a previous call
have stopped.
If previously created threads for the given session are still running, no
new threads will be created.
Args:
sess: A `Session`.
......@@ -300,16 +303,16 @@ class QueueRunner(object):
Returns:
A list of threads.
Raises:
RuntimeError: If threads from a previous call to `create_threads()` are
still running.
"""
with self._lock:
if self._runs > 0:
# Already started: no new threads to return.
return []
self._runs = len(self._enqueue_ops)
try:
if self._runs_per_session[sess] > 0:
# Already started: no new threads to return.
return []
except KeyError:
# We haven't seen this session yet.
pass
self._runs_per_session[sess] = len(self._enqueue_ops)
self._exceptions_raised = []
ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord))
......
......@@ -173,6 +173,21 @@ class QueueRunnerTest(tf.test.TestCase):
# the queue to be closed and the enqueue to terminate.
coord.join(stop_grace_period_secs=0.05)
def testMultipleSessions(self):
with self.test_session() as sess:
with tf.Session() as other_sess:
zero64 = tf.constant(0, dtype=tf.int64)
var = tf.Variable(zero64)
count_up_to = var.count_up_to(3)
queue = tf.FIFOQueue(10, tf.float32)
tf.initialize_all_variables().run()
coord = tf.train.Coordinator()
qr = tf.train.QueueRunner(queue, [count_up_to])
# NOTE that this test does not actually start the threads.
threads = qr.create_threads(sess, coord=coord)
other_threads = qr.create_threads(other_sess, coord=coord)
self.assertEqual(len(threads), len(other_threads))
def testIgnoreMultiStarts(self):
with self.test_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册