Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
19ef8215
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
19ef8215
编写于
3月 15, 2017
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
3月 15, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make SavedModel exports include all the SAVEABLE objects and not just global variables.
Change: 150243023
上级
b05e0840
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
303 addition
and
77 deletion
+303
-77
tensorflow/contrib/learn/python/learn/estimators/estimator.py
...orflow/contrib/learn/python/learn/estimators/estimator.py
+6
-1
tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
...w/contrib/learn/python/learn/estimators/estimator_test.py
+87
-0
tensorflow/python/BUILD
tensorflow/python/BUILD
+12
-1
tensorflow/python/estimator/BUILD
tensorflow/python/estimator/BUILD
+1
-0
tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/estimator.py
+1
-1
tensorflow/python/estimator/estimator_test.py
tensorflow/python/estimator/estimator_test.py
+59
-0
tensorflow/python/saved_model/BUILD
tensorflow/python/saved_model/BUILD
+1
-0
tensorflow/python/saved_model/builder_impl.py
tensorflow/python/saved_model/builder_impl.py
+4
-4
tensorflow/python/saved_model/saved_model_test.py
tensorflow/python/saved_model/saved_model_test.py
+30
-0
tensorflow/python/training/saver_test.py
tensorflow/python/training/saver_test.py
+16
-70
tensorflow/python/training/saver_test_utils.py
tensorflow/python/training/saver_test_utils.py
+86
-0
未找到文件。
tensorflow/contrib/learn/python/learn/estimators/estimator.py
浏览文件 @
19ef8215
...
...
@@ -58,6 +58,7 @@ from tensorflow.python.framework import random_seed
from
tensorflow.python.framework
import
sparse_tensor
from
tensorflow.python.ops
import
control_flow_ops
from
tensorflow.python.ops
import
data_flow_ops
from
tensorflow.python.ops
import
resources
from
tensorflow.python.ops
import
variables
from
tensorflow.python.platform
import
gfile
from
tensorflow.python.platform
import
tf_logging
as
logging
...
...
@@ -1254,13 +1255,17 @@ class Estimator(BaseEstimator):
with
tf_session
.
Session
(
''
)
as
session
:
variables
.
initialize_local_variables
()
data_flow_ops
.
tables_initializer
()
resources
.
initialize_resources
(
resources
.
shared_resources
())
saver_for_restore
=
saver
.
Saver
(
variables
.
global_variables
(),
# pylint: disable=protected-access
variables
.
_all_saveable_objects
(),
# pylint: enable=protected-access
sharded
=
True
)
saver_for_restore
.
restore
(
session
,
checkpoint_path
)
init_op
=
control_flow_ops
.
group
(
variables
.
local_variables_initializer
(),
resources
.
initialize_resources
(
resources
.
shared_resources
()),
data_flow_ops
.
tables_initializer
())
# Perform the export
...
...
tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
浏览文件 @
19ef8215
...
...
@@ -50,6 +50,7 @@ from tensorflow.python.framework import constant_op
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
ops
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
control_flow_ops
from
tensorflow.python.ops
import
data_flow_ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
parsing_ops
...
...
@@ -225,6 +226,49 @@ def _build_estimator_for_export_tests(tmpdir):
return
est
,
serving_input_fn_with_asset
def
_build_estimator_for_resource_export_test
():
def
_input_fn
():
iris
=
base
.
load_iris
()
return
{
'feature'
:
constant_op
.
constant
(
iris
.
data
,
dtype
=
dtypes
.
float32
)
},
constant_op
.
constant
(
iris
.
target
,
shape
=
[
150
],
dtype
=
dtypes
.
int32
)
feature_columns
=
[
feature_column_lib
.
real_valued_column
(
'feature'
,
dimension
=
4
)
]
def
resource_constant_model_fn
(
unused_features
,
unused_labels
,
mode
):
"""A model_fn that loads a constant from a resource and serves it."""
assert
mode
in
(
model_fn
.
ModeKeys
.
TRAIN
,
model_fn
.
ModeKeys
.
EVAL
,
model_fn
.
ModeKeys
.
INFER
)
const
=
constant_op
.
constant
(
-
1
,
dtype
=
dtypes
.
int64
)
table
=
lookup
.
MutableHashTable
(
dtypes
.
string
,
dtypes
.
int64
,
const
,
name
=
'LookupTableModel'
)
if
mode
in
(
model_fn
.
ModeKeys
.
TRAIN
,
model_fn
.
ModeKeys
.
EVAL
):
key
=
constant_op
.
constant
([
'key'
])
value
=
constant_op
.
constant
([
42
],
dtype
=
dtypes
.
int64
)
train_op_1
=
table
.
insert
(
key
,
value
)
training_state
=
lookup
.
MutableHashTable
(
dtypes
.
string
,
dtypes
.
int64
,
const
,
name
=
'LookupTableTrainingState'
)
training_op_2
=
training_state
.
insert
(
key
,
value
)
return
const
,
const
,
control_flow_ops
.
group
(
train_op_1
,
training_op_2
)
if
mode
==
model_fn
.
ModeKeys
.
INFER
:
key
=
constant_op
.
constant
([
'key'
])
prediction
=
table
.
lookup
(
key
)
return
prediction
,
const
,
control_flow_ops
.
no_op
()
est
=
estimator
.
Estimator
(
model_fn
=
resource_constant_model_fn
)
est
.
fit
(
input_fn
=
_input_fn
,
steps
=
1
)
feature_spec
=
feature_column_lib
.
create_feature_spec_for_parsing
(
feature_columns
)
serving_input_fn
=
input_fn_utils
.
build_parsing_serving_input_fn
(
feature_spec
)
return
est
,
serving_input_fn
class
CheckCallsMonitor
(
monitors_lib
.
BaseMonitor
):
def
__init__
(
self
,
expect_calls
):
...
...
@@ -753,6 +797,49 @@ class EstimatorTest(test.TestCase):
# cleanup
gfile
.
DeleteRecursively
(
tmpdir
)
def
test_export_savedmodel_with_resource
(
self
):
tmpdir
=
tempfile
.
mkdtemp
()
est
,
serving_input_fn
=
_build_estimator_for_resource_export_test
()
export_dir_base
=
os
.
path
.
join
(
compat
.
as_bytes
(
tmpdir
),
compat
.
as_bytes
(
'export'
))
export_dir
=
est
.
export_savedmodel
(
export_dir_base
,
serving_input_fn
)
self
.
assertTrue
(
gfile
.
Exists
(
export_dir_base
))
self
.
assertTrue
(
gfile
.
Exists
(
export_dir
))
self
.
assertTrue
(
gfile
.
Exists
(
os
.
path
.
join
(
compat
.
as_bytes
(
export_dir
),
compat
.
as_bytes
(
'saved_model.pb'
))))
self
.
assertTrue
(
gfile
.
Exists
(
os
.
path
.
join
(
compat
.
as_bytes
(
export_dir
),
compat
.
as_bytes
(
'variables'
))))
self
.
assertTrue
(
gfile
.
Exists
(
os
.
path
.
join
(
compat
.
as_bytes
(
export_dir
),
compat
.
as_bytes
(
'variables/variables.index'
))))
self
.
assertTrue
(
gfile
.
Exists
(
os
.
path
.
join
(
compat
.
as_bytes
(
export_dir
),
compat
.
as_bytes
(
'variables/variables.data-00000-of-00001'
))))
# Restore, to validate that the export was well-formed.
with
ops
.
Graph
().
as_default
()
as
graph
:
with
session_lib
.
Session
(
graph
=
graph
)
as
sess
:
loader
.
load
(
sess
,
[
tag_constants
.
SERVING
],
export_dir
)
graph_ops
=
[
x
.
name
for
x
in
graph
.
get_operations
()]
self
.
assertTrue
(
'input_example_tensor'
in
graph_ops
)
self
.
assertTrue
(
'ParseExample/ParseExample'
in
graph_ops
)
self
.
assertTrue
(
'LookupTableModel'
in
graph_ops
)
self
.
assertFalse
(
'LookupTableTrainingState'
in
graph_ops
)
# cleanup
gfile
.
DeleteRecursively
(
tmpdir
)
class
InferRealValuedColumnsTest
(
test
.
TestCase
):
...
...
tensorflow/python/BUILD
浏览文件 @
19ef8215
...
...
@@ -74,6 +74,7 @@ py_library(
":tensor_array_ops"
,
":training"
,
":ops"
,
":saver_test_utils"
,
":test_ops"
,
# TODO: Break testing code out into separate rule.
":util"
,
":weights_broadcast_ops"
,
...
...
@@ -2935,6 +2936,16 @@ cuda_py_tests(
],
)
py_library
(
name
=
"saver_test_utils"
,
srcs
=
[
"training/saver_test_utils.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":data_flow_ops_gen"
,
":training"
,
],
)
cuda_py_test
(
name
=
"saver_test"
,
size
=
"medium"
,
...
...
@@ -2946,12 +2957,12 @@ cuda_py_test(
":client_testlib"
,
":control_flow_ops"
,
":data_flow_ops"
,
":data_flow_ops_gen"
,
":errors"
,
":gradients"
,
":math_ops"
,
":nn_grad"
,
":nn_ops"
,
":saver_test_utils"
,
":partitioned_variables"
,
":platform"
,
":platform_test"
,
...
...
tensorflow/python/estimator/BUILD
浏览文件 @
19ef8215
...
...
@@ -136,6 +136,7 @@ py_test(
"//tensorflow/python:control_flow_ops"
,
"//tensorflow/python:framework_for_generated_wrappers"
,
"//tensorflow/python:layers"
,
"//tensorflow/python:saver_test_utils"
,
"//tensorflow/python:session"
,
"//tensorflow/python:state_ops"
,
"//tensorflow/python:training"
,
...
...
tensorflow/python/estimator/estimator.py
浏览文件 @
19ef8215
...
...
@@ -411,7 +411,7 @@ class Estimator(object):
with
tf_session
.
Session
()
as
session
:
saver_for_restore
=
estimator_spec
.
scaffold
.
saver
or
saver
.
Saver
(
variables
.
global_variables
(),
variables
.
_all_saveable_objects
(),
# pylint: disable=protected-access
sharded
=
True
)
saver_for_restore
.
restore
(
session
,
checkpoint_path
)
...
...
tensorflow/python/estimator/estimator_test.py
浏览文件 @
19ef8215
...
...
@@ -48,6 +48,7 @@ from tensorflow.python.platform import tf_logging as logging
from
tensorflow.python.saved_model
import
loader
from
tensorflow.python.saved_model
import
tag_constants
from
tensorflow.python.training
import
saver
from
tensorflow.python.training
import
saver_test_utils
from
tensorflow.python.training
import
session_run_hook
from
tensorflow.python.training
import
training
from
tensorflow.python.util
import
compat
...
...
@@ -814,6 +815,20 @@ def _model_fn_for_export_tests(features, labels, mode):
'test'
:
export_output
.
ClassificationOutput
(
scores
,
classes
)})
def
_model_fn_with_saveables_for_export_tests
(
features
,
labels
,
mode
):
_
,
_
=
features
,
labels
table
=
saver_test_utils
.
CheckpointedOp
(
name
=
'v2'
)
train_op
=
table
.
insert
(
'k1'
,
30.0
)
prediction
=
table
.
lookup
(
'k1'
,
0.0
)
return
model_fn_lib
.
EstimatorSpec
(
mode
,
predictions
=
prediction
,
loss
=
constant_op
.
constant
(
1.
),
train_op
=
train_op
,
export_outputs
=
{
'test'
:
export_output
.
PredictOutput
({
'prediction'
:
prediction
})})
_VOCAB_FILE_CONTENT
=
'emerson
\n
lake
\n
palmer
\n
'
_EXTRA_FILE_CONTENT
=
'kermit
\n
piggy
\n
ralph
\n
'
...
...
@@ -863,6 +878,50 @@ class EstimatorExportTest(test.TestCase):
# Clean up.
gfile
.
DeleteRecursively
(
tmpdir
)
def
test_export_savedmodel_with_saveables_proto_roundtrip
(
self
):
tmpdir
=
tempfile
.
mkdtemp
()
est
=
estimator
.
Estimator
(
model_fn
=
_model_fn_with_saveables_for_export_tests
)
est
.
train
(
input_fn
=
dummy_input_fn
,
steps
=
1
)
feature_spec
=
{
'x'
:
parsing_ops
.
VarLenFeature
(
dtype
=
dtypes
.
int64
),
'y'
:
parsing_ops
.
VarLenFeature
(
dtype
=
dtypes
.
int64
)}
serving_input_receiver_fn
=
export
.
build_parsing_serving_input_receiver_fn
(
feature_spec
)
# Perform the export.
export_dir_base
=
os
.
path
.
join
(
compat
.
as_bytes
(
tmpdir
),
compat
.
as_bytes
(
'export'
))
export_dir
=
est
.
export_savedmodel
(
export_dir_base
,
serving_input_receiver_fn
)
# Check that all the files are in the right places.
self
.
assertTrue
(
gfile
.
Exists
(
export_dir_base
))
self
.
assertTrue
(
gfile
.
Exists
(
export_dir
))
self
.
assertTrue
(
gfile
.
Exists
(
os
.
path
.
join
(
compat
.
as_bytes
(
export_dir
),
compat
.
as_bytes
(
'saved_model.pb'
))))
self
.
assertTrue
(
gfile
.
Exists
(
os
.
path
.
join
(
compat
.
as_bytes
(
export_dir
),
compat
.
as_bytes
(
'variables'
))))
self
.
assertTrue
(
gfile
.
Exists
(
os
.
path
.
join
(
compat
.
as_bytes
(
export_dir
),
compat
.
as_bytes
(
'variables/variables.index'
))))
self
.
assertTrue
(
gfile
.
Exists
(
os
.
path
.
join
(
compat
.
as_bytes
(
export_dir
),
compat
.
as_bytes
(
'variables/variables.data-00000-of-00001'
))))
# Restore, to validate that the export was well-formed.
with
ops
.
Graph
().
as_default
()
as
graph
:
with
session
.
Session
(
graph
=
graph
)
as
sess
:
loader
.
load
(
sess
,
[
tag_constants
.
SERVING
],
export_dir
)
graph_ops
=
[
x
.
name
for
x
in
graph
.
get_operations
()]
self
.
assertTrue
(
'input_example_tensor'
in
graph_ops
)
self
.
assertTrue
(
'ParseExample/ParseExample'
in
graph_ops
)
self
.
assertTrue
(
'save/LookupTableImport'
in
graph_ops
)
# Clean up.
gfile
.
DeleteRecursively
(
tmpdir
)
def
test_export_savedmodel_assets
(
self
):
tmpdir
=
tempfile
.
mkdtemp
()
est
=
estimator
.
Estimator
(
model_fn
=
_model_fn_for_export_tests
)
...
...
tensorflow/python/saved_model/BUILD
浏览文件 @
19ef8215
...
...
@@ -122,6 +122,7 @@ py_test(
"//tensorflow/python:framework_for_generated_wrappers"
,
"//tensorflow/python:lib"
,
"//tensorflow/python:math_ops"
,
"//tensorflow/python:saver_test_utils"
,
"//tensorflow/python:state_ops"
,
"//tensorflow/python:util"
,
"//tensorflow/python:variables"
,
...
...
tensorflow/python/saved_model/builder_impl.py
浏览文件 @
19ef8215
...
...
@@ -352,10 +352,10 @@ class SavedModelBuilder(object):
else
:
self
.
_add_main_op
(
main_op
)
# Initialize a saver to generate a sharded output for all
vari
ables in the
# Initialize a saver to generate a sharded output for all
save
ables in the
# current scope.
saver
=
tf_saver
.
Saver
(
variables
.
global_variables
(),
variables
.
_all_saveable_objects
(),
# pylint: disable=protected-access
sharded
=
True
,
write_version
=
saver_pb2
.
SaverDef
.
V2
,
allow_empty
=
True
)
...
...
@@ -423,10 +423,10 @@ class SavedModelBuilder(object):
else
:
self
.
_add_main_op
(
main_op
)
# Initialize a saver to generate a sharded output for all
vari
ables in the
# Initialize a saver to generate a sharded output for all
save
ables in the
# current scope.
saver
=
tf_saver
.
Saver
(
variables
.
global_variables
(),
variables
.
_all_saveable_objects
(),
# pylint: disable=protected-access
sharded
=
True
,
write_version
=
saver_pb2
.
SaverDef
.
V2
,
allow_empty
=
True
)
...
...
tensorflow/python/saved_model/saved_model_test.py
浏览文件 @
19ef8215
...
...
@@ -39,6 +39,7 @@ from tensorflow.python.saved_model import loader
from
tensorflow.python.saved_model
import
main_op
from
tensorflow.python.saved_model
import
signature_def_utils
from
tensorflow.python.saved_model
import
tag_constants
from
tensorflow.python.training
import
saver_test_utils
from
tensorflow.python.util
import
compat
SAVED_MODEL_PATH
=
(
"cc/saved_model/testdata/half_plus_two/00000123"
)
...
...
@@ -734,6 +735,35 @@ class SavedModelTest(test.TestCase):
ops
.
get_collection
(
"init_op"
)[
0
].
run
()
self
.
assertEqual
(
3
,
ops
.
get_collection
(
"v"
)[
2
].
eval
())
def
testCustomSaveable
(
self
):
export_dir
=
os
.
path
.
join
(
test
.
get_temp_dir
(),
"custom_saveable"
)
builder
=
saved_model_builder
.
SavedModelBuilder
(
export_dir
)
with
session
.
Session
(
graph
=
ops
.
Graph
(),
config
=
config_pb2
.
ConfigProto
(
device_count
=
{
"CPU"
:
2
}))
as
sess
:
# CheckpointedOp is a key-value table that can be saved across sessions.
# The table register itself in SAVEABLE_OBJECTS collection.
v1
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v1"
)
variables
.
global_variables_initializer
().
run
()
v1
.
insert
(
"k1"
,
3.0
).
run
()
# Once the table is restored, we can access it through this reference.
ops
.
add_to_collection
(
"table_ref"
,
v1
.
table_ref
)
builder
.
add_meta_graph_and_variables
(
sess
,
[
"foo"
])
# Save the SavedModel to disk.
builder
.
save
()
with
session
.
Session
(
graph
=
ops
.
Graph
(),
config
=
config_pb2
.
ConfigProto
(
device_count
=
{
"CPU"
:
2
}))
as
sess
:
loader
.
load
(
sess
,
[
"foo"
],
export_dir
)
# Instantiate a wrapper object from the checkpointed reference.
v1
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v1"
,
table_ref
=
ops
.
get_collection
(
"table_ref"
)[
0
])
self
.
assertEqual
(
b
"k1"
,
v1
.
keys
().
eval
())
self
.
assertEqual
(
3.0
,
v1
.
values
().
eval
())
def
testClearDevices
(
self
):
export_dir
=
os
.
path
.
join
(
test
.
get_temp_dir
(),
"test_clear_devices"
)
builder
=
saved_model_builder
.
SavedModelBuilder
(
export_dir
)
...
...
tensorflow/python/training/saver_test.py
浏览文件 @
19ef8215
...
...
@@ -48,7 +48,6 @@ from tensorflow.python.framework import ops as ops_lib
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
control_flow_ops
from
tensorflow.python.ops
import
data_flow_ops
from
tensorflow.python.ops
import
gen_data_flow_ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
nn_ops
from
tensorflow.python.ops
import
partitioned_variables
...
...
@@ -65,63 +64,10 @@ from tensorflow.python.training import adam
from
tensorflow.python.training
import
gradient_descent
from
tensorflow.python.training
import
queue_runner_impl
from
tensorflow.python.training
import
saver
as
saver_module
from
tensorflow.python.training
import
saver_test_utils
from
tensorflow.python.util
import
compat
class
CheckpointedOp
(
object
):
"""Op with a custom checkpointing implementation.
Defined as part of the test because the MutableHashTable Python code is
currently in contrib.
"""
def
__init__
(
self
,
name
):
self
.
_table_ref
=
gen_data_flow_ops
.
_mutable_hash_table
(
key_dtype
=
dtypes
.
string
,
value_dtype
=
dtypes
.
float32
,
name
=
name
)
self
.
_name
=
name
self
.
_saveable
=
CheckpointedOp
.
CustomSaveable
(
self
,
name
)
ops_lib
.
add_to_collection
(
ops_lib
.
GraphKeys
.
SAVEABLE_OBJECTS
,
self
.
_saveable
)
@
property
def
name
(
self
):
return
self
.
_name
@
property
def
saveable
(
self
):
return
self
.
_saveable
def
insert
(
self
,
keys
,
values
):
return
gen_data_flow_ops
.
_lookup_table_insert
(
self
.
_table_ref
,
keys
,
values
)
def
keys
(
self
):
return
self
.
_export
()[
0
]
def
values
(
self
):
return
self
.
_export
()[
1
]
def
_export
(
self
):
return
gen_data_flow_ops
.
_lookup_table_export
(
self
.
_table_ref
,
dtypes
.
string
,
dtypes
.
float32
)
class
CustomSaveable
(
saver_module
.
BaseSaverBuilder
.
SaveableObject
):
def
__init__
(
self
,
table
,
name
):
tensors
=
table
.
_export
()
specs
=
[
saver_module
.
BaseSaverBuilder
.
SaveSpec
(
tensors
[
0
],
""
,
name
+
"-keys"
),
saver_module
.
BaseSaverBuilder
.
SaveSpec
(
tensors
[
1
],
""
,
name
+
"-values"
)
]
super
(
CheckpointedOp
.
CustomSaveable
,
self
).
__init__
(
table
,
specs
,
name
)
def
restore
(
self
,
restore_tensors
,
shapes
):
return
gen_data_flow_ops
.
_lookup_table_import
(
self
.
op
.
_table_ref
,
restore_tensors
[
0
],
restore_tensors
[
1
])
class
SaverTest
(
test
.
TestCase
):
def
basicSaveRestore
(
self
,
variable_op
):
...
...
@@ -131,7 +77,7 @@ class SaverTest(test.TestCase):
# Restore nodes for them.
v0
=
variable_op
(
10.0
,
name
=
"v0"
)
v1
=
variable_op
(
20.0
,
name
=
"v1"
)
v2
=
CheckpointedOp
(
name
=
"v2"
)
v2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
v2_init
=
v2
.
insert
(
"k1"
,
30.0
)
save
=
saver_module
.
Saver
(
{
...
...
@@ -161,7 +107,7 @@ class SaverTest(test.TestCase):
with
self
.
test_session
()
as
sess
:
v0
=
variable_op
(
-
1.0
,
name
=
"v0"
)
v1
=
variable_op
(
-
1.0
,
name
=
"v1"
)
v2
=
CheckpointedOp
(
name
=
"v2"
)
v2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
save
=
saver_module
.
Saver
({
"v0"
:
v0
,
"v1"
:
v1
,
"v2"
:
v2
.
saveable
})
# Assert that the variables are not initialized.
...
...
@@ -183,7 +129,7 @@ class SaverTest(test.TestCase):
with
self
.
test_session
()
as
sess
:
v0_2
=
variable_op
(
1000.0
,
name
=
"v0"
)
v1_2
=
variable_op
(
2000.0
,
name
=
"v1"
)
v2_2
=
CheckpointedOp
(
name
=
"v2"
)
v2_2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
save2
=
saver_module
.
Saver
({
"v0"
:
v0_2
,
"v1"
:
v1_2
,
"v2"
:
v2_2
.
saveable
})
v2_2
.
insert
(
"k1000"
,
3000.0
).
run
()
variables
.
global_variables_initializer
().
run
()
...
...
@@ -276,7 +222,7 @@ class SaverTest(test.TestCase):
def
testSameName
(
self
):
with
ops_lib
.
Graph
().
as_default
():
v0
=
variables
.
Variable
([
10.0
],
name
=
"v0"
)
v2
=
CheckpointedOp
(
name
=
"v2"
)
v2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
# Saving one variable under two names raises an error.
with
self
.
assertRaisesRegexp
(
...
...
@@ -299,7 +245,7 @@ class SaverTest(test.TestCase):
# Restore nodes for them.
v0
=
variables
.
Variable
(
10.0
,
name
=
"v0"
)
v1
=
variables
.
Variable
(
20.0
,
name
=
"v1"
)
v2
=
CheckpointedOp
(
name
=
"v2"
)
v2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
v2_init
=
v2
.
insert
(
"k1"
,
30.0
)
save
=
saver_module
.
Saver
([
v0
,
v1
,
v2
.
saveable
])
variables
.
global_variables_initializer
().
run
()
...
...
@@ -321,7 +267,7 @@ class SaverTest(test.TestCase):
with
self
.
test_session
(
graph
=
ops_lib
.
Graph
())
as
sess
:
v0
=
variables
.
Variable
(
-
1.0
,
name
=
"v0"
)
v1
=
variables
.
Variable
(
-
1.0
,
name
=
"v1"
)
v2
=
CheckpointedOp
(
name
=
"v2"
)
v2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
save
=
saver_module
.
Saver
([
v0
,
v1
,
v2
.
saveable
])
with
self
.
assertRaisesWithPredicateMatch
(
...
...
@@ -346,7 +292,7 @@ class SaverTest(test.TestCase):
with
self
.
test_session
(
graph
=
ops_lib
.
Graph
())
as
sess
:
v0_2
=
variables
.
Variable
(
1000.0
,
name
=
"v0"
)
v1_2
=
variables
.
Variable
(
2000.0
,
name
=
"v1"
)
v2_2
=
CheckpointedOp
(
name
=
"v2"
)
v2_2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
save2
=
saver_module
.
Saver
([
v0_2
,
v1_2
,
v2_2
.
saveable
])
v2_2
.
insert
(
"k1000"
,
3000.0
).
run
()
variables
.
global_variables_initializer
().
run
()
...
...
@@ -418,7 +364,7 @@ class SaverTest(test.TestCase):
with
session
.
Session
(
""
,
graph
=
ops_lib
.
Graph
())
as
sess
:
one
=
variables
.
Variable
(
1.0
)
twos
=
variables
.
Variable
([
2.0
,
2.0
,
2.0
])
v2
=
CheckpointedOp
(
name
=
"v2"
)
v2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
init
=
variables
.
global_variables_initializer
()
save
=
saver_module
.
Saver
()
init
.
run
()
...
...
@@ -428,7 +374,7 @@ class SaverTest(test.TestCase):
with
session
.
Session
(
""
,
graph
=
ops_lib
.
Graph
())
as
sess
:
one
=
variables
.
Variable
(
0.0
)
twos
=
variables
.
Variable
([
0.0
,
0.0
,
0.0
])
v2
=
CheckpointedOp
(
name
=
"v2"
)
v2
=
saver_test_utils
.
CheckpointedOp
(
name
=
"v2"
)
# Saver with no arg, defaults to 'all variables'.
save
=
saver_module
.
Saver
()
save
.
restore
(
sess
,
save_path
)
...
...
@@ -593,10 +539,10 @@ class SaveRestoreShardedTest(test.TestCase):
config
=
config_pb2
.
ConfigProto
(
device_count
=
{
"CPU"
:
2
}))
as
sess
:
with
sess
.
graph
.
device
(
"/cpu:0"
):
v0
=
variables
.
Variable
(
10
,
name
=
"v0"
)
t0
=
CheckpointedOp
(
name
=
"t0"
)
t0
=
saver_test_utils
.
CheckpointedOp
(
name
=
"t0"
)
with
sess
.
graph
.
device
(
"/cpu:1"
):
v1
=
variables
.
Variable
(
20
,
name
=
"v1"
)
t1
=
CheckpointedOp
(
name
=
"t1"
)
t1
=
saver_test_utils
.
CheckpointedOp
(
name
=
"t1"
)
save
=
saver_module
.
Saver
(
{
"v0"
:
v0
,
...
...
@@ -623,7 +569,7 @@ class SaveRestoreShardedTest(test.TestCase):
config
=
config_pb2
.
ConfigProto
(
device_count
=
{
"CPU"
:
2
}))
as
sess
:
with
sess
.
graph
.
device
(
"/cpu:0"
):
v0
=
variables
.
Variable
(
111
,
name
=
"v0"
)
t0
=
CheckpointedOp
(
name
=
"t0"
)
t0
=
saver_test_utils
.
CheckpointedOp
(
name
=
"t0"
)
save
=
saver_module
.
Saver
({
"v0"
:
v0
,
"t0"
:
t0
.
saveable
},
sharded
=
True
)
variables
.
global_variables_initializer
().
run
()
t0
.
insert
(
"k11"
,
33.0
).
run
()
...
...
@@ -641,7 +587,7 @@ class SaveRestoreShardedTest(test.TestCase):
config
=
config_pb2
.
ConfigProto
(
device_count
=
{
"CPU"
:
2
}))
as
sess
:
with
sess
.
graph
.
device
(
"/cpu:0"
):
v1
=
variables
.
Variable
(
222
)
t1
=
CheckpointedOp
(
name
=
"t1"
)
t1
=
saver_test_utils
.
CheckpointedOp
(
name
=
"t1"
)
save
=
saver_module
.
Saver
({
"v1"
:
v1
,
"t1"
:
t1
.
saveable
},
sharded
=
True
)
variables
.
global_variables_initializer
().
run
()
t1
.
insert
(
"k22"
,
44.0
).
run
()
...
...
@@ -659,10 +605,10 @@ class SaveRestoreShardedTest(test.TestCase):
config
=
config_pb2
.
ConfigProto
(
device_count
=
{
"CPU"
:
2
}))
as
sess
:
with
sess
.
graph
.
device
(
"/cpu:0"
):
v0
=
variables
.
Variable
(
111
,
name
=
"v0"
)
t0
=
CheckpointedOp
(
name
=
"t0"
)
t0
=
saver_test_utils
.
CheckpointedOp
(
name
=
"t0"
)
with
sess
.
graph
.
device
(
"/cpu:1"
):
v1
=
variables
.
Variable
(
222
,
name
=
"v1"
)
t1
=
CheckpointedOp
(
name
=
"t1"
)
t1
=
saver_test_utils
.
CheckpointedOp
(
name
=
"t1"
)
save
=
saver_module
.
Saver
(
{
"v0"
:
v0
,
...
...
tensorflow/python/training/saver_test_utils.py
0 → 100644
浏览文件 @
19ef8215
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility classes for testing checkpointing."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
ops
as
ops_lib
from
tensorflow.python.ops
import
gen_data_flow_ops
from
tensorflow.python.training
import
saver
as
saver_module
class
CheckpointedOp
(
object
):
"""Op with a custom checkpointing implementation.
Defined as part of the test because the MutableHashTable Python code is
currently in contrib.
"""
# pylint: disable=protected-access
def
__init__
(
self
,
name
,
table_ref
=
None
):
if
table_ref
is
None
:
self
.
table_ref
=
gen_data_flow_ops
.
_mutable_hash_table
(
key_dtype
=
dtypes
.
string
,
value_dtype
=
dtypes
.
float32
,
name
=
name
)
else
:
self
.
table_ref
=
table_ref
self
.
_name
=
name
self
.
_saveable
=
CheckpointedOp
.
CustomSaveable
(
self
,
name
)
ops_lib
.
add_to_collection
(
ops_lib
.
GraphKeys
.
SAVEABLE_OBJECTS
,
self
.
_saveable
)
@
property
def
name
(
self
):
return
self
.
_name
@
property
def
saveable
(
self
):
return
self
.
_saveable
def
insert
(
self
,
keys
,
values
):
return
gen_data_flow_ops
.
_lookup_table_insert
(
self
.
table_ref
,
keys
,
values
)
def
lookup
(
self
,
keys
,
default
):
return
gen_data_flow_ops
.
_lookup_table_find
(
self
.
table_ref
,
keys
,
default
)
def
keys
(
self
):
return
self
.
_export
()[
0
]
def
values
(
self
):
return
self
.
_export
()[
1
]
def
_export
(
self
):
return
gen_data_flow_ops
.
_lookup_table_export
(
self
.
table_ref
,
dtypes
.
string
,
dtypes
.
float32
)
class
CustomSaveable
(
saver_module
.
BaseSaverBuilder
.
SaveableObject
):
"""A custom saveable for CheckpointedOp."""
def
__init__
(
self
,
table
,
name
):
tensors
=
table
.
_export
()
specs
=
[
saver_module
.
BaseSaverBuilder
.
SaveSpec
(
tensors
[
0
],
""
,
name
+
"-keys"
),
saver_module
.
BaseSaverBuilder
.
SaveSpec
(
tensors
[
1
],
""
,
name
+
"-values"
)
]
super
(
CheckpointedOp
.
CustomSaveable
,
self
).
__init__
(
table
,
specs
,
name
)
def
restore
(
self
,
restore_tensors
,
shapes
):
return
gen_data_flow_ops
.
_lookup_table_import
(
self
.
op
.
table_ref
,
restore_tensors
[
0
],
restore_tensors
[
1
])
# pylint: enable=protected-access
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录