Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
4ecd2a70
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4ecd2a70
编写于
2月 24, 2016
作者:
S
Sherry Moore
提交者:
TensorFlower Gardener
2月 25, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added unit test for max_to_keep being None.
Change: 115516426
上级
77da168d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
47 addition
and
41 deletion
+47
-41
tensorflow/python/training/saver_test.py
tensorflow/python/training/saver_test.py
+47
-41
未找到文件。
tensorflow/python/training/saver_test.py
浏览文件 @
4ecd2a70
...
...
@@ -37,6 +37,14 @@ from tensorflow.python.framework import function
from
tensorflow.python.platform
import
gfile
def
_TestDir
(
test_name
):
test_dir
=
os
.
path
.
join
(
tf
.
test
.
get_temp_dir
(),
test_name
)
if
os
.
path
.
exists
(
test_dir
):
shutil
.
rmtree
(
test_dir
)
gfile
.
MakeDirs
(
test_dir
)
return
test_dir
class
SaverTest
(
tf
.
test
.
TestCase
):
def
testBasics
(
self
):
...
...
@@ -349,12 +357,7 @@ class SaveRestoreShardedTest(tf.test.TestCase):
class
MaxToKeepTest
(
tf
.
test
.
TestCase
):
def
testNonSharded
(
self
):
save_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"max_to_keep_non_sharded"
)
try
:
gfile
.
DeleteRecursively
(
save_dir
)
except
OSError
:
pass
# Ignore
gfile
.
MakeDirs
(
save_dir
)
save_dir
=
_TestDir
(
"max_to_keep_non_sharded"
)
with
self
.
test_session
()
as
sess
:
v
=
tf
.
Variable
(
10.0
,
name
=
"v"
)
...
...
@@ -456,12 +459,7 @@ class MaxToKeepTest(tf.test.TestCase):
self
.
assertTrue
(
gfile
.
Exists
(
save
.
_MetaGraphFilename
(
s1
)))
def
testSharded
(
self
):
save_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"max_to_keep_sharded"
)
try
:
gfile
.
DeleteRecursively
(
save_dir
)
except
OSError
:
pass
# Ignore
gfile
.
MakeDirs
(
save_dir
)
save_dir
=
_TestDir
(
"max_to_keep_sharded"
)
with
tf
.
Session
(
target
=
""
,
...
...
@@ -495,17 +493,39 @@ class MaxToKeepTest(tf.test.TestCase):
self
.
assertEqual
(
2
,
len
(
gfile
.
Glob
(
s3
)))
self
.
assertTrue
(
gfile
.
Exists
(
save
.
_MetaGraphFilename
(
s3
)))
def
testNoMaxToKeep
(
self
):
save_dir
=
_TestDir
(
"no_max_to_keep"
)
save_dir2
=
_TestDir
(
"max_to_keep_0"
)
with
self
.
test_session
()
as
sess
:
v
=
tf
.
Variable
(
10.0
,
name
=
"v"
)
tf
.
initialize_all_variables
().
run
()
# Test max_to_keep being None.
save
=
tf
.
train
.
Saver
({
"v"
:
v
},
max_to_keep
=
None
)
self
.
assertEqual
([],
save
.
last_checkpoints
)
s1
=
save
.
save
(
sess
,
os
.
path
.
join
(
save_dir
,
"s1"
))
self
.
assertEqual
([],
save
.
last_checkpoints
)
self
.
assertTrue
(
gfile
.
Exists
(
s1
))
s2
=
save
.
save
(
sess
,
os
.
path
.
join
(
save_dir
,
"s2"
))
self
.
assertEqual
([],
save
.
last_checkpoints
)
self
.
assertTrue
(
gfile
.
Exists
(
s2
))
# Test max_to_keep being 0.
save2
=
tf
.
train
.
Saver
({
"v"
:
v
},
max_to_keep
=
0
)
self
.
assertEqual
([],
save2
.
last_checkpoints
)
s1
=
save2
.
save
(
sess
,
os
.
path
.
join
(
save_dir2
,
"s1"
))
self
.
assertEqual
([],
save2
.
last_checkpoints
)
self
.
assertTrue
(
gfile
.
Exists
(
s1
))
s2
=
save2
.
save
(
sess
,
os
.
path
.
join
(
save_dir2
,
"s2"
))
self
.
assertEqual
([],
save2
.
last_checkpoints
)
self
.
assertTrue
(
gfile
.
Exists
(
s2
))
class
KeepCheckpointEveryNHoursTest
(
tf
.
test
.
TestCase
):
def
testNonSharded
(
self
):
save_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"keep_checkpoint_every_n_hours"
)
try
:
gfile
.
DeleteRecursively
(
save_dir
)
except
OSError
:
pass
# Ignore
gfile
.
MakeDirs
(
save_dir
)
save_dir
=
_TestDir
(
"keep_checkpoint_every_n_hours"
)
with
self
.
test_session
()
as
sess
:
v
=
tf
.
Variable
([
10.0
],
name
=
"v"
)
...
...
@@ -685,15 +705,8 @@ class LatestCheckpointWithRelativePaths(tf.test.TestCase):
class
CheckpointStateTest
(
tf
.
test
.
TestCase
):
def
_TestDir
(
self
,
test_name
):
test_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
test_name
)
if
os
.
path
.
exists
(
test_dir
):
shutil
.
rmtree
(
test_dir
)
gfile
.
MakeDirs
(
test_dir
)
return
test_dir
def
testAbsPath
(
self
):
save_dir
=
self
.
_TestDir
(
"abs_paths"
)
save_dir
=
_TestDir
(
"abs_paths"
)
abs_path
=
os
.
path
.
join
(
save_dir
,
"model-0"
)
ckpt
=
tf
.
train
.
generate_checkpoint_state_proto
(
save_dir
,
abs_path
)
self
.
assertEqual
(
ckpt
.
model_checkpoint_path
,
abs_path
)
...
...
@@ -712,7 +725,7 @@ class CheckpointStateTest(tf.test.TestCase):
self
.
assertEqual
(
ckpt
.
all_model_checkpoint_paths
[
-
1
],
new_rel_path
)
def
testAllModelCheckpointPaths
(
self
):
save_dir
=
self
.
_TestDir
(
"all_models_test"
)
save_dir
=
_TestDir
(
"all_models_test"
)
abs_path
=
os
.
path
.
join
(
save_dir
,
"model-0"
)
for
paths
in
[
None
,
[],
[
"model-2"
]]:
ckpt
=
tf
.
train
.
generate_checkpoint_state_proto
(
...
...
@@ -726,7 +739,7 @@ class CheckpointStateTest(tf.test.TestCase):
self
.
assertEqual
(
ckpt
.
all_model_checkpoint_paths
[
-
1
],
abs_path
)
def
testUpdateCheckpointState
(
self
):
save_dir
=
self
.
_TestDir
(
"update_checkpoint_state"
)
save_dir
=
_TestDir
(
"update_checkpoint_state"
)
os
.
chdir
(
save_dir
)
# Make a temporary train directory.
train_dir
=
"train"
...
...
@@ -746,15 +759,8 @@ class CheckpointStateTest(tf.test.TestCase):
class
MetaGraphTest
(
tf
.
test
.
TestCase
):
def
_TestDir
(
self
,
test_name
):
test_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
test_name
)
if
os
.
path
.
exists
(
test_dir
):
shutil
.
rmtree
(
test_dir
)
gfile
.
MakeDirs
(
test_dir
)
return
test_dir
def
testAddCollectionDef
(
self
):
test_dir
=
self
.
_TestDir
(
"good_collection"
)
test_dir
=
_TestDir
(
"good_collection"
)
filename
=
os
.
path
.
join
(
test_dir
,
"metafile"
)
with
self
.
test_session
():
# Creates a graph.
...
...
@@ -819,7 +825,7 @@ class MetaGraphTest(tf.test.TestCase):
self
.
assertEqual
(
len
(
meta_graph_def
.
collection_def
),
0
)
def
_testMultiSaverCollectionSave
(
self
):
test_dir
=
self
.
_TestDir
(
"saver_collection"
)
test_dir
=
_TestDir
(
"saver_collection"
)
filename
=
os
.
path
.
join
(
test_dir
,
"metafile"
)
saver0_ckpt
=
os
.
path
.
join
(
test_dir
,
"saver0.ckpt"
)
saver1_ckpt
=
os
.
path
.
join
(
test_dir
,
"saver1.ckpt"
)
...
...
@@ -894,7 +900,7 @@ class MetaGraphTest(tf.test.TestCase):
self
.
_testMultiSaverCollectionRestore
()
def
testBinaryAndTextFormat
(
self
):
test_dir
=
self
.
_TestDir
(
"binary_and_text"
)
test_dir
=
_TestDir
(
"binary_and_text"
)
filename
=
os
.
path
.
join
(
test_dir
,
"metafile"
)
with
self
.
test_session
(
graph
=
tf
.
Graph
()):
# Creates a graph.
...
...
@@ -924,7 +930,7 @@ class MetaGraphTest(tf.test.TestCase):
tf
.
train
.
import_meta_graph
(
filename
)
def
testSliceVariable
(
self
):
test_dir
=
self
.
_TestDir
(
"slice_saver"
)
test_dir
=
_TestDir
(
"slice_saver"
)
filename
=
os
.
path
.
join
(
test_dir
,
"metafile"
)
with
self
.
test_session
():
v1
=
tf
.
Variable
([
20.0
],
name
=
"v1"
)
...
...
@@ -946,7 +952,7 @@ class MetaGraphTest(tf.test.TestCase):
self
.
assertProtoEquals
(
meta_graph_def
,
new_meta_graph_def
)
def
_testGraphExtensionSave
(
self
):
test_dir
=
self
.
_TestDir
(
"graph_extension"
)
test_dir
=
_TestDir
(
"graph_extension"
)
filename
=
os
.
path
.
join
(
test_dir
,
"metafile"
)
saver0_ckpt
=
os
.
path
.
join
(
test_dir
,
"saver0.ckpt"
)
with
self
.
test_session
(
graph
=
tf
.
Graph
())
as
sess
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录