Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
bed0e5c3
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
bed0e5c3
编写于
3月 16, 2017
作者:
M
Martin Wicke
提交者:
TensorFlower Gardener
3月 16, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Expose Estimator and associated utilities in the API.
Change: 150292011
上级
1ce24242
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
181 addition
and
50 deletion
+181
-50
tensorflow/contrib/cmake/tf_python.cmake
tensorflow/contrib/cmake/tf_python.cmake
+1
-0
tensorflow/python/__init__.py
tensorflow/python/__init__.py
+5
-4
tensorflow/python/estimator/BUILD
tensorflow/python/estimator/BUILD
+23
-7
tensorflow/python/estimator/__init__.py
tensorflow/python/estimator/__init__.py
+42
-0
tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/estimator.py
+19
-12
tensorflow/python/estimator/estimator_test.py
tensorflow/python/estimator/estimator_test.py
+21
-6
tensorflow/python/estimator/export/__init__.py
tensorflow/python/estimator/export/__init__.py
+42
-0
tensorflow/python/estimator/export/export.py
tensorflow/python/estimator/export/export.py
+0
-0
tensorflow/python/estimator/export/export_output.py
tensorflow/python/estimator/export/export_output.py
+0
-0
tensorflow/python/estimator/export/export_output_test.py
tensorflow/python/estimator/export/export_output_test.py
+1
-1
tensorflow/python/estimator/export/export_test.py
tensorflow/python/estimator/export/export_test.py
+2
-2
tensorflow/python/estimator/inputs/__init__.py
tensorflow/python/estimator/inputs/__init__.py
+10
-1
tensorflow/python/estimator/model_fn.py
tensorflow/python/estimator/model_fn.py
+2
-4
tensorflow/python/estimator/model_fn_test.py
tensorflow/python/estimator/model_fn_test.py
+12
-12
tensorflow/python/estimator/run_config.py
tensorflow/python/estimator/run_config.py
+1
-1
未找到文件。
tensorflow/contrib/cmake/tf_python.cmake
浏览文件 @
bed0e5c3
...
...
@@ -184,6 +184,7 @@ add_python_module("tensorflow/python/debug/examples")
add_python_module
(
"tensorflow/python/debug/lib"
)
add_python_module
(
"tensorflow/python/debug/wrappers"
)
add_python_module
(
"tensorflow/python/estimator"
)
add_python_module
(
"tensorflow/python/estimator/export"
)
add_python_module
(
"tensorflow/python/estimator/inputs"
)
add_python_module
(
"tensorflow/python/estimator/inputs/queues"
)
add_python_module
(
"tensorflow/python/framework"
)
...
...
tensorflow/python/__init__.py
浏览文件 @
bed0e5c3
...
...
@@ -74,18 +74,18 @@ from tensorflow.python.ops.standard_ops import *
# pylint: enable=wildcard-import
# Bring in subpackages.
from
tensorflow.python
import
estimator
from
tensorflow.python.layers
import
layers
from
tensorflow.python.ops
import
image_ops
as
image
from
tensorflow.python.ops
import
metrics
from
tensorflow.python.ops
import
nn
from
tensorflow.python.ops
import
sdca_ops
as
sdca
from
tensorflow.python.ops
import
sets
from
tensorflow.python.ops
import
spectral_ops
as
spectral
from
tensorflow.python.ops
import
image_ops
as
image
from
tensorflow.python.ops.losses
import
losses
from
tensorflow.python.ops
import
sets
from
tensorflow.python.saved_model
import
saved_model
from
tensorflow.python.util
import
compat
from
tensorflow.python.user_ops
import
user_ops
from
tensorflow.python.util
import
compat
from
tensorflow.python.saved_model
import
saved_model
from
tensorflow.python.summary
import
summary
# Import the names from python/training.py as train.Name.
...
...
@@ -209,6 +209,7 @@ _allowed_symbols.extend([
'app'
,
'compat'
,
'errors'
,
'estimator'
,
'flags'
,
'gfile'
,
'graph_util'
,
...
...
tensorflow/python/estimator/BUILD
浏览文件 @
bed0e5c3
...
...
@@ -14,6 +14,7 @@ load("//tensorflow:tensorflow.bzl", "py_test")
py_library
(
name
=
"estimator_py"
,
srcs
=
[
"__init__.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":checkpoint_utils"
,
...
...
@@ -22,6 +23,7 @@ py_library(
":inputs"
,
":model_fn"
,
":run_config"
,
"//tensorflow/python:util"
,
],
)
...
...
@@ -69,7 +71,7 @@ py_library(
srcs
=
[
"model_fn.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":export
_output
"
,
":export"
,
"//tensorflow/python:array_ops"
,
"//tensorflow/python:framework_for_generated_wrappers"
,
"//tensorflow/python:training"
,
...
...
@@ -107,7 +109,6 @@ py_library(
deps
=
[
":checkpoint_utils"
,
":export"
,
":export_output"
,
":model_fn"
,
":run_config"
,
"//tensorflow/core:protos_all_py"
,
...
...
@@ -147,7 +148,7 @@ py_test(
py_library
(
name
=
"export_output"
,
srcs
=
[
"export_output.py"
],
srcs
=
[
"export
/export
_output.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
"//tensorflow/python/saved_model:signature_def_utils"
,
...
...
@@ -157,7 +158,7 @@ py_library(
py_test
(
name
=
"export_output_test"
,
size
=
"small"
,
srcs
=
[
"export_output_test.py"
],
srcs
=
[
"export
/export
_output_test.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":export_output"
,
...
...
@@ -168,7 +169,21 @@ py_test(
py_library
(
name
=
"export"
,
srcs
=
[
"export.py"
],
srcs
=
[
"export/__init__.py"
,
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":export_export"
,
":export_output"
,
],
)
py_library
(
name
=
"export_export"
,
srcs
=
[
"export/export.py"
,
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":export_output"
,
...
...
@@ -182,10 +197,10 @@ py_library(
py_test
(
name
=
"export_test"
,
size
=
"small"
,
srcs
=
[
"export_test.py"
],
srcs
=
[
"export
/export
_test.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":export"
,
":export
_export
"
,
":export_output"
,
"//tensorflow/python:client_testlib"
,
],
...
...
@@ -198,6 +213,7 @@ py_library(
deps
=
[
":numpy_io"
,
":pandas_io"
,
"//tensorflow/python:util"
,
],
)
...
...
tensorflow/python/estimator/__init__.py
0 → 100644
浏览文件 @
bed0e5c3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Estimator: High level tools for working with models."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
tensorflow.python.estimator
import
export
from
tensorflow.python.estimator
import
inputs
from
tensorflow.python.estimator.estimator
import
Estimator
from
tensorflow.python.estimator.model_fn
import
EstimatorSpec
from
tensorflow.python.estimator.model_fn
import
ModeKeys
from
tensorflow.python.estimator.run_config
import
RunConfig
from
tensorflow.python.util.all_util
import
remove_undocumented
_allowed_symbols
=
[
'inputs'
,
'export'
,
'Estimator'
,
'EstimatorSpec'
,
'ModeKeys'
,
'RunConfig'
,
]
remove_undocumented
(
__name__
,
allowed_exception_list
=
_allowed_symbols
)
tensorflow/python/estimator/estimator.py
浏览文件 @
bed0e5c3
...
...
@@ -30,9 +30,10 @@ import six
from
tensorflow.core.framework
import
summary_pb2
from
tensorflow.core.protobuf
import
config_pb2
from
tensorflow.python.client
import
session
as
tf_session
from
tensorflow.python.estimator
import
export
from
tensorflow.python.estimator
import
model_fn
as
model_fn_lib
from
tensorflow.python.estimator
import
run_config
from
tensorflow.python.estimator.export.export
import
build_all_signature_defs
from
tensorflow.python.estimator.export.export
import
get_timestamped_export_dir
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
random_seed
from
tensorflow.python.ops
import
control_flow_ops
...
...
@@ -56,9 +57,9 @@ _VALID_MODEL_FN_ARGS = set(
class
Estimator
(
object
):
"""Estimator class to train and evaluate TensorFlow models.
The
Estimator object wraps a model which is specified by a `model_fn`, which
,
given inputs and a number of other parameters, returns the ops necessary to
perform training, evaluation, or predictions, respectively
.
The
`Estimator` object wraps a model which is specified by a `model_fn`
,
which, given inputs and a number of other parameters, returns the ops
necessary to perform training, evaluation, or predictions
.
All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a
subdirectory thereof. If `model_dir` is not set, a temporary directory is
...
...
@@ -68,15 +69,20 @@ class Estimator(object):
about the execution environment. It is passed on to the `model_fn`, if the
`model_fn` has a parameter named "config" (and input functions in the same
manner). If the `config` parameter is not passed, it is instantiated by the
Estimator
. Not passing config means that defaults useful for local execution
are used.
Estimator
makes config available to the model (for instance, to
`Estimator`
. Not passing config means that defaults useful for local execution
are used.
`Estimator`
makes config available to the model (for instance, to
allow specialization based on the number of workers available), and also uses
some of its fields to control internals, especially regarding checkpointing.
The `params` argument contains hyperparameters. It is passed to the
`model_fn`, if the `model_fn` has a parameter named "params", and to the input
functions in the same manner. Estimator only passes params along, it does not
inspect it. The structure of params is therefore entirely up to the developer.
functions in the same manner. `Estimator` only passes params along, it does
not inspect it. The structure of `params` is therefore entirely up to the
developer.
None of `Estimator`'s methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
"""
def
__init__
(
self
,
model_fn
,
model_dir
=
None
,
config
=
None
,
params
=
None
):
...
...
@@ -116,7 +122,7 @@ class Estimator(object):
ValueError: if this is called via a subclass and if that class overrides
a member of `Estimator`.
"""
self
.
_assert_members_are_not_overridden
(
)
Estimator
.
_assert_members_are_not_overridden
(
self
)
# Model directory.
self
.
_model_dir
=
model_dir
if
self
.
_model_dir
is
None
:
...
...
@@ -395,7 +401,7 @@ class Estimator(object):
mode
=
model_fn_lib
.
ModeKeys
.
PREDICT
)
# Build the SignatureDefs from receivers and all outputs
signature_def_map
=
export
.
build_all_signature_defs
(
signature_def_map
=
build_all_signature_defs
(
serving_input_receiver
.
receiver_tensors
,
estimator_spec
.
export_outputs
)
...
...
@@ -405,7 +411,7 @@ class Estimator(object):
if
not
checkpoint_path
:
raise
ValueError
(
"Couldn't find trained model at %s."
%
self
.
_model_dir
)
export_dir
=
export
.
get_timestamped_export_dir
(
export_dir_base
)
export_dir
=
get_timestamped_export_dir
(
export_dir_base
)
# TODO(soergel): Consider whether MonitoredSession makes sense here
with
tf_session
.
Session
()
as
session
:
...
...
@@ -600,7 +606,8 @@ class Estimator(object):
if
model_fn_lib
.
MetricKeys
.
LOSS
in
estimator_spec
.
eval_metric_ops
:
raise
ValueError
(
'Metric with name `loss` is not allowed, because Estimator '
'Metric with name "%s" is not allowed, because Estimator '
%
(
model_fn_lib
.
MetricKeys
.
LOSS
)
+
'already defines a default metric with the same name.'
)
estimator_spec
.
eval_metric_ops
[
model_fn_lib
.
MetricKeys
.
LOSS
]
=
metrics_lib
.
mean
(
estimator_spec
.
loss
)
...
...
tensorflow/python/estimator/estimator_test.py
浏览文件 @
bed0e5c3
...
...
@@ -26,7 +26,6 @@ import numpy as np
from
tensorflow.python.client
import
session
from
tensorflow.python.estimator
import
estimator
from
tensorflow.python.estimator
import
export
from
tensorflow.python.estimator
import
export_output
from
tensorflow.python.estimator
import
model_fn
as
model_fn_lib
from
tensorflow.python.estimator
import
run_config
from
tensorflow.python.estimator.inputs
import
numpy_io
...
...
@@ -74,6 +73,22 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
ValueError
,
'cannot override members of Estimator.*predict'
):
_Estimator
()
def
test_override_a_method_with_tricks
(
self
):
class
_Estimator
(
estimator
.
Estimator
):
def
__init__
(
self
):
super
(
_Estimator
,
self
).
__init__
(
model_fn
=
dummy_model_fn
)
def
_assert_members_are_not_overridden
(
self
):
pass
# HAHA! I tricked you!
def
predict
(
self
,
input_fn
,
predict_keys
=
None
,
hooks
=
None
):
pass
with
self
.
assertRaisesRegexp
(
ValueError
,
'cannot override members of Estimator.*predict'
):
_Estimator
()
def
test_extension_of_api_is_ok
(
self
):
class
_Estimator
(
estimator
.
Estimator
):
...
...
@@ -812,7 +827,7 @@ def _model_fn_for_export_tests(features, labels, mode):
loss
=
constant_op
.
constant
(
1.
),
train_op
=
constant_op
.
constant
(
2.
),
export_outputs
=
{
'test'
:
export
_output
.
ClassificationOutput
(
scores
,
classes
)})
'test'
:
export
.
ClassificationOutput
(
scores
,
classes
)})
def
_model_fn_with_saveables_for_export_tests
(
features
,
labels
,
mode
):
...
...
@@ -826,7 +841,7 @@ def _model_fn_with_saveables_for_export_tests(features, labels, mode):
loss
=
constant_op
.
constant
(
1.
),
train_op
=
train_op
,
export_outputs
=
{
'test'
:
export
_output
.
PredictOutput
({
'prediction'
:
prediction
})})
'test'
:
export
.
PredictOutput
({
'prediction'
:
prediction
})})
_VOCAB_FILE_CONTENT
=
'emerson
\n
lake
\n
palmer
\n
'
...
...
@@ -1038,7 +1053,7 @@ class EstimatorExportTest(test.TestCase):
loss
=
constant_op
.
constant
(
0.
),
train_op
=
constant_op
.
constant
(
0.
),
scaffold
=
training
.
Scaffold
(
saver
=
self
.
mock_saver
),
export_outputs
=
{
'test'
:
export
_output
.
ClassificationOutput
(
scores
)})
export_outputs
=
{
'test'
:
export
.
ClassificationOutput
(
scores
)})
est
=
estimator
.
Estimator
(
model_fn
=
_model_fn_scaffold
)
est
.
train
(
dummy_input_fn
,
steps
=
1
)
...
...
@@ -1075,7 +1090,7 @@ class EstimatorExportTest(test.TestCase):
loss
=
constant_op
.
constant
(
0.
),
train_op
=
constant_op
.
constant
(
0.
),
scaffold
=
training
.
Scaffold
(
local_init_op
=
custom_local_init_op
),
export_outputs
=
{
'test'
:
export
_output
.
ClassificationOutput
(
scores
)})
export_outputs
=
{
'test'
:
export
.
ClassificationOutput
(
scores
)})
est
=
estimator
.
Estimator
(
model_fn
=
_model_fn_scaffold
)
est
.
train
(
dummy_input_fn
,
steps
=
1
)
...
...
@@ -1107,7 +1122,7 @@ class EstimatorIntegrationTest(test.TestCase):
predictions
=
layers
.
dense
(
features
[
'x'
],
1
,
kernel_initializer
=
init_ops
.
zeros_initializer
())
export_outputs
=
{
'predictions'
:
export
_output
.
RegressionOutput
(
predictions
)
'predictions'
:
export
.
RegressionOutput
(
predictions
)
}
if
mode
==
model_fn_lib
.
ModeKeys
.
PREDICT
:
...
...
tensorflow/python/estimator/export/__init__.py
0 → 100644
浏览文件 @
bed0e5c3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility methods for exporting Estimator."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
tensorflow.python.estimator.export.export
import
build_parsing_serving_input_receiver_fn
from
tensorflow.python.estimator.export.export
import
build_raw_serving_input_receiver_fn
from
tensorflow.python.estimator.export.export
import
ServingInputReceiver
from
tensorflow.python.estimator.export.export_output
import
ClassificationOutput
from
tensorflow.python.estimator.export.export_output
import
ExportOutput
from
tensorflow.python.estimator.export.export_output
import
PredictOutput
from
tensorflow.python.estimator.export.export_output
import
RegressionOutput
from
tensorflow.python.util.all_util
import
remove_undocumented
_allowed_symbols
=
[
'build_parsing_serving_input_receiver_fn'
,
'build_raw_serving_input_receiver_fn'
,
'ServingInputReceiver'
,
'ClassificationOutput'
,
'ExportOutput'
,
'PredictOutput'
,
'RegressionOutput'
,
]
remove_undocumented
(
__name__
,
allowed_exception_list
=
_allowed_symbols
)
tensorflow/python/estimator/export.py
→
tensorflow/python/estimator/export
/export
.py
浏览文件 @
bed0e5c3
文件已移动
tensorflow/python/estimator/export_output.py
→
tensorflow/python/estimator/export
/export
_output.py
浏览文件 @
bed0e5c3
文件已移动
tensorflow/python/estimator/export_output_test.py
→
tensorflow/python/estimator/export
/export
_output_test.py
浏览文件 @
bed0e5c3
...
...
@@ -21,7 +21,7 @@ from __future__ import print_function
from
tensorflow.core.framework
import
tensor_shape_pb2
from
tensorflow.core.framework
import
types_pb2
from
tensorflow.core.protobuf
import
meta_graph_pb2
from
tensorflow.python.estimator
import
export_output
as
export_output_lib
from
tensorflow.python.estimator
.export
import
export_output
as
export_output_lib
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.platform
import
test
...
...
tensorflow/python/estimator/export_test.py
→
tensorflow/python/estimator/export
/export
_test.py
浏览文件 @
bed0e5c3
...
...
@@ -25,8 +25,8 @@ import time
from
google.protobuf
import
text_format
from
tensorflow.core.example
import
example_pb2
from
tensorflow.python.estimator
import
export
from
tensorflow.python.estimator
import
export_output
from
tensorflow.python.estimator
.export
import
export
from
tensorflow.python.estimator
.export
import
export_output
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
dtypes
...
...
tensorflow/python/estimator/inputs/__init__.py
浏览文件 @
bed0e5c3
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Methods to create input_fn
."""
"""
Utility methods to create simple input_fns
."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -20,3 +20,12 @@ from __future__ import print_function
from
tensorflow.python.estimator.inputs.numpy_io
import
numpy_input_fn
from
tensorflow.python.estimator.inputs.pandas_io
import
pandas_input_fn
from
tensorflow.python.util.all_util
import
remove_undocumented
_allowed_symbols
=
[
'numpy_input_fn'
,
'pandas_input_fn'
]
remove_undocumented
(
__name__
,
allowed_exception_list
=
_allowed_symbols
)
tensorflow/python/estimator/model_fn.py
浏览文件 @
bed0e5c3
...
...
@@ -23,7 +23,7 @@ import collections
import
six
from
tensorflow.python.estimator
import
export_o
utput
from
tensorflow.python.estimator
.export.export_output
import
ExportO
utput
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
tensor_shape
from
tensorflow.python.ops
import
array_ops
...
...
@@ -50,8 +50,6 @@ class ModeKeys(object):
class
MetricKeys
(
object
):
"""Metric key strings."""
LOSS
=
'loss'
AUC
=
'auc'
ACCURACY
=
'accuracy'
class
EstimatorSpec
(
...
...
@@ -214,7 +212,7 @@ class EstimatorSpec(
raise
TypeError
(
'export_outputs must be dict, given: {}'
.
format
(
export_outputs
))
for
v
in
six
.
itervalues
(
export_outputs
):
if
not
isinstance
(
v
,
export_output
.
ExportOutput
):
if
not
isinstance
(
v
,
ExportOutput
):
raise
TypeError
(
'Values in export_outputs must be ExportOutput objects. '
'Given: {}'
.
format
(
export_outputs
))
...
...
tensorflow/python/estimator/model_fn_test.py
浏览文件 @
bed0e5c3
...
...
@@ -19,7 +19,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
tensorflow.python.estimator
import
export
_output
from
tensorflow.python.estimator
import
export
from
tensorflow.python.estimator
import
model_fn
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
ops
...
...
@@ -67,7 +67,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op
=
control_flow_ops
.
no_op
(),
eval_metric_ops
=
{
'loss'
:
(
control_flow_ops
.
no_op
(),
loss
)},
export_outputs
=
{
'head_name'
:
export
_output
.
ClassificationOutput
(
classes
=
classes
)
'head_name'
:
export
.
ClassificationOutput
(
classes
=
classes
)
},
training_chief_hooks
=
[
_FakeHook
()],
training_hooks
=
[
_FakeHook
()],
...
...
@@ -217,7 +217,7 @@ class EstimatorSpecEvalTest(test.TestCase):
train_op
=
control_flow_ops
.
no_op
(),
eval_metric_ops
=
{
'loss'
:
(
control_flow_ops
.
no_op
(),
loss
)},
export_outputs
=
{
'head_name'
:
export
_output
.
ClassificationOutput
(
classes
=
classes
)
'head_name'
:
export
.
ClassificationOutput
(
classes
=
classes
)
},
training_chief_hooks
=
[
_FakeHook
()],
training_hooks
=
[
_FakeHook
()],
...
...
@@ -401,7 +401,7 @@ class EstimatorSpecInferTest(test.TestCase):
train_op
=
control_flow_ops
.
no_op
(),
eval_metric_ops
=
{
'loss'
:
(
control_flow_ops
.
no_op
(),
loss
)},
export_outputs
=
{
'head_name'
:
export
_output
.
ClassificationOutput
(
classes
=
classes
)
'head_name'
:
export
.
ClassificationOutput
(
classes
=
classes
)
},
training_chief_hooks
=
[
_FakeHook
()],
training_hooks
=
[
_FakeHook
()],
...
...
@@ -446,7 +446,7 @@ class EstimatorSpecInferTest(test.TestCase):
model_fn
.
EstimatorSpec
(
mode
=
model_fn
.
ModeKeys
.
PREDICT
,
predictions
=
predictions
,
export_outputs
=
export
_output
.
ClassificationOutput
(
classes
=
classes
))
export_outputs
=
export
.
ClassificationOutput
(
classes
=
classes
))
def
testExportOutputsValueNotExportOutput
(
self
):
with
ops
.
Graph
().
as_default
(),
self
.
test_session
():
...
...
@@ -465,7 +465,7 @@ class EstimatorSpecInferTest(test.TestCase):
with
ops
.
Graph
().
as_default
(),
self
.
test_session
():
predictions
=
{
'loss'
:
constant_op
.
constant
(
1.
)}
output_1
=
constant_op
.
constant
([
1.
])
regression_output
=
export
_output
.
RegressionOutput
(
value
=
output_1
)
regression_output
=
export
.
RegressionOutput
(
value
=
output_1
)
export_outputs
=
{
'head-1'
:
regression_output
,
}
...
...
@@ -488,9 +488,9 @@ class EstimatorSpecInferTest(test.TestCase):
output_3
=
constant_op
.
constant
([
'3'
])
export_outputs
=
{
signature_constants
.
DEFAULT_SERVING_SIGNATURE_DEF_KEY
:
export
_output
.
RegressionOutput
(
value
=
output_1
),
'head-2'
:
export
_output
.
ClassificationOutput
(
classes
=
output_2
),
'head-3'
:
export
_output
.
PredictOutput
(
outputs
=
{
export
.
RegressionOutput
(
value
=
output_1
),
'head-2'
:
export
.
ClassificationOutput
(
classes
=
output_2
),
'head-3'
:
export
.
PredictOutput
(
outputs
=
{
'some_output_3'
:
output_3
})}
estimator_spec
=
model_fn
.
EstimatorSpec
(
...
...
@@ -506,9 +506,9 @@ class EstimatorSpecInferTest(test.TestCase):
output_2
=
constant_op
.
constant
([
'2'
])
output_3
=
constant_op
.
constant
([
'3'
])
export_outputs
=
{
'head-1'
:
export
_output
.
RegressionOutput
(
value
=
output_1
),
'head-2'
:
export
_output
.
ClassificationOutput
(
classes
=
output_2
),
'head-3'
:
export
_output
.
PredictOutput
(
outputs
=
{
'head-1'
:
export
.
RegressionOutput
(
value
=
output_1
),
'head-2'
:
export
.
ClassificationOutput
(
classes
=
output_2
),
'head-3'
:
export
.
PredictOutput
(
outputs
=
{
'some_output_3'
:
output_3
})}
with
self
.
assertRaisesRegexp
(
...
...
tensorflow/python/estimator/run_config.py
浏览文件 @
bed0e5c3
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Run Config
."""
"""
Environment configuration object for Estimators
."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录