提交 bed0e5c3 编写于 作者: M Martin Wicke 提交者: TensorFlower Gardener

Expose Estimator and associated utilities in the API.

Change: 150292011
上级 1ce24242
......@@ -184,6 +184,7 @@ add_python_module("tensorflow/python/debug/examples")
add_python_module("tensorflow/python/debug/lib")
add_python_module("tensorflow/python/debug/wrappers")
add_python_module("tensorflow/python/estimator")
add_python_module("tensorflow/python/estimator/export")
add_python_module("tensorflow/python/estimator/inputs")
add_python_module("tensorflow/python/estimator/inputs/queues")
add_python_module("tensorflow/python/framework")
......
......@@ -74,18 +74,18 @@ from tensorflow.python.ops.standard_ops import *
# pylint: enable=wildcard-import
# Bring in subpackages.
from tensorflow.python import estimator
from tensorflow.python.layers import layers
from tensorflow.python.ops import image_ops as image
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import sdca_ops as sdca
from tensorflow.python.ops import sets
from tensorflow.python.ops import spectral_ops as spectral
from tensorflow.python.ops import image_ops as image
from tensorflow.python.ops.losses import losses
from tensorflow.python.ops import sets
from tensorflow.python.saved_model import saved_model
from tensorflow.python.util import compat
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
from tensorflow.python.saved_model import saved_model
from tensorflow.python.summary import summary
# Import the names from python/training.py as train.Name.
......@@ -209,6 +209,7 @@ _allowed_symbols.extend([
'app',
'compat',
'errors',
'estimator',
'flags',
'gfile',
'graph_util',
......
......@@ -14,6 +14,7 @@ load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "estimator_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
":checkpoint_utils",
......@@ -22,6 +23,7 @@ py_library(
":inputs",
":model_fn",
":run_config",
"//tensorflow/python:util",
],
)
......@@ -69,7 +71,7 @@ py_library(
srcs = ["model_fn.py"],
srcs_version = "PY2AND3",
deps = [
":export_output",
":export",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:training",
......@@ -107,7 +109,6 @@ py_library(
deps = [
":checkpoint_utils",
":export",
":export_output",
":model_fn",
":run_config",
"//tensorflow/core:protos_all_py",
......@@ -147,7 +148,7 @@ py_test(
py_library(
name = "export_output",
srcs = ["export_output.py"],
srcs = ["export/export_output.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/saved_model:signature_def_utils",
......@@ -157,7 +158,7 @@ py_library(
py_test(
name = "export_output_test",
size = "small",
srcs = ["export_output_test.py"],
srcs = ["export/export_output_test.py"],
srcs_version = "PY2AND3",
deps = [
":export_output",
......@@ -168,7 +169,21 @@ py_test(
py_library(
name = "export",
srcs = ["export.py"],
srcs = [
"export/__init__.py",
],
srcs_version = "PY2AND3",
deps = [
":export_export",
":export_output",
],
)
py_library(
name = "export_export",
srcs = [
"export/export.py",
],
srcs_version = "PY2AND3",
deps = [
":export_output",
......@@ -182,10 +197,10 @@ py_library(
py_test(
name = "export_test",
size = "small",
srcs = ["export_test.py"],
srcs = ["export/export_test.py"],
srcs_version = "PY2AND3",
deps = [
":export",
":export_export",
":export_output",
"//tensorflow/python:client_testlib",
],
......@@ -198,6 +213,7 @@ py_library(
deps = [
":numpy_io",
":pandas_io",
"//tensorflow/python:util",
],
)
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Estimator: High level tools for working with models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import export
from tensorflow.python.estimator import inputs
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.python.estimator.model_fn import ModeKeys
from tensorflow.python.estimator.run_config import RunConfig
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'inputs',
'export',
'Estimator',
'EstimatorSpec',
'ModeKeys',
'RunConfig',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
......@@ -30,9 +30,10 @@ import six
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.estimator import export
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator.export.export import build_all_signature_defs
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import control_flow_ops
......@@ -56,9 +57,9 @@ _VALID_MODEL_FN_ARGS = set(
class Estimator(object):
"""Estimator class to train and evaluate TensorFlow models.
The Estimator object wraps a model which is specified by a `model_fn`, which,
given inputs and a number of other parameters, returns the ops necessary to
perform training, evaluation, or predictions, respectively.
The `Estimator` object wraps a model which is specified by a `model_fn`,
which, given inputs and a number of other parameters, returns the ops
necessary to perform training, evaluation, or predictions.
All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a
subdirectory thereof. If `model_dir` is not set, a temporary directory is
......@@ -68,15 +69,20 @@ class Estimator(object):
about the execution environment. It is passed on to the `model_fn`, if the
`model_fn` has a parameter named "config" (and input functions in the same
manner). If the `config` parameter is not passed, it is instantiated by the
Estimator. Not passing config means that defaults useful for local execution
are used. Estimator makes config available to the model (for instance, to
`Estimator`. Not passing config means that defaults useful for local execution
are used. `Estimator` makes config available to the model (for instance, to
allow specialization based on the number of workers available), and also uses
some of its fields to control internals, especially regarding checkpointing.
The `params` argument contains hyperparameters. It is passed to the
`model_fn`, if the `model_fn` has a parameter named "params", and to the input
functions in the same manner. Estimator only passes params along, it does not
inspect it. The structure of params is therefore entirely up to the developer.
functions in the same manner. `Estimator` only passes params along, it does
not inspect it. The structure of `params` is therefore entirely up to the
developer.
None of `Estimator`'s methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
"""
def __init__(self, model_fn, model_dir=None, config=None, params=None):
......@@ -116,7 +122,7 @@ class Estimator(object):
ValueError: if this is called via a subclass and if that class overrides
a member of `Estimator`.
"""
self._assert_members_are_not_overridden()
Estimator._assert_members_are_not_overridden(self)
# Model directory.
self._model_dir = model_dir
if self._model_dir is None:
......@@ -395,7 +401,7 @@ class Estimator(object):
mode=model_fn_lib.ModeKeys.PREDICT)
# Build the SignatureDefs from receivers and all outputs
signature_def_map = export.build_all_signature_defs(
signature_def_map = build_all_signature_defs(
serving_input_receiver.receiver_tensors,
estimator_spec.export_outputs)
......@@ -405,7 +411,7 @@ class Estimator(object):
if not checkpoint_path:
raise ValueError("Couldn't find trained model at %s." % self._model_dir)
export_dir = export.get_timestamped_export_dir(export_dir_base)
export_dir = get_timestamped_export_dir(export_dir_base)
# TODO(soergel): Consider whether MonitoredSession makes sense here
with tf_session.Session() as session:
......@@ -600,7 +606,8 @@ class Estimator(object):
if model_fn_lib.MetricKeys.LOSS in estimator_spec.eval_metric_ops:
raise ValueError(
'Metric with name `loss` is not allowed, because Estimator '
'Metric with name "%s" is not allowed, because Estimator ' % (
model_fn_lib.MetricKeys.LOSS) +
'already defines a default metric with the same name.')
estimator_spec.eval_metric_ops[
model_fn_lib.MetricKeys.LOSS] = metrics_lib.mean(estimator_spec.loss)
......
......@@ -26,7 +26,6 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import export
from tensorflow.python.estimator import export_output
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator.inputs import numpy_io
......@@ -74,6 +73,22 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
ValueError, 'cannot override members of Estimator.*predict'):
_Estimator()
def test_override_a_method_with_tricks(self):
class _Estimator(estimator.Estimator):
def __init__(self):
super(_Estimator, self).__init__(model_fn=dummy_model_fn)
def _assert_members_are_not_overridden(self):
pass # HAHA! I tricked you!
def predict(self, input_fn, predict_keys=None, hooks=None):
pass
with self.assertRaisesRegexp(
ValueError, 'cannot override members of Estimator.*predict'):
_Estimator()
def test_extension_of_api_is_ok(self):
class _Estimator(estimator.Estimator):
......@@ -812,7 +827,7 @@ def _model_fn_for_export_tests(features, labels, mode):
loss=constant_op.constant(1.),
train_op=constant_op.constant(2.),
export_outputs={
'test': export_output.ClassificationOutput(scores, classes)})
'test': export.ClassificationOutput(scores, classes)})
def _model_fn_with_saveables_for_export_tests(features, labels, mode):
......@@ -826,7 +841,7 @@ def _model_fn_with_saveables_for_export_tests(features, labels, mode):
loss=constant_op.constant(1.),
train_op=train_op,
export_outputs={
'test': export_output.PredictOutput({'prediction': prediction})})
'test': export.PredictOutput({'prediction': prediction})})
_VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
......@@ -1038,7 +1053,7 @@ class EstimatorExportTest(test.TestCase):
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
scaffold=training.Scaffold(saver=self.mock_saver),
export_outputs={'test': export_output.ClassificationOutput(scores)})
export_outputs={'test': export.ClassificationOutput(scores)})
est = estimator.Estimator(model_fn=_model_fn_scaffold)
est.train(dummy_input_fn, steps=1)
......@@ -1075,7 +1090,7 @@ class EstimatorExportTest(test.TestCase):
loss=constant_op.constant(0.),
train_op=constant_op.constant(0.),
scaffold=training.Scaffold(local_init_op=custom_local_init_op),
export_outputs={'test': export_output.ClassificationOutput(scores)})
export_outputs={'test': export.ClassificationOutput(scores)})
est = estimator.Estimator(model_fn=_model_fn_scaffold)
est.train(dummy_input_fn, steps=1)
......@@ -1107,7 +1122,7 @@ class EstimatorIntegrationTest(test.TestCase):
predictions = layers.dense(
features['x'], 1, kernel_initializer=init_ops.zeros_initializer())
export_outputs = {
'predictions': export_output.RegressionOutput(predictions)
'predictions': export.RegressionOutput(predictions)
}
if mode == model_fn_lib.ModeKeys.PREDICT:
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility methods for exporting Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator.export.export import build_parsing_serving_input_receiver_fn
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
from tensorflow.python.estimator.export.export import ServingInputReceiver
from tensorflow.python.estimator.export.export_output import ClassificationOutput
from tensorflow.python.estimator.export.export_output import ExportOutput
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.estimator.export.export_output import RegressionOutput
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'build_parsing_serving_input_receiver_fn',
'build_raw_serving_input_receiver_fn',
'ServingInputReceiver',
'ClassificationOutput',
'ExportOutput',
'PredictOutput',
'RegressionOutput',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
......@@ -21,7 +21,7 @@ from __future__ import print_function
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.estimator import export_output as export_output_lib
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
......
......@@ -25,8 +25,8 @@ import time
from google.protobuf import text_format
from tensorflow.core.example import example_pb2
from tensorflow.python.estimator import export
from tensorflow.python.estimator import export_output
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Methods to create input_fn."""
"""Utility methods to create simple input_fns."""
from __future__ import absolute_import
from __future__ import division
......@@ -20,3 +20,12 @@ from __future__ import print_function
from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn
from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'numpy_input_fn',
'pandas_input_fn'
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
......@@ -23,7 +23,7 @@ import collections
import six
from tensorflow.python.estimator import export_output
from tensorflow.python.estimator.export.export_output import ExportOutput
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
......@@ -50,8 +50,6 @@ class ModeKeys(object):
class MetricKeys(object):
"""Metric key strings."""
LOSS = 'loss'
AUC = 'auc'
ACCURACY = 'accuracy'
class EstimatorSpec(
......@@ -214,7 +212,7 @@ class EstimatorSpec(
raise TypeError('export_outputs must be dict, given: {}'.format(
export_outputs))
for v in six.itervalues(export_outputs):
if not isinstance(v, export_output.ExportOutput):
if not isinstance(v, ExportOutput):
raise TypeError(
'Values in export_outputs must be ExportOutput objects. '
'Given: {}'.format(export_outputs))
......
......@@ -19,7 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import export_output
from tensorflow.python.estimator import export
from tensorflow.python.estimator import model_fn
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
......@@ -67,7 +67,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op(),
eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
'head_name': export.ClassificationOutput(classes=classes)
},
training_chief_hooks=[_FakeHook()],
training_hooks=[_FakeHook()],
......@@ -217,7 +217,7 @@ class EstimatorSpecEvalTest(test.TestCase):
train_op=control_flow_ops.no_op(),
eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
'head_name': export.ClassificationOutput(classes=classes)
},
training_chief_hooks=[_FakeHook()],
training_hooks=[_FakeHook()],
......@@ -401,7 +401,7 @@ class EstimatorSpecInferTest(test.TestCase):
train_op=control_flow_ops.no_op(),
eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
'head_name': export.ClassificationOutput(classes=classes)
},
training_chief_hooks=[_FakeHook()],
training_hooks=[_FakeHook()],
......@@ -446,7 +446,7 @@ class EstimatorSpecInferTest(test.TestCase):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs=export_output.ClassificationOutput(classes=classes))
export_outputs=export.ClassificationOutput(classes=classes))
def testExportOutputsValueNotExportOutput(self):
with ops.Graph().as_default(), self.test_session():
......@@ -465,7 +465,7 @@ class EstimatorSpecInferTest(test.TestCase):
with ops.Graph().as_default(), self.test_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
regression_output = export_output.RegressionOutput(value=output_1)
regression_output = export.RegressionOutput(value=output_1)
export_outputs = {
'head-1': regression_output,
}
......@@ -488,9 +488,9 @@ class EstimatorSpecInferTest(test.TestCase):
output_3 = constant_op.constant(['3'])
export_outputs = {
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
export_output.RegressionOutput(value=output_1),
'head-2': export_output.ClassificationOutput(classes=output_2),
'head-3': export_output.PredictOutput(outputs={
export.RegressionOutput(value=output_1),
'head-2': export.ClassificationOutput(classes=output_2),
'head-3': export.PredictOutput(outputs={
'some_output_3': output_3
})}
estimator_spec = model_fn.EstimatorSpec(
......@@ -506,9 +506,9 @@ class EstimatorSpecInferTest(test.TestCase):
output_2 = constant_op.constant(['2'])
output_3 = constant_op.constant(['3'])
export_outputs = {
'head-1': export_output.RegressionOutput(value=output_1),
'head-2': export_output.ClassificationOutput(classes=output_2),
'head-3': export_output.PredictOutput(outputs={
'head-1': export.RegressionOutput(value=output_1),
'head-2': export.ClassificationOutput(classes=output_2),
'head-3': export.PredictOutput(outputs={
'some_output_3': output_3
})}
with self.assertRaisesRegexp(
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run Config."""
"""Environment configuration object for Estimators."""
from __future__ import absolute_import
from __future__ import division
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册