提交 7ad0d069 编写于 作者: M Mustafa Ispir 提交者: TensorFlower Gardener

Add type error to start_queue_runners if given session is not a `tf.Session`....

Add type error to start_queue_runners if given session is not a `tf.Session`. Due to semver, we suppress the error if a MonitoredSession is provided.

PiperOrigin-RevId: 157748375
上级 7106f9fa
......@@ -22,6 +22,7 @@ import threading
import weakref
from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
......@@ -401,6 +402,10 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
collection: A `GraphKey` specifying the graph collection to
get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.
Raises:
ValueError: if `sess` is None and there isn't any default session.
TypeError: if `sess` is not a `tf.Session` object.
Returns:
A list of threads.
"""
......@@ -410,6 +415,15 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
raise ValueError("Cannot start queue runners: No default session is "
"registered. Use `with sess.as_default()` or pass an "
"explicit session to tf.start_queue_runners(sess=sess)")
if not isinstance(sess, session.SessionInterface):
# Following check is due to backward compatibility. (b/62061352)
if sess.__class__.__name__ in [
"MonitoredSession", "SingularMonitoredSession"]:
return []
raise TypeError("sess must be a `tf.Session` object. "
"Given class: {}".format(sess.__class__))
with sess.graph.as_default():
threads = []
for qr in ops.get_collection(collection):
......
......@@ -30,6 +30,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import queue_runner_impl
......@@ -247,6 +248,33 @@ class QueueRunnerTest(test.TestCase):
# The variable should be 3.
self.assertEqual(3, var.eval())
def testStartQueueRunnersRaisesIfNotASession(self):
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
with self.test_session():
init_op.run()
with self.assertRaisesRegexp(TypeError, "tf.Session"):
queue_runner_impl.start_queue_runners("NotASession")
def testStartQueueRunnersIgnoresMonitoredSession(self):
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
with self.test_session():
init_op.run()
threads = queue_runner_impl.start_queue_runners(
monitored_session.MonitoredSession())
self.assertFalse(threads)
def testStartQueueRunnersNonDefaultGraph(self):
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
graph = ops.Graph()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册