提交 434d29d2 编写于 作者: C Chenkai Kuang 提交者: TensorFlower Gardener

Fix keras metric.result_state when the metric variables are sharded variable.

PiperOrigin-RevId: 381292911
上级 1232f05b
......@@ -108,22 +108,28 @@ class ShardedVariableTest(tf.test.TestCase):
def test_keras_metrics(self):
with self.strategy.scope():
fp = keras.metrics.FalsePositives(thresholds=[0.2, 0.5, 0.7, 0.8])
auc = keras.metrics.AUC(num_thresholds=10)
@tf.function
def update():
fp.update_state([0., 1., 0., 0.], [0., 0., 0.3, 0.9])
auc.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
@tf.function
def reset():
fp.reset_state()
auc.reset_state()
update()
self.assertEqual(auc.result(), 0.75)
self.assertAllEqual(fp.result(), [2., 1., 1., 1.])
reset()
self.assertEqual(auc.result(), 0.0)
self.assertAllEqual(fp.result(), [0., 0., 0., 0.])
self.assertTrue(hasattr(auc.true_positives, 'variables'))
self.assertTrue(hasattr(fp.accumulator, 'variables'))
def test_saved_model(self):
......
......@@ -1038,9 +1038,9 @@ class _ConfusionMatrixConditionCount(Metric):
return tf.convert_to_tensor(result)
def reset_state(self):
num_thresholds = len(to_list(self.thresholds))
backend.batch_set_value(
[(v, np.zeros((num_thresholds,))) for v in self.variables])
backend.batch_set_value([
(v, np.zeros(v.shape.as_list())) for v in self.variables
])
def get_config(self):
config = {'thresholds': self.init_thresholds}
......@@ -3175,8 +3175,9 @@ class MeanTensor(Metric):
def reset_state(self):
if self._built:
backend.batch_set_value(
[(v, np.zeros(self._shape.as_list())) for v in self.variables])
backend.batch_set_value([
(v, np.zeros(v.shape.as_list())) for v in self.variables
])
@keras_export('keras.metrics.BinaryCrossentropy')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册