提交 524cc3ea 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Delete AUC metrics from MultiClassHead. This removed AUC metrics from...

Delete AUC metrics from MultiClassHead. This removed AUC metrics from DNNClassifier, LinearClassifier and DNNLinearCombinedClassifier when n_classes > 2.
Change: 149656850
上级 168e8dac
......@@ -1033,9 +1033,6 @@ class _MultiClassHead(_SingleHead):
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
metrics_lib.streaming_accuracy(classes, labels, weights))
metrics[_summary_key(self.head_name, mkey.AUC)] = (
_streaming_auc_with_class_id_label(
probabilities, labels, weights, self.logits_dimension))
for class_id in self._metric_class_ids:
# TODO(ptucker): Add per-class accuracy, precision, recall.
......@@ -1051,9 +1048,6 @@ class _MultiClassHead(_SingleHead):
metrics[_summary_key(
self.head_name, mkey.CLASS_LOGITS_MEAN % class_id)] = (
_predictions_streaming_mean(logits, weights, class_id))
metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = (
_class_streaming_auc(probabilities, labels, weights, class_id,
self.logits_dimension))
return metrics
......@@ -1773,20 +1767,6 @@ def _class_labels_streaming_mean(labels, weights, class_id):
weights=weights)
def _class_streaming_auc(predictions, labels, weights, class_id,
num_classes):
indicator_labels = _class_id_labels_to_indicator(
labels, num_classes=num_classes)
return _streaming_auc(predictions, indicator_labels, weights, class_id)
def _streaming_auc_with_class_id_label(predictions, labels, weights,
num_classes):
indicator_labels = _class_id_labels_to_indicator(
labels, num_classes=num_classes)
return _streaming_auc(predictions, indicator_labels, weights)
def _streaming_auc(predictions, labels, weights=None, class_id=None):
predictions = ops.convert_to_tensor(predictions)
labels = ops.convert_to_tensor(labels)
......
......@@ -883,11 +883,7 @@ class MultiClassHeadTest(test.TestCase):
def _expected_eval_metrics(self, expected_loss):
return {
"accuracy": 0.,
"auc": 1. / 4,
"loss": expected_loss,
"auc/class0": 1.,
"auc/class1": 1.,
"auc/class2": 0.,
"labels/actual_label_mean/class0": 0. / 1,
"labels/actual_label_mean/class1": 0. / 1,
"labels/actual_label_mean/class2": 1. / 1,
......@@ -957,11 +953,7 @@ class MultiClassHeadTest(test.TestCase):
expected_loss = 1.0986123
_assert_metrics(self, expected_loss, {
"accuracy": 0.,
"auc": 2. / 4,
"loss": expected_loss,
"auc/class0": 1.,
"auc/class1": 1.,
"auc/class2": 0.,
"labels/actual_label_mean/class0": 0. / 1,
"labels/actual_label_mean/class1": 0. / 1,
"labels/actual_label_mean/class2": 1. / 1,
......@@ -1023,11 +1015,7 @@ class MultiClassHeadTest(test.TestCase):
expected_loss = 3.1698461
expected_eval_metrics = {
"accuracy": 0.,
"auc": 9.99999e-07,
"loss": expected_loss,
"auc/class0": 1.,
"auc/class1": 1.,
"auc/class2": 0.,
"labels/actual_label_mean/class0": 0. / 1,
"labels/actual_label_mean/class1": 0. / 1,
"labels/actual_label_mean/class2": 1. / 1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册