Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
5ad6738c
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5ad6738c
编写于
10月 31, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
10月 31, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow a QueueRunner to create_threads on multiple sessions.
Change: 137701036
上级
57f42975
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
48 addition
and
30 deletion
+48
-30
tensorflow/contrib/learn/python/learn/dataframe/queues/feeding_queue_runner.py
...arn/python/learn/dataframe/queues/feeding_queue_runner.py
+14
-14
tensorflow/python/training/queue_runner.py
tensorflow/python/training/queue_runner.py
+19
-16
tensorflow/python/training/queue_runner_test.py
tensorflow/python/training/queue_runner_test.py
+15
-0
未找到文件。
tensorflow/contrib/learn/python/learn/dataframe/queues/feeding_queue_runner.py
浏览文件 @
5ad6738c
...
...
@@ -95,9 +95,9 @@ class FeedingQueueRunner(qr.QueueRunner):
except
(
errors
.
OutOfRangeError
,
errors
.
CancelledError
):
# This exception indicates that a queue was closed.
with
self
.
_lock
:
self
.
_runs
-=
1
self
.
_runs
_per_session
[
sess
]
-=
1
decremented
=
True
if
self
.
_runs
==
0
:
if
self
.
_runs
_per_session
[
sess
]
==
0
:
try
:
sess
.
run
(
self
.
_close_op
)
except
Exception
as
e
:
...
...
@@ -117,10 +117,10 @@ class FeedingQueueRunner(qr.QueueRunner):
# Make sure we account for all terminations: normal or errors.
if
not
decremented
:
with
self
.
_lock
:
self
.
_runs
-=
1
self
.
_runs
_per_session
[
sess
]
-=
1
def
create_threads
(
self
,
sess
,
coord
=
None
,
daemon
=
False
,
start
=
False
):
"""Create threads to run the enqueue ops.
"""Create threads to run the enqueue ops
for the given session
.
This method requires a session in which the graph was launched. It creates
a list of threads, optionally starting them. There is one thread for each
...
...
@@ -131,8 +131,8 @@ class FeedingQueueRunner(qr.QueueRunner):
this method starts an additional thread to close the queue when the
coordinator requests a stop.
This method may be called again as long as all threads from a previous call
have stopp
ed.
If previously created threads for the given session are still running, no
new threads will be creat
ed.
Args:
sess: A `Session`.
...
...
@@ -144,16 +144,16 @@ class FeedingQueueRunner(qr.QueueRunner):
Returns:
A list of threads.
Raises:
RuntimeError: If threads from a previous call to `create_threads()` are
still running.
"""
with
self
.
_lock
:
if
self
.
_runs
>
0
:
# Already started: no new threads to return.
return
[]
self
.
_runs
=
len
(
self
.
_enqueue_ops
)
try
:
if
self
.
_runs_per_session
[
sess
]
>
0
:
# Already started: no new threads to return.
return
[]
except
KeyError
:
# We haven't seen this session yet.
pass
self
.
_runs_per_session
[
sess
]
=
len
(
self
.
_enqueue_ops
)
self
.
_exceptions_raised
=
[]
ret_threads
=
[
threading
.
Thread
(
target
=
self
.
_run
,
...
...
tensorflow/python/training/queue_runner.py
浏览文件 @
5ad6738c
...
...
@@ -19,6 +19,7 @@ from __future__ import division
from
__future__
import
print_function
import
threading
import
weakref
from
tensorflow.core.protobuf
import
queue_runner_pb2
from
tensorflow.python.framework
import
errors
...
...
@@ -90,7 +91,9 @@ class QueueRunner(object):
queue_closed_exception_types
=
queue_closed_exception_types
)
# Protect the count of runs to wait for.
self
.
_lock
=
threading
.
Lock
()
self
.
_runs
=
0
# A map from a session object to the number of outstanding queue runner
# threads for that session.
self
.
_runs_per_session
=
weakref
.
WeakKeyDictionary
()
# List of exceptions raised by the running threads.
self
.
_exceptions_raised
=
[]
...
...
@@ -234,9 +237,9 @@ class QueueRunner(object):
except
self
.
_queue_closed_exception_types
:
# pylint: disable=catching-non-exception
# This exception indicates that a queue was closed.
with
self
.
_lock
:
self
.
_runs
-=
1
self
.
_runs
_per_session
[
sess
]
-=
1
decremented
=
True
if
self
.
_runs
==
0
:
if
self
.
_runs
_per_session
[
sess
]
==
0
:
try
:
sess
.
run
(
self
.
_close_op
)
except
Exception
as
e
:
...
...
@@ -256,7 +259,7 @@ class QueueRunner(object):
# Make sure we account for all terminations: normal or errors.
if
not
decremented
:
with
self
.
_lock
:
self
.
_runs
-=
1
self
.
_runs
_per_session
[
sess
]
-=
1
def
_close_on_stop
(
self
,
sess
,
cancel_op
,
coord
):
"""Close the queue when the Coordinator requests stop.
...
...
@@ -276,19 +279,19 @@ class QueueRunner(object):
# pylint: enable=broad-except
def
create_threads
(
self
,
sess
,
coord
=
None
,
daemon
=
False
,
start
=
False
):
"""Create threads to run the enqueue ops.
"""Create threads to run the enqueue ops
for the given session
.
This method requires a session in which the graph was launched. It creates
a list of threads, optionally starting them. There is one thread for each
op passed in `enqueue_ops`.
The `coord` argument is an optional coordinator
,
that the threads will use
The `coord` argument is an optional coordinator that the threads will use
to terminate together and report exceptions. If a coordinator is given,
this method starts an additional thread to close the queue when the
coordinator requests a stop.
This method may be called again as long as all threads from a previous call
have stopp
ed.
If previously created threads for the given session are still running, no
new threads will be creat
ed.
Args:
sess: A `Session`.
...
...
@@ -300,16 +303,16 @@ class QueueRunner(object):
Returns:
A list of threads.
Raises:
RuntimeError: If threads from a previous call to `create_threads()` are
still running.
"""
with
self
.
_lock
:
if
self
.
_runs
>
0
:
# Already started: no new threads to return.
return
[]
self
.
_runs
=
len
(
self
.
_enqueue_ops
)
try
:
if
self
.
_runs_per_session
[
sess
]
>
0
:
# Already started: no new threads to return.
return
[]
except
KeyError
:
# We haven't seen this session yet.
pass
self
.
_runs_per_session
[
sess
]
=
len
(
self
.
_enqueue_ops
)
self
.
_exceptions_raised
=
[]
ret_threads
=
[
threading
.
Thread
(
target
=
self
.
_run
,
args
=
(
sess
,
op
,
coord
))
...
...
tensorflow/python/training/queue_runner_test.py
浏览文件 @
5ad6738c
...
...
@@ -173,6 +173,21 @@ class QueueRunnerTest(tf.test.TestCase):
# the queue to be closed and the enqueue to terminate.
coord
.
join
(
stop_grace_period_secs
=
0.05
)
def
testMultipleSessions
(
self
):
with
self
.
test_session
()
as
sess
:
with
tf
.
Session
()
as
other_sess
:
zero64
=
tf
.
constant
(
0
,
dtype
=
tf
.
int64
)
var
=
tf
.
Variable
(
zero64
)
count_up_to
=
var
.
count_up_to
(
3
)
queue
=
tf
.
FIFOQueue
(
10
,
tf
.
float32
)
tf
.
initialize_all_variables
().
run
()
coord
=
tf
.
train
.
Coordinator
()
qr
=
tf
.
train
.
QueueRunner
(
queue
,
[
count_up_to
])
# NOTE that this test does not actually start the threads.
threads
=
qr
.
create_threads
(
sess
,
coord
=
coord
)
other_threads
=
qr
.
create_threads
(
other_sess
,
coord
=
coord
)
self
.
assertEqual
(
len
(
threads
),
len
(
other_threads
))
def
testIgnoreMultiStarts
(
self
):
with
self
.
test_session
()
as
sess
:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录