From 5ed1de777fe1733f1609cff00c679c3a59df9a4c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 19 Oct 2016 13:19:44 -0800
Subject: [PATCH] Automated rollback of change 136502135

Change: 136641403
---
 .../python/learn/estimators/estimator.py | 37 +++----------------
 .../learn/python/learn/graph_actions.py  | 13 ++++---
 2 files changed, 13 insertions(+), 37 deletions(-)

diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 9b8206f8145..d1a2d6be5c8 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -216,12 +216,8 @@ def _make_metrics_ops(metrics, features, targets, predictions):
       predictions.
 
   Returns:
-    `dict` whose keys are summary names, and values are the result of the
-    metric, either:
-    - `Tensor` values (in which case only the result of the last eval batch
-      will be summarized).
-    - `tuple` of 2 `Tensor` objects, update op and value. The update op will
-      be run once each eval step, and the value written to summary.
+    A dict mapping the friendly name given in `metrics` to the result of
+    calling the given metric function.
 
   Raises:
     ValueError: If metrics specifications do not work with the type of
@@ -271,13 +267,6 @@ def _make_metrics_ops(metrics, features, targets, predictions):
   return result
 
 
-def _maybe_add_streaming_mean(result, key, value):
-  if key in result:
-    logging.warning('Metrics already contains %s, skipping.', key)
-    return
-  result[key] = metrics_lib.streaming_mean(value)
-
-
 class BaseEstimator(
     sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
   """Abstract BaseEstimator class to train and evaluate TensorFlow models.
@@ -585,7 +574,7 @@ class BaseEstimator(
     Args:
       features: `Tensor` or `dict` of `Tensor` objects.
       targets: `Tensor` or `dict` of `Tensor` objects.
-      metrics: Dict of metrics to run. If `None`, the default metric functions
+      metrics: Dict of metrics to run. If None, the default metric functions
         are used; if {}, no metrics are used. Otherwise, `metrics` should map
        friendly names for the metric to a `MetricSpec` object defining which
        model outputs to evaluate against which targets with which metric
@@ -1055,28 +1044,14 @@ class Estimator(BaseEstimator):
         `../metric_spec.py`.
 
     Returns:
-      `dict` whose keys are summary names, and values are either:
-      - `Tensor` values (in which case only the result of the last eval batch
-        will be summarized).
-      - `tuple` of 2 `Tensor` objects, update op and value. The update op will
-        be run once each eval step, and the value written to summary.
+      metrics: `dict` of `Tensor` objects.
 
     Raises:
       ValueError: if `metrics` don't match `targets`.
     """
     predictions, loss, _ = self._call_model_fn(features, targets, ModeKeys.EVAL)
-    result = _make_metrics_ops(metrics, features, targets, predictions)
-    _maybe_add_streaming_mean(result, 'loss', loss)
-
-    # TODO(ptucker): Work-around until we have an easier way to specify metrics
-    # from model_fn.
-    if predictions is not None:
-      if isinstance(predictions, dict):
-        for k, v in six.iteritems(predictions):
-          _maybe_add_streaming_mean(result, k, v)
-      else:
-        _maybe_add_streaming_mean(result, 'predictions', predictions)
-
+    result = {'loss': metrics_lib.streaming_mean(loss)}
+    result.update(_make_metrics_ops(metrics, features, targets, predictions))
     return result
 
   def _get_predict_ops(self, features):
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index aed130875f6..c67e41131dc 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -25,7 +25,8 @@ import threading
 import time
 
 import numpy as np
-import six
+
+from six import reraise
 
 from tensorflow.contrib.framework import load_variable
 from tensorflow.contrib.framework.python.ops import ops as contrib_ops
@@ -578,7 +579,7 @@ def _train_internal(graph,
       logging.error('Got exception during tf.learn final checkpoint %s.', e)
   finally:
     if excinfo:
-      six.reraise(*excinfo)
+      reraise(*excinfo)
   return loss_value
 
 
@@ -628,14 +629,14 @@ def _write_summary_results(output_dir, eval_results, current_global_step):
               _eval_results_to_str(eval_results))
   summary_writer = get_summary_writer(output_dir)
   summary = summary_pb2.Summary()
-  for key, eval_result in six.iteritems(eval_results):
+  for key in eval_results:
     if eval_results[key] is None:
       continue
     value = summary.value.add()
     value.tag = key
-    if (isinstance(eval_result, np.float32) or
-        isinstance(eval_result, float)):
-      value.simple_value = float(eval_result)
+    if (isinstance(eval_results[key], np.float32) or
+        isinstance(eval_results[key], float)):
+      value.simple_value = float(eval_results[key])
     else:
       logging.warn('Skipping summary for %s, must be a float or np.float32.',
                    key)
-- 
GitLab
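
Note: the restored `_get_eval_ops` body builds `{'loss': metrics_lib.streaming_mean(loss)}`, relying on the (value, update_op) contract described in the removed `_make_metrics_ops` docstring: the update op runs once per eval batch, and the value tensor is read afterwards and written to the summary. A minimal sketch of that contract, assuming the contrib-era `tf.contrib.metrics.streaming_mean` API (the script itself is illustrative, not from the patch):

    import tensorflow as tf

    batch_loss = tf.placeholder(tf.float32, shape=[])
    # streaming_mean returns a (value, update_op) pair; the update op folds
    # one batch into the running mean, and the value tensor reads it back.
    mean_loss, update_op = tf.contrib.metrics.streaming_mean(batch_loss)

    with tf.Session() as sess:
      # Streaming metrics keep their accumulators in local variables.
      sess.run(tf.initialize_local_variables())
      for loss in [0.9, 0.7, 0.5]:  # one update per eval batch
        sess.run(update_op, feed_dict={batch_loss: loss})
      print(sess.run(mean_loss))  # 0.7, the mean over all batches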
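
Likewise, a minimal sketch of what the `_write_summary_results` loop in graph_actions.py does with eval results, assuming the `summary_pb2` proto API used in the patch; `results_to_summary` is a hypothetical helper name, not a function from the patch:

    import numpy as np
    from tensorflow.core.framework import summary_pb2

    def results_to_summary(eval_results):
      # Scalar results become Summary.Value entries with simple_value set.
      # None and non-float results are simply skipped here; the loop in the
      # patch also logs a warning for non-floats.
      summary = summary_pb2.Summary()
      for key, eval_result in eval_results.items():
        if eval_result is None:
          continue
        if isinstance(eval_result, (float, np.float32)):
          value = summary.value.add()
          value.tag = key
          value.simple_value = float(eval_result)
      return summary

    print(results_to_summary({'loss': 0.42, 'labels': None}))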