提交 23c93ebb 编写于 作者: K Karmel Allison 提交者: TensorFlower Gardener

Metrics tests: adding v2 decorators.

PiperOrigin-RevId: 225398873
上级 7dcebc86
......@@ -27,9 +27,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import metrics
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
......@@ -285,19 +286,6 @@ class KerasAccuracyTest(test.TestCase):
metrics._assert_thresholds_range([None, 0.5])
def _get_simple_sequential_model(compile_metrics):
model = Sequential()
model.add(
layers.Dense(
3, activation='relu', input_dim=4, kernel_initializer='ones'))
model.add(layers.Dense(1, activation='sigmoid', kernel_initializer='ones'))
model.compile(
loss='mae',
metrics=compile_metrics,
optimizer=RMSPropOptimizer(learning_rate=0.001))
return model
@test_util.run_all_in_graph_and_eager_modes
class FalsePositivesTest(test.TestCase):
......@@ -366,16 +354,6 @@ class FalsePositivesTest(test.TestCase):
r'Threshold values must be in \[0, 1\]. Invalid values: \[-1, 2\]'):
metrics.FalsePositives(thresholds=[-1, 0.5, 2])
def test_reset_states(self):
fp_obj = metrics.FalsePositives()
model = _get_simple_sequential_model([fp_obj])
x = np.ones((100, 4))
y = np.zeros((100, 1))
model.evaluate(x, y)
self.assertEqual(self.evaluate(fp_obj.accumulator), 100.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(fp_obj.accumulator), 100.)
@test_util.run_all_in_graph_and_eager_modes
class FalseNegativesTest(test.TestCase):
......@@ -438,16 +416,6 @@ class FalseNegativesTest(test.TestCase):
result = fn_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose([4., 16., 23.], self.evaluate(result))
def test_reset_states(self):
fn_obj = metrics.FalseNegatives()
model = _get_simple_sequential_model([fn_obj])
x = np.zeros((100, 4))
y = np.ones((100, 1))
model.evaluate(x, y)
self.assertEqual(self.evaluate(fn_obj.accumulator), 100.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(fn_obj.accumulator), 100.)
@test_util.run_all_in_graph_and_eager_modes
class TrueNegativesTest(test.TestCase):
......@@ -510,16 +478,6 @@ class TrueNegativesTest(test.TestCase):
result = tn_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose([5., 15., 23.], self.evaluate(result))
def test_reset_states(self):
tn_obj = metrics.TrueNegatives()
model = _get_simple_sequential_model([tn_obj])
x = np.zeros((100, 4))
y = np.zeros((100, 1))
model.evaluate(x, y)
self.assertEqual(self.evaluate(tn_obj.accumulator), 100.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(tn_obj.accumulator), 100.)
@test_util.run_all_in_graph_and_eager_modes
class TruePositivesTest(test.TestCase):
......@@ -581,16 +539,6 @@ class TruePositivesTest(test.TestCase):
result = tp_obj(y_true, y_pred, sample_weight=37.)
self.assertAllClose([222., 111., 37.], self.evaluate(result))
def test_reset_states(self):
tp_obj = metrics.TruePositives()
model = _get_simple_sequential_model([tp_obj])
x = np.ones((100, 4))
y = np.ones((100, 1))
model.evaluate(x, y)
self.assertEqual(self.evaluate(tp_obj.accumulator), 100.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(tp_obj.accumulator), 100.)
@test_util.run_all_in_graph_and_eager_modes
class PrecisionTest(test.TestCase):
......@@ -703,18 +651,6 @@ class PrecisionTest(test.TestCase):
self.assertArrayNear([expected_precision, 0], self.evaluate(p_obj.result()),
1e-3)
def test_reset_states(self):
p_obj = metrics.Precision()
model = _get_simple_sequential_model([p_obj])
x = np.concatenate((np.ones((50, 4)), np.ones((50, 4))))
y = np.concatenate((np.ones((50, 1)), np.zeros((50, 1))))
model.evaluate(x, y)
self.assertEqual(self.evaluate(p_obj.tp), 50.)
self.assertEqual(self.evaluate(p_obj.fp), 50.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(p_obj.tp), 50.)
self.assertEqual(self.evaluate(p_obj.fp), 50.)
@test_util.run_all_in_graph_and_eager_modes
class RecallTest(test.TestCase):
......@@ -826,18 +762,6 @@ class RecallTest(test.TestCase):
self.assertArrayNear([expected_recall, 0], self.evaluate(r_obj.result()),
1e-3)
def test_reset_states(self):
r_obj = metrics.Recall()
model = _get_simple_sequential_model([r_obj])
x = np.concatenate((np.ones((50, 4)), np.zeros((50, 4))))
y = np.concatenate((np.ones((50, 1)), np.ones((50, 1))))
model.evaluate(x, y)
self.assertEqual(self.evaluate(r_obj.tp), 50.)
self.assertEqual(self.evaluate(r_obj.fn), 50.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(r_obj.tp), 50.)
self.assertEqual(self.evaluate(r_obj.fn), 50.)
@test_util.run_all_in_graph_and_eager_modes
class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):
......@@ -927,24 +851,6 @@ class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, '`num_thresholds` must be > 0.'):
metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)
def test_reset_states(self):
s_obj = metrics.SensitivityAtSpecificity(0.5, num_thresholds=1)
model = _get_simple_sequential_model([s_obj])
x = np.concatenate((np.ones((25, 4)), np.zeros((25, 4)), np.zeros((25, 4)),
np.ones((25, 4))))
y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones((25, 1)),
np.zeros((25, 1))))
model.evaluate(x, y)
self.assertEqual(self.evaluate(s_obj.tp), 25.)
self.assertEqual(self.evaluate(s_obj.fp), 25.)
self.assertEqual(self.evaluate(s_obj.fn), 25.)
self.assertEqual(self.evaluate(s_obj.tn), 25.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(s_obj.tp), 25.)
self.assertEqual(self.evaluate(s_obj.fp), 25.)
self.assertEqual(self.evaluate(s_obj.fn), 25.)
self.assertEqual(self.evaluate(s_obj.tn), 25.)
@test_util.run_all_in_graph_and_eager_modes
class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
......@@ -1034,24 +940,6 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, '`num_thresholds` must be > 0.'):
metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)
def test_reset_states(self):
s_obj = metrics.SpecificityAtSensitivity(0.5, num_thresholds=1)
model = _get_simple_sequential_model([s_obj])
x = np.concatenate((np.ones((25, 4)), np.zeros((25, 4)), np.zeros((25, 4)),
np.ones((25, 4))))
y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones((25, 1)),
np.zeros((25, 1))))
model.evaluate(x, y)
self.assertEqual(self.evaluate(s_obj.tp), 25.)
self.assertEqual(self.evaluate(s_obj.fp), 25.)
self.assertEqual(self.evaluate(s_obj.fn), 25.)
self.assertEqual(self.evaluate(s_obj.tn), 25.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(s_obj.tp), 25.)
self.assertEqual(self.evaluate(s_obj.fp), 25.)
self.assertEqual(self.evaluate(s_obj.fn), 25.)
self.assertEqual(self.evaluate(s_obj.tn), 25.)
@test_util.run_all_in_graph_and_eager_modes
class CosineProximityTest(test.TestCase):
......@@ -1086,5 +974,125 @@ class CosineProximityTest(test.TestCase):
result = cosine_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(-0.59916, self.evaluate(result), atol=1e-5)
def _get_model(compile_metrics):
model_layers = [
layers.Dense(3, activation='relu', kernel_initializer='ones'),
layers.Dense(1, activation='sigmoid', kernel_initializer='ones')]
model = testing_utils.get_model_from_layers(model_layers, input_shape=(4,))
model.compile(
loss='mae',
metrics=compile_metrics,
optimizer=RMSPropOptimizer(learning_rate=0.001),
run_eagerly=testing_utils.should_run_eagerly())
return model
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class ResetStatesTest(keras_parameterized.TestCase):
def test_reset_states_false_positives(self):
fp_obj = metrics.FalsePositives()
model = _get_model([fp_obj])
x = np.ones((100, 4))
y = np.zeros((100, 1))
model.evaluate(x, y)
self.assertEqual(self.evaluate(fp_obj.accumulator), 100.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(fp_obj.accumulator), 100.)
def test_reset_states_false_negatives(self):
fn_obj = metrics.FalseNegatives()
model = _get_model([fn_obj])
x = np.zeros((100, 4))
y = np.ones((100, 1))
model.evaluate(x, y)
self.assertEqual(self.evaluate(fn_obj.accumulator), 100.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(fn_obj.accumulator), 100.)
def test_reset_states_true_negatives(self):
tn_obj = metrics.TrueNegatives()
model = _get_model([tn_obj])
x = np.zeros((100, 4))
y = np.zeros((100, 1))
model.evaluate(x, y)
self.assertEqual(self.evaluate(tn_obj.accumulator), 100.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(tn_obj.accumulator), 100.)
def test_reset_states_true_positives(self):
tp_obj = metrics.TruePositives()
model = _get_model([tp_obj])
x = np.ones((100, 4))
y = np.ones((100, 1))
model.evaluate(x, y)
self.assertEqual(self.evaluate(tp_obj.accumulator), 100.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(tp_obj.accumulator), 100.)
def test_reset_states_precision(self):
p_obj = metrics.Precision()
model = _get_model([p_obj])
x = np.concatenate((np.ones((50, 4)), np.ones((50, 4))))
y = np.concatenate((np.ones((50, 1)), np.zeros((50, 1))))
model.evaluate(x, y)
self.assertEqual(self.evaluate(p_obj.tp), 50.)
self.assertEqual(self.evaluate(p_obj.fp), 50.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(p_obj.tp), 50.)
self.assertEqual(self.evaluate(p_obj.fp), 50.)
def test_reset_states_recall(self):
r_obj = metrics.Recall()
model = _get_model([r_obj])
x = np.concatenate((np.ones((50, 4)), np.zeros((50, 4))))
y = np.concatenate((np.ones((50, 1)), np.ones((50, 1))))
model.evaluate(x, y)
self.assertEqual(self.evaluate(r_obj.tp), 50.)
self.assertEqual(self.evaluate(r_obj.fn), 50.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(r_obj.tp), 50.)
self.assertEqual(self.evaluate(r_obj.fn), 50.)
def test_reset_states_sensitivity_at_specificity(self):
s_obj = metrics.SensitivityAtSpecificity(0.5, num_thresholds=1)
model = _get_model([s_obj])
x = np.concatenate((np.ones((25, 4)), np.zeros((25, 4)), np.zeros((25, 4)),
np.ones((25, 4))))
y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones((25, 1)),
np.zeros((25, 1))))
model.evaluate(x, y)
self.assertEqual(self.evaluate(s_obj.tp), 25.)
self.assertEqual(self.evaluate(s_obj.fp), 25.)
self.assertEqual(self.evaluate(s_obj.fn), 25.)
self.assertEqual(self.evaluate(s_obj.tn), 25.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(s_obj.tp), 25.)
self.assertEqual(self.evaluate(s_obj.fp), 25.)
self.assertEqual(self.evaluate(s_obj.fn), 25.)
self.assertEqual(self.evaluate(s_obj.tn), 25.)
def test_reset_states_specificity_at_sensitivity(self):
s_obj = metrics.SpecificityAtSensitivity(0.5, num_thresholds=1)
model = _get_model([s_obj])
x = np.concatenate((np.ones((25, 4)), np.zeros((25, 4)), np.zeros((25, 4)),
np.ones((25, 4))))
y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones((25, 1)),
np.zeros((25, 1))))
model.evaluate(x, y)
self.assertEqual(self.evaluate(s_obj.tp), 25.)
self.assertEqual(self.evaluate(s_obj.fp), 25.)
self.assertEqual(self.evaluate(s_obj.fn), 25.)
self.assertEqual(self.evaluate(s_obj.tn), 25.)
model.evaluate(x, y)
self.assertEqual(self.evaluate(s_obj.tp), 25.)
self.assertEqual(self.evaluate(s_obj.fp), 25.)
self.assertEqual(self.evaluate(s_obj.fn), 25.)
self.assertEqual(self.evaluate(s_obj.tn), 25.)
if __name__ == '__main__':
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册