Commit c5073deb authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Initial API in contrib/metrics/, adding one Op (auc_using_histogram).

Add an AUC Op based on internal histograms: auc_using_histogram().
* Works well with large class imbalance (no special handling is needed when an entire batch is one class)
* Runs in O(batch_size + histogram_width) time, so it works for huge batches
Change: 119324088
Parent 3c3a833b
@@ -20,6 +20,7 @@ py_library(
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/skflow",
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/contrib/testing:testing_py",
@@ -26,5 +26,6 @@ from tensorflow.contrib import layers
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
from tensorflow.contrib import losses
from tensorflow.contrib import metrics
from tensorflow.contrib import testing
from tensorflow.contrib import util
tensorflow/contrib/metrics/BUILD
# Description:
# Contains ops for evaluation metrics and summary statistics.
# APIs here are meant to evolve over time.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
py_library(
name = "metrics_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
srcs_version = "PY2AND3",
)
cuda_py_tests(
name = "histogram_ops_test",
size = "small",
srcs = ["python/kernel_tests/histogram_ops_test.py"],
additional_deps = [
":metrics_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
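# To run these tests directly (a standard Bazel invocation; the target path
# assumes this package lives at tensorflow/contrib/metrics):
#   bazel test //tensorflow/contrib/metrics:histogram_ops_test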
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/contrib/metrics/README.md
# TensorFlow evaluation metrics and summary statistics
## Evaluation metrics
Evaluation metrics compare predictions and labels, producing an aggregate
score. A metric typically returns a `value` and an `update_op`. The
`update_op` is run with every batch to update internal state (e.g. accumulated
counts of right/wrong predictions), and the `value` is extracted after all
batches have been read (e.g. precision = number correct / total).
```python
predictions = ...
labels = ...
value, update_op = some_metric(predictions, labels)

for step_num in range(max_steps):
  update_op.run()

print('evaluation score: ', value.eval())
```
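For example, the `auc_using_histogram` metric defined in this package follows
the same pattern. A minimal sketch (assuming 1-D boolean `labels` and
`float32` `scores` in a known range; the internal histograms are local
variables, so they must be initialized first):

```python
labels = ...  # 1-D bool Tensor
scores = ...  # 1-D float32 Tensor, same shape as labels
auc, update_op = tf.contrib.metrics.auc_using_histogram(
    labels, scores, score_range=[0., 1.], nbins=100)

tf.initialize_local_variables().run()
for step_num in range(max_steps):
  update_op.run()

print('AUC: ', auc.eval())
```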
tensorflow/contrib/metrics/__init__.py
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for evaluation metrics and summary statistics.
## This package provides Ops for evaluation metrics and summary statistics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,g-importing-member
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for histogram_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.metrics.python.ops import histogram_ops
class Strict1dCumsumTest(tf.test.TestCase):
"""Test this private function."""
def test_empty_tensor_returns_empty(self):
with self.test_session():
tensor = tf.constant([])
result = histogram_ops._strict_1d_cumsum(tensor, 0)
expected = tf.constant([])
np.testing.assert_array_equal(expected.eval(), result.eval())
def test_length_1_tensor_works(self):
with self.test_session():
tensor = tf.constant([3], dtype=tf.float32)
result = histogram_ops._strict_1d_cumsum(tensor, 1)
expected = tf.constant([3], dtype=tf.float32)
np.testing.assert_array_equal(expected.eval(), result.eval())
def test_length_3_tensor_works(self):
with self.test_session():
tensor = tf.constant([1, 2, 3], dtype=tf.float32)
result = histogram_ops._strict_1d_cumsum(tensor, 3)
expected = tf.constant([1, 3, 6], dtype=tf.float32)
np.testing.assert_array_equal(expected.eval(), result.eval())
class AUCUsingHistogramTest(tf.test.TestCase):
def setUp(self):
self.rng = np.random.RandomState(0)
def test_empty_labels_and_scores_gives_nan_auc(self):
with self.test_session():
labels = tf.constant([], shape=[0], dtype=tf.bool)
scores = tf.constant([], shape=[0], dtype=tf.float32)
score_range = [0, 1.]
auc, update_op = tf.contrib.metrics.auc_using_histogram(labels, scores,
score_range)
tf.initialize_local_variables().run()
update_op.run()
self.assertTrue(np.isnan(auc.eval()))
def test_perfect_scores_gives_auc_1(self):
self._check_auc(nbins=100,
desired_auc=1.0,
score_range=[0, 1.],
num_records=50,
frac_true=0.5,
atol=0.05,
num_updates=1)
def test_terrible_scores_gives_auc_0(self):
self._check_auc(nbins=100,
desired_auc=0.0,
score_range=[0, 1.],
num_records=50,
frac_true=0.5,
atol=0.05,
num_updates=1)
def test_many_common_conditions(self):
for nbins in [50]:
for desired_auc in [0.3, 0.5, 0.8]:
for score_range in [[-1, 1], [-10, 0]]:
for frac_true in [0.3, 0.8]:
# Tests pass with atol = 0.03. Moved up to 0.05 to avoid flakes.
self._check_auc(nbins=nbins,
desired_auc=desired_auc,
score_range=score_range,
num_records=100,
frac_true=frac_true,
atol=0.05,
num_updates=50)
def test_large_class_imbalance_still_ok(self):
# With probability frac_true ** num_records, each batch contains only True
# records. In this case, ~ 95%.
# Tests pass with atol = 0.02. Increased to 0.05 to avoid flakes.
self._check_auc(nbins=100,
desired_auc=0.8,
score_range=[-1, 1.],
num_records=10,
frac_true=0.995,
atol=0.05,
num_updates=1000)
def test_super_accuracy_with_many_bins_and_records(self):
# Test passes with atol = 0.0005. Increased atol to avoid flakes.
self._check_auc(nbins=1000,
desired_auc=0.75,
score_range=[0, 1.],
num_records=1000,
frac_true=0.5,
atol=0.005,
num_updates=100)
def _check_auc(self,
nbins=100,
desired_auc=0.75,
score_range=None,
num_records=50,
frac_true=0.5,
atol=0.05,
num_updates=10):
"""Check auc accuracy against synthetic data.
Args:
nbins: nbins arg from contrib.metrics.auc_using_histogram.
desired_auc: Number in [0, 1]. The desired auc for synthetic data.
score_range: 2-tuple, (low, high), giving the range of the resultant
scores. Defaults to [0, 1.].
num_records: Positive integer. The number of records to return.
frac_true: Number in (0, 1). Expected fraction of resultant labels that
will be True; the actual fraction in any one batch may differ.
atol: Absolute tolerance for final AUC estimate.
num_updates: Update internal histograms this many times, each with a new
batch of synthetic data, before computing final AUC.
Raises:
AssertionError: If resultant AUC is not within atol of theoretical AUC
from synthetic data.
"""
score_range = score_range or [0, 1.]
with self.test_session():
labels = tf.placeholder(tf.bool, shape=[num_records])
scores = tf.placeholder(tf.float32, shape=[num_records])
auc, update_op = tf.contrib.metrics.auc_using_histogram(labels,
scores,
score_range,
nbins=nbins)
tf.initialize_local_variables().run()
# Updates, then extract auc.
for _ in range(num_updates):
labels_a, scores_a = synthetic_data(desired_auc, score_range,
num_records, self.rng, frac_true)
update_op.run(feed_dict={labels: labels_a, scores: scores_a})
# Fetch current auc, and verify that fetching again doesn't change it.
auc_eval = auc.eval()
self.assertEqual(auc_eval, auc.eval())
msg = ('nbins: %s, desired_auc: %s, score_range: %s, '
'num_records: %s, frac_true: %s, num_updates: %s') % (nbins,
desired_auc,
score_range,
num_records,
frac_true,
num_updates)
np.testing.assert_allclose(desired_auc, auc_eval, atol=atol, err_msg=msg)
def synthetic_data(desired_auc, score_range, num_records, rng, frac_true):
"""Create synthetic boolean_labels and scores with adjustable auc.
Args:
desired_auc: Number in [0, 1], the theoretical AUC of resultant data.
score_range: 2-tuple, (low, high), giving the range of the resultant scores.
num_records: Positive integer. The number of records to return.
rng: Initialized np.random.RandomState random number generator.
frac_true: Number in (0, 1). Expected fraction of resultant labels that
will be True; the actual fraction in any one draw may differ.
Returns:
boolean_labels: np.array, dtype=bool.
scores: np.array, dtype=np.float32
"""
# We prove here why the method (below) for computing AUC works. Of course we
# also checked this against sklearn.metrics.roc_auc_score.
#
# First do this for score_range = [0, 1], then rescale.
# WLOG assume AUC >= 0.5, otherwise we solve for AUC >= 0.5 and then swap
# the labels.
# So for AUC in [0.5, 1] we create False and True labels with corresponding
# scores drawn from:
#   F ~ U[0, 1],  T ~ U[x, 1]
# We have,
#   AUC
#   = P[T > F]
#   = P[T > F | F < x] P[F < x] + P[T > F | F > x] P[F > x]
#   = (1 * x) + (0.5 * (1 - x)).
# Inverting, we have:
#   x = 2 * AUC - 1, when AUC >= 0.5.
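# For example, desired_auc = 0.8 gives x = 0.6, and indeed
#   0.6 + 0.5 * (1 - 0.6) = 0.8.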
assert 0 <= desired_auc <= 1
assert 0 < frac_true < 1
if desired_auc < 0.5:
flip_labels = True
desired_auc = 1 - desired_auc
frac_true = 1 - frac_true
else:
flip_labels = False
x = 2 * desired_auc - 1
labels = rng.binomial(1, frac_true, size=num_records).astype(bool)
num_true = labels.sum()
num_false = num_records - num_true
# Draw F ~ U[0, 1], and T ~ U[x, 1]
false_scores = rng.rand(num_false)
true_scores = x + rng.rand(num_true) * (1 - x)
# Reshape [0, 1] to score_range.
def reshape(scores):
return score_range[0] + scores * (score_range[1] - score_range[0])
false_scores = reshape(false_scores)
true_scores = reshape(true_scores)
# Place into one array corresponding with the labels.
scores = np.nan * np.ones(num_records, dtype=np.float32)
scores[labels] = true_scores
scores[~labels] = false_scores
if flip_labels:
labels = ~labels
return labels, scores
if __name__ == '__main__':
tf.test.main()
tensorflow/contrib/metrics/python/ops/histogram_ops.py
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""## Metrics that use histograms.
@@auc_using_histogram
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import histogram_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
def auc_using_histogram(boolean_labels,
scores,
score_range,
nbins=100,
collections=None,
check_shape=True,
name=None):
"""AUC computed by maintaining histograms.
Rather than computing AUC directly, this Op maintains Variables containing
histograms of the scores associated with `True` and `False` labels. Comparing
these histograms yields the AUC, with some discretization error.
See: "Efficient AUC Learning Curve Calculation" by Bouckaert.
This AUC Op updates in `O(batch_size + nbins)` time and works well even with
large class imbalance. Accuracy is limited by discretization error due to the
finite number of bins. If scores are concentrated in a small number of bins,
accuracy is lower. If this is a concern, we recommend trying different
numbers of bins and comparing results.
Args:
boolean_labels: 1-D boolean `Tensor`. Entry is `True` if the corresponding
record is in class.
scores: 1-D numeric `Tensor`, same shape as boolean_labels.
score_range: `Tensor` of shape `[2]`, same dtype as `scores`. The min/max
values of score that we expect. Scores outside range will be clipped.
nbins: Integer number of bins to use. Accuracy generally improves as the
number of bins increases.
collections: List of graph collections keys. Internal histogram Variables
are added to these collections. Defaults to `[GraphKeys.LOCAL_VARIABLES]`.
check_shape: Boolean. If `True`, do a runtime shape check on the scores
and labels.
name: A name for this Op. Defaults to "auc_using_histogram".
Returns:
auc: `float32` scalar `Tensor`. Fetching this converts internal histograms
to auc value.
update_op: `Op`, when run, updates internal histograms.
"""
if collections is None:
collections = [ops.GraphKeys.LOCAL_VARIABLES]
with variable_scope.variable_op_scope(
[boolean_labels, scores, score_range], name, 'auc_using_histogram'):
score_range = ops.convert_to_tensor(score_range, name='score_range')
boolean_labels, scores = _check_labels_and_scores(
boolean_labels, scores, check_shape)
hist_true, hist_false = _make_auc_histograms(boolean_labels, scores,
score_range, nbins)
hist_true_acc, hist_false_acc, update_op = _auc_hist_accumulate(hist_true,
hist_false,
nbins,
collections)
auc = _auc_convert_hist_to_auc(hist_true_acc, hist_false_acc, nbins)
return auc, update_op
def _check_labels_and_scores(boolean_labels, scores, check_shape):
"""Check the rank of labels/scores, return tensor versions."""
with ops.op_scope([boolean_labels, scores], '_check_labels_and_scores'):
boolean_labels = ops.convert_to_tensor(boolean_labels,
name='boolean_labels')
scores = ops.convert_to_tensor(scores, name='scores')
if boolean_labels.dtype != dtypes.bool:
raise ValueError(
'Argument boolean_labels should have dtype bool. Found: %s' %
boolean_labels.dtype)
if check_shape:
labels_rank_1 = logging_ops.Assert(
math_ops.equal(1, array_ops.rank(boolean_labels)),
['Argument boolean_labels should have rank 1. Found: ',
boolean_labels.name, array_ops.shape(boolean_labels)])
scores_rank_1 = logging_ops.Assert(
math_ops.equal(1, array_ops.rank(scores)),
['Argument scores should have rank 1. Found: ', scores.name,
array_ops.shape(scores)])
with ops.control_dependencies([labels_rank_1, scores_rank_1]):
return boolean_labels, scores
else:
return boolean_labels, scores
def _make_auc_histograms(boolean_labels, scores, score_range, nbins):
"""Create histogram tensors from one batch of labels/scores."""
with variable_scope.variable_op_scope(
[boolean_labels, scores, nbins], None, 'make_auc_histograms'):
# Histogram of scores for records in this batch with True label.
hist_true = histogram_ops.histogram_fixed_width(
array_ops.boolean_mask(scores, boolean_labels),
score_range,
nbins=nbins,
dtype=dtypes.int64,
name='hist_true')
# Histogram of scores for records in this batch with False label.
hist_false = histogram_ops.histogram_fixed_width(
array_ops.boolean_mask(scores, math_ops.logical_not(boolean_labels)),
score_range,
nbins=nbins,
dtype=dtypes.int64,
name='hist_false')
return hist_true, hist_false
def _auc_hist_accumulate(hist_true, hist_false, nbins, collections):
"""Accumulate histograms in new variables."""
with variable_scope.variable_op_scope(
[hist_true, hist_false], None, 'hist_accumulate'):
# Holds running total histogram of scores for records labeled True.
hist_true_acc = variable_scope.get_variable(
'hist_true_acc',
initializer=array_ops.zeros_initializer(
[nbins],
dtype=hist_true.dtype),
collections=collections,
trainable=False)
# Holds running total histogram of scores for records labeled False.
hist_false_acc = variable_scope.get_variable(
'hist_false_acc',
initializer=array_ops.zeros_initializer(
[nbins],
dtype=hist_false.dtype),
collections=collections,
trainable=False)
update_op = control_flow_ops.group(
hist_true_acc.assign_add(hist_true),
hist_false_acc.assign_add(hist_false),
name='update_op')
return hist_true_acc, hist_false_acc, update_op
def _auc_convert_hist_to_auc(hist_true_acc, hist_false_acc, nbins):
"""Convert histograms to auc.
Args:
hist_true_acc: `Tensor` holding accumulated histogram of scores for records
that were `True`.
hist_false_acc: `Tensor` holding accumulated histogram of scores for
records that were `False`.
nbins: Integer number of bins in the histograms.
Returns:
Scalar `Tensor` estimating AUC.
"""
# Note that this follows the "Approximating AUC" section in:
# Efficient AUC learning curve calculation, R. R. Bouckaert,
# AI'06 Proceedings of the 19th Australian joint conference on Artificial
# Intelligence: advances in Artificial Intelligence
# Pages 181-191.
# Note that the above paper has an error, and we need to re-order our bins to
# go from high to low score.
# Normalize histogram so we get fraction in each bin.
normed_hist_true = math_ops.truediv(hist_true_acc,
math_ops.reduce_sum(hist_true_acc))
normed_hist_false = math_ops.truediv(hist_false_acc,
math_ops.reduce_sum(hist_false_acc))
# These become delta x, delta y from the paper.
delta_y_t = array_ops.reverse(normed_hist_true, [True], name='delta_y_t')
delta_x_t = array_ops.reverse(normed_hist_false, [True], name='delta_x_t')
# strict_1d_cumsum requires float32 args.
delta_y_t = math_ops.cast(delta_y_t, dtypes.float32)
delta_x_t = math_ops.cast(delta_x_t, dtypes.float32)
# Trapezoidal integration, \int_0^1 0.5 * (y_t + y_{t-1}) dx_t
y_t = _strict_1d_cumsum(delta_y_t, nbins)
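# The first trapezoid rises from y = 0, so its area is
# delta_x_t[0] * y_t[0] / 2; each later trapezoid averages adjacent y_t values.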
first_trap = delta_x_t[0] * y_t[0] / 2.0
other_traps = delta_x_t[1:] * (y_t[1:] + y_t[:nbins - 1]) / 2.0
return math_ops.add(first_trap, math_ops.reduce_sum(other_traps), name='auc')
# TODO(langmore) Remove once a faster cumsum (accumulate_sum) Op is available.
# Also see if cast to float32 above can be removed with new cumsum.
# See: https://github.com/tensorflow/tensorflow/issues/813
def _strict_1d_cumsum(tensor, len_tensor):
"""Cumsum of a 1D tensor with defined shape by padding and convolving."""
# Assumes tensor shape is fully defined.
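# For example, tensor = [1, 2, 3] with len_tensor = 3 gives
# x = [0, 0, 1, 2, 3] and h = [1, 1, 1, 1, 1]; the SAME convolution yields
# [1, 3, 6, 6, 6], and slicing to len_tensor recovers the cumsum [1, 3, 6].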
with ops.op_scope([tensor], 'strict_1d_cumsum'):
if len_tensor == 0:
return constant_op.constant([])
len_pad = len_tensor - 1
x = array_ops.pad(tensor, [[len_pad, 0]])
h = array_ops.ones_like(x)
return _strict_conv1d(x, h)[:len_tensor]
# TODO(langmore) Remove once a faster cumsum (accumulate_sum) Op is available.
# See: https://github.com/tensorflow/tensorflow/issues/813
def _strict_conv1d(x, h):
"""Return x * h for rank 1 tensors x and h."""
with ops.op_scope([x, h], 'strict_conv1d'):
x = array_ops.reshape(x, (1, -1, 1, 1))
h = array_ops.reshape(h, (-1, 1, 1, 1))
result = nn_ops.conv2d(x, h, [1, 1, 1, 1], 'SAME')
return array_ops.reshape(result, [-1])