Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
8596df45
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
8596df45
编写于
12月 13, 2018
作者:
K
Katherine Wu
提交者:
TensorFlower Gardener
12月 13, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add `serving_only` option to save_keras_model, allowing subclassed models to be saved.
PiperOrigin-RevId: 225402096
上级
99313dd8
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
184 addition
and
71 deletion
+184
-71
tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
...ntrib/saved_model/python/saved_model/keras_saved_model.py
+112
-62
tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
.../saved_model/python/saved_model/keras_saved_model_test.py
+72
-9
未找到文件。
tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
浏览文件 @
8596df45
...
...
@@ -22,53 +22,57 @@ import os
import
six
from
tensorflow.python.client
import
session
from
tensorflow.python.estimator
import
keras
as
estimator_keras_util
from
tensorflow.python.estimator
import
model_fn
as
model_fn_lib
from
tensorflow.python.estimator.export
import
export
as
export_helpers
from
tensorflow.python.framework
import
ops
from
tensorflow.python.keras
import
backend
as
K
from
tensorflow.python.keras
import
models
as
models_lib
from
tensorflow.python.keras
import
optimizers
from
tensorflow.python.keras.engine
import
sequential
from
tensorflow.python.keras.engine
import
training_utils
from
tensorflow.python.keras.metrics
import
Metric
from
tensorflow.python.keras.models
import
model_from_json
from
tensorflow.python.lib.io
import
file_io
from
tensorflow.python.ops
import
variables
from
tensorflow.python.platform
import
gfile
from
tensorflow.python.platform
import
tf_logging
as
logging
from
tensorflow.python.saved_model
import
builder
as
saved_model_builder
from
tensorflow.python.saved_model
import
constants
from
tensorflow.python.saved_model
import
save
as
save_lib
from
tensorflow.python.saved_model
import
utils_impl
as
saved_model_utils
from
tensorflow.python.training
import
saver
as
saver_lib
from
tensorflow.python.training.checkpointable
import
util
as
checkpointable_utils
from
tensorflow.python.util
import
compat
from
tensorflow.python.util
import
nest
from
tensorflow_estimator.python.estimator
import
keras
as
estimator_keras_util
from
tensorflow_estimator.python.estimator
import
model_fn
as
model_fn_lib
from
tensorflow_estimator.python.estimator.export
import
export
as
export_helpers
def
save_keras_model
(
model
,
saved_model_path
,
custom_objects
=
None
,
as_text
=
None
):
"""Save a `tf.keras.Model` into Tensorflow SavedModel format.
model
,
saved_model_path
,
custom_objects
=
None
,
as_text
=
None
,
input_signature
=
None
,
serving_only
=
False
):
"""Saves a `tf.keras.Model` into Tensorflow SavedModel format.
`save_model` generates new files/folders under the `saved_model_path` folder:
1) an asset folder containing the json string of the model's
configuration (topology).
2) a checkpoint containing the model weights.
3) a saved_model.pb file containing the model's MetaGraphs. The prediction
1) a checkpoint containing the model weights.
2) a saved_model.pb file containing the model's MetaGraphs. The prediction
graph is always exported. The evaluaton and training graphs are exported
if the following conditions are met:
- Evaluation: model loss is defined.
- Training: model is compiled with an optimizer defined under `tf.train`.
This is because `tf.keras.optimizers.Optimizer` instances cannot be
saved to checkpoints.
Model Requirements:
- Model must be a sequential model or functional model. Subclassed models can
not be saved via this function, unless you provide an implementation for
get_config() and from_config().
- All variables must be saveable by the model. In general, this condition is
met through the use of layers defined in the keras library. However,
there is currently a bug with variables created in Lambda layer functions
not being saved correctly (see
https://github.com/keras-team/keras/issues/9740).
3) Model's json configuration, if model.get_config() has been implemented.
This file can be used to reload the model using
tf.keras.models.model_from_json(). Note that if any custom objects were
used, they should be passed to the `custom_object` argument when loading
the model.
Model limitations:
- Sequential and functional models can always be saved.
- Subclassed models can only be saved when `serving_only=True`. This is due to
the current implementation copying the model in order to export the training
and evaluation graphs. Because the topology of subclassed models cannot be
determined, the subclassed models cannot be cloned. Subclassed models will
be entirely exportable in the future.
Note that each mode is exported in separate graphs, so different modes do not
share variables. To use the train graph with evaluation or prediction graphs,
...
...
@@ -94,38 +98,88 @@ def save_keras_model(
```
Args:
model: A `tf.keras.Model` to be saved.
model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag
`serving_only` must be set to True.
saved_model_path: a string specifying the path to the SavedModel directory.
The SavedModel will be saved to a timestamped folder created within this
directory.
custom_objects: Optional dictionary mapping string names to custom classes
or functions (e.g. custom loss functions).
as_text: whether to write the `SavedModel` proto in text format.
as_text: whether to write the `SavedModel` proto in text format. Currently
unavailable in serving-only mode.
input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used
to specify the expected model inputs. `input_signature`'s nested structure
should match the expected nested structure of the inputs to the model. If
this is not set, this function will attempt to infer the input shapes and
dtypes from the model. Note that if the model is subclassed, the tensor
inputs to the call function should be nested in the first argument (this
is a general requirement for using subclassed models with Keras functions
.fit(), .predict(), etc.).
serving_only: Export only the outputs produced from calling the model in
predict mode. The losses, optimizer, and other training configurations are
not saved. If the SavedModel will only be used for serving (rather than
retraining), or if the model is subclassed, this can be set to True.
Returns:
String path to the SavedModel folder, a subdirectory of `saved_model_path`.
Raises:
NotImplementedError: If the model is a subclassed model
.
ValueError: If a Sequential model does not have input shapes defined by the
user, and is not built
.
NotImplementedError: If the model is a subclassed model
, and serving_only is
False.
ValueError: If the input signature cannot be inferred from the model
.
"""
export_dir
=
export_helpers
.
get_timestamped_export_dir
(
saved_model_path
)
if
serving_only
:
save_lib
.
save
(
model
,
export_dir
,
signatures
=
training_utils
.
trace_model_call
(
model
,
input_signature
))
else
:
_save_v1_format
(
model
,
export_dir
,
custom_objects
,
as_text
,
input_signature
)
try
:
_export_model_json
(
model
,
export_dir
)
except
NotImplementedError
:
logging
.
warning
(
'Skipped saving model JSON, subclassed model does not have '
'get_config() defined.'
)
return
export_dir
def
_export_model_json
(
model
,
saved_model_path
):
"""Saves model configuration as a json string under assets folder."""
model_json
=
model
.
to_json
()
model_json_filepath
=
os
.
path
.
join
(
saved_model_utils
.
get_or_create_assets_dir
(
saved_model_path
),
compat
.
as_text
(
constants
.
SAVED_MODEL_FILENAME_JSON
))
file_io
.
write_string_to_file
(
model_json_filepath
,
model_json
)
def
_export_model_variables
(
model
,
saved_model_path
):
"""Saves model weights in checkpoint format under variables folder."""
saved_model_utils
.
get_or_create_variables_dir
(
saved_model_path
)
checkpoint_prefix
=
saved_model_utils
.
get_variables_path
(
saved_model_path
)
model
.
save_weights
(
checkpoint_prefix
,
save_format
=
'tf'
,
overwrite
=
True
)
return
checkpoint_prefix
def
_save_v1_format
(
model
,
path
,
custom_objects
,
as_text
,
input_signature
):
"""Exports model to v1 SavedModel format."""
if
not
model
.
_is_graph_network
:
if
isinstance
(
model
,
sequential
.
Sequential
):
# If input shape is not directly set in the model, the exported model
# will assume that the inputs have the same shape as the shape the model
# was built model with.
if
not
model
.
built
:
# will infer the expected shapes of the input from the model.
if
not
model
.
built
and
input_signature
is
None
:
raise
ValueError
(
'Sequential model must be built before it can be exported.'
)
'Sequential model
\'
s input shape is unknown. Please build the '
'model, or use the input_signature argument to specify the '
'model inputs.'
)
else
:
raise
NotImplementedError
(
'Exporting subclassed models is not yet supported.'
)
'Subclassed models can only be exported for serving. Please set '
'argument serving_only=True.'
)
export_dir
=
export_helpers
.
get_timestamped_export_dir
(
saved_model_path
)
temp_export_dir
=
export_helpers
.
get_temp_export_dir
(
export_dir
)
builder
=
saved_model_builder
.
_SavedModelBuilder
(
temp_export_dir
)
builder
=
saved_model_builder
.
_SavedModelBuilder
(
path
)
# Manually save variables to export them in an object-based checkpoint. This
# skips the `builder.add_meta_graph_and_variables()` step, which saves a
...
...
@@ -133,7 +187,7 @@ def save_keras_model(
# TODO(b/113134168): Add fn to Builder to save with object-based saver.
# TODO(b/113178242): This should only export the model json structure. Only
# one save is needed once the weights can be copied from the model to clone.
checkpoint_path
=
_export_model_
json_and_variables
(
model
,
temp_export_dir
)
checkpoint_path
=
_export_model_
variables
(
model
,
path
)
# Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
# Keras models and `Estimator`s are exported with the same format.
...
...
@@ -143,10 +197,12 @@ def save_keras_model(
export_args
=
{
'builder'
:
builder
,
'model'
:
model
,
'custom_objects'
:
custom_objects
,
'checkpoint_path'
:
checkpoint_path
}
'checkpoint_path'
:
checkpoint_path
,
'input_signature'
:
input_signature
}
has_saved_vars
=
False
if
model
.
optimizer
:
# TODO(kathywu): Verify this works with v2 optimizer.
if
isinstance
(
model
.
optimizer
,
optimizers
.
TFOptimizer
):
_export_mode
(
model_fn_lib
.
ModeKeys
.
TRAIN
,
has_saved_vars
,
**
export_args
)
has_saved_vars
=
True
...
...
@@ -161,34 +217,20 @@ def save_keras_model(
builder
.
save
(
as_text
)
gfile
.
Rename
(
temp_export_dir
,
export_dir
)
return
export_dir
def
_export_model_json_and_variables
(
model
,
saved_model_path
):
"""Save model variables and json structure into SavedModel subdirectories."""
# Save model configuration as a json string under assets folder.
model_json
=
model
.
to_json
()
model_json_filepath
=
os
.
path
.
join
(
saved_model_utils
.
get_or_create_assets_dir
(
saved_model_path
),
compat
.
as_text
(
constants
.
SAVED_MODEL_FILENAME_JSON
))
file_io
.
write_string_to_file
(
model_json_filepath
,
model_json
)
# Save model weights in checkpoint format under variables folder.
saved_model_utils
.
get_or_create_variables_dir
(
saved_model_path
)
checkpoint_prefix
=
saved_model_utils
.
get_variables_path
(
saved_model_path
)
model
.
save_weights
(
checkpoint_prefix
,
save_format
=
'tf'
,
overwrite
=
True
)
return
checkpoint_prefix
def
_get_var_list
(
model
):
"""Return list of all checkpointed saveable objects in the model."""
"""Return
s
list of all checkpointed saveable objects in the model."""
return
checkpointable_utils
.
named_saveables
(
model
)
def
create_placeholder
(
spec
):
return
K
.
placeholder
(
shape
=
spec
.
shape
,
dtype
=
spec
.
dtype
,
name
=
spec
.
name
)
def
_export_mode
(
mode
,
has_saved_vars
,
builder
,
model
,
custom_objects
,
checkpoint_path
):
"""Export a model, and optionally save new vars from the clone model.
mode
,
has_saved_vars
,
builder
,
model
,
custom_objects
,
checkpoint_path
,
input_signature
):
"""Exports a model, and optionally saves new vars from the clone model.
Args:
mode: A `tf.estimator.ModeKeys` string.
...
...
@@ -199,6 +241,8 @@ def _export_mode(
custom_objects: A dictionary mapping string names to custom classes
or functions.
checkpoint_path: String path to checkpoint.
input_signature: Nested TensorSpec containing the expected inputs. Can be
`None`, in which case the signature will be inferred from the model.
Raises:
ValueError: If the train/eval mode is being exported, but the model does
...
...
@@ -214,10 +258,16 @@ def _export_mode(
K
.
set_learning_phase
(
mode
==
model_fn_lib
.
ModeKeys
.
TRAIN
)
if
input_signature
is
None
:
input_tensors
=
None
else
:
input_tensors
=
nest
.
map_structure
(
create_placeholder
,
input_signature
)
# Clone the model into blank graph. This will create placeholders for inputs
# and targets.
clone
=
models_lib
.
clone_and_build_model
(
model
,
custom_objects
=
custom_objects
,
compile_clone
=
compile_clone
)
model
,
input_tensors
=
input_tensors
,
custom_objects
=
custom_objects
,
compile_clone
=
compile_clone
)
# Make sure that iterations variable is added to the global step collection,
# to ensure that, when the SavedModel graph is loaded, the iterations
...
...
@@ -271,7 +321,7 @@ def _export_mode(
def
_create_signature_def_map
(
model
,
mode
):
"""Create a SignatureDef map from a Keras model."""
"""Create
s
a SignatureDef map from a Keras model."""
inputs_dict
=
{
name
:
x
for
name
,
x
in
zip
(
model
.
input_names
,
model
.
inputs
)}
if
model
.
optimizer
:
targets_dict
=
{
x
.
name
.
split
(
':'
)[
0
]:
x
...
...
@@ -309,14 +359,14 @@ def _create_signature_def_map(model, mode):
def
_assert_same_non_optimizer_objects
(
model
,
model_graph
,
clone
,
clone_graph
):
# pylint: disable=unused-argument
"""Assert model and clone contain the same checkpointable objects."""
"""Assert
s
model and clone contain the same checkpointable objects."""
# TODO(fchollet, kathywu): make sure this works in eager mode.
return
True
def
load_keras_model
(
saved_model_path
):
"""Load a keras.Model from SavedModel.
"""Load
s
a keras.Model from SavedModel.
load_model reinstantiates model state by:
1) loading model topology from json (this will eventually come
...
...
tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
浏览文件 @
8596df45
...
...
@@ -29,7 +29,9 @@ from tensorflow.python import keras
from
tensorflow.python.client
import
session
from
tensorflow.python.eager
import
context
from
tensorflow.python.estimator
import
model_fn
as
model_fn_lib
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
tensor_spec
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.keras.engine
import
training
from
tensorflow.python.keras.utils
import
tf_utils
...
...
@@ -215,7 +217,7 @@ class LayerWithLearningPhase(keras.engine.base_layer.Layer):
return
input_shape
def
functional_model
(
uses_learning_phase
):
def
functional_model
(
uses_learning_phase
=
True
):
inputs
=
keras
.
layers
.
Input
(
shape
=
(
3
,))
x
=
keras
.
layers
.
Dense
(
2
)(
inputs
)
x
=
keras
.
layers
.
Dense
(
3
)(
x
)
...
...
@@ -224,7 +226,7 @@ def functional_model(uses_learning_phase):
return
keras
.
models
.
Model
(
inputs
,
x
)
def
sequential_model
(
uses_learning_phase
):
def
sequential_model
(
uses_learning_phase
=
True
):
model
=
keras
.
models
.
Sequential
()
model
.
add
(
keras
.
layers
.
Dense
(
2
,
input_shape
=
(
3
,)))
model
.
add
(
keras
.
layers
.
Dense
(
3
))
...
...
@@ -233,7 +235,7 @@ def sequential_model(uses_learning_phase):
return
model
def
sequential_model_without_input_shape
(
uses_learning_phase
):
def
sequential_model_without_input_shape
(
uses_learning_phase
=
True
):
model
=
keras
.
models
.
Sequential
()
model
.
add
(
keras
.
layers
.
Dense
(
2
))
model
.
add
(
keras
.
layers
.
Dense
(
3
))
...
...
@@ -242,10 +244,30 @@ def sequential_model_without_input_shape(uses_learning_phase):
return
model
class
Subclassed
(
keras
.
models
.
Model
):
def
__init__
(
self
):
super
(
Subclassed
,
self
).
__init__
()
self
.
dense1
=
keras
.
layers
.
Dense
(
2
)
self
.
dense2
=
keras
.
layers
.
Dense
(
3
)
def
call
(
self
,
inputs
):
x
=
self
.
dense1
(
inputs
)
x
=
self
.
dense2
(
x
)
return
x
def
subclassed_model
():
return
Subclassed
()
def
load_model
(
sess
,
path
,
mode
):
tags
=
model_fn_lib
.
EXPORT_TAG_MAP
[
mode
]
sig_def_key
=
(
signature_constants
.
DEFAULT_SERVING_SIGNATURE_DEF_KEY
if
mode
==
model_fn_lib
.
ModeKeys
.
PREDICT
else
mode
)
if
mode
==
model_fn_lib
.
ModeKeys
.
PREDICT
:
sig_def_key
=
signature_constants
.
DEFAULT_SERVING_SIGNATURE_DEF_KEY
else
:
sig_def_key
=
mode
meta_graph_def
=
loader_impl
.
load
(
sess
,
tags
,
path
)
inputs
=
{
k
:
sess
.
graph
.
get_tensor_by_name
(
v
.
name
)
...
...
@@ -463,13 +485,54 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
clone
.
compile
(
loss
=
'mse'
,
optimizer
=
keras
.
optimizers
.
RMSprop
(
lr
=
0.0001
))
clone
.
train_on_batch
(
input_arr
,
target_arr
)
def
testSaveSeqModelWithoutInputShapesRaisesError
(
self
):
"""A Sequential model that hasn't been built should raise an error."""
def
testSaveSequentialModelWithoutInputShapes
(
self
):
model
=
sequential_model_without_input_shape
(
True
)
with
self
.
assertRaisesRegexp
(
ValueError
,
'must be built
'
):
# A Sequential model that hasn't been built should raise an error.
with
self
.
assertRaisesRegexp
(
ValueError
,
'Please build the model
'
):
keras_saved_model
.
save_keras_model
(
model
,
''
)
saved_model_path
=
self
.
_save_model_dir
()
output_path
=
keras_saved_model
.
save_keras_model
(
model
,
saved_model_path
,
input_signature
=
tensor_spec
.
TensorSpec
(
shape
=
(
10
,
11
,
12
,
13
,
14
),
dtype
=
dtypes
.
float32
,
name
=
'spec_input'
))
with
session
.
Session
(
graph
=
ops
.
Graph
())
as
sess
:
inputs
,
outputs
,
_
=
load_model
(
sess
,
output_path
,
model_fn_lib
.
ModeKeys
.
PREDICT
)
self
.
assertEqual
(
5
,
inputs
[
next
(
iter
(
inputs
.
keys
()))].
shape
.
ndims
)
self
.
assertEqual
(
5
,
outputs
[
next
(
iter
(
outputs
.
keys
()))].
shape
.
ndims
)
self
.
assertEqual
(
3
,
outputs
[
next
(
iter
(
outputs
.
keys
()))].
shape
[
-
1
])
@
test_util
.
run_v2_only
@
parameterized
.
parameters
(
{
'model_builder'
:
sequential_model_without_input_shape
,
'input_signature'
:
[
tensor_spec
.
TensorSpec
(
shape
=
[
None
,
3
],
dtype
=
dtypes
.
float32
)]},
{
'model_builder'
:
subclassed_model
,
'input_signature'
:
[
tensor_spec
.
TensorSpec
(
shape
=
[
None
,
3
],
dtype
=
dtypes
.
float32
)]})
def
testServingOnly
(
self
,
model_builder
,
input_signature
):
saved_model_path
=
self
.
_save_model_dir
()
input_arr
=
np
.
random
.
random
((
5
,
3
)).
astype
(
np
.
float32
)
model
=
model_builder
()
ref_predict
=
model
.
predict
(
input_arr
)
output_path
=
keras_saved_model
.
save_keras_model
(
model
,
saved_model_path
,
serving_only
=
True
,
input_signature
=
input_signature
)
# Load predict graph, and test predictions
with
session
.
Session
(
graph
=
ops
.
Graph
())
as
sess
:
inputs
,
outputs
,
_
=
load_model
(
sess
,
output_path
,
model_fn_lib
.
ModeKeys
.
PREDICT
)
predictions
=
sess
.
run
(
outputs
[
next
(
iter
(
outputs
.
keys
()))],
{
inputs
[
next
(
iter
(
inputs
.
keys
()))]:
input_arr
})
self
.
assertAllClose
(
ref_predict
,
predictions
,
atol
=
1e-05
)
if
__name__
==
'__main__'
:
test
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录