From 434d29d277630a6b6ae2c405d4f7ea787db56ee9 Mon Sep 17 00:00:00 2001 From: Chenkai Kuang Date: Thu, 24 Jun 2021 11:01:22 -0700 Subject: [PATCH] Fix keras metric.result_state when the metric variables are sharded variable. PiperOrigin-RevId: 381292911 --- keras/distribute/sharded_variable_test.py | 6 ++++++ keras/metrics.py | 11 ++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/keras/distribute/sharded_variable_test.py b/keras/distribute/sharded_variable_test.py index 26a8bc40a..d0a945870 100644 --- a/keras/distribute/sharded_variable_test.py +++ b/keras/distribute/sharded_variable_test.py @@ -108,22 +108,28 @@ class ShardedVariableTest(tf.test.TestCase): def test_keras_metrics(self): with self.strategy.scope(): + fp = keras.metrics.FalsePositives(thresholds=[0.2, 0.5, 0.7, 0.8]) auc = keras.metrics.AUC(num_thresholds=10) @tf.function def update(): + fp.update_state([0., 1., 0., 0.], [0., 0., 0.3, 0.9]) auc.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) @tf.function def reset(): + fp.reset_state() auc.reset_state() update() self.assertEqual(auc.result(), 0.75) + self.assertAllEqual(fp.result(), [2., 1., 1., 1.]) reset() self.assertEqual(auc.result(), 0.0) + self.assertAllEqual(fp.result(), [0., 0., 0., 0.]) self.assertTrue(hasattr(auc.true_positives, 'variables')) + self.assertTrue(hasattr(fp.accumulator, 'variables')) def test_saved_model(self): diff --git a/keras/metrics.py b/keras/metrics.py index 49c4856b2..7e7b0c44d 100644 --- a/keras/metrics.py +++ b/keras/metrics.py @@ -1038,9 +1038,9 @@ class _ConfusionMatrixConditionCount(Metric): return tf.convert_to_tensor(result) def reset_state(self): - num_thresholds = len(to_list(self.thresholds)) - backend.batch_set_value( - [(v, np.zeros((num_thresholds,))) for v in self.variables]) + backend.batch_set_value([ + (v, np.zeros(v.shape.as_list())) for v in self.variables + ]) def get_config(self): config = {'thresholds': self.init_thresholds} @@ -3175,8 +3175,9 @@ class MeanTensor(Metric): def reset_state(self): if self._built: - backend.batch_set_value( - [(v, np.zeros(self._shape.as_list())) for v in self.variables]) + backend.batch_set_value([ + (v, np.zeros(v.shape.as_list())) for v in self.variables + ]) @keras_export('keras.metrics.BinaryCrossentropy') -- GitLab