提交 5ed1de77 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Automated rollback of change 136502135

Change: 136641403
上级 ad3bcda5
......@@ -216,12 +216,8 @@ def _make_metrics_ops(metrics, features, targets, predictions):
predictions.
Returns:
`dict` whose keys are summary names, and values are the result of the
metric, either:
- `Tensor` values (in which case only the result of the last eval batch
will be summarized).
- `tuple` of 2 `Tensor` objects, update op and value. The update op will
be run once each eval step, and the value written to summary.
A dict mapping the friendly given in `metrics` to the result of calling the
given metric function.
Raises:
ValueError: If metrics specifications do not work with the type of
......@@ -271,13 +267,6 @@ def _make_metrics_ops(metrics, features, targets, predictions):
return result
def _maybe_add_streaming_mean(result, key, value):
if key in result:
logging.warning('Metrics already contains %s, skipping.', key)
return
result[key] = metrics_lib.streaming_mean(value)
class BaseEstimator(
sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
......@@ -585,7 +574,7 @@ class BaseEstimator(
Args:
features: `Tensor` or `dict` of `Tensor` objects.
targets: `Tensor` or `dict` of `Tensor` objects.
metrics: Dict of metrics to run. If `None`, the default metric functions
metrics: Dict of metrics to run. If None, the default metric functions
are used; if {}, no metrics are used. Otherwise, `metrics` should map
friendly names for the metric to a `MetricSpec` object defining which
model outputs to evaluate against which targets with which metric
......@@ -1055,28 +1044,14 @@ class Estimator(BaseEstimator):
`../metric_spec.py`.
Returns:
`dict` whose keys are summary names, and values are either:
- `Tensor` values (in which case only the result of the last eval batch
will be summarized).
- `tuple` of 2 `Tensor` objects, update op and value. The update op will
be run once each eval step, and the value written to summary.
metrics: `dict` of `Tensor` objects.
Raises:
ValueError: if `metrics` don't match `targets`.
"""
predictions, loss, _ = self._call_model_fn(features, targets, ModeKeys.EVAL)
result = _make_metrics_ops(metrics, features, targets, predictions)
_maybe_add_streaming_mean(result, 'loss', loss)
# TODO(ptucker): Work-around until we have an easier way to specify metrics
# from model_fn.
if predictions is not None:
if isinstance(predictions, dict):
for k, v in six.iteritems(predictions):
_maybe_add_streaming_mean(result, k, v)
else:
_maybe_add_streaming_mean(result, 'predictions', predictions)
result = {'loss': metrics_lib.streaming_mean(loss)}
result.update(_make_metrics_ops(metrics, features, targets, predictions))
return result
def _get_predict_ops(self, features):
......
......@@ -25,7 +25,8 @@ import threading
import time
import numpy as np
import six
from six import reraise
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.ops import ops as contrib_ops
......@@ -578,7 +579,7 @@ def _train_internal(graph,
logging.error('Got exception during tf.learn final checkpoint %s.', e)
finally:
if excinfo:
six.reraise(*excinfo)
reraise(*excinfo)
return loss_value
......@@ -628,14 +629,14 @@ def _write_summary_results(output_dir, eval_results, current_global_step):
_eval_results_to_str(eval_results))
summary_writer = get_summary_writer(output_dir)
summary = summary_pb2.Summary()
for key, eval_result in six.iteritems(eval_results):
for key in eval_results:
if eval_results[key] is None:
continue
value = summary.value.add()
value.tag = key
if (isinstance(eval_result, np.float32) or
isinstance(eval_result, float)):
value.simple_value = float(eval_result)
if (isinstance(eval_results[key], np.float32) or
isinstance(eval_results[key], float)):
value.simple_value = float(eval_results[key])
else:
logging.warn('Skipping summary for %s, must be a float or np.float32.',
key)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册